diff --git a/ratis-common/dev-support/findbugsExcludeFile.xml b/ratis-common/dev-support/findbugsExcludeFile.xml index 882f08b7fa..0b397e3077 100644 --- a/ratis-common/dev-support/findbugsExcludeFile.xml +++ b/ratis-common/dev-support/findbugsExcludeFile.xml @@ -111,4 +111,8 @@ - \ No newline at end of file + + + + + diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java index 1a9f83c823..02e5550c7b 100644 --- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java +++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientAsynchronousProtocol.java @@ -17,12 +17,30 @@ */ package org.apache.ratis.protocol; +import org.apache.ratis.util.ReferenceCountedObject; + import java.io.IOException; import java.util.concurrent.CompletableFuture; /** Asynchronous version of {@link RaftClientProtocol}. */ public interface RaftClientAsynchronousProtocol { - CompletableFuture submitClientRequestAsync( - RaftClientRequest request) throws IOException; + /** + * A plain request is submitted from a client for processing. + * + * This default keeps older call sites compatible with implementations that operate on reference-counted requests. + */ + default CompletableFuture submitClientRequestAsync(RaftClientRequest request) throws IOException { + return submitClientRequestAsync(ReferenceCountedObject.wrap(request)); + } + /** + * A referenced counted request is submitted from a client for processing. + * Implementations of this method should retain the request, process it and then release it. + * The request may be retained even after the future returned by this method has completed. + * + * @return a future of the reply + * @see ReferenceCountedObject + */ + CompletableFuture submitClientRequestAsync( + ReferenceCountedObject requestRef) throws IOException; } \ No newline at end of file diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java index b04402fe15..85ede62e8c 100644 --- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java +++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java @@ -488,7 +488,13 @@ public SpanContextProto getSpanContext() { @Override public String toString() { - return super.toString() + ", seq=" + ProtoUtils.toString(slidingWindowEntry) + ", " - + type + ", " + getMessage(); + return toStringShort() + ", " + getMessage(); + } + + /** + * @return a short string which does not include {@link #message}. + */ + public String toStringShort() { + return super.toString() + ", seq=" + ProtoUtils.toString(slidingWindowEntry) + ", " + type; } } diff --git a/ratis-common/src/main/java/org/apache/ratis/util/DataBlockingQueue.java b/ratis-common/src/main/java/org/apache/ratis/util/DataBlockingQueue.java index e905893e5b..fb0f0715c5 100644 --- a/ratis-common/src/main/java/org/apache/ratis/util/DataBlockingQueue.java +++ b/ratis-common/src/main/java/org/apache/ratis/util/DataBlockingQueue.java @@ -29,6 +29,7 @@ import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import java.util.function.ToLongFunction; /** @@ -46,6 +47,8 @@ public class DataBlockingQueue extends DataQueue { private final Condition notFull = lock.newCondition(); private final Condition notEmpty = lock.newCondition(); + private boolean closed = false; + public DataBlockingQueue(Object name, SizeInBytes byteLimit, int elementLimit, ToLongFunction getNumBytes) { super(name, byteLimit, elementLimit, getNumBytes); } @@ -72,10 +75,34 @@ public void clear() { } } + /** Apply the given handler to each element and then {@link #clear()}. */ + public void clear(Consumer handler) { + try(AutoCloseableLock auto = AutoCloseableLock.acquire(lock)) { + for(E e : this) { + handler.accept(e); + } + super.clear(); + } + } + + /** + * Close this queue to stop accepting new elements, i.e. the offer(…) methods always return false. + * Note that closing the queue will not clear the existing elements. + * The existing elements can be peeked, polled or cleared after close. + */ + public void close() { + try(AutoCloseableLock ignored = AutoCloseableLock.acquire(lock)) { + closed = true; + } + } + @Override public boolean offer(E element) { Objects.requireNonNull(element, "element == null"); try(AutoCloseableLock auto = AutoCloseableLock.acquire(lock)) { + if (closed) { + return false; + } if (super.offer(element)) { notEmpty.signal(); return true; @@ -95,6 +122,9 @@ public boolean offer(E element, TimeDuration timeout) throws InterruptedExceptio long nanos = timeout.toLong(TimeUnit.NANOSECONDS); try(AutoCloseableLock auto = AutoCloseableLock.acquire(lock)) { for(;;) { + if (closed) { + return false; + } if (super.offer(element)) { notEmpty.signal(); return true; diff --git a/ratis-common/src/main/java/org/apache/ratis/util/LeakDetector.java b/ratis-common/src/main/java/org/apache/ratis/util/LeakDetector.java new file mode 100644 index 0000000000..19d1729acf --- /dev/null +++ b/ratis-common/src/main/java/org/apache/ratis/util/LeakDetector.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * Simple general resource leak detector using {@link ReferenceQueue} and {@link java.lang.ref.WeakReference} to + * observe resource object life-cycle and assert proper resource closure before they are GCed. + * + *

+ * Example usage: + * + *

 {@code
+ * class MyResource implements AutoClosable {
+ *   static final LeakDetector LEAK_DETECTOR = new LeakDetector("MyResource");
+ *
+ *   private final UncheckedAutoCloseable leakTracker = LEAK_DETECTOR.track(this, () -> {
+ *      // report leaks, don't refer to the original object (MyResource) here.
+ *      System.out.println("MyResource is not closed before being discarded.");
+ *   });
+ *
+ *   @Override
+ *   public void close() {
+ *     // proper resources cleanup...
+ *     // inform tracker that this object is closed properly.
+ *     leakTracker.close();
+ *   }
+ * }
+ *
+ * }
+ */ +public class LeakDetector { + private static final Logger LOG = LoggerFactory.getLogger(LeakDetector.class); + + private static class LeakTrackerSet { + private final Set set = Collections.newSetFromMap(new HashMap<>()); + + synchronized boolean remove(LeakTracker tracker) { + return set.remove(tracker); + } + + synchronized void removeExisting(LeakTracker tracker) { + final boolean removed = set.remove(tracker); + Preconditions.assertTrue(removed, () -> "Failed to remove existing " + tracker); + } + + synchronized LeakTracker add(Object referent, ReferenceQueue queue, Supplier leakReporter) { + final LeakTracker tracker = new LeakTracker(referent, queue, this::removeExisting, leakReporter); + final boolean added = set.add(tracker); + Preconditions.assertTrue(added, () -> "Failed to add " + tracker + " for " + referent); + return tracker; + } + + synchronized int getNumLeaks(boolean throwException) { + if (set.isEmpty()) { + return 0; + } + + int n = 0; + for (LeakTracker tracker : set) { + if (tracker.reportLeak() != null) { + n++; + } + } + if (throwException) { + assertNoLeaks(n); + } + return n; + } + + synchronized void assertNoLeaks(int leaks) { + Preconditions.assertTrue(leaks == 0, () -> { + final int size = set.size(); + return "#leaks = " + leaks + " > 0, #leaks " + (leaks == size? "==" : "!=") + " set.size = " + size; + }); + } + } + + private static final AtomicLong COUNTER = new AtomicLong(); + + private final ReferenceQueue queue = new ReferenceQueue<>(); + /** All the {@link LeakTracker}s. */ + private final LeakTrackerSet trackers = new LeakTrackerSet(); + /** When a leak is discovered, a message is printed and added to this list. */ + private final List leakMessages = Collections.synchronizedList(new ArrayList<>()); + private final String name; + + LeakDetector(String name) { + this.name = name + COUNTER.getAndIncrement(); + } + + LeakDetector start() { + Thread t = new Thread(this::run); + t.setName(LeakDetector.class.getSimpleName() + "-" + name); + t.setDaemon(true); + LOG.info("Starting leak detector thread {}.", name); + t.start(); + return this; + } + + private void run() { + while (true) { + try { + LeakTracker tracker = (LeakTracker) queue.remove(); + // Original resource already been GCed, if tracker is not closed yet, + // report a leak. + if (trackers.remove(tracker)) { + final String leak = tracker.reportLeak(); + if (leak != null) { + leakMessages.add(leak); + } + } + } catch (InterruptedException e) { + LOG.warn("Thread interrupted, exiting.", e); + break; + } + } + + LOG.warn("Exiting leak detector {}.", name); + } + + Runnable track(Object leakable, Supplier reportLeak) { + // TODO: A rate filter can be put here to only track a subset of all objects, e.g. 5%, 10%, + // if we have proofs that leak tracking impacts performance, or a single LeakDetector + // thread can't keep up with the pace of object allocation. + // For now, it looks effective enough and let keep it simple. + return trackers.add(leakable, queue, reportLeak)::remove; + } + + public int getLeakCount() { + return trackers.getNumLeaks(false); + } + + public void assertNoLeaks(int maxRetries, TimeDuration retrySleep) throws InterruptedException { + synchronized (leakMessages) { + // leakMessages are all the leaks discovered so far. + Preconditions.assertTrue(leakMessages.isEmpty(), + () -> "#leaks = " + leakMessages.size() + "\n" + leakMessages); + } + + for(int i = 0; i < maxRetries; i++) { + final int numLeaks = trackers.getNumLeaks(false); + if (numLeaks == 0) { + return; + } + LOG.warn("{}/{}) numLeaks == {} > 0, will wait and retry ...", i, maxRetries, numLeaks); + retrySleep.sleep(); + } + trackers.getNumLeaks(true); + } + + private static final class LeakTracker extends WeakReference { + private final Consumer removeMethod; + private final Supplier getLeakMessage; + + LeakTracker(Object referent, ReferenceQueue referenceQueue, + Consumer removeMethod, Supplier getLeakMessage) { + super(referent, referenceQueue); + this.removeMethod = removeMethod; + this.getLeakMessage = getLeakMessage; + } + + /** Called by the tracked resource when the object is completely released. */ + void remove() { + removeMethod.accept(this); + } + + /** @return the leak message if there is a leak; return null if there is no leak. */ + String reportLeak() { + return getLeakMessage.get(); + } + } +} diff --git a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedLeakDetector.java b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedLeakDetector.java new file mode 100644 index 0000000000..ec99eee58e --- /dev/null +++ b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedLeakDetector.java @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * A utility to detect leaks from @{@link ReferenceCountedObject}. + */ +public final class ReferenceCountedLeakDetector { + private static final Logger LOG = LoggerFactory.getLogger(ReferenceCountedLeakDetector.class); + // Leak detection is turned off by default. + + private static final AtomicReference FACTORY = new AtomicReference<>(Mode.NONE); + private static final Supplier SUPPLIER + = MemoizedSupplier.valueOf(() -> new LeakDetector(FACTORY.get().name()).start()); + + static Factory getFactory() { + return FACTORY.get(); + } + + public static LeakDetector getLeakDetector() { + return SUPPLIER.get(); + } + + private ReferenceCountedLeakDetector() { + } + + public static synchronized void enable(boolean advanced) { + FACTORY.set(advanced ? Mode.ADVANCED : Mode.SIMPLE); + } + + interface Factory { + ReferenceCountedObject create(V value, Runnable retainMethod, Consumer releaseMethod); + } + + private enum Mode implements Factory { + /** Leak detector is not enable in production to avoid performance impacts. */ + NONE { + @Override + public ReferenceCountedObject create(V value, Runnable retainMethod, Consumer releaseMethod) { + return new Impl<>(value, retainMethod, releaseMethod); + } + }, + /** Leak detector is enabled to detect leaks. This is intended to use in every tests. */ + SIMPLE { + @Override + public ReferenceCountedObject create(V value, Runnable retainMethod, Consumer releaseMethod) { + return new SimpleTracing<>(value, retainMethod, releaseMethod, getLeakDetector()); + } + }, + /** + * Leak detector is enabled to detect leaks and report object creation stacktrace as well as every retain and + * release stacktraces. This has severe impact in performance and only used to debug specific test cases. + */ + ADVANCED { + @Override + public ReferenceCountedObject create(V value, Runnable retainMethod, Consumer releaseMethod) { + return new AdvancedTracing<>(value, retainMethod, releaseMethod, getLeakDetector()); + } + } + } + + private static class Impl implements ReferenceCountedObject { + private final AtomicInteger count; + private final V value; + private final Runnable retainMethod; + private final Consumer releaseMethod; + + Impl(V value, Runnable retainMethod, Consumer releaseMethod) { + this.value = value; + this.retainMethod = retainMethod; + this.releaseMethod = releaseMethod; + count = new AtomicInteger(); + } + + @Override + public V get() { + final int previous = count.get(); + if (previous < 0) { + throw new IllegalStateException("Failed to get: object has already been completely released."); + } else if (previous == 0) { + throw new IllegalStateException("Failed to get: object has not yet been retained."); + } + return value; + } + + final int getCount() { + return count.get(); + } + + @Override + public V retain() { + // n < 0: exception + // n >= 0: n++ + if (count.getAndUpdate(n -> n < 0? n : n + 1) < 0) { + throw new IllegalStateException("Failed to retain: object has already been completely released."); + } + + retainMethod.run(); + return value; + } + + @Override + public boolean release() { + // n <= 0: exception + // n > 1: n-- + // n == 1: n = -1 + final int previous = count.getAndUpdate(n -> n <= 1? -1: n - 1); + if (previous < 0) { + throw new IllegalStateException("Failed to release: object has already been completely released."); + } else if (previous == 0) { + throw new IllegalStateException("Failed to release: object has not yet been retained."); + } + final boolean completedReleased = previous == 1; + releaseMethod.accept(completedReleased); + return completedReleased; + } + } + + private static class SimpleTracing extends Impl { + private final LeakDetector leakDetector; + private final Class valueClass; + private String valueString = null; + private Runnable removeMethod = null; + + SimpleTracing(T value, Runnable retainMethod, Consumer releaseMethod, LeakDetector leakDetector) { + super(value, retainMethod, releaseMethod); + this.valueClass = value.getClass(); + this.leakDetector = leakDetector; + } + + String getTraceString(int count) { + return "(" + valueClass + ", count=" + count + ", value=" + valueString + ")"; + } + + /** @return the leak message if there is a leak; return null if there is no leak. */ + String logLeakMessage() { + final int count = getCount(); + if (count == 0) { // never retain + return null; + } + final String message = "LEAK: " + getTraceString(count); + LOG.warn(message); + return message; + } + + @Override + public synchronized T get() { + try { + return super.get(); + } catch (Exception e) { + throw new IllegalStateException("Failed to get: " + getTraceString(getCount()), e); + } + } + + @Override + public synchronized T retain() { + final T value; + try { + value = super.retain(); + } catch (Exception e) { + throw new IllegalStateException("Failed to retain: " + getTraceString(getCount()), e); + } + if (getCount() == 1) { // this is the first retain + this.removeMethod = leakDetector.track(this, this::logLeakMessage); + this.valueString = value.toString(); + } + return value; + } + + @Override + public synchronized boolean release() { + final boolean released; + try { + released = super.release(); + } catch (Exception e) { + throw new IllegalStateException("Failed to release: " + getTraceString(getCount()), e); + } + + if (released) { + Preconditions.assertNotNull(removeMethod, () -> "Not yet retained (removeMethod == null): " + valueClass); + removeMethod.run(); + } + return released; + } + } + + private static class AdvancedTracing extends SimpleTracing { + enum Op {CREATION, RETAIN, RELEASE, CURRENT} + + static class Counts { + private final int refCount; + private final int retainCount; + private final int releaseCount; + + Counts() { + this.refCount = 0; + this.retainCount = 0; + this.releaseCount = 0; + } + + Counts(Op op, Counts previous) { + if (op == Op.RETAIN) { + this.refCount = previous.refCount + 1; + this.retainCount = previous.retainCount + 1; + this.releaseCount = previous.releaseCount; + } else if (op == Op.RELEASE) { + this.refCount = previous.refCount - 1; + this.retainCount = previous.retainCount; + this.releaseCount = previous.releaseCount + 1; + } else { + throw new IllegalStateException("Unexpected op: " + op); + } + } + + @Override + public String toString() { + return "refCount=" + refCount + + ", retainCount=" + retainCount + + ", releaseCount=" + releaseCount; + } + } + + static class TraceInfo { + private final int id; + private final Op op; + private final int previousRefCount; + private final Counts counts; + + private final String threadInfo; + private final StackTraceElement[] stackTraces; + private final int newTraceElementIndex; + + TraceInfo(int id, Op op, TraceInfo previous, int previousRefCount) { + this.id = id; + this.op = op; + this.previousRefCount = previousRefCount; + this.counts = previous == null? new Counts() + : op == Op.CURRENT ? previous.counts + : new Counts(op, previous.counts); + + final Thread thread = Thread.currentThread(); + this.threadInfo = "Thread_" + thread.getId() + ":" + thread.getName(); + this.stackTraces = thread.getStackTrace(); + this.newTraceElementIndex = previous == null? stackTraces.length - 1 + : findFirstUnequalFromTail(this.stackTraces, previous.stackTraces); + } + + static int findFirstUnequalFromTail(T[] current, T[] previous) { + int c = current.length == 0 ? 0 : current.length - 1; + for(int p = previous.length - 1; p >= 0; p--, c--) { + if (!previous[p].equals(current[c])) { + return c; + } + } + return -1; + } + + private StringBuilder appendTo(StringBuilder b) { + b.append(op).append("_").append(id) + .append(": previousRefCount=").append(previousRefCount) + .append(", ").append(counts) + .append(", ").append(threadInfo).append("\n"); + final int n = newTraceElementIndex + 1; + int line = 3; + for (; line <= n && line < stackTraces.length; line++) { + b.append(" ").append(stackTraces[line]).append("\n"); + } + if (line < stackTraces.length) { + b.append(" ...\n"); + } + return b; + } + + @Override + public String toString() { + return appendTo(new StringBuilder()).toString(); + } + } + + private final List traceInfos = new ArrayList<>(); + private TraceInfo previous; + + AdvancedTracing(T value, Runnable retainMethod, Consumer releaseMethod, LeakDetector leakDetector) { + super(value, retainMethod, releaseMethod, leakDetector); + addTraceInfo(Op.CREATION, -1); + } + + private synchronized TraceInfo addTraceInfo(Op op, int previousRefCount) { + final TraceInfo current = new TraceInfo(traceInfos.size(), op, previous, previousRefCount); + traceInfos.add(current); + previous = current; + return current; + } + + + @Override + public synchronized T retain() { + final int previousRefCount = getCount(); + final T retained = super.retain(); + final TraceInfo info = addTraceInfo(Op.RETAIN, previousRefCount); + Preconditions.assertSame(getCount(), info.counts.refCount, "refCount"); + return retained; + } + + @Override + public synchronized boolean release() { + final int previousRefCount = getCount(); + final boolean released = super.release(); + final TraceInfo info = addTraceInfo(Op.RELEASE, previousRefCount); + final int count = getCount(); + final int expected = count == -1? 0 : count; + Preconditions.assertSame(expected, info.counts.refCount, "refCount"); + return released; + } + + @Override + synchronized String getTraceString(int count) { + return super.getTraceString(count) + getTraceInfosString(); + } + + private String getTraceInfosString() { + final int n = traceInfos.size(); + final StringBuilder b = new StringBuilder(n << 10).append(" #TraceInfos=").append(n); + TraceInfo last = null; + for (TraceInfo info : traceInfos) { + info.appendTo(b.append("\n")); + last = info; + } + + // append current track info + final TraceInfo current = new TraceInfo(n, Op.CURRENT, last, getCount()); + current.appendTo(b.append("\n")); + + return b.toString(); + } + } +} diff --git a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java index 0dd378dc01..1fc72c3445 100644 --- a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java +++ b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java @@ -19,10 +19,11 @@ import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier; +import java.util.Collection; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import java.util.function.Function; /** * A reference-counted object can be retained for later use @@ -44,6 +45,7 @@ * @param The object type. */ public interface ReferenceCountedObject { + /** @return the object. */ T get(); @@ -101,65 +103,81 @@ static ReferenceCountedObject wrap(V value) { return wrap(value, () -> {}, ignored -> {}); } - /** - * Wrap the given value as a {@link ReferenceCountedObject}. - * - * @param value the value being wrapped. - * @param retainMethod a method to run when {@link #retain()} is invoked. - * @param releaseMethod a method to run when {@link #release()} is invoked, - * where the method takes a boolean which is the same as the one returned by {@link #release()}. - * @param The value type. - * @return the wrapped reference-counted object. - */ - static ReferenceCountedObject wrap(V value, Runnable retainMethod, Consumer releaseMethod) { - Objects.requireNonNull(value, "value == null"); - Objects.requireNonNull(retainMethod, "retainMethod == null"); - Objects.requireNonNull(releaseMethod, "releaseMethod == null"); - + static ReferenceCountedObject delegateFrom(Collection> fromRefs, V value) { return new ReferenceCountedObject() { - private final AtomicInteger count = new AtomicInteger(); - @Override public V get() { - final int previous = count.get(); - if (previous < 0) { - throw new IllegalStateException("Failed to get: object has already been completely released."); - } else if (previous == 0) { - throw new IllegalStateException("Failed to get: object has not yet been retained."); - } return value; } @Override public V retain() { - // n < 0: exception - // n >= 0: n++ - if (count.getAndUpdate(n -> n < 0? n : n + 1) < 0) { - throw new IllegalStateException("Failed to retain: object has already been completely released."); + fromRefs.forEach(ReferenceCountedObject::retain); + return value; + } + + @Override + public boolean release() { + boolean allReleased = true; + for (ReferenceCountedObject ref : fromRefs) { + if (!ref.release()) { + allReleased = false; + } } + return allReleased; + } + }; + } + + /** + * @return a {@link ReferenceCountedObject} of the given value by delegating to this object. + */ + default ReferenceCountedObject delegate(V value) { + final ReferenceCountedObject delegated = this; + return new ReferenceCountedObject() { + @Override + public V get() { + return value; + } - retainMethod.run(); + @Override + public V retain() { + delegated.retain(); return value; } @Override public boolean release() { - // n <= 0: exception - // n > 1: n-- - // n == 1: n = -1 - final int previous = count.getAndUpdate(n -> n <= 1? -1: n - 1); - if (previous < 0) { - throw new IllegalStateException("Failed to release: object has already been completely released."); - } else if (previous == 0) { - throw new IllegalStateException("Failed to release: object has not yet been retained."); - } - final boolean completedReleased = previous == 1; - releaseMethod.accept(completedReleased); - return completedReleased; + return delegated.release(); } }; } + /** + * @return a {@link ReferenceCountedObject} by apply the given function to this object. + */ + default ReferenceCountedObject apply(Function function) { + return delegate(function.apply(get())); + } + + /** + * Wrap the given value as a {@link ReferenceCountedObject}. + * + * @param value the value being wrapped. + * @param retainMethod a method to run when {@link #retain()} is invoked. + * @param releaseMethod a method to run when {@link #release()} is invoked, + * where the method takes a boolean which is the same as the one returned by {@link #release()}. + * @param The value type. + * @return the wrapped reference-counted object. + */ + static ReferenceCountedObject wrap(V value, Runnable retainMethod, Consumer releaseMethod) { + Objects.requireNonNull(value, "value == null"); + Objects.requireNonNull(retainMethod, "retainMethod == null"); + Objects.requireNonNull(releaseMethod, "releaseMethod == null"); + + return ReferenceCountedLeakDetector.getFactory().create(value, retainMethod, releaseMethod); + } + /** The same as wrap(value, retainMethod, ignored -> releaseMethod.run()). */ static ReferenceCountedObject wrap(V value, Runnable retainMethod, Runnable releaseMethod) { return wrap(value, retainMethod, ignored -> releaseMethod.run()); diff --git a/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java b/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java index 7e37d81322..a41b4b136a 100644 --- a/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java +++ b/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java @@ -51,6 +51,9 @@ interface Request { boolean hasReply(); void fail(Throwable e); + + default void release() { + } } interface ClientSideRequest extends Request { @@ -160,14 +163,19 @@ void endOfRequests(long nextToProcess, REQUEST end, Consumer replyMetho + " will NEVER be processed; request = " + r); r.fail(e); replyMethod.accept(r); + r.release(); } tail.clear(); putNewRequest(end); } - void clear() { + void clear(long nextToProcess) { LOG.debug("close {}", this); + final SortedMap tail = requests.tailMap(nextToProcess); + for (REQUEST r : tail.values()) { + r.release(); + } requests.clear(); } @@ -444,19 +452,26 @@ public synchronized String toString() { /** A request (or a retry) arrives (may be out-of-order except for the first request). */ public synchronized void receivedRequest(REQUEST request, Consumer processingMethod) { final long seqNum = request.getSeqNum(); + final boolean accepted; if (nextToProcess == -1 && (request.isFirstRequest() || seqNum == 0)) { nextToProcess = seqNum; requests.putNewRequest(request); LOG.debug("Received seq={} (first request), {}", seqNum, this); + accepted = true; + } else if (request.getSeqNum() < nextToProcess) { + LOG.debug("Received seq={} < nextToProcess {}, {}", seqNum, nextToProcess, this); + accepted = false; } else { final boolean isRetry = requests.putIfAbsent(request); LOG.debug("Received seq={}, isRetry? {}, {}", seqNum, isRetry, this); - if (isRetry) { - return; - } + accepted = !isRetry; } - processRequestsFromHead(processingMethod); + if (accepted) { + processRequestsFromHead(processingMethod); + } else { + request.release(); + } } private void processRequestsFromHead(Consumer processingMethod) { @@ -465,6 +480,7 @@ private void processRequestsFromHead(Consumer processingMethod) { return; } else if (r.getSeqNum() == nextToProcess) { processingMethod.accept(r); + r.release(); nextToProcess++; } } @@ -510,7 +526,7 @@ public synchronized boolean endOfRequests(Consumer replyMethod) { @Override public void close() { - requests.clear(); + requests.clear(nextToProcess); } } } \ No newline at end of file diff --git a/ratis-common/src/test/java/org/apache/ratis/BaseTest.java b/ratis-common/src/test/java/org/apache/ratis/BaseTest.java index 52b986ab9b..45c86cad0b 100644 --- a/ratis-common/src/test/java/org/apache/ratis/BaseTest.java +++ b/ratis-common/src/test/java/org/apache/ratis/BaseTest.java @@ -21,6 +21,7 @@ import org.apache.ratis.util.ExitUtils; import org.apache.ratis.util.FileUtils; import org.apache.ratis.util.JavaUtils; +import org.apache.ratis.util.ReferenceCountedLeakDetector; import org.apache.ratis.util.Slf4jUtils; import org.apache.ratis.util.StringUtils; import org.apache.ratis.util.TimeDuration; @@ -73,20 +74,12 @@ public void setFirstException(Throwable e) { @BeforeEach public void setup(TestInfo testInfo) { - checkAssumptions(); + testCaseName = testInfo.getTestMethod() + .orElseThrow(() -> new RuntimeException("Exception while getting test name.")) + .getName(); - final Method method = testInfo.getTestMethod().orElse(null); - testCaseName = testInfo.getTestClass().orElse(getClass()).getSimpleName() - + "." + (method == null? null : method.getName()); - } - - @BeforeEach - public void checkAssumptions() { - final Throwable first = firstException.get(); - Assumptions.assumeTrue(first == null, () -> "Already failed with " + first); - - final Throwable exited = ExitUtils.getFirstExitException(); - Assumptions.assumeTrue(exited == null, () -> "Already exited with " + exited); + final int leaks = ReferenceCountedLeakDetector.getLeakDetector().getLeakCount(); + Assumptions.assumeFalse(0 < leaks, () -> "numLeaks " + leaks + " > 0"); } @AfterEach diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java index fa0dc6d8e2..c4adff5987 100644 --- a/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java +++ b/ratis-examples/src/main/java/org/apache/ratis/examples/arithmetic/ArithmeticStateMachine.java @@ -164,7 +164,7 @@ public void close() { @Override public CompletableFuture applyTransaction(TransactionContext trx) { - final LogEntryProto entry = trx.getLogEntry(); + final LogEntryProto entry = trx.getLogEntryUnsafe(); final AssignmentMessage assignment = new AssignmentMessage(entry.getStateMachineLogEntry().getLogData()); final long index = entry.getIndex(); diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java index 914180feb8..e8b09c77ae 100644 --- a/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java +++ b/ratis-examples/src/main/java/org/apache/ratis/examples/counter/server/CounterStateMachine.java @@ -261,7 +261,7 @@ public TransactionContext startTransaction(RaftClientRequest request) throws IOE */ @Override public CompletableFuture applyTransaction(TransactionContext trx) { - final LogEntryProto entry = trx.getLogEntry(); + final LogEntryProto entry = trx.getLogEntryUnsafe(); //increment the counter and update term-index final TermIndex termIndex = TermIndex.valueOf(entry); final int incremented = incrementCounter(termIndex); diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileInfo.java b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileInfo.java index c7d8cb7cd1..bba001002a 100644 --- a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileInfo.java +++ b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileInfo.java @@ -19,6 +19,7 @@ import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; +import org.apache.ratis.thirdparty.com.google.protobuf.UnsafeByteOperations; import org.apache.ratis.util.CollectionUtils; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.LogUtils; @@ -84,7 +85,7 @@ ByteString read(CheckedFunction resolver, long offset, final ByteBuffer buffer = ByteBuffer.allocateDirect(FileStoreCommon.getChunkSize(length)); in.position(offset).read(buffer); buffer.flip(); - return ByteString.copyFrom(buffer); + return UnsafeByteOperations.unsafeWrap(buffer); } } diff --git a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java index 5f258ee3b7..345831dfb2 100644 --- a/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java +++ b/ratis-examples/src/main/java/org/apache/ratis/examples/filestore/FileStoreStateMachine.java @@ -40,6 +40,7 @@ import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException; import org.apache.ratis.util.FileUtils; +import org.apache.ratis.util.ReferenceCountedObject; import java.io.IOException; import java.nio.file.Path; @@ -112,16 +113,18 @@ public TransactionContext startTransaction(RaftClientRequest request) throws IOE @Override public TransactionContext startTransaction(LogEntryProto entry, RaftProtos.RaftPeerRole role) { + ByteString copied = ByteString.copyFrom(entry.getStateMachineLogEntry().getLogData().asReadOnlyByteBuffer()); return TransactionContext.newBuilder() .setStateMachine(this) .setLogEntry(entry) .setServerRole(role) - .setStateMachineContext(getProto(entry)) + .setStateMachineContext(getProto(copied)) .build(); } @Override - public CompletableFuture write(LogEntryProto entry, TransactionContext context) { + public CompletableFuture write(ReferenceCountedObject entryRef, TransactionContext context) { + LogEntryProto entry = entryRef.retain(); final FileStoreRequestProto proto = getProto(context, entry); if (proto.getRequestCase() != FileStoreRequestProto.RequestCase.WRITEHEADER) { return null; @@ -130,9 +133,10 @@ public CompletableFuture write(LogEntryProto entry, TransactionContext final WriteRequestHeaderProto h = proto.getWriteHeader(); final CompletableFuture f = files.write(entry.getIndex(), h.getPath().toStringUtf8(), h.getClose(), h.getSync(), h.getOffset(), - entry.getStateMachineLogEntry().getStateMachineEntry().getStateMachineData()); + entry.getStateMachineLogEntry().getStateMachineEntry().getStateMachineData() + ).whenComplete((r, e) -> entryRef.release()); // sync only if closing the file - return h.getClose()? f: null; + return h.getClose() ? f: null; } static FileStoreRequestProto getProto(TransactionContext context, LogEntryProto entry) { @@ -142,14 +146,14 @@ static FileStoreRequestProto getProto(TransactionContext context, LogEntryProto return proto; } } - return getProto(entry); + return getProto(entry.getStateMachineLogEntry().getLogData()); } - static FileStoreRequestProto getProto(LogEntryProto entry) { + static FileStoreRequestProto getProto(ByteString bytes) { try { - return FileStoreRequestProto.parseFrom(entry.getStateMachineLogEntry().getLogData()); + return FileStoreRequestProto.parseFrom(bytes); } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException("Failed to parse data, entry=" + entry, e); + throw new IllegalArgumentException("Failed to parse data", e); } } @@ -214,7 +218,7 @@ public CompletableFuture link(DataStream stream, LogEntryProto entry) { @Override public CompletableFuture applyTransaction(TransactionContext trx) { - final LogEntryProto entry = trx.getLogEntry(); + final LogEntryProto entry = trx.getLogEntryUnsafe(); final long index = entry.getIndex(); updateLastAppliedTermIndex(entry.getTerm(), index); diff --git a/ratis-examples/src/test/java/org/apache/ratis/TestMultiRaftGroup.java b/ratis-examples/src/test/java/org/apache/ratis/TestMultiRaftGroup.java index 7822b694df..ea3962c088 100644 --- a/ratis-examples/src/test/java/org/apache/ratis/TestMultiRaftGroup.java +++ b/ratis-examples/src/test/java/org/apache/ratis/TestMultiRaftGroup.java @@ -22,29 +22,20 @@ import org.apache.ratis.examples.arithmetic.ArithmeticStateMachine; import org.apache.ratis.examples.arithmetic.TestArithmetic; import org.apache.ratis.protocol.RaftGroup; -import org.apache.ratis.server.RaftServer; import org.apache.ratis.server.impl.GroupManagementBaseTest; import org.apache.ratis.server.impl.MiniRaftCluster; -import org.apache.ratis.util.Slf4jUtils; -import org.apache.ratis.test.tag.Flaky; import org.apache.ratis.util.function.CheckedBiConsumer; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.slf4j.event.Level; import java.io.IOException; import java.util.Collection; import java.util.concurrent.atomic.AtomicInteger; -@Flaky("RATIS-2218") @Timeout(value = 300) public class TestMultiRaftGroup extends BaseTest { - static { - Slf4jUtils.setLogLevel(RaftServer.Division.LOG, Level.DEBUG); - } - - public static Collection data() throws IOException { + public static Collection data() { return ParameterizedBaseTest.getMiniRaftClusters(ArithmeticStateMachine.class, 0); } diff --git a/ratis-grpc/dev-support/findbugsExcludeFile.xml b/ratis-grpc/dev-support/findbugsExcludeFile.xml index c13c34ade0..c318da58a0 100644 --- a/ratis-grpc/dev-support/findbugsExcludeFile.xml +++ b/ratis-grpc/dev-support/findbugsExcludeFile.xml @@ -23,4 +23,16 @@ - \ No newline at end of file + + + + + + + + + + + + + diff --git a/ratis-grpc/pom.xml b/ratis-grpc/pom.xml index 360131d55b..47e6e1bb40 100644 --- a/ratis-grpc/pom.xml +++ b/ratis-grpc/pom.xml @@ -28,6 +28,12 @@ org.apache.ratis ratis-thirdparty-misc + + com.google.errorprone + error_prone_annotations + 2.29.2 + provided + ratis-proto org.apache.ratis diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java index f31794ac36..2fcb9b6b0a 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java @@ -54,29 +54,50 @@ static Consumer getDefaultLog() { interface TLS { String PREFIX = GrpcConfigKeys.PREFIX + ".tls"; - @Deprecated + String ENABLED_KEY = PREFIX + ".enabled"; + boolean ENABLED_DEFAULT = false; + static boolean enabled(RaftProperties properties) { + return getBoolean(properties::getBoolean, ENABLED_KEY, ENABLED_DEFAULT, getDefaultLog()); + } static void setEnabled(RaftProperties properties, boolean enabled) { - LOG.warn("This method has no effect. Use setConf(Parameters, GrpcTlsConfig) instead."); + setBoolean(properties::setBoolean, ENABLED_KEY, enabled); } - @Deprecated + String MUTUAL_AUTHN_ENABLED_KEY = PREFIX + ".mutual_authn.enabled"; + boolean MUTUAL_AUTHN_ENABLED_DEFAULT = false; + static boolean mutualAuthnEnabled(RaftProperties properties) { + return getBoolean(properties::getBoolean, + MUTUAL_AUTHN_ENABLED_KEY, MUTUAL_AUTHN_ENABLED_DEFAULT, getDefaultLog()); + } static void setMutualAuthnEnabled(RaftProperties properties, boolean mutualAuthnEnabled) { - LOG.warn("This method has no effect. Use setConf(Parameters, GrpcTlsConfig) instead."); + setBoolean(properties::setBoolean, MUTUAL_AUTHN_ENABLED_KEY, mutualAuthnEnabled); } - @Deprecated + String PRIVATE_KEY_FILE_NAME_KEY = PREFIX + ".private.key.file.name"; + String PRIVATE_KEY_FILE_NAME_DEFAULT = "private.pem"; + static String privateKeyFileName(RaftProperties properties) { + return get(properties::get, PRIVATE_KEY_FILE_NAME_KEY, PRIVATE_KEY_FILE_NAME_DEFAULT, getDefaultLog()); + } static void setPrivateKeyFileName(RaftProperties properties, String privateKeyFileName) { - LOG.warn("This method has no effect. Use setConf(Parameters, GrpcTlsConfig) instead."); + set(properties::set, PRIVATE_KEY_FILE_NAME_KEY, privateKeyFileName); } - @Deprecated + String CERT_CHAIN_FILE_NAME_KEY = PREFIX + ".cert.chain.file.name"; + String CERT_CHAIN_FILE_NAME_DEFAULT = "certificate.crt"; + static String certChainFileName(RaftProperties properties) { + return get(properties::get, CERT_CHAIN_FILE_NAME_KEY, CERT_CHAIN_FILE_NAME_DEFAULT, getDefaultLog()); + } static void setCertChainFileName(RaftProperties properties, String certChainFileName) { - LOG.warn("This method has no effect. Use setConf(Parameters, GrpcTlsConfig) instead."); + set(properties::set, CERT_CHAIN_FILE_NAME_KEY, certChainFileName); } - @Deprecated + String TRUST_STORE_KEY = PREFIX + ".trust.store"; + String TRUST_STORE_DEFAULT = "ca.crt"; + static String trustStore(RaftProperties properties) { + return get(properties::get, TRUST_STORE_KEY, TRUST_STORE_DEFAULT, getDefaultLog()); + } static void setTrustStore(RaftProperties properties, String trustStore) { - LOG.warn("This method has no effect. Use setConf(Parameters, GrpcTlsConfig) instead."); + set(properties::set, TRUST_STORE_KEY, trustStore); } String CONF_PARAMETER = PREFIX + ".conf"; @@ -264,6 +285,15 @@ static void setLogMessageBatchDuration(RaftProperties properties, LOG_MESSAGE_BATCH_DURATION_KEY, logMessageBatchDuration); } + String ZERO_COPY_ENABLED_KEY = PREFIX + ".zerocopy.enabled"; + boolean ZERO_COPY_ENABLED_DEFAULT = false; + static boolean zeroCopyEnabled(RaftProperties properties) { + return getBoolean(properties::getBoolean, ZERO_COPY_ENABLED_KEY, ZERO_COPY_ENABLED_DEFAULT, getDefaultLog()); + } + static void setZeroCopyEnabled(RaftProperties properties, boolean enabled) { + setBoolean(properties::setBoolean, ZERO_COPY_ENABLED_KEY, enabled); + } + String SERVICES_CUSTOMIZER_PARAMETER = PREFIX + ".services.customizer"; Class SERVICES_CUSTOMIZER_CLASS = GrpcServices.Customizer.class; static GrpcServices.Customizer servicesCustomizer(Parameters parameters) { @@ -282,15 +312,6 @@ static GrpcTlsConfig tlsConf(Parameters parameters) { static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) { parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS); } - - String STUB_POOL_SIZE_KEY = PREFIX + ".stub.pool.size"; - int STUB_POOL_SIZE_DEFAULT = 1; - static int stubPoolSize(RaftProperties properties) { - return get(properties::getInt, STUB_POOL_SIZE_KEY, STUB_POOL_SIZE_DEFAULT, getDefaultLog()); - } - static void setStubPoolSize(RaftProperties properties, int size) { - setInt(properties::setInt, STUB_POOL_SIZE_KEY, size); - } } String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max"; diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java index 1053cab80e..331d1a8585 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcFactory.java @@ -32,15 +32,11 @@ import org.apache.ratis.server.leader.FollowerInfo; import org.apache.ratis.server.leader.LeaderState; import org.apache.ratis.thirdparty.io.netty.buffer.PooledByteBufAllocator; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; import org.apache.ratis.util.JavaUtils; -import org.apache.ratis.util.MemoizedSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.function.Supplier; public class GrpcFactory implements ServerFactory, ClientFactory { @@ -69,32 +65,19 @@ static boolean checkPooledByteBufAllocatorUseCacheForAllThreads(Consumer return value; } - static final BiFunction BUILD_SSL_CONTEXT_FOR_SERVER - = (tlsConf, defaultContext) -> tlsConf == null ? defaultContext : GrpcUtil.buildSslContextForServer(tlsConf); - - static final BiFunction BUILD_SSL_CONTEXT_FOR_CLIENT - = (tlsConf, defaultContext) -> tlsConf == null ? defaultContext : GrpcUtil.buildSslContextForClient(tlsConf); + private final GrpcServices.Customizer servicesCustomizer; - static final class SslContexts { - private final SslContext adminSslContext; - private final SslContext clientSslContext; - private final SslContext serverSslContext; + private final GrpcTlsConfig tlsConfig; + private final GrpcTlsConfig adminTlsConfig; + private final GrpcTlsConfig clientTlsConfig; + private final GrpcTlsConfig serverTlsConfig; - private SslContexts(GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig, - GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig, - BiFunction buildMethod) { - final SslContext defaultSslContext = buildMethod.apply(tlsConfig, null); - this.adminSslContext = buildMethod.apply(adminTlsConfig, defaultSslContext); - this.clientSslContext = buildMethod.apply(clientTlsConfig, defaultSslContext); - this.serverSslContext = buildMethod.apply(serverTlsConfig, defaultSslContext); - } + public static Parameters newRaftParameters(GrpcTlsConfig conf) { + final Parameters p = new Parameters(); + GrpcConfigKeys.TLS.setConf(p, conf); + return p; } - private final GrpcServices.Customizer servicesCustomizer; - - private final Supplier forServerSupplier; - private final Supplier forClientSupplier; - public GrpcFactory(Parameters parameters) { this(GrpcConfigKeys.Server.servicesCustomizer(parameters), GrpcConfigKeys.TLS.conf(parameters), @@ -104,15 +87,35 @@ public GrpcFactory(Parameters parameters) { ); } + public GrpcFactory(GrpcTlsConfig tlsConfig) { + this(null, tlsConfig, null, null, null); + } + private GrpcFactory(GrpcServices.Customizer servicesCustomizer, GrpcTlsConfig tlsConfig, GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) { this.servicesCustomizer = servicesCustomizer; - this.forServerSupplier = MemoizedSupplier.valueOf(() -> new SslContexts( - tlsConfig, adminTlsConfig, clientTlsConfig, serverTlsConfig, BUILD_SSL_CONTEXT_FOR_SERVER)); - this.forClientSupplier = MemoizedSupplier.valueOf(() -> new SslContexts( - tlsConfig, adminTlsConfig, clientTlsConfig, serverTlsConfig, BUILD_SSL_CONTEXT_FOR_CLIENT)); + this.tlsConfig = tlsConfig; + this.adminTlsConfig = adminTlsConfig; + this.clientTlsConfig = clientTlsConfig; + this.serverTlsConfig = serverTlsConfig; + } + + public GrpcTlsConfig getTlsConfig() { + return tlsConfig; + } + + public GrpcTlsConfig getAdminTlsConfig() { + return adminTlsConfig != null ? adminTlsConfig : tlsConfig; + } + + public GrpcTlsConfig getClientTlsConfig() { + return clientTlsConfig != null ? clientTlsConfig : tlsConfig; + } + + public GrpcTlsConfig getServerTlsConfig() { + return serverTlsConfig != null ? serverTlsConfig : tlsConfig; } @Override @@ -128,24 +131,19 @@ public LogAppender newLogAppender(RaftServer.Division server, LeaderState state, @Override public GrpcServices newRaftServerRpc(RaftServer server) { checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::info); - - final SslContexts forServer = forServerSupplier.get(); - final SslContexts forClient = forClientSupplier.get(); return GrpcServicesImpl.newBuilder() .setServer(server) .setCustomizer(servicesCustomizer) - .setAdminSslContext(forServer.adminSslContext) - .setServerSslContextForServer(forServer.serverSslContext) - .setServerSslContextForClient(forClient.serverSslContext) - .setClientSslContext(forServer.clientSslContext) + .setAdminTlsConfig(getAdminTlsConfig()) + .setServerTlsConfig(getServerTlsConfig()) + .setClientTlsConfig(getClientTlsConfig()) .build(); } @Override public GrpcClientRpc newRaftClientRpc(ClientId clientId, RaftProperties properties) { checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::debug); - - final SslContexts forClient = forClientSupplier.get(); - return new GrpcClientRpc(clientId, properties, forClient.adminSslContext, forClient.clientSslContext); + return new GrpcClientRpc(clientId, properties, + getAdminTlsConfig(), getClientTlsConfig()); } } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java index df076875bf..311bcb8778 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcUtil.java @@ -24,14 +24,15 @@ import org.apache.ratis.security.TlsConf.CertificatesConf; import org.apache.ratis.security.TlsConf.PrivateKeyConf; import org.apache.ratis.security.TlsConf.KeyManagerConf; +import org.apache.ratis.thirdparty.com.google.protobuf.MessageLite; import org.apache.ratis.thirdparty.io.grpc.ManagedChannel; import org.apache.ratis.thirdparty.io.grpc.Metadata; +import org.apache.ratis.thirdparty.io.grpc.MethodDescriptor; +import org.apache.ratis.thirdparty.io.grpc.ServerCallHandler; +import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition; import org.apache.ratis.thirdparty.io.grpc.Status; import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException; -import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder; import org.apache.ratis.util.IOUtils; import org.apache.ratis.util.JavaUtils; @@ -42,7 +43,6 @@ import org.slf4j.LoggerFactory; import javax.net.ssl.KeyManager; -import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import java.io.IOException; import java.util.concurrent.CompletableFuture; @@ -50,8 +50,6 @@ import java.util.function.Function; import java.util.function.Supplier; -import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL; - public interface GrpcUtil { Logger LOG = LoggerFactory.getLogger(GrpcUtil.class); @@ -65,14 +63,8 @@ public interface GrpcUtil { Metadata.Key.of("heartbeat", Metadata.ASCII_STRING_MARSHALLER); static StatusRuntimeException wrapException(Throwable t) { - return wrapException(t, -1); - } - - static StatusRuntimeException wrapException(Throwable t, long callId) { t = JavaUtils.unwrapCompletionException(t); - Metadata trailers = new StatusRuntimeExceptionMetadataBuilder(t) - .addCallId(callId) - .build(); + Metadata trailers = new StatusRuntimeExceptionMetadataBuilder(t).build(); return wrapException(t, trailers); } @@ -84,6 +76,10 @@ static StatusRuntimeException wrapException(Throwable t, long callId, boolean is .build(); return wrapException(t, trailers); } + static StatusRuntimeException wrapException(Throwable t, long callId) { + return wrapException(t, callId, false); + } + static StatusRuntimeException wrapException(Throwable t, Metadata trailers) { return new StatusRuntimeException( @@ -97,7 +93,7 @@ static Throwable unwrapThrowable(Throwable t) { return unwrapped; } } - return JavaUtils.unwrapCompletionException(t); + return t; } static IOException unwrapException(StatusRuntimeException se) { @@ -144,10 +140,8 @@ static Throwable tryUnwrapThrowable(StatusRuntimeException se) { static long getCallId(Throwable t) { if (t instanceof StatusRuntimeException) { final Metadata trailers = ((StatusRuntimeException)t).getTrailers(); - if (trailers != null) { - final String callId = trailers.get(CALL_ID); - return callId != null ? Long.parseUnsignedLong(callId) : -1; - } + String callId = trailers.get(CALL_ID); + return callId != null ? Integer.parseInt(callId) : -1; } return -1; } @@ -155,8 +149,8 @@ static long getCallId(Throwable t) { static boolean isHeartbeat(Throwable t) { if (t instanceof StatusRuntimeException) { final Metadata trailers = ((StatusRuntimeException)t).getTrailers(); - final String isHeartbeat = trailers != null ? trailers.get(HEARTBEAT) : null; - return Boolean.parseBoolean(isHeartbeat); + String isHeartbeat = trailers != null ? trailers.get(HEARTBEAT) : null; + return isHeartbeat != null && Boolean.valueOf(isHeartbeat); } return false; } @@ -164,7 +158,7 @@ static boolean isHeartbeat(Throwable t) { static IOException unwrapIOException(Throwable t) { final IOException e; if (t instanceof StatusRuntimeException) { - e = unwrapException((StatusRuntimeException) t); + e = GrpcUtil.unwrapException((StatusRuntimeException) t); } else { e = IOUtils.asIOException(t); } @@ -180,7 +174,7 @@ static void asyncCall( supplier.get().whenComplete((reply, exception) -> { if (exception != null) { warning.accept(exception); - responseObserver.onError(wrapException(exception)); + responseObserver.onError(GrpcUtil.wrapException(exception)); } else { responseObserver.onNext(toProto.apply(reply)); responseObserver.onCompleted(); @@ -188,7 +182,7 @@ static void asyncCall( }); } catch (Exception e) { warning.accept(e); - responseObserver.onError(wrapException(e)); + responseObserver.onError(GrpcUtil.wrapException(e)); } } @@ -197,7 +191,7 @@ static void warn(Logger log, Supplier message, Throwable t) { } class StatusRuntimeExceptionMetadataBuilder { - private final Metadata trailers = new Metadata(); + private Metadata trailers = new Metadata(); StatusRuntimeExceptionMetadataBuilder(Throwable t) { trailers.put(EXCEPTION_TYPE_KEY, t.getClass().getCanonicalName()); @@ -306,37 +300,25 @@ static void setKeyManager(SslContextBuilder b, KeyManagerConf keyManagerConfig) } } - static SslContext buildSslContextForServer(GrpcTlsConfig tlsConf) { - if (tlsConf == null) { - return null; - } - SslContextBuilder b = initSslContextBuilderForServer(tlsConf.getKeyManager()); - if (tlsConf.getMtlsEnabled()) { - b.clientAuth(ClientAuth.REQUIRE); - setTrustManager(b, tlsConf.getTrustManager()); - } - b = GrpcSslContexts.configure(b, OPENSSL); - try { - return b.build(); - } catch (Exception e) { - throw new IllegalArgumentException("Failed to buildSslContextForServer from tlsConfig " + tlsConf, e); - } - } - - static SslContext buildSslContextForClient(GrpcTlsConfig tlsConf) { - if (tlsConf == null) { - return null; - } - - final SslContextBuilder b = GrpcSslContexts.forClient(); - setTrustManager(b, tlsConf.getTrustManager()); - if (tlsConf.getMtlsEnabled()) { - setKeyManager(b, tlsConf.getKeyManager()); - } - try { - return b.build(); - } catch (SSLException e) { - throw new IllegalArgumentException("Failed to buildSslContextForClient from tlsConfig " + tlsConf, e); - } + /** + * Used to add a method to Service definition with a custom request marshaller. + * + * @param orig original service definition. + * @param newServiceBuilder builder of the new service definition. + * @param origMethod the original method definition. + * @param customMarshaller custom marshaller to be set for the method. + * @param + * @param + */ + static void addMethodWithCustomMarshaller( + ServerServiceDefinition orig, ServerServiceDefinition.Builder newServiceBuilder, + MethodDescriptor origMethod, MethodDescriptor.PrototypeMarshaller customMarshaller) { + MethodDescriptor newMethod = origMethod.toBuilder() + .setRequestMarshaller(customMarshaller) + .build(); + @SuppressWarnings("unchecked") + ServerCallHandler serverCallHandler = + (ServerCallHandler) orig.getMethod(newMethod.getFullMethodName()).getServerCallHandler(); + newServiceBuilder.addMethod(newMethod, serverCallHandler); } } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java index 0eaec6b962..3b9d512683 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolClient.java @@ -21,6 +21,7 @@ import org.apache.ratis.client.impl.ClientProtoUtils; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.grpc.GrpcConfigKeys; +import org.apache.ratis.grpc.GrpcTlsConfig; import org.apache.ratis.grpc.GrpcUtil; import org.apache.ratis.grpc.metrics.intercept.client.MetricClientInterceptor; import org.apache.ratis.proto.RaftProtos.GroupInfoReplyProto; @@ -48,10 +49,11 @@ import org.apache.ratis.protocol.exceptions.TimeoutIOException; import org.apache.ratis.thirdparty.io.grpc.ManagedChannel; import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException; +import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts; import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType; import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder; import org.apache.ratis.util.CollectionUtils; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.SizeInBytes; @@ -82,7 +84,6 @@ public class GrpcClientProtocolClient implements Closeable { private final ManagedChannel clientChannel; private final ManagedChannel adminChannel; - private final SizeInBytes maxMessageSize; private final TimeDuration requestTimeoutDuration; private final TimeDuration watchRequestTimeoutDuration; private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance(); @@ -96,11 +97,11 @@ public class GrpcClientProtocolClient implements Closeable { private final MetricClientInterceptor metricClientInterceptor; GrpcClientProtocolClient(ClientId id, RaftPeer target, RaftProperties properties, - SslContext adminSslContext, SslContext clientSslContext) { + GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig) { this.name = JavaUtils.memoize(() -> id + "->" + target.getId()); this.target = target; final SizeInBytes flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::debug); - this.maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, LOG::debug); + final SizeInBytes maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, LOG::debug); metricClientInterceptor = new MetricClientInterceptor(getName()); final String clientAddress = Optional.ofNullable(target.getClientAddress()) @@ -109,9 +110,11 @@ public class GrpcClientProtocolClient implements Closeable { .filter(x -> !x.isEmpty()).orElse(target.getAddress()); final boolean separateAdminChannel = !Objects.equals(clientAddress, adminAddress); - clientChannel = buildChannel(clientAddress, clientSslContext, flowControlWindow); + clientChannel = buildChannel(clientAddress, clientTlsConfig, + flowControlWindow, maxMessageSize); adminChannel = separateAdminChannel - ? buildChannel(adminAddress, adminSslContext, flowControlWindow) + ? buildChannel(adminAddress, adminTlsConfig, + flowControlWindow, maxMessageSize) : clientChannel; asyncStub = RaftClientProtocolServiceGrpc.newStub(clientChannel); @@ -121,16 +124,26 @@ public class GrpcClientProtocolClient implements Closeable { RaftClientConfigKeys.Rpc.watchRequestTimeout(properties); } - private ManagedChannel buildChannel(String address, SslContext sslContext, - SizeInBytes flowControlWindow) { + private ManagedChannel buildChannel(String address, GrpcTlsConfig tlsConf, + SizeInBytes flowControlWindow, SizeInBytes maxMessageSize) { NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(address); // ignore any http proxy for grpc channelBuilder.proxyDetector(uri -> null); - if (sslContext != null) { + if (tlsConf != null) { LOG.debug("Setting TLS for {}", address); - channelBuilder.useTransportSecurity().sslContext(sslContext); + SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + GrpcUtil.setTrustManager(sslContextBuilder, tlsConf.getTrustManager()); + if (tlsConf.getMtlsEnabled()) { + GrpcUtil.setKeyManager(sslContextBuilder, tlsConf.getKeyManager()); + } + try { + channelBuilder.useTransportSecurity().sslContext( + sslContextBuilder.build()); + } catch (Exception ex) { + throw new RuntimeException(ex); + } } else { channelBuilder.negotiationType(NegotiationType.PLAINTEXT); } @@ -333,20 +346,13 @@ public void onCompleted() { } CompletableFuture onNext(RaftClientRequest request) { - final RaftClientRequestProto proto = ClientProtoUtils.toRaftClientRequestProto(request); - if (proto.getSerializedSize() > maxMessageSize.getSizeInt()) { - return JavaUtils.completeExceptionally(new IllegalArgumentException(getName() - + ": request serialized size " + proto.getSerializedSize() - + " exceeds maximum " + maxMessageSize + " for " + request)); - } - final long callId = request.getCallId(); final CompletableFuture f = replies.putNew(callId); if (f == null) { return JavaUtils.completeExceptionally(new AlreadyClosedException(getName() + " is closed.")); } try { - if (!requestStreamer.onNext(proto)) { + if (!requestStreamer.onNext(ClientProtoUtils.toRaftClientRequestProto(request))) { return JavaUtils.completeExceptionally(new AlreadyClosedException(getName() + ": the stream is closed.")); } } catch(Exception t) { diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java new file mode 100644 index 0000000000..95119ef7d7 --- /dev/null +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientProtocolProxy.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.grpc.client; + +import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.grpc.GrpcTlsConfig; +import org.apache.ratis.protocol.ClientId; +import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; +import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto; +import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto; +import org.apache.ratis.protocol.RaftPeer; + +import java.io.Closeable; +import java.io.IOException; +import java.util.function.Function; + +public class GrpcClientProtocolProxy implements Closeable { + private final GrpcClientProtocolClient proxy; + private final Function responseHandlerCreation; + private RpcSession currentSession; + + public GrpcClientProtocolProxy(ClientId clientId, RaftPeer target, + Function responseHandlerCreation, + RaftProperties properties, GrpcTlsConfig tlsConfig) { + proxy = new GrpcClientProtocolClient(clientId, target, properties, tlsConfig, tlsConfig); + this.responseHandlerCreation = responseHandlerCreation; + } + + @Override + public void close() throws IOException { + closeCurrentSession(); + proxy.close(); + } + + @Override + public String toString() { + return "ProxyTo:" + proxy.getTarget(); + } + + public void closeCurrentSession() { + if (currentSession != null) { + currentSession.close(); + currentSession = null; + } + } + + public void onNext(RaftClientRequestProto request) { + if (currentSession == null) { + currentSession = new RpcSession( + responseHandlerCreation.apply(proxy.getTarget())); + } + currentSession.requestObserver.onNext(request); + } + + public void onError() { + if (currentSession != null) { + currentSession.onError(); + } + } + + public interface CloseableStreamObserver + extends StreamObserver, Closeable { + } + + class RpcSession implements Closeable { + private final StreamObserver requestObserver; + private final CloseableStreamObserver responseHandler; + private boolean hasError = false; + + RpcSession(CloseableStreamObserver responseHandler) { + this.responseHandler = responseHandler; + this.requestObserver = proxy.ordered(responseHandler); + } + + void onError() { + hasError = true; + } + + @Override + public void close() { + if (!hasError) { + try { + requestObserver.onCompleted(); + } catch (Exception ignored) { + } + } + try { + responseHandler.close(); + } catch (IOException ignored) { + } + } + } +} diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java index 65175dc2a1..b825429ae4 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/client/GrpcClientRpc.java @@ -17,16 +17,14 @@ */ package org.apache.ratis.grpc.client; -import org.apache.ratis.client.RaftClientConfigKeys; import org.apache.ratis.client.impl.ClientProtoUtils; import org.apache.ratis.client.impl.RaftClientRpcWithProxy; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.grpc.GrpcConfigKeys; +import org.apache.ratis.grpc.GrpcTlsConfig; import org.apache.ratis.grpc.GrpcUtil; import org.apache.ratis.protocol.*; import org.apache.ratis.protocol.exceptions.AlreadyClosedException; -import org.apache.ratis.protocol.exceptions.TimeoutIOException; -import org.apache.ratis.thirdparty.io.grpc.Status; import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; import org.apache.ratis.proto.RaftProtos.GroupInfoRequestProto; @@ -38,11 +36,9 @@ import org.apache.ratis.proto.RaftProtos.TransferLeadershipRequestProto; import org.apache.ratis.proto.RaftProtos.SnapshotManagementRequestProto; import org.apache.ratis.proto.RaftProtos.LeaderElectionManagementRequestProto; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; import org.apache.ratis.util.IOUtils; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.PeerProxyMap; -import org.apache.ratis.util.TimeDuration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,25 +46,19 @@ import java.io.InterruptedIOException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; public class GrpcClientRpc extends RaftClientRpcWithProxy { public static final Logger LOG = LoggerFactory.getLogger(GrpcClientRpc.class); private final ClientId clientId; private final int maxMessageSize; - private final TimeDuration requestTimeoutDuration; - private final TimeDuration watchRequestTimeoutDuration; public GrpcClientRpc(ClientId clientId, RaftProperties properties, - SslContext adminSslContext, SslContext clientSslContext) { + GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig) { super(new PeerProxyMap<>(clientId.toString(), - p -> new GrpcClientProtocolClient(clientId, p, properties, adminSslContext, clientSslContext))); + p -> new GrpcClientProtocolClient(clientId, p, properties, adminTlsConfig, clientTlsConfig))); this.clientId = clientId; this.maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, LOG::debug).getSizeInt(); - this.requestTimeoutDuration = RaftClientConfigKeys.Rpc.requestTimeout(properties); - this.watchRequestTimeoutDuration = RaftClientConfigKeys.Rpc.watchRequestTimeout(properties); } @Override @@ -131,11 +121,24 @@ public RaftClientReply sendRequest(RaftClientRequest request) ((LeaderElectionManagementRequest) request); return ClientProtoUtils.toRaftClientReply(proxy.leaderElectionManagement(proto)); } else { - return sendRequest(request, proxy); + final CompletableFuture f = sendRequest(request, proxy); + // TODO: timeout support + try { + return f.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new InterruptedIOException( + "Interrupted while waiting for response of request " + request); + } catch (ExecutionException e) { + if (LOG.isTraceEnabled()) { + LOG.trace(clientId + ": failed " + request, e); + } + throw IOUtils.toIOException(e); + } } } - private RaftClientReply sendRequest( + private CompletableFuture sendRequest( RaftClientRequest request, GrpcClientProtocolClient proxy) throws IOException { final RaftClientRequestProto requestProto = toRaftClientRequestProto(request); @@ -164,44 +167,7 @@ public void onCompleted() { requestObserver.onNext(requestProto); requestObserver.onCompleted(); - final TimeDuration timeout = getTimeoutDuration(request); - try { - return replyFuture.thenApply(ClientProtoUtils::toRaftClientReply) - .get(timeout.getDuration(), timeout.getUnit()); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - replyFuture.cancel(true); - final InterruptedIOException ioe = new InterruptedIOException(clientId + ": Interrupted " + request); - sendOnError(requestObserver, Status.CANCELLED, ioe.getMessage()); - throw ioe; - } catch (TimeoutException e) { - replyFuture.cancel(true); - final TimeoutIOException ioe = - new TimeoutIOException(clientId + ": Timed out " + timeout + " for " + request, e); - sendOnError(requestObserver, Status.DEADLINE_EXCEEDED, ioe.getMessage()); - throw ioe; - } catch (ExecutionException e) { - if (LOG.isTraceEnabled()) { - LOG.trace("{} : failed {}", clientId, request, e); - } - throw IOUtils.toIOException(e); - } - } - - private void sendOnError(StreamObserver requestObserver, Status status, String message) { - try { - requestObserver.onError(status.withDescription(message).asException()); - } catch (Exception ignored) { - // the stream already closed. - } - } - - private TimeDuration getTimeoutDuration(RaftClientRequest request) { - final long timeoutMs = request.getTimeoutMs(); - if (timeoutMs > 0) { - return TimeDuration.valueOf(timeoutMs, TimeUnit.MILLISECONDS); - } - return request.is(RaftClientRequestProto.TypeCase.WATCH) ? watchRequestTimeoutDuration : requestTimeoutDuration; + return replyFuture.thenApply(ClientProtoUtils::toRaftClientReply); } private RaftClientRequestProto toRaftClientRequestProto(RaftClientRequest request) throws IOException { diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/MessageMetrics.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/MessageMetrics.java index b152c67098..2a211aae80 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/MessageMetrics.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/MessageMetrics.java @@ -61,6 +61,14 @@ private void inc(String metricNamePrefix, Type t) { types.get(t) .computeIfAbsent(metricNamePrefix, prefix -> getRegistry().counter(prefix + t.getSuffix())) .inc(); + final Map counters = types.get(t); + LongCounter c = counters.get(metricNamePrefix); + if (c == null) { + synchronized (counters) { + c = counters.computeIfAbsent(metricNamePrefix, prefix -> getRegistry().counter(prefix + t.getSuffix())); + } + } + c.inc(); } /** diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java new file mode 100644 index 0000000000..1fcc317f9d --- /dev/null +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java @@ -0,0 +1,79 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.grpc.metrics; + +import org.apache.ratis.metrics.LongCounter; +import org.apache.ratis.metrics.MetricRegistryInfo; +import org.apache.ratis.metrics.RatisMetricRegistry; +import org.apache.ratis.metrics.RatisMetrics; +import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; +import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage; + +import java.util.function.Supplier; + +public class ZeroCopyMetrics extends RatisMetrics { + private static final String RATIS_GRPC_METRICS_APP_NAME = "ratis_grpc"; + private static final String RATIS_GRPC_METRICS_COMP_NAME = "zero_copy"; + private static final String RATIS_GRPC_METRICS_DESC = "Metrics for Ratis Grpc Zero copy"; + + private final LongCounter zeroCopyMessages = getRegistry().counter("num_zero_copy_messages"); + private final LongCounter nonZeroCopyMessages = getRegistry().counter("num_non_zero_copy_messages"); + private final LongCounter releasedMessages = getRegistry().counter("num_released_messages"); + + public ZeroCopyMetrics() { + super(createRegistry()); + } + + private static RatisMetricRegistry createRegistry() { + return create(new MetricRegistryInfo("", + RATIS_GRPC_METRICS_APP_NAME, + RATIS_GRPC_METRICS_COMP_NAME, RATIS_GRPC_METRICS_DESC)); + } + + public void addUnreleased(String name, Supplier unreleased) { + getRegistry().gauge(name + "_num_unreleased_messages", () -> unreleased); + } + + + public void onZeroCopyMessage(AbstractMessage ignored) { + zeroCopyMessages.inc(); + } + + public void onNonZeroCopyMessage(AbstractMessage ignored) { + nonZeroCopyMessages.inc(); + } + + public void onReleasedMessage(AbstractMessage ignored) { + releasedMessages.inc(); + } + + @VisibleForTesting + public long zeroCopyMessages() { + return zeroCopyMessages.getCount(); + } + + @VisibleForTesting + public long nonZeroCopyMessages() { + return nonZeroCopyMessages.getCount(); + } + + @VisibleForTesting + public long releasedMessages() { + return releasedMessages.getCount(); + } +} \ No newline at end of file diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java index 1da3587e91..78ad8759fb 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java @@ -19,10 +19,13 @@ import org.apache.ratis.client.impl.ClientProtoUtils; import org.apache.ratis.grpc.GrpcUtil; +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; +import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller; import org.apache.ratis.protocol.*; import org.apache.ratis.protocol.exceptions.AlreadyClosedException; import org.apache.ratis.protocol.exceptions.GroupMismatchException; import org.apache.ratis.protocol.exceptions.RaftException; +import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto; import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto; @@ -30,6 +33,7 @@ import org.apache.ratis.util.CollectionUtils; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SlidingWindow; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +43,6 @@ import java.util.Iterator; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; @@ -48,15 +51,21 @@ import java.util.function.Consumer; import java.util.function.Supplier; +import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller; +import static org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getOrderedMethod; +import static org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getUnorderedMethod; + class GrpcClientProtocolService extends RaftClientProtocolServiceImplBase { private static final Logger LOG = LoggerFactory.getLogger(GrpcClientProtocolService.class); private static class PendingOrderedRequest implements SlidingWindow.ServerSideRequest { + private final ReferenceCountedObject requestRef; private final RaftClientRequest request; private final AtomicReference reply = new AtomicReference<>(); - PendingOrderedRequest(RaftClientRequest request) { - this.request = request; + PendingOrderedRequest(ReferenceCountedObject requestRef) { + this.requestRef = requestRef; + this.request = requestRef != null ? requestRef.retain() : null; } @Override @@ -76,15 +85,23 @@ public boolean hasReply() { @Override public void setReply(RaftClientReply r) { final boolean set = reply.compareAndSet(null, r); - Preconditions.assertTrue(set, () -> "Reply is already set: request=" + request + ", reply=" + reply); + Preconditions.assertTrue(set, () -> "Reply is already set: request=" + + request.toStringShort() + ", reply=" + reply); } RaftClientReply getReply() { return reply.get(); } - RaftClientRequest getRequest() { - return request; + ReferenceCountedObject getRequestRef() { + return requestRef; + } + + @Override + public void release() { + if (requestRef != null) { + requestRef.release(); + } } @Override @@ -135,18 +152,38 @@ void closeAllExisting(RaftGroupId groupId) { private final ExecutorService executor; private final OrderedStreamObservers orderedStreamObservers = new OrderedStreamObservers(); + private final boolean zeroCopyEnabled; + private final ZeroCopyMessageMarshaller zeroCopyRequestMarshaller; GrpcClientProtocolService(Supplier idSupplier, RaftClientAsynchronousProtocol protocol, - ExecutorService executor) { + ExecutorService executor, boolean zeroCopyEnabled, ZeroCopyMetrics zeroCopyMetrics) { this.idSupplier = idSupplier; this.protocol = protocol; this.executor = executor; + this.zeroCopyEnabled = zeroCopyEnabled; + this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(RaftClientRequestProto.getDefaultInstance(), + zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage); + zeroCopyMetrics.addUnreleased("client_protocol", zeroCopyRequestMarshaller::getUnclosedCount); } RaftPeerId getId() { return idSupplier.get(); } + ServerServiceDefinition bindServiceWithZeroCopy() { + ServerServiceDefinition orig = super.bindService(); + if (!zeroCopyEnabled) { + LOG.info("{}: Zero copy is disabled.", getId()); + return orig; + } + ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName()); + + addMethodWithCustomMarshaller(orig, builder, getOrderedMethod(), zeroCopyRequestMarshaller); + addMethodWithCustomMarshaller(orig, builder, getUnorderedMethod(), zeroCopyRequestMarshaller); + + return builder.build(); + } + @Override public StreamObserver ordered(StreamObserver responseObserver) { final OrderedRequestStreamObserver so = new OrderedRequestStreamObserver(responseObserver); @@ -220,31 +257,43 @@ boolean isClosed() { return isClosed.get(); } - CompletableFuture processClientRequest(RaftClientRequest request, Consumer replyHandler) { + CompletableFuture processClientRequest(ReferenceCountedObject requestRef, + Consumer replyHandler) { + final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + requestRef.get() : ""; + CompletableFuture replyFuture; try { - final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + request : ""; - return protocol.submitClientRequestAsync(request - ).thenAcceptAsync(replyHandler, executor - ).exceptionally(exception -> { - // TODO: the exception may be from either raft or state machine. - // Currently we skip all the following responses when getting an - // exception from the state machine. - responseError(exception, () -> errMsg); - return null; - }); + replyFuture = protocol.submitClientRequestAsync(requestRef); } catch (IOException e) { - throw new CompletionException("Failed processClientRequest for " + request + " in " + name, e); + replyFuture = new CompletableFuture<>(); + replyFuture.completeExceptionally(e); } + return replyFuture.thenAcceptAsync(replyHandler, executor).exceptionally(exception -> { + // TODO: the exception may be from either raft or state machine. + // Currently we skip all the following responses when getting an + // exception from the state machine. + responseError(exception, () -> errMsg); + return null; + }); } - abstract void processClientRequest(RaftClientRequest request); + abstract void processClientRequest(ReferenceCountedObject requestRef); @Override public void onNext(RaftClientRequestProto request) { + ReferenceCountedObject requestRef = null; try { final RaftClientRequest r = ClientProtoUtils.toRaftClientRequest(request); - processClientRequest(r); + requestRef = ReferenceCountedObject.wrap(r, () -> {}, released -> { + if (released) { + zeroCopyRequestMarshaller.release(request); + } + }); + + processClientRequest(requestRef); } catch (Exception e) { + if (requestRef == null) { + zeroCopyRequestMarshaller.release(request); + } responseError(e, () -> "onNext for " + ClientProtoUtils.toString(request) + " in " + name); } } @@ -278,15 +327,21 @@ private class UnorderedRequestStreamObserver extends RequestStreamObserver { } @Override - void processClientRequest(RaftClientRequest request) { - final CompletableFuture f = processClientRequest(request, reply -> { - if (!reply.isSuccess()) { - LOG.info("Failed " + request + ", reply=" + reply); - } - final RaftClientReplyProto proto = ClientProtoUtils.toRaftClientReplyProto(reply); - responseNext(proto); - }); - final long callId = request.getCallId(); + void processClientRequest(ReferenceCountedObject requestRef) { + final long callId = requestRef.retain().getCallId(); + final CompletableFuture f; + try { + f = processClientRequest(requestRef, reply -> { + if (!reply.isSuccess()) { + LOG.info("Failed request cid={}, reply={}", callId, reply); + } + final RaftClientReplyProto proto = ClientProtoUtils.toRaftClientReplyProto(reply); + responseNext(proto); + }); + } finally { + requestRef.release(); + } + put(callId, f); f.thenAccept(dummy -> remove(callId)); } @@ -329,32 +384,40 @@ RaftGroupId getGroupId() { void processClientRequest(PendingOrderedRequest pending) { final long seq = pending.getSeqNum(); - processClientRequest(pending.getRequest(), + processClientRequest(pending.getRequestRef(), reply -> slidingWindow.receiveReply(seq, reply, this::sendReply)); } @Override - void processClientRequest(RaftClientRequest r) { - if (isClosed()) { - final AlreadyClosedException exception = new AlreadyClosedException(getName() + ": the stream is closed"); - responseError(exception, () -> "processClientRequest (stream already closed) for " + r); - return; - } + void processClientRequest(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.retain(); + try { + if (isClosed()) { + final AlreadyClosedException exception = new AlreadyClosedException(getName() + ": the stream is closed"); + responseError(exception, () -> "processClientRequest (stream already closed) for " + request); + } - final RaftGroupId requestGroupId = r.getRaftGroupId(); - // use the group id in the first request as the group id of this observer - final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g: requestGroupId); - final PendingOrderedRequest pending = new PendingOrderedRequest(r); - - if (!requestGroupId.equals(updated)) { - final GroupMismatchException exception = new GroupMismatchException(getId() - + ": The group (" + requestGroupId + ") of " + r.getClientId() - + " does not match the group (" + updated + ") of the " + JavaUtils.getClassSimpleName(getClass())); - responseError(exception, () -> "processClientRequest (Group mismatched) for " + r); - return; - } + final RaftGroupId requestGroupId = request.getRaftGroupId(); + // use the group id in the first request as the group id of this observer + final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g : requestGroupId); - slidingWindow.receivedRequest(pending, this::processClientRequest); + if (!requestGroupId.equals(updated)) { + final GroupMismatchException exception = new GroupMismatchException(getId() + + ": The group (" + requestGroupId + ") of " + request.getClientId() + + " does not match the group (" + updated + ") of the " + JavaUtils.getClassSimpleName(getClass())); + responseError(exception, () -> "processClientRequest (Group mismatched) for " + request); + return; + } + final PendingOrderedRequest pending = new PendingOrderedRequest(requestRef); + try { + slidingWindow.receivedRequest(pending, this::processClientRequest); + } catch (Exception e) { + pending.release(); + throw e; + } + } finally { + requestRef.release(); + } } private void sendReply(PendingOrderedRequest ready) { diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java index 053cc5c0f4..1e9519d24c 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java @@ -24,6 +24,7 @@ import org.apache.ratis.metrics.Timekeeper; import org.apache.ratis.proto.RaftProtos.InstallSnapshotResult; import org.apache.ratis.protocol.RaftPeerId; +import org.apache.ratis.retry.MultipleLinearRandomRetry; import org.apache.ratis.retry.RetryPolicy; import org.apache.ratis.server.RaftServer; import org.apache.ratis.server.RaftServerConfigKeys; @@ -63,8 +64,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; -import static org.apache.ratis.server.raftlog.LogProtoUtils.toLogEntryTermIndexString; - /** * A new log appender implementation using grpc bi-directional stream API. */ @@ -74,11 +73,7 @@ public class GrpcLogAppender extends LogAppenderBase { private enum BatchLogKey implements BatchLogger.Key { RESET_CLIENT, INCONSISTENCY_REPLY, - APPEND_LOG_RESPONSE_HANDLER_ON_ERROR, - INSTALL_SNAPSHOT_NOTIFY, - INSTALL_SNAPSHOT_REPLY, - INSTALL_SNAPSHOT_IN_PROGRESS, - SNAPSHOT_UNAVAILABLE + APPEND_LOG_RESPONSE_HANDLER_ON_ERROR } public static final int INSTALL_SNAPSHOT_NOTIFICATION_INDEX = 0; @@ -199,9 +194,8 @@ public GrpcLogAppender(RaftServer.Division server, LeaderState leaderState, Foll lock = new AutoCloseableReadWriteLock(this); caller = LOG.isTraceEnabled()? JavaUtils.getCallerStackTraceElement(): null; - errorRetryWaitPolicy = RetryPolicy.parse( - RaftServerConfigKeys.Log.Appender.retryPolicy(properties), - RaftServerConfigKeys.Log.Appender.RETRY_POLICY_KEY); + errorRetryWaitPolicy = MultipleLinearRandomRetry.parseCommaSeparated( + RaftServerConfigKeys.Log.Appender.retryPolicy(properties)); } @Override @@ -240,7 +234,7 @@ private void resetClient(AppendEntriesRequest request, Event event) { } getFollower().computeNextIndex(getNextIndexForError(nextIndex)); } catch (IOException ie) { - LOG.warn("{}: Failed to resetClient for {}", this, getFollowerId(), ie); + LOG.warn(this + ": Failed to getClient for " + getFollowerId(), ie); } } @@ -303,8 +297,8 @@ private void mayWait() { getEventAwaitForSignal().await(getWaitTimeMs() + errorWaitTimeMs(), TimeUnit.MILLISECONDS); } catch (InterruptedException ie) { + LOG.warn(this + ": Wait interrupted by " + ie); Thread.currentThread().interrupt(); - LOG.warn("{} is interrupted: {}", this, ie.toString()); } } @@ -315,8 +309,14 @@ private long errorWaitTimeMs() { @Override public CompletableFuture stopAsync() { - grpcServerMetrics.unregister(); - return super.stopAsync(); + try (AutoCloseableLock ignored = lock.writeLock(caller, LOG::trace)) { + if (appendLogRequestObserver != null) { + appendLogRequestObserver.stop(); + appendLogRequestObserver = null; + } + grpcServerMetrics.unregister(); + return super.stopAsync(); + } } @Override @@ -392,30 +392,42 @@ public Comparator getCallIdComparator() { } private void appendLog(boolean heartbeat) throws IOException { - final AppendEntriesRequestProto pending; + final ReferenceCountedObject pending; final AppendEntriesRequest request; try (AutoCloseableLock writeLock = lock.writeLock(caller, LOG::trace)) { + if (!isRunning()) { + return; + } // Prepare and send the append request. // Note changes on follower's nextIndex and ops on pendingRequests should always be done under the write-lock - pending = newAppendEntriesRequest(callId.getAndIncrement(), heartbeat); + pending = nextAppendEntriesRequest(callId.getAndIncrement(), heartbeat); if (pending == null) { return; } - request = new AppendEntriesRequest(pending, getFollowerId(), grpcServerMetrics); - pendingRequests.put(request); - increaseNextIndex(pending); - if (appendLogRequestObserver == null) { - appendLogRequestObserver = new StreamObservers( - getClient(), new AppendLogResponseHandler(), useSeparateHBChannel, getWaitTimeMin()); + try { + request = new AppendEntriesRequest(pending.get(), getFollowerId(), grpcServerMetrics); + pendingRequests.put(request); + increaseNextIndex(pending.get()); + if (appendLogRequestObserver == null) { + appendLogRequestObserver = new StreamObservers( + getClient(), new AppendLogResponseHandler(), useSeparateHBChannel, getWaitTimeMin()); + } + } catch (Exception e) { + pending.release(); + throw e; } } - final TimeDuration remaining = getRemainingWaitTime(); - if (remaining.isPositive()) { - sleep(remaining, heartbeat); - } - if (isRunning()) { - sendRequest(request, pending); + try { + final TimeDuration remaining = getRemainingWaitTime(); + if (remaining.isPositive()) { + sleep(remaining, heartbeat); + } + if (isRunning()) { + sendRequest(request, pending.get()); + } + } finally { + pending.release(); } } @@ -503,8 +515,8 @@ public void onNext(AppendEntriesReplyProto reply) { try { onNextImpl(request, reply); } catch(Exception t) { - LOG.error("Failed onNext(reply), request={}, reply={}", - request, ServerStringUtils.toAppendEntriesReplyString(reply), t); + LOG.error("Failed onNext request=" + request + + ", reply=" + ServerStringUtils.toAppendEntriesReplyString(reply), t); } } @@ -579,8 +591,8 @@ private void updateNextIndex(long replyNextIndex) { } private class InstallSnapshotResponseHandler implements StreamObserver { - private final String name; - private final Queue pending = new LinkedList<>(); + private final String name = getFollower().getName() + "-" + JavaUtils.getClassSimpleName(getClass()); + private final Queue pending; private final CompletableFuture done = new CompletableFuture<>(); private final boolean isNotificationOnly; @@ -589,8 +601,8 @@ private class InstallSnapshotResponseHandler implements StreamObserver(); this.isNotificationOnly = notifyOnly; - this.name = getFollower().getName() + "-InstallSnapshot" + (isNotificationOnly ? "Notification" : ""); } void addPending(InstallSnapshotRequestProto request) { @@ -632,8 +644,8 @@ void onFollowerCatchup(long followerSnapshotIndex) { final long leaderStartIndex = getRaftLog().getStartIndex(); final long followerNextIndex = followerSnapshotIndex + 1; if (followerNextIndex >= leaderStartIndex) { - LOG.info("{}: follower nextIndex = {} >= leader startIndex = {}", - this, followerNextIndex, leaderStartIndex); + LOG.info("{}: Follower can catch up leader after install the snapshot, as leader's start index is {}", + this, followerNextIndex); notifyInstallSnapshotFinished(InstallSnapshotResult.SUCCESS, followerSnapshotIndex); } } @@ -665,10 +677,10 @@ boolean hasAllResponse() { @Override public void onNext(InstallSnapshotReplyProto reply) { - BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_REPLY, name, - suffix -> LOG.info("{}: received {} reply {} {}", this, - replyState.isFirstReplyReceived() ? "a" : "the first", - ServerStringUtils.toInstallSnapshotReplyString(reply), suffix)); + if (LOG.isInfoEnabled()) { + LOG.info("{}: received {} reply {}", this, replyState.isFirstReplyReceived()? "a" : "the first", + ServerStringUtils.toInstallSnapshotReplyString(reply)); + } // update the last rpc time getFollower().updateLastRpcResponseTime(); @@ -677,13 +689,12 @@ public void onNext(InstallSnapshotReplyProto reply) { final long followerSnapshotIndex; switch (reply.getResult()) { case SUCCESS: - LOG.info("{}: Completed", this); + LOG.info("{}: Completed InstallSnapshot. Reply: {}", this, reply); getFollower().setAttemptedToInstallSnapshot(); removePending(reply); break; case IN_PROGRESS: - BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_IN_PROGRESS, name, - suffix -> LOG.info("{}: in progress, {}", this, suffix)); + LOG.info("{}: InstallSnapshot in progress.", this); removePending(reply); break; case ALREADY_INSTALLED: @@ -699,7 +710,7 @@ public void onNext(InstallSnapshotReplyProto reply) { onFollowerTerm(reply.getTerm()); break; case CONF_MISMATCH: - LOG.error("{}: CONF_MISMATCH ({}): Leader {} has it set to {} but follower {} has it set to {}", + LOG.error("{}: Configuration Mismatch ({}): Leader {} has it set to {} but follower {} has it set to {}", this, RaftServerConfigKeys.Log.Appender.INSTALL_SNAPSHOT_ENABLED_KEY, getServer().getId(), installSnapshotEnabled, getFollowerId(), !installSnapshotEnabled); break; @@ -714,19 +725,17 @@ public void onNext(InstallSnapshotReplyProto reply) { removePending(reply); break; case SNAPSHOT_UNAVAILABLE: - BatchLogger.print(BatchLogKey.SNAPSHOT_UNAVAILABLE, name, - suffix -> LOG.info("{}: Follower failed since the snapshot is unavailable {}", this, suffix)); + LOG.info("{}: Follower could not install snapshot as it is not available.", this); getFollower().setAttemptedToInstallSnapshot(); notifyInstallSnapshotFinished(InstallSnapshotResult.SNAPSHOT_UNAVAILABLE, RaftLog.INVALID_LOG_INDEX); removePending(reply); break; case UNRECOGNIZED: - LOG.error("{}: Reply result {}, {}", - name, reply.getResult(), ServerStringUtils.toInstallSnapshotReplyString(reply)); + LOG.error("Unrecognized the reply result {}: Leader is {}, follower is {}", + reply.getResult(), getServer().getId(), getFollowerId()); break; case SNAPSHOT_EXPIRED: - LOG.warn("{}: Follower failed since the request expired, {}", - name, ServerStringUtils.toInstallSnapshotReplyString(reply)); + LOG.warn("{}: Follower could not install snapshot as it is expired.", this); default: break; } @@ -805,9 +814,8 @@ private void installSnapshot(SnapshotInfo snapshot) { * @param firstAvailable the first available log's index on the Leader */ private void notifyInstallSnapshot(TermIndex firstAvailable) { - BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_NOTIFY, getFollower().getName(), - suffix -> LOG.info("{}: notifyInstallSnapshot with firstAvailable={}, followerNextIndex={} {}", - this, firstAvailable, getFollower().getNextIndex(), suffix)); + LOG.info("{}: notifyInstallSnapshot with firstAvailable={}, followerNextIndex={}", + this, firstAvailable, getFollower().getNextIndex()); final InstallSnapshotResponseHandler responseHandler = new InstallSnapshotResponseHandler(true); StreamObserver snapshotRequestObserver = null; @@ -834,6 +842,45 @@ private void notifyInstallSnapshot(TermIndex firstAvailable) { responseHandler.waitForResponse(); } + /** + * Should the Leader notify the Follower to install the snapshot through + * its own State Machine. + * @return the first available log's start term index + */ + @Override + public TermIndex shouldNotifyToInstallSnapshot() { + final FollowerInfo follower = getFollower(); + final long leaderNextIndex = getRaftLog().getNextIndex(); + final boolean isFollowerBootstrapping = getLeaderState().isFollowerBootstrapping(follower); + final long leaderStartIndex = getRaftLog().getStartIndex(); + final TermIndex firstAvailable = Optional.ofNullable(getRaftLog().getTermIndex(leaderStartIndex)) + .orElseGet(() -> TermIndex.valueOf(getServer().getInfo().getCurrentTerm(), leaderNextIndex)); + if (isFollowerBootstrapping && !follower.hasAttemptedToInstallSnapshot()) { + // If the follower is bootstrapping and has not yet installed any snapshot from leader, then the follower should + // be notified to install a snapshot. Every follower should try to install at least one snapshot during + // bootstrapping, if available. + LOG.debug("{}: follower is bootstrapping, notify to install snapshot to {}.", this, firstAvailable); + return firstAvailable; + } + + final long followerNextIndex = follower.getNextIndex(); + if (followerNextIndex >= leaderNextIndex) { + return null; + } + + if (followerNextIndex < leaderStartIndex) { + // The Leader does not have the logs from the Follower's last log + // index onwards. And install snapshot is disabled. So the Follower + // should be notified to install the latest snapshot through its + // State Machine. + return firstAvailable; + } else if (leaderStartIndex == RaftLog.INVALID_LOG_INDEX) { + // Leader has no logs to check from, hence return next index. + return firstAvailable; + } + + return null; + } static class AppendEntriesRequest { private final Timekeeper timer; @@ -891,9 +938,13 @@ boolean isHeartbeat() { @Override public String toString() { + final String entries = entriesCount == 0? "" + : entriesCount == 1? ",entry=" + firstEntry + : ",entries=" + firstEntry + "..." + lastEntry; return JavaUtils.getClassSimpleName(getClass()) + ":cid=" + callId - + ":" + toLogEntryTermIndexString(entriesCount, firstEntry, lastEntry); + + ",entriesCount=" + entriesCount + + entries; } } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java index d2748c7be2..4a280ab335 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java @@ -1,4 +1,4 @@ -/* +/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,11 +17,13 @@ */ package org.apache.ratis.grpc.server; +import org.apache.ratis.grpc.GrpcTlsConfig; import org.apache.ratis.grpc.GrpcUtil; import org.apache.ratis.grpc.util.StreamObserverWithTimeout; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.server.util.ServerStringUtils; import org.apache.ratis.thirdparty.io.grpc.ManagedChannel; +import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts; import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType; import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder; import org.apache.ratis.thirdparty.io.grpc.stub.CallStreamObserver; @@ -31,7 +33,7 @@ import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceBlockingStub; import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceStub; import org.apache.ratis.protocol.RaftPeer; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder; import org.apache.ratis.util.TimeDuration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,10 +44,9 @@ * This is a RaftClient implementation that supports streaming data to the raft * ring. The stream implementation utilizes gRPC. */ -class GrpcServerProtocolClient implements Closeable { +public class GrpcServerProtocolClient implements Closeable { // Common channel private final ManagedChannel channel; - private final GrpcStubPool pool; // Channel and stub for heartbeat private ManagedChannel hbChannel; private RaftServerProtocolServiceStub hbAsyncStub; @@ -58,34 +59,40 @@ class GrpcServerProtocolClient implements Closeable { //visible for using in log / error messages AND to use in instrumented tests private final RaftPeerId raftPeerId; - GrpcServerProtocolClient(RaftPeer target, int connections, int flowControlWindow, - TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel) { + public GrpcServerProtocolClient(RaftPeer target, int flowControlWindow, + TimeDuration requestTimeout, GrpcTlsConfig tlsConfig, boolean separateHBChannel) { raftPeerId = target.getId(); LOG.info("Build channel for {}", target); useSeparateHBChannel = separateHBChannel; - channel = buildChannel(target, flowControlWindow, sslContext); + channel = buildChannel(target, flowControlWindow, tlsConfig); blockingStub = RaftServerProtocolServiceGrpc.newBlockingStub(channel); asyncStub = RaftServerProtocolServiceGrpc.newStub(channel); if (useSeparateHBChannel) { - hbChannel = buildChannel(target, flowControlWindow, sslContext); + hbChannel = buildChannel(target, flowControlWindow, tlsConfig); hbAsyncStub = RaftServerProtocolServiceGrpc.newStub(hbChannel); } requestTimeoutDuration = requestTimeout; - this.pool = connections == 1? null : newGrpcStubPool(target.getAddress(), sslContext, connections); - } - - GrpcStubPool newGrpcStubPool(String address, SslContext sslContext, int connections) { - return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16); } - private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslContext sslContext) { + private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, + GrpcTlsConfig tlsConfig) { NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(target.getAddress()); // ignore any http proxy for grpc channelBuilder.proxyDetector(uri -> null); - if (sslContext != null) { - channelBuilder.useTransportSecurity().sslContext(sslContext); + if (tlsConfig!= null) { + SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager()); + if (tlsConfig.getMtlsEnabled()) { + GrpcUtil.setKeyManager(sslContextBuilder, tlsConfig.getKeyManager()); + } + try { + channelBuilder.useTransportSecurity().sslContext(sslContextBuilder.build()); + } catch (Exception ex) { + throw new IllegalArgumentException("Failed to build SslContext, peerId=" + raftPeerId + + ", tlsConfig=" + tlsConfig, ex); + } } else { channelBuilder.negotiationType(NegotiationType.PLAINTEXT); } @@ -100,9 +107,6 @@ public void close() { GrpcUtil.shutdownManagedChannel(hbChannel); } GrpcUtil.shutdownManagedChannel(channel); - if (pool != null) { - pool.close(); - } } public RequestVoteReplyProto requestVote(RequestVoteRequestProto request) { @@ -121,44 +125,8 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ } void readIndex(ReadIndexRequestProto request, StreamObserver s) { - if (pool == null) { - asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) - .readIndex(request, s); - } else { - GrpcStubPool.Stub p; - try { - p = pool.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - s.onError(e); - return; - } - p.getStub().withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) - .readIndex(request, new StreamObserver() { - @Override - public void onNext(ReadIndexReplyProto v) { - s.onNext(v); - } - - @Override - public void onError(Throwable t) { - try { - s.onError(t); - } finally { - p.release(); - } - } - - @Override - public void onCompleted() { - try { - s.onCompleted(); - } finally { - p.release(); - } - } - }); - } + asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit()) + .readIndex(request, s); } CallStreamObserver appendEntries( diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java index a13e74b89d..7e17cb3cf4 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java @@ -20,19 +20,21 @@ import java.util.function.Consumer; import java.util.function.Function; import org.apache.ratis.grpc.GrpcUtil; +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; +import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.server.RaftServer; import org.apache.ratis.server.protocol.RaftServerProtocol; import org.apache.ratis.server.util.ServerStringUtils; -import org.apache.ratis.thirdparty.com.google.protobuf.MessageOrBuilder; +import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition; import org.apache.ratis.thirdparty.io.grpc.Status; -import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; import org.apache.ratis.proto.RaftProtos.*; import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceImplBase; import org.apache.ratis.util.BatchLogger; import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.ProtoUtils; +import org.apache.ratis.util.ReferenceCountedObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,6 +45,9 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller; +import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getAppendEntriesMethod; + class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase { public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class); @@ -52,24 +57,31 @@ private enum BatchLogKey implements BatchLogger.Key { } static class PendingServerRequest { - private final REQUEST request; + private final AtomicReference> requestRef; private final CompletableFuture future = new CompletableFuture<>(); - PendingServerRequest(REQUEST request) { - this.request = request; + PendingServerRequest(ReferenceCountedObject requestRef) { + requestRef.retain(); + this.requestRef = new AtomicReference<>(requestRef); } REQUEST getRequest() { - return request; + return Optional.ofNullable(requestRef.get()) + .map(ReferenceCountedObject::get) + .orElse(null); } CompletableFuture getFuture() { return future; } + + void release() { + Optional.ofNullable(requestRef.getAndSet(null)) + .ifPresent(ReferenceCountedObject::release); + } } - abstract class ServerRequestStreamObserver - implements StreamObserver { + abstract class ServerRequestStreamObserver implements StreamObserver { private final RaftServer.Op op; private final Supplier nameSupplier; private final StreamObserver responseObserver; @@ -97,24 +109,39 @@ private String getPreviousRequestString() { .orElse(null); } - abstract CompletableFuture process(REQUEST request) throws IOException; + CompletableFuture process(REQUEST request) throws IOException { + throw new UnsupportedOperationException("This method is not supported."); + } + + CompletableFuture process(ReferenceCountedObject requestRef) + throws IOException { + try { + return process(requestRef.retain()); + } finally { + requestRef.release(); + } + } + + void release(REQUEST req) { + } abstract long getCallId(REQUEST request); + boolean isHeartbeat(REQUEST request) { + return false; + } + abstract String requestToString(REQUEST request); abstract String replyToString(REPLY reply); abstract boolean replyInOrder(REQUEST request); - StatusRuntimeException wrapException(Throwable e, REQUEST request) { - return GrpcUtil.wrapException(e, getCallId(request)); - } - - private void handleError(Throwable e, REQUEST request) { - GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request " + requestToString(request), e); + private synchronized void handleError(Throwable e, long callId, boolean isHeartbeat) { + GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request cid=" + callId + ", isHeartbeat? " + + isHeartbeat, e); if (isClosed.compareAndSet(false, true)) { - responseObserver.onError(wrapException(e, request)); + responseObserver.onError(GrpcUtil.wrapException(e, callId, isHeartbeat)); } } @@ -134,24 +161,32 @@ void composeRequest(CompletableFuture current) { @Override public void onNext(REQUEST request) { + ReferenceCountedObject requestRef = ReferenceCountedObject.wrap(request, () -> {}, released -> { + if (released) { + release(request); + } + }); + if (!replyInOrder(request)) { try { - composeRequest(process(request).thenApply(this::handleReply)); + composeRequest(process(requestRef).thenApply(this::handleReply)); } catch (Exception e) { - handleError(e, request); + handleError(e, getCallId(request), isHeartbeat(request)); + release(request); } return; } - final PendingServerRequest current = new PendingServerRequest<>(request); - final PendingServerRequest previous = previousOnNext.getAndSet(current); - final CompletableFuture previousFuture = Optional.ofNullable(previous) - .map(PendingServerRequest::getFuture) + final PendingServerRequest current = new PendingServerRequest<>(requestRef); + final long callId = getCallId(current.getRequest()); + final boolean isHeartbeat = isHeartbeat(current.getRequest()); + final Optional> previous = Optional.ofNullable(previousOnNext.getAndSet(current)); + final CompletableFuture previousFuture = previous.map(PendingServerRequest::getFuture) .orElse(CompletableFuture.completedFuture(null)); try { - final CompletableFuture f = process(request).exceptionally(e -> { + final CompletableFuture f = process(requestRef).exceptionally(e -> { // Handle cases, such as RaftServer is paused - handleError(e, request); + handleError(e, callId, isHeartbeat); current.getFuture().completeExceptionally(e); return null; }).thenCombine(previousFuture, (reply, v) -> { @@ -161,8 +196,14 @@ public void onNext(REQUEST request) { }); composeRequest(f); } catch (Exception e) { - handleError(e, request); + handleError(e, callId, isHeartbeat); current.getFuture().completeExceptionally(e); + } finally { + previous.ifPresent(PendingServerRequest::release); + if (isClosed.get()) { + // Some requests may come after onCompleted or onError, ensure they're released. + releaseLast(); + } } } @@ -174,12 +215,13 @@ public void onCompleted() { getId(), op, getPreviousRequestString(), suffix)); requestFuture.get().thenAccept(reply -> { BatchLogger.print(BatchLogKey.COMPLETED_REPLY, getName(), - suffix -> LOG.info("{}: Completed {}, lastReply: {} {}", - getId(), op, ProtoUtils.shortDebugString(reply), suffix)); + suffix -> LOG.info("{}: Completed {}, lastReply: {} {}", getId(), op, reply, suffix)); responseObserver.onCompleted(); }); + releaseLast(); } } + @Override public void onError(Throwable t) { GrpcUtil.warn(LOG, () -> getId() + ": "+ op + " onError, lastRequest: " + getPreviousRequestString(), t); @@ -188,22 +230,54 @@ public void onError(Throwable t) { if (status != null && status.getCode() != Status.Code.CANCELLED) { responseObserver.onCompleted(); } + releaseLast(); } } + + private void releaseLast() { + Optional.ofNullable(previousOnNext.get()).ifPresent(PendingServerRequest::release); + } } private final Supplier idSupplier; private final RaftServer server; + private final boolean zeroCopyEnabled; + private final ZeroCopyMessageMarshaller zeroCopyRequestMarshaller; - GrpcServerProtocolService(Supplier idSupplier, RaftServer server) { + GrpcServerProtocolService(Supplier idSupplier, RaftServer server, boolean zeroCopyEnabled, + ZeroCopyMetrics zeroCopyMetrics) { this.idSupplier = idSupplier; this.server = server; + this.zeroCopyEnabled = zeroCopyEnabled; + this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(AppendEntriesRequestProto.getDefaultInstance(), + zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage); + zeroCopyMetrics.addUnreleased("server_protocol", zeroCopyRequestMarshaller::getUnclosedCount); } RaftPeerId getId() { return idSupplier.get(); } + ServerServiceDefinition bindServiceWithZeroCopy() { + ServerServiceDefinition orig = super.bindService(); + if (!zeroCopyEnabled) { + LOG.info("{}: Zero copy is disabled.", getId()); + return orig; + } + ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName()); + + // Add appendEntries with zero copy marshaller. + addMethodWithCustomMarshaller(orig, builder, getAppendEntriesMethod(), zeroCopyRequestMarshaller); + // Add remaining methods as is. + orig.getMethods().stream().filter( + x -> !x.getMethodDescriptor().getFullMethodName().equals(getAppendEntriesMethod().getFullMethodName()) + ).forEach( + builder::addMethod + ); + + return builder.build(); + } + @Override public void requestVote(RequestVoteRequestProto request, StreamObserver responseObserver) { @@ -244,8 +318,14 @@ public StreamObserver appendEntries( return new ServerRequestStreamObserver( RaftServerProtocol.Op.APPEND_ENTRIES, responseObserver) { @Override - CompletableFuture process(AppendEntriesRequestProto request) throws IOException { - return server.appendEntriesAsync(request); + CompletableFuture process(ReferenceCountedObject requestRef) + throws IOException { + return server.appendEntriesAsync(requestRef); + } + + @Override + void release(AppendEntriesRequestProto req) { + zeroCopyRequestMarshaller.release(req); } @Override @@ -253,6 +333,11 @@ long getCallId(AppendEntriesRequestProto request) { return request.getServerRequest().getCallId(); } + @Override + boolean isHeartbeat(AppendEntriesRequestProto request) { + return request.getEntriesCount() == 0; + } + @Override String requestToString(AppendEntriesRequestProto request) { return ServerStringUtils.toAppendEntriesRequestString(request, null); @@ -267,11 +352,6 @@ String replyToString(AppendEntriesReplyProto reply) { boolean replyInOrder(AppendEntriesRequestProto request) { return request.getEntriesCount() != 0; } - - @Override - StatusRuntimeException wrapException(Throwable e, AppendEntriesRequestProto request) { - return GrpcUtil.wrapException(e, getCallId(request), request.getEntriesCount() == 0); - } }; } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java new file mode 100644 index 0000000000..e3c0a5eddb --- /dev/null +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.grpc.server; + +import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.grpc.GrpcConfigKeys; +import org.apache.ratis.grpc.GrpcTlsConfig; +import org.apache.ratis.grpc.GrpcUtil; +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; +import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor; +import org.apache.ratis.protocol.RaftGroupId; +import org.apache.ratis.protocol.RaftPeerId; +import org.apache.ratis.rpc.SupportedRpcType; +import org.apache.ratis.server.RaftServer; +import org.apache.ratis.server.RaftServerConfigKeys; +import org.apache.ratis.server.RaftServerRpcWithProxy; +import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol; +import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; +import org.apache.ratis.thirdparty.io.grpc.ServerInterceptors; +import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts; +import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder; +import org.apache.ratis.thirdparty.io.grpc.Server; +import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; +import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder; + +import org.apache.ratis.proto.RaftProtos.*; +import org.apache.ratis.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.Supplier; + +import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL; + +/** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */ +public final class GrpcService extends RaftServerRpcWithProxy> { + static final Logger LOG = LoggerFactory.getLogger(GrpcService.class); + public static final String GRPC_SEND_SERVER_REQUEST = + JavaUtils.getClassSimpleName(GrpcService.class) + ".sendRequest"; + + class AsyncService implements RaftServerAsynchronousProtocol { + + @Override + public CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request) + throws IOException { + throw new UnsupportedOperationException("This method is not supported"); + } + + @Override + public CompletableFuture readIndexAsync(ReadIndexRequestProto request) throws IOException { + CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(), null, request); + + final CompletableFuture f = new CompletableFuture<>(); + final StreamObserver s = new StreamObserver() { + @Override + public void onNext(ReadIndexReplyProto reply) { + f.complete(reply); + } + + @Override + public void onError(Throwable throwable) { + f.completeExceptionally(throwable); + } + + @Override + public void onCompleted() { + } + }; + + final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId()); + getProxies().getProxy(target).readIndex(request, s); + return f; + } + } + + public static final class Builder { + private RaftServer server; + private GrpcTlsConfig tlsConfig; + private GrpcTlsConfig adminTlsConfig; + private GrpcTlsConfig clientTlsConfig; + private GrpcTlsConfig serverTlsConfig; + + private Builder() {} + + public Builder setServer(RaftServer raftServer) { + this.server = raftServer; + return this; + } + + public GrpcService build() { + return new GrpcService(server, adminTlsConfig, clientTlsConfig, serverTlsConfig); + } + + public Builder setTlsConfig(GrpcTlsConfig tlsConfig) { + this.tlsConfig = tlsConfig; + return this; + } + + public Builder setAdminTlsConfig(GrpcTlsConfig config) { + this.adminTlsConfig = config; + return this; + } + + public Builder setClientTlsConfig(GrpcTlsConfig config) { + this.clientTlsConfig = config; + return this; + } + + public Builder setServerTlsConfig(GrpcTlsConfig config) { + this.serverTlsConfig = config; + return this; + } + + public GrpcTlsConfig getTlsConfig() { + return tlsConfig; + } + } + + public static Builder newBuilder() { + return new Builder(); + } + + private final Map servers = new HashMap<>(); + private final Supplier addressSupplier; + private final Supplier clientServerAddressSupplier; + private final Supplier adminServerAddressSupplier; + + private final AsyncService asyncService = new AsyncService(); + + private final ExecutorService executor; + private final GrpcClientProtocolService clientProtocolService; + + private final MetricServerInterceptor serverInterceptor; + private final ZeroCopyMetrics zeroCopyMetrics; + + public MetricServerInterceptor getServerInterceptor() { + return serverInterceptor; + } + + private GrpcService(RaftServer server, + GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) { + this(server, server::getId, + GrpcConfigKeys.Admin.host(server.getProperties()), + GrpcConfigKeys.Admin.port(server.getProperties()), + adminTlsConfig, + GrpcConfigKeys.Client.host(server.getProperties()), + GrpcConfigKeys.Client.port(server.getProperties()), + clientTlsConfig, + GrpcConfigKeys.Server.host(server.getProperties()), + GrpcConfigKeys.Server.port(server.getProperties()), + serverTlsConfig, + GrpcConfigKeys.messageSizeMax(server.getProperties(), LOG::info), + RaftServerConfigKeys.Log.Appender.bufferByteLimit(server.getProperties()), + GrpcConfigKeys.flowControlWindow(server.getProperties(), LOG::info), + RaftServerConfigKeys.Rpc.requestTimeout(server.getProperties()), + GrpcConfigKeys.Server.heartbeatChannel(server.getProperties()), + GrpcConfigKeys.Server.zeroCopyEnabled(server.getProperties())); + } + + @SuppressWarnings("checkstyle:ParameterNumber") // private constructor + private GrpcService(RaftServer raftServer, Supplier idSupplier, + String adminHost, int adminPort, GrpcTlsConfig adminTlsConfig, + String clientHost, int clientPort, GrpcTlsConfig clientTlsConfig, + String serverHost, int serverPort, GrpcTlsConfig serverTlsConfig, + SizeInBytes grpcMessageSizeMax, SizeInBytes appenderBufferSize, + SizeInBytes flowControlWindow,TimeDuration requestTimeoutDuration, + boolean useSeparateHBChannel, boolean zeroCopyEnabled) { + super(idSupplier, id -> new PeerProxyMap<>(id.toString(), + p -> new GrpcServerProtocolClient(p, flowControlWindow.getSizeInt(), + requestTimeoutDuration, serverTlsConfig, useSeparateHBChannel))); + if (appenderBufferSize.getSize() > grpcMessageSizeMax.getSize()) { + throw new IllegalArgumentException("Illegal configuration: " + + RaftServerConfigKeys.Log.Appender.BUFFER_BYTE_LIMIT_KEY + " = " + appenderBufferSize + + " > " + GrpcConfigKeys.MESSAGE_SIZE_MAX_KEY + " = " + grpcMessageSizeMax); + } + + final RaftProperties properties = raftServer.getProperties(); + this.executor = ConcurrentUtils.newThreadPoolWithMax( + GrpcConfigKeys.Server.asyncRequestThreadPoolCached(properties), + GrpcConfigKeys.Server.asyncRequestThreadPoolSize(properties), + getId() + "-request-"); + this.zeroCopyMetrics = new ZeroCopyMetrics(); + this.clientProtocolService = new GrpcClientProtocolService(idSupplier, raftServer, executor, + zeroCopyEnabled, zeroCopyMetrics); + + this.serverInterceptor = new MetricServerInterceptor( + idSupplier, + JavaUtils.getClassSimpleName(getClass()) + "_" + serverPort + ); + + final boolean separateAdminServer = adminPort != serverPort && adminPort > 0; + final boolean separateClientServer = clientPort != serverPort && clientPort > 0; + + final NettyServerBuilder serverBuilder = + startBuildingNettyServer(serverHost, serverPort, serverTlsConfig, grpcMessageSizeMax, flowControlWindow); + GrpcServerProtocolService serverProtocolService = new GrpcServerProtocolService(idSupplier, raftServer, + zeroCopyEnabled, zeroCopyMetrics); + serverBuilder.addService(ServerInterceptors.intercept( + serverProtocolService.bindServiceWithZeroCopy(), serverInterceptor)); + if (!separateAdminServer) { + addAdminService(raftServer, serverBuilder); + } + if (!separateClientServer) { + addClientService(serverBuilder); + } + + final Server server = serverBuilder.build(); + servers.put(GrpcServerProtocolService.class.getSimpleName(), server); + addressSupplier = newAddressSupplier(serverPort, server); + + if (separateAdminServer) { + final NettyServerBuilder builder = + startBuildingNettyServer(adminHost, adminPort, adminTlsConfig, grpcMessageSizeMax, flowControlWindow); + addAdminService(raftServer, builder); + final Server adminServer = builder.build(); + servers.put(GrpcAdminProtocolService.class.getName(), adminServer); + adminServerAddressSupplier = newAddressSupplier(adminPort, adminServer); + } else { + adminServerAddressSupplier = addressSupplier; + } + + if (separateClientServer) { + final NettyServerBuilder builder = + startBuildingNettyServer(clientHost, clientPort, clientTlsConfig, grpcMessageSizeMax, flowControlWindow); + addClientService(builder); + final Server clientServer = builder.build(); + servers.put(GrpcClientProtocolService.class.getName(), clientServer); + clientServerAddressSupplier = newAddressSupplier(clientPort, clientServer); + } else { + clientServerAddressSupplier = addressSupplier; + } + } + + private MemoizedSupplier newAddressSupplier(int port, Server server) { + return JavaUtils.memoize(() -> new InetSocketAddress(port != 0 ? port : server.getPort())); + } + + private void addClientService(NettyServerBuilder builder) { + builder.addService(ServerInterceptors.intercept(clientProtocolService.bindServiceWithZeroCopy(), + serverInterceptor)); + } + + private void addAdminService(RaftServer raftServer, NettyServerBuilder nettyServerBuilder) { + nettyServerBuilder.addService(ServerInterceptors.intercept( + new GrpcAdminProtocolService(raftServer), + serverInterceptor)); + } + + private static NettyServerBuilder startBuildingNettyServer(String hostname, int port, GrpcTlsConfig tlsConfig, + SizeInBytes grpcMessageSizeMax, SizeInBytes flowControlWindow) { + InetSocketAddress address = hostname == null || hostname.isEmpty() ? + new InetSocketAddress(port) : new InetSocketAddress(hostname, port); + NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address) + .withChildOption(ChannelOption.SO_REUSEADDR, true) + .maxInboundMessageSize(grpcMessageSizeMax.getSizeInt()) + .flowControlWindow(flowControlWindow.getSizeInt()); + + if (tlsConfig != null) { + SslContextBuilder sslContextBuilder = GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager()); + if (tlsConfig.getMtlsEnabled()) { + sslContextBuilder.clientAuth(ClientAuth.REQUIRE); + GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager()); + } + sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, OPENSSL); + try { + nettyServerBuilder.sslContext(sslContextBuilder.build()); + } catch (Exception ex) { + throw new IllegalArgumentException("Failed to build SslContext, tlsConfig=" + tlsConfig, ex); + } + } + return nettyServerBuilder; + } + + @Override + public SupportedRpcType getRpcType() { + return SupportedRpcType.GRPC; + } + + @Override + public void startImpl() { + for (Server server : servers.values()) { + try { + server.start(); + } catch (IOException e) { + ExitUtils.terminate(1, "Failed to start Grpc server", e, LOG); + } + LOG.info("{}: {} started, listening on {}", + getId(), JavaUtils.getClassSimpleName(getClass()), server.getPort()); + } + } + + @Override + public void closeImpl() throws IOException { + for (Map.Entry server : servers.entrySet()) { + final String name = getId() + ": shutdown server " + server.getKey(); + LOG.info("{} now", name); + final Server s = server.getValue().shutdownNow(); + super.closeImpl(); + try { + s.awaitTermination(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw IOUtils.toInterruptedIOException(name + " failed", e); + } + LOG.info("{} successfully", name); + } + + serverInterceptor.close(); + ConcurrentUtils.shutdownAndWait(executor); + zeroCopyMetrics.unregister(); + } + + @Override + public void notifyNotLeader(RaftGroupId groupId) { + clientProtocolService.closeAllOrderedRequestStreamObservers(groupId); + } + + @Override + public InetSocketAddress getInetSocketAddress() { + return addressSupplier.get(); + } + + @Override + public InetSocketAddress getClientServerAddress() { + return clientServerAddressSupplier.get(); + } + + @Override + public InetSocketAddress getAdminServerAddress() { + return adminServerAddressSupplier.get(); + } + + @Override + public RaftServerAsynchronousProtocol async() { + return asyncService; + } + + @Override + public AppendEntriesReplyProto appendEntries(AppendEntriesRequestProto request) { + throw new UnsupportedOperationException( + "Blocking " + JavaUtils.getCurrentStackTraceElement().getMethodName() + " call is not supported"); + } + + @Override + public InstallSnapshotReplyProto installSnapshot(InstallSnapshotRequestProto request) { + throw new UnsupportedOperationException( + "Blocking " + JavaUtils.getCurrentStackTraceElement().getMethodName() + " call is not supported"); + } + + @Override + public RequestVoteReplyProto requestVote(RequestVoteRequestProto request) + throws IOException { + CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(), + null, request); + + final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId()); + return getProxies().getProxy(target).requestVote(request); + } + + @Override + public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequestProto request) throws IOException { + CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(), null, request); + + final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId()); + return getProxies().getProxy(target).startLeaderElection(request); + } + + @VisibleForTesting + public ZeroCopyMetrics getZeroCopyMetrics() { + return zeroCopyMetrics; + } +} diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java index d554ca583a..d6f6a0c866 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java @@ -19,7 +19,10 @@ import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.grpc.GrpcConfigKeys; +import org.apache.ratis.grpc.GrpcTlsConfig; +import org.apache.ratis.grpc.GrpcUtil; import org.apache.ratis.grpc.metrics.MessageMetrics; +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor; import org.apache.ratis.protocol.AdminAsynchronousProtocol; import org.apache.ratis.protocol.RaftGroupId; @@ -30,13 +33,17 @@ import org.apache.ratis.server.RaftServerConfigKeys; import org.apache.ratis.server.RaftServerRpcWithProxy; import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol; +import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; import org.apache.ratis.thirdparty.io.grpc.ServerInterceptor; import org.apache.ratis.thirdparty.io.grpc.ServerInterceptors; +import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition; +import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts; import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder; import org.apache.ratis.thirdparty.io.grpc.Server; import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver; import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption; -import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth; +import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder; import org.apache.ratis.proto.RaftProtos.*; import org.apache.ratis.util.*; @@ -52,6 +59,8 @@ import java.util.concurrent.ExecutorService; import java.util.function.Supplier; +import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL; + /** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */ public final class GrpcServicesImpl extends RaftServerRpcWithProxy> @@ -100,20 +109,19 @@ public static final class Builder { private String adminHost; private int adminPort; - private SslContext adminSslContext; + private GrpcTlsConfig adminTlsConfig; private String clientHost; private int clientPort; - private SslContext clientSslContext; + private GrpcTlsConfig clientTlsConfig; private String serverHost; private int serverPort; - private SslContext serverSslContextForServer; - private SslContext serverSslContextForClient; - private int serverStubPoolSize; + private GrpcTlsConfig serverTlsConfig; private SizeInBytes messageSizeMax; private SizeInBytes flowControlWindow; private TimeDuration requestTimeoutDuration; private boolean separateHeartbeatChannel; + private boolean zeroCopyEnabled; private Builder() {} @@ -131,7 +139,7 @@ public Builder setServer(RaftServer raftServer) { this.flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::info); this.requestTimeoutDuration = RaftServerConfigKeys.Rpc.requestTimeout(properties); this.separateHeartbeatChannel = GrpcConfigKeys.Server.heartbeatChannel(properties); - this.serverStubPoolSize = GrpcConfigKeys.Server.stubPoolSize(properties); + this.zeroCopyEnabled = GrpcConfigKeys.Server.zeroCopyEnabled(properties); final SizeInBytes appenderBufferSize = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties); final SizeInBytes gap = SizeInBytes.ONE_MB; @@ -152,8 +160,8 @@ public Builder setCustomizer(Customizer customizer) { } private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer target) { - return new GrpcServerProtocolClient(target, serverStubPoolSize, flowControlWindow.getSizeInt(), - requestTimeoutDuration, serverSslContextForClient, separateHeartbeatChannel); + return new GrpcServerProtocolClient(target, flowControlWindow.getSizeInt(), + requestTimeoutDuration, serverTlsConfig, separateHeartbeatChannel); } private ExecutorService newExecutor() { @@ -165,12 +173,12 @@ private ExecutorService newExecutor() { } private GrpcClientProtocolService newGrpcClientProtocolService( - ExecutorService executor) { - return new GrpcClientProtocolService(server::getId, server, executor); + ExecutorService executor, ZeroCopyMetrics zeroCopyMetrics) { + return new GrpcClientProtocolService(server::getId, server, executor, zeroCopyEnabled, zeroCopyMetrics); } - private GrpcServerProtocolService newGrpcServerProtocolService() { - return new GrpcServerProtocolService(server::getId, server); + private GrpcServerProtocolService newGrpcServerProtocolService(ZeroCopyMetrics zeroCopyMetrics) { + return new GrpcServerProtocolService(server::getId, server, zeroCopyEnabled, zeroCopyMetrics); } private MetricServerInterceptor newMetricServerInterceptor() { @@ -183,18 +191,18 @@ Server buildServer(NettyServerBuilder builder, EnumSet types) } private NettyServerBuilder newNettyServerBuilderForServer() { - return newNettyServerBuilder(serverHost, serverPort, serverSslContextForServer); + return newNettyServerBuilder(serverHost, serverPort, serverTlsConfig); } private NettyServerBuilder newNettyServerBuilderForAdmin() { - return newNettyServerBuilder(adminHost, adminPort, adminSslContext); + return newNettyServerBuilder(adminHost, adminPort, adminTlsConfig); } private NettyServerBuilder newNettyServerBuilderForClient() { - return newNettyServerBuilder(clientHost, clientPort, clientSslContext); + return newNettyServerBuilder(clientHost, clientPort, clientTlsConfig); } - private NettyServerBuilder newNettyServerBuilder(String hostname, int port, SslContext sslContext) { + private NettyServerBuilder newNettyServerBuilder(String hostname, int port, GrpcTlsConfig tlsConfig) { final InetSocketAddress address = hostname == null || hostname.isEmpty() ? new InetSocketAddress(port) : new InetSocketAddress(hostname, port); final NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address) @@ -202,9 +210,19 @@ private NettyServerBuilder newNettyServerBuilder(String hostname, int port, SslC .maxInboundMessageSize(messageSizeMax.getSizeInt()) .flowControlWindow(flowControlWindow.getSizeInt()); - if (sslContext != null) { + if (tlsConfig != null) { LOG.info("Setting TLS for {}", address); - nettyServerBuilder.sslContext(sslContext); + SslContextBuilder sslContextBuilder = GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager()); + if (tlsConfig.getMtlsEnabled()) { + sslContextBuilder.clientAuth(ClientAuth.REQUIRE); + GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager()); + } + sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, OPENSSL); + try { + nettyServerBuilder.sslContext(sslContextBuilder.build()); + } catch (Exception ex) { + throw new IllegalArgumentException("Failed to build SslContext, tlsConfig=" + tlsConfig, ex); + } } return nettyServerBuilder; } @@ -217,10 +235,10 @@ private boolean separateClientServer() { return clientPort > 0 && clientPort != serverPort; } - Server newServer(GrpcClientProtocolService client, ServerInterceptor interceptor) { + Server newServer(GrpcClientProtocolService client, ZeroCopyMetrics zeroCopyMetrics, ServerInterceptor interceptor) { final EnumSet types = EnumSet.of(GrpcServices.Type.SERVER); final NettyServerBuilder serverBuilder = newNettyServerBuilderForServer(); - final GrpcServerProtocolService service = newGrpcServerProtocolService(); + final ServerServiceDefinition service = newGrpcServerProtocolService(zeroCopyMetrics).bindServiceWithZeroCopy(); serverBuilder.addService(ServerInterceptors.intercept(service, interceptor)); if (!separateAdminServer()) { @@ -238,23 +256,18 @@ public GrpcServicesImpl build() { return new GrpcServicesImpl(this); } - public Builder setAdminSslContext(SslContext adminSslContext) { - this.adminSslContext = adminSslContext; - return this; - } - - public Builder setClientSslContext(SslContext clientSslContext) { - this.clientSslContext = clientSslContext; + public Builder setAdminTlsConfig(GrpcTlsConfig config) { + this.adminTlsConfig = config; return this; } - public Builder setServerSslContextForServer(SslContext serverSslContextForServer) { - this.serverSslContextForServer = serverSslContextForServer; + public Builder setClientTlsConfig(GrpcTlsConfig config) { + this.clientTlsConfig = config; return this; } - public Builder setServerSslContextForClient(SslContext serverSslContextForClient) { - this.serverSslContextForClient = serverSslContextForClient; + public Builder setServerTlsConfig(GrpcTlsConfig config) { + this.serverTlsConfig = config; return this; } } @@ -274,14 +287,15 @@ public static Builder newBuilder() { private final GrpcClientProtocolService clientProtocolService; private final MetricServerInterceptor serverInterceptor; + private final ZeroCopyMetrics zeroCopyMetrics = new ZeroCopyMetrics(); private GrpcServicesImpl(Builder b) { super(b.server::getId, id -> new PeerProxyMap<>(id.toString(), b::newGrpcServerProtocolClient)); this.executor = b.newExecutor(); - this.clientProtocolService = b.newGrpcClientProtocolService(executor); + this.clientProtocolService = b.newGrpcClientProtocolService(executor, zeroCopyMetrics); this.serverInterceptor = b.newMetricServerInterceptor(); - final Server server = b.newServer(clientProtocolService, serverInterceptor); + final Server server = b.newServer(clientProtocolService, zeroCopyMetrics, serverInterceptor); servers.put(GrpcServerProtocolService.class.getSimpleName(), server); addressSupplier = newAddressSupplier(b.serverPort, server); @@ -313,7 +327,8 @@ static MemoizedSupplier newAddressSupplier(int port, Server s static void addClientService(NettyServerBuilder builder, GrpcClientProtocolService client, ServerInterceptor interceptor) { - builder.addService(ServerInterceptors.intercept(client, interceptor)); + final ServerServiceDefinition service = client.bindServiceWithZeroCopy(); + builder.addService(ServerInterceptors.intercept(service, interceptor)); } static void addAdminService(NettyServerBuilder builder, AdminAsynchronousProtocol admin, @@ -341,40 +356,24 @@ public void startImpl() { } @Override - public void closeImpl() { - for (Server server : servers.values()) { - server.shutdownNow(); - } - boolean interrupted = false; + public void closeImpl() throws IOException { for (Map.Entry server : servers.entrySet()) { + final String name = getId() + ": shutdown server " + server.getKey(); + LOG.info("{} now", name); + final Server s = server.getValue().shutdownNow(); + super.closeImpl(); try { - server.getValue().awaitTermination(); - LOG.info("{}: Shutdown {} successfully", getId(), server.getKey()); + s.awaitTermination(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - LOG.warn("{}: Interrupted shutdown {}", getId(), server.getKey()); - interrupted = true; - break; + throw IOUtils.toInterruptedIOException(name + " failed", e); } + LOG.info("{} successfully", name); } - try { - serverInterceptor.close(); - } catch (Exception e) { - LOG.warn("{}: Failed to unregister metrics", getId(), e); - } - - if (interrupted) { - executor.shutdown(); // shutdown but not wait - } else { - ConcurrentUtils.shutdownAndWait(executor); - } - - try { - super.closeImpl(); - } catch (IOException e) { - LOG.warn("{}: Failed to close proxies", getId(), e); - } + serverInterceptor.close(); + ConcurrentUtils.shutdownAndWait(executor); + zeroCopyMetrics.unregister(); } @Override @@ -432,7 +431,13 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ return getProxies().getProxy(target).startLeaderElection(request); } + @VisibleForTesting MessageMetrics getMessageMetrics() { return serverInterceptor.getMetrics(); } + + @VisibleForTesting + public ZeroCopyMetrics getZeroCopyMetrics() { + return zeroCopyMetrics; + } } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java index 3fb3f067be..eddf2495e4 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java @@ -62,12 +62,14 @@ public class ZeroCopyMessageMarshaller implements Prototy private final Consumer zeroCopyCount; private final Consumer nonZeroCopyCount; + private final Consumer releasedCount; public ZeroCopyMessageMarshaller(T defaultInstance) { - this(defaultInstance, m -> {}, m -> {}); + this(defaultInstance, m -> {}, m -> {}, m -> {}); } - public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount) { + public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount, + Consumer releasedCount) { this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) + "-Marshaller"; @SuppressWarnings("unchecked") final Parser p = (Parser) defaultInstance.getParserForType(); @@ -76,6 +78,7 @@ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, C this.zeroCopyCount = zeroCopyCount; this.nonZeroCopyCount = nonZeroCopyCount; + this.releasedCount = releasedCount; } @Override @@ -124,6 +127,7 @@ public void release(T message) { } try { stream.close(); + releasedCount.accept(message); } catch (IOException e) { LOG.error(name + ": Failed to close stream.", e); } @@ -223,6 +227,10 @@ public InputStream popStream(T message) { return unclosedStreams.remove(message); } + public int getUnclosedCount() { + return unclosedStreams.size(); + } + void assertNoUnclosedStreams() { // Intended for tests/teardown to fail fast if callers forgot to release streams. final int size = unclosedStreams.size(); diff --git a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java index 0e4eb55544..bd1c72b241 100644 --- a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java +++ b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java @@ -21,14 +21,19 @@ import org.apache.ratis.RaftTestUtil; import org.apache.ratis.conf.Parameters; import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; import org.apache.ratis.grpc.server.GrpcServicesImpl; import org.apache.ratis.protocol.RaftGroup; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.rpc.SupportedRpcType; +import org.apache.ratis.server.RaftServer; import org.apache.ratis.server.impl.DelayLocalExecutionInjection; import org.apache.ratis.server.impl.MiniRaftCluster; +import org.apache.ratis.server.impl.RaftServerTestUtil; import org.apache.ratis.util.NetUtils; +import org.apache.ratis.util.ReferenceCountedLeakDetector; +import org.junit.jupiter.api.Assertions; import java.util.Optional; @@ -45,6 +50,11 @@ public MiniRaftClusterWithGrpc newCluster(String[] ids, String[] listenerIds, Ra } }; + static { + // TODO move it to MiniRaftCluster for detecting non-gRPC cases + ReferenceCountedLeakDetector.enable(false); + } + public interface FactoryGet extends Factory.Get { @Override default Factory getFactory() { @@ -54,7 +64,7 @@ default Factory getFactory() { public static final DelayLocalExecutionInjection SEND_SERVER_REQUEST_INJECTION = new DelayLocalExecutionInjection(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST); - + public MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties, Parameters parameters) { this(ids, new String[0], properties, parameters); } @@ -70,6 +80,8 @@ protected Parameters setPropertiesAndInitParameters(RaftPeerId id, RaftGroup gro GrpcConfigKeys.Client.setPort(properties, NetUtils.createSocketAddr(address).getPort())); Optional.ofNullable(getAddress(id, group, RaftPeer::getAdminAddress)).ifPresent(address -> GrpcConfigKeys.Admin.setPort(properties, NetUtils.createSocketAddr(address).getPort())); + // Always run grpc integration tests with zero-copy enabled because the path of nonzero-copy is not risky. + GrpcConfigKeys.Server.setZeroCopyEnabled(properties, true); return parameters; } @@ -79,4 +91,22 @@ protected void blockQueueAndSetDelay(String leaderId, int delayMs) RaftTestUtil.blockQueueAndSetDelay(getServers(), SEND_SERVER_REQUEST_INJECTION, leaderId, delayMs, getTimeoutMax()); } + + @Override + public void shutdown() { + super.shutdown(); + assertZeroCopyMetrics(); + } + + public void assertZeroCopyMetrics() { + getServers().forEach(server -> server.getGroupIds().forEach(id -> { + LOG.info("Checking {}-{}", server.getId(), id); + RaftServer.Division division = RaftServerTestUtil.getDivision(server, id); + final GrpcServicesImpl service = (GrpcServicesImpl) RaftServerTestUtil.getServerRpc(division); + ZeroCopyMetrics zeroCopyMetrics = service.getZeroCopyMetrics(); + Assertions.assertEquals(0, zeroCopyMetrics.nonZeroCopyMessages()); + Assertions.assertEquals(zeroCopyMetrics.zeroCopyMessages(), zeroCopyMetrics.releasedMessages(), + "Unreleased zero copy messages: please check logs to find the leaks. "); + })); + } } diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java b/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java index 33914fde7f..dc189a14aa 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java @@ -125,7 +125,9 @@ default RaftPeerId getFollowerId() { * @param heartbeat the returned request must be a heartbeat. * * @return a new {@link AppendEntriesRequestProto} object. + * @deprecated this is no longer a public API. */ + @Deprecated AppendEntriesRequestProto newAppendEntriesRequest(long callId, boolean heartbeat) throws RaftLogIOException; /** @return a new {@link InstallSnapshotRequestProto} object. */ diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java b/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java index 8a904069ba..035e0a815f 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java @@ -22,14 +22,39 @@ import org.apache.ratis.proto.RaftProtos.ReadIndexReplyProto; import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto; import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto; +import org.apache.ratis.util.ReferenceCountedObject; import java.io.IOException; import java.util.concurrent.CompletableFuture; public interface RaftServerAsynchronousProtocol { - CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request) - throws IOException; + /** + * It is recommended to override {@link #appendEntriesAsync(ReferenceCountedObject)} instead. + * Then, it does not have to override this method. + */ + default CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request) + throws IOException { + throw new UnsupportedOperationException(); + } + + /** + * A referenced counted request is submitted from a client for processing. + * Implementations of this method should retain the request, process it and then release it. + * The request may be retained even after the future returned by this method has completed. + * + * @return a future of the reply + * @see ReferenceCountedObject + */ + default CompletableFuture appendEntriesAsync( + ReferenceCountedObject requestRef) throws IOException { + // Default implementation for backward compatibility. + try { + return appendEntriesAsync(requestRef.retain()); + } finally { + requestRef.release(); + } + } CompletableFuture readIndexAsync(ReadIndexRequestProto request) throws IOException; diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java index e194f865ed..07446282e7 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java @@ -21,6 +21,7 @@ import org.apache.ratis.server.metrics.RaftLogMetrics; import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.server.storage.RaftStorageMetadata; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.TimeDuration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -60,16 +61,44 @@ default boolean contains(TermIndex ti) { /** * @return null if the log entry is not found in this log; - * otherwise, return the log entry corresponding to the given index. + * otherwise, return a copy of the log entry corresponding to the given index. + * @deprecated use {@link RaftLog#retainLog(long)} instead in order to avoid copying. */ + @Deprecated LogEntryProto get(long index) throws RaftLogIOException; + /** + * @return a retained {@link ReferenceCountedObject} to the log entry corresponding to the given index if it exists; + * otherwise, return null. + * Since the returned reference is retained, the caller must call {@link ReferenceCountedObject#release()}} + * after use. + */ + default ReferenceCountedObject retainLog(long index) throws RaftLogIOException { + ReferenceCountedObject wrap = ReferenceCountedObject.wrap(get(index)); + wrap.retain(); + return wrap; + } + /** * @return null if the log entry is not found in this log; * otherwise, return the {@link EntryWithData} corresponding to the given index. + * @deprecated use {@link #retainEntryWithData(long)}. */ + @Deprecated EntryWithData getEntryWithData(long index) throws RaftLogIOException; + /** + * @return null if the log entry is not found in this log; + * otherwise, return a retained reference of the {@link EntryWithData} corresponding to the given index. + * Since the returned reference is retained, the caller must call {@link ReferenceCountedObject#release()}} + * after use. + */ + default ReferenceCountedObject retainEntryWithData(long index) throws RaftLogIOException { + final ReferenceCountedObject wrap = ReferenceCountedObject.wrap(getEntryWithData(index)); + wrap.retain(); + return wrap; +} + /** * @param startIndex the starting log index (inclusive) * @param endIndex the ending log index (exclusive) @@ -160,6 +189,15 @@ default long getNextIndex() { * containing both the log entry and the state machine data. */ interface EntryWithData { + /** @return the index of this entry. */ + default long getIndex() { + try { + return getEntry(TimeDuration.ONE_MINUTE).getIndex(); + } catch (Exception e) { + throw new IllegalStateException("Failed to getIndex", e); + } + } + /** @return the serialized size including both log entry and state machine data. */ int getSerializedSize(); diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java index 5e8bd6d784..5a25728830 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java @@ -22,8 +22,10 @@ import org.apache.ratis.server.RaftConfiguration; import org.apache.ratis.statemachine.TransactionContext; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.StringUtils; import org.apache.ratis.util.function.CheckedSupplier; +import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier; import java.util.Arrays; import java.util.List; @@ -124,31 +126,56 @@ OUTPUT runSequentially( CompletableFuture appendEntry(LogEntryProto entry); /** - * Append asynchronously an entry. - * Used by the leader. + * @deprecated use {@link #appendEntry(ReferenceCountedObject, TransactionContext)}}. */ + @Deprecated default CompletableFuture appendEntry(LogEntryProto entry, TransactionContext context) { - return appendEntry(entry); + throw new UnsupportedOperationException(); + } + + /** + * Append asynchronously an entry. + * Used for scenarios that there is a ReferenceCountedObject context for resource cleanup when the given entry + * is no longer used/referenced by this log. + */ + default CompletableFuture appendEntry(ReferenceCountedObject entryRef, + TransactionContext context) { + return appendEntry(entryRef.get(), context); } /** * The same as append(Arrays.asList(entries)). * - * @deprecated use {@link #append(List)} + * @deprecated use {@link #append(ReferenceCountedObject)}. */ @Deprecated default List> append(LogEntryProto... entries) { return append(Arrays.asList(entries)); } + /** + * @deprecated use {@link #append(ReferenceCountedObject)}. + */ + @Deprecated + default List> append(List entries) { + throw new UnsupportedOperationException(); + } + /** * Append asynchronously all the given log entries. * Used by the followers. * * If an existing entry conflicts with a new one (same index but different terms), * delete the existing entry and all entries that follow it (§5.3). + * + * A reference counter is also submitted. + * For each entry, implementations of this method should retain the counter, process it and then release. */ - List> append(List entries); + default List> append(ReferenceCountedObject> entriesRef) { + try(UncheckedAutoCloseableSupplier> entries = entriesRef.retainAndReleaseOnClose()) { + return append(entries.get()); + } + } /** * Truncate asynchronously the log entries till the given index (inclusively). diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java index 98d4537847..3960ab8287 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/StateMachine.java @@ -88,11 +88,35 @@ default CompletableFuture read(LogEntryProto entry, TransactionConte return read(entry); } + /** + * Read asynchronously the state machine data from this state machine. + * StateMachines implement this method when the read result contains retained resources that should be released + * after use. + * + * @return a future for the read task. The result of the future is a {@link ReferenceCountedObject} wrapping the + * read result. Client code of this method must call {@link ReferenceCountedObject#release()} after + * use. + */ + default CompletableFuture> retainRead(LogEntryProto entry, + TransactionContext context) { + return read(entry, context).thenApply(r -> { + if (r == null) { + return null; + } + ReferenceCountedObject ref = ReferenceCountedObject.wrap(r); + ref.retain(); + return ref; + + }); + } + /** * Write asynchronously the state machine data in the given log entry to this state machine. * * @return a future for the write task + * @deprecated Applications should implement {@link #write(ReferenceCountedObject, TransactionContext)} instead. */ + @Deprecated default CompletableFuture write(LogEntryProto entry) { return CompletableFuture.completedFuture(null); } @@ -101,11 +125,36 @@ default CompletableFuture write(LogEntryProto entry) { * Write asynchronously the state machine data in the given log entry to this state machine. * * @return a future for the write task + * @deprecated Applications should implement {@link #write(ReferenceCountedObject, TransactionContext)} instead. */ + @Deprecated default CompletableFuture write(LogEntryProto entry, TransactionContext context) { return write(entry); } + /** + * Write asynchronously the state machine data in the given log entry to this state machine. + * + * @param entryRef Reference to a log entry. + * Implementations of this method may call {@link ReferenceCountedObject#get()} + * to access the log entry before this method returns. + * If the log entry is needed after this method returns, + * e.g. for asynchronous computation or caching, + * the implementation must invoke {@link ReferenceCountedObject#retain()} + * and {@link ReferenceCountedObject#release()}. + * @return a future for the write task + */ + default CompletableFuture write(ReferenceCountedObject entryRef, TransactionContext context) { + final LogEntryProto entry = entryRef.get(); + try { + final LogEntryProto copy = LogEntryProto.parseFrom(entry.toByteString()); + return write(copy, context); + } catch (InvalidProtocolBufferException e) { + return JavaUtils.completeExceptionally(new IllegalStateException( + "Failed to copy log entry " + TermIndex.valueOf(entry), e)); + } + } + /** * Create asynchronously a {@link DataStream} to stream state machine data. * The state machine may use the first message (i.e. request.getMessage()) as the header to create the stream. diff --git a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java index 35a40efb57..6b63624c81 100644 --- a/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java +++ b/ratis-server-api/src/main/java/org/apache/ratis/statemachine/TransactionContext.java @@ -24,9 +24,11 @@ import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; import org.apache.ratis.util.Preconditions; import org.apache.ratis.util.ReflectionUtils; +import org.apache.ratis.util.ReferenceCountedObject; import java.io.IOException; import java.util.Objects; +import java.util.Optional; /** * Context for a transaction. @@ -58,7 +60,10 @@ public interface TransactionContext { /** * Returns the data from the {@link StateMachine} * @return the data from the {@link StateMachine} + * @deprecated access StateMachineLogEntry via {@link TransactionContext#getLogEntryRef()} or + * {@link TransactionContext#getLogEntryUnsafe()} */ + @Deprecated StateMachineLogEntryProto getStateMachineLogEntry(); /** Set exception in case of failure. */ @@ -93,11 +98,47 @@ public interface TransactionContext { LogEntryProto initLogEntry(long term, long index); /** - * Returns the committed log entry - * @return the committed log entry + * @return a copy of the committed log entry if it exists; otherwise, returns null + * + * @deprecated Use {@link #getLogEntryRef()} or {@link #getLogEntryUnsafe()} to avoid copying. */ + @Deprecated LogEntryProto getLogEntry(); + /** + * @return the committed log entry if it exists; otherwise, returns null. + * The returned value is safe to use only before {@link StateMachine#applyTransaction} returns. + * Once {@link StateMachine#applyTransaction} has returned, it is unsafe to use the log entry + * since the underlying buffers can possiby be released. + */ + default LogEntryProto getLogEntryUnsafe() { + return getLogEntryRef().get(); + } + + /** + * Get a {@link ReferenceCountedObject} to the committed log entry. + * + * It is safe to access the log entry by calling {@link ReferenceCountedObject#get()} + * (without {@link ReferenceCountedObject#retain()}) + * inside the scope of {@link StateMachine#applyTransaction}. + * + * If the log entry is needed after {@link StateMachine#applyTransaction} returns, + * e.g. for asynchronous computation or caching, + * the caller must invoke {@link ReferenceCountedObject#retain()} and {@link ReferenceCountedObject#release()}. + * + * @return a reference to the committed log entry if it exists; otherwise, returns null. + */ + default ReferenceCountedObject getLogEntryRef() { + return Optional.ofNullable(getLogEntryUnsafe()).map(this::wrap).orElse(null); + } + + /** Wrap the given log entry as a {@link ReferenceCountedObject} for retaining it for later use. */ + default ReferenceCountedObject wrap(LogEntryProto entry) { + Preconditions.assertSame(getLogEntry().getTerm(), entry.getTerm(), "entry.term"); + Preconditions.assertSame(getLogEntry().getIndex(), entry.getIndex(), "entry.index"); + return ReferenceCountedObject.wrap(entry); + } + /** * Sets whether to commit the transaction to the RAFT log or not * @param shouldCommit true if the transaction is supposed to be committed to the RAFT log diff --git a/ratis-server/dev-support/findbugsExcludeFile.xml b/ratis-server/dev-support/findbugsExcludeFile.xml index 0161c226bb..384d00cb29 100644 --- a/ratis-server/dev-support/findbugsExcludeFile.xml +++ b/ratis-server/dev-support/findbugsExcludeFile.xml @@ -75,4 +75,8 @@ - \ No newline at end of file + + + + + diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java index 1c986ca638..dab660fc05 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/LeaderStateImpl.java @@ -39,7 +39,6 @@ import org.apache.ratis.protocol.exceptions.ReadIndexException; import org.apache.ratis.protocol.exceptions.ReconfigurationTimeoutException; import org.apache.ratis.server.RaftServerConfigKeys; -import org.apache.ratis.server.RaftServerConfigKeys.Read.ReadIndex.Type; import org.apache.ratis.server.impl.ReadIndexHeartbeats.AppendEntriesListener; import org.apache.ratis.server.leader.FollowerInfo; import org.apache.ratis.server.leader.LeaderState; @@ -58,6 +57,7 @@ import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.TimeDuration; import org.apache.ratis.util.Timestamp; @@ -556,20 +556,21 @@ PendingRequests.Permit tryAcquirePendingRequest(Message message) { PendingRequest addPendingRequest(PendingRequests.Permit permit, RaftClientRequest request, TransactionContext entry) { if (LOG.isDebugEnabled()) { LOG.debug("{}: addPendingRequest at {}, entry={}", this, request, - LogProtoUtils.toLogEntryString(entry.getLogEntry())); + LogProtoUtils.toLogEntryString(entry.getLogEntryUnsafe())); } return pendingRequests.add(permit, request, entry); } - CompletableFuture streamAsync(RaftClientRequest request) { - return messageStreamRequests.streamAsync(request) + CompletableFuture streamAsync(ReferenceCountedObject requestRef) { + RaftClientRequest request = requestRef.get(); + return messageStreamRequests.streamAsync(requestRef) .thenApply(dummy -> server.newSuccessReply(request)) .exceptionally(e -> exception2RaftClientReply(request, e)); } - CompletableFuture streamEndOfRequestAsync(RaftClientRequest request) { - return messageStreamRequests.streamEndOfRequestAsync(request) - .thenApply(bytes -> RaftClientRequest.toWriteRequest(request, Message.valueOf(bytes))); + CompletableFuture> streamEndOfRequestAsync( + ReferenceCountedObject requestRef) { + return messageStreamRequests.streamEndOfRequestAsync(requestRef); } CompletableFuture addWatchRequest(RaftClientRequest request) { @@ -1241,22 +1242,8 @@ private boolean checkLeaderLease() { && (server.getRaftConf().isSingleton() || lease.isValid()); } - void replyPendingRequest(TermIndex termIndex, RaftClientReply reply, RetryCacheImpl.CacheEntry cacheEntry) { - final PendingRequest pending = pendingRequests.remove(termIndex); - - final LongSupplier replyMethod = () -> { - cacheEntry.updateResult(reply); - if (pending != null) { - pending.setReply(reply); - } - return termIndex.getIndex(); - }; - - if (readIndexType == Type.REPLIED_INDEX) { - replyFlusher.hold(replyMethod); - } else { - replyMethod.getAsLong(); - } + void replyPendingRequest(TermIndex termIndex, RaftClientReply reply) { + pendingRequests.replyPendingRequest(termIndex, reply); } TransactionContext getTransactionContext(TermIndex termIndex) { @@ -1351,6 +1338,7 @@ private static boolean isCaughtUp(FollowerInfo follower) { } @Override + @SuppressWarnings("deprecation") public void checkHealth(FollowerInfo follower) { final TimeDuration elapsedTime = follower.getLastRpcResponseTime().elapsedTime(); if (elapsedTime.compareTo(server.properties().rpcSlownessTimeout()) > 0) { diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java index ac81b348bb..c00c57b364 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java @@ -25,12 +25,15 @@ import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; class MessageStreamRequests { public static final Logger LOG = LoggerFactory.getLogger(MessageStreamRequests.class); @@ -39,12 +42,14 @@ private static class PendingStream { private final ClientInvocationId key; private long nextId = -1; private ByteString bytes = ByteString.EMPTY; + private final List> pendingRefs = new LinkedList<>(); PendingStream(ClientInvocationId key) { this.key = key; } - synchronized CompletableFuture append(long messageId, Message message) { + synchronized CompletableFuture append(long messageId, + ReferenceCountedObject requestRef) { if (nextId == -1) { nextId = messageId; } else if (messageId != nextId) { @@ -52,27 +57,38 @@ synchronized CompletableFuture append(long messageId, Message messag "Unexpected message id in " + key + ": messageId = " + messageId + " != nextId = " + nextId)); } nextId++; + final Message message = requestRef.retain().getMessage(); + pendingRefs.add(requestRef); bytes = bytes.concat(message.getContent()); return CompletableFuture.completedFuture(bytes); } - synchronized CompletableFuture getBytes(long messageId, Message message) { - return append(messageId, message); + synchronized CompletableFuture> getWriteRequest(long messageId, + ReferenceCountedObject requestRef) { + return append(messageId, requestRef) + .thenApply(appended -> RaftClientRequest.toWriteRequest(requestRef.get(), () -> appended)) + .thenApply(request -> ReferenceCountedObject.delegateFrom(pendingRefs, request)); + } + + synchronized void clear() { + pendingRefs.forEach(ReferenceCountedObject::release); + pendingRefs.clear(); } } static class StreamMap { - private final ConcurrentMap map = new ConcurrentHashMap<>(); + private final Map map = new HashMap<>(); - PendingStream computeIfAbsent(ClientInvocationId key) { + synchronized PendingStream computeIfAbsent(ClientInvocationId key) { return map.computeIfAbsent(key, PendingStream::new); } - PendingStream remove(ClientInvocationId key) { + synchronized PendingStream remove(ClientInvocationId key) { return map.remove(key); } - void clear() { + synchronized void clear() { + map.values().forEach(PendingStream::clear); map.clear(); } } @@ -84,15 +100,18 @@ void clear() { this.name = name + "-" + JavaUtils.getClassSimpleName(getClass()); } - CompletableFuture streamAsync(RaftClientRequest request) { + CompletableFuture streamAsync(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); final MessageStreamRequestTypeProto stream = request.getType().getMessageStream(); Preconditions.assertTrue(!stream.getEndOfRequest()); final ClientInvocationId key = ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId()); final PendingStream pending = streams.computeIfAbsent(key); - return pending.append(stream.getMessageId(), request.getMessage()); + return pending.append(stream.getMessageId(), requestRef); } - CompletableFuture streamEndOfRequestAsync(RaftClientRequest request) { + CompletableFuture> streamEndOfRequestAsync( + ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); final MessageStreamRequestTypeProto stream = request.getType().getMessageStream(); Preconditions.assertTrue(stream.getEndOfRequest()); final ClientInvocationId key = ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId()); @@ -101,7 +120,7 @@ CompletableFuture streamEndOfRequestAsync(RaftClientRequest request) if (pending == null) { return JavaUtils.completeExceptionally(new StreamException(name + ": " + key + " not found")); } - return pending.getBytes(stream.getMessageId(), request.getMessage()); + return pending.getWriteRequest(stream.getMessageId(), requestRef); } void clear() { diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java index ed13b10113..d72fcde90b 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequest.java @@ -38,7 +38,7 @@ class PendingRequest { private final CompletableFuture futureToReturn; PendingRequest(RaftClientRequest request, TransactionContext entry) { - this.termIndex = entry == null? null: TermIndex.valueOf(entry.getLogEntry()); + this.termIndex = entry == null? null: TermIndex.valueOf(entry.getLogEntryUnsafe()); this.request = request; this.entry = entry; if (request.is(TypeCase.FORWARD)) { diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java index f89d354e6a..259695d5ed 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/PendingRequests.java @@ -117,7 +117,7 @@ private static class RequestMap { raftServerMetrics.addNumPendingRequestsMegaByteSize(resource::getMegaByteSize); } - synchronized Permit tryAcquire(Message message) { + Permit tryAcquire(Message message) { final int messageSize = Message.getSize(message); final int messageSizeMb = roundUpMb(messageSize ); final Acquired acquired = resource.tryAcquire(messageSizeMb); @@ -139,7 +139,13 @@ synchronized Permit tryAcquire(Message message) { if (messageSizeMb > diffMb) { resource.releaseExtraMb(messageSizeMb - diffMb); } + return putPermit(); + } + private synchronized Permit putPermit() { + if (resource.isClosed()) { + return null; + } final Permit permit = new Permit(); permits.put(permit, permit); return permit; @@ -151,9 +157,9 @@ synchronized PendingRequest put(Permit permit, PendingRequest p) { if (removed == null) { return null; } - Preconditions.assertSame(permit, removed, "permit"); + Preconditions.assertTrue(removed == permit); final PendingRequest previous = map.put(p.getTermIndex(), p); - Preconditions.assertNull(previous, "previous"); + Preconditions.assertTrue(previous == null); return p; } @@ -264,13 +270,12 @@ TransactionContext getTransactionContext(TermIndex termIndex) { return pendingRequest != null ? pendingRequest.getEntry() : null; } - /** @return the removed the {@link PendingRequest} for the given {@link TermIndex}. */ - PendingRequest remove(TermIndex termIndex) { + void replyPendingRequest(TermIndex termIndex, RaftClientReply reply) { final PendingRequest pending = pendingRequests.remove(termIndex); if (pending != null) { Preconditions.assertEquals(termIndex, pending.getTermIndex(), "termIndex"); + pending.setReply(reply); } - return pending; } /** diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java index 1c9cd3f658..043ba1ee71 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java @@ -18,7 +18,6 @@ package org.apache.ratis.server.impl; import org.apache.ratis.client.impl.ClientProtoUtils; -import org.apache.ratis.client.impl.OrderedAsync; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.metrics.Timekeeper; import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto; @@ -82,9 +81,9 @@ import org.apache.ratis.server.RaftServerRpc; import org.apache.ratis.server.impl.LeaderElection.Phase; import org.apache.ratis.server.impl.RetryCacheImpl.CacheEntry; +import org.apache.ratis.server.impl.ServerImplUtils.ConsecutiveIndices; import org.apache.ratis.server.impl.ServerImplUtils.NavigableIndices; -import org.apache.ratis.server.leader.LeaderState.StepDownReason; -import org.apache.ratis.server.leader.LogAppender; +import org.apache.ratis.server.leader.LeaderState; import org.apache.ratis.server.metrics.LeaderElectionMetrics; import org.apache.ratis.server.metrics.RaftServerMetricsImpl; import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol; @@ -101,8 +100,6 @@ import org.apache.ratis.statemachine.impl.TransactionContextImpl; import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException; -import org.apache.ratis.trace.TraceServer; -import org.apache.ratis.trace.TraceUtils; import org.apache.ratis.util.CodeInjectionForTesting; import org.apache.ratis.util.CollectionUtils; import org.apache.ratis.util.ConcurrentUtils; @@ -114,14 +111,17 @@ import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.Preconditions; import org.apache.ratis.util.ProtoUtils; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.TimeDuration; import org.apache.ratis.util.function.CheckedSupplier; +import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier; import java.io.File; import java.io.IOException; import java.nio.file.NoSuchFileException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -136,12 +136,12 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.apache.ratis.server.impl.LeaderElection.Result.NOT_IN_CONF; import static org.apache.ratis.server.impl.ServerImplUtils.assertEntries; import static org.apache.ratis.server.impl.ServerImplUtils.assertGroup; import static org.apache.ratis.server.impl.ServerImplUtils.effectiveCommitIndex; @@ -150,7 +150,6 @@ import static org.apache.ratis.server.impl.ServerProtoUtils.toReadIndexRequestProto; import static org.apache.ratis.server.impl.ServerProtoUtils.toRequestVoteReplyProto; import static org.apache.ratis.server.impl.ServerProtoUtils.toStartLeaderElectionReplyProto; -import static org.apache.ratis.server.raftlog.LogProtoUtils.toLogEntryTermIndexString; import static org.apache.ratis.server.util.ServerStringUtils.toAppendEntriesReplyString; import static org.apache.ratis.server.util.ServerStringUtils.toAppendEntriesRequestString; import static org.apache.ratis.server.util.ServerStringUtils.toRequestVoteReplyString; @@ -175,7 +174,7 @@ public RaftPeerRole getCurrentRole() { @Override public boolean isLeaderReady() { - return getRole().isLeaderReady(); + return isLeader() && getRole().isLeaderReady(); } @Override @@ -240,25 +239,32 @@ public long[] getFollowerMatchIndices() { private final RetryCacheImpl retryCache; private final CommitInfoCache commitInfoCache = new CommitInfoCache(); private final WriteIndexCache writeIndexCache; - private final NavigableIndices appendLogTermIndices; private final RaftServerJmxAdapter jmxAdapter = new RaftServerJmxAdapter(this); private final LeaderElectionMetrics leaderElectionMetrics; private final RaftServerMetricsImpl raftServerMetrics; - - // Disallow appendEntries before start() complete; otherwise, it could fail with illegal lifeCycle transition - private final AtomicBoolean startComplete = new AtomicBoolean(false); - private final AtomicBoolean firstElectionSinceStartup = new AtomicBoolean(true); private final CountDownLatch closeFinishedLatch = new CountDownLatch(1); + // To avoid append entry before complete start() method + // For example, if thread1 start(), but before thread1 startAsFollower(), thread2 receive append entry + // request, and change state to RUNNING by lifeCycle.compareAndTransition(STARTING, RUNNING), + // then thread1 execute lifeCycle.transition(RUNNING) in startAsFollower(), + // So happens IllegalStateException: ILLEGAL TRANSITION: RUNNING -> RUNNING, + private final AtomicBoolean startComplete; + private final TransferLeadership transferLeadership; private final SnapshotManagementRequestHandler snapshotRequestHandler; private final SnapshotInstallationHandler snapshotInstallationHandler; private final ExecutorService serverExecutor; private final ExecutorService clientExecutor; + + private final AtomicBoolean firstElectionSinceStartup = new AtomicBoolean(true); private final ThreadGroup threadGroup; + private final AtomicReference> appendLogFuture; + private final NavigableIndices appendLogTermIndices = new NavigableIndices(); + RaftServerImpl(RaftGroup group, StateMachine stateMachine, RaftServerProxy proxy, RaftStorage.StartupOption option) throws IOException { final RaftPeerId id = proxy.getId(); @@ -280,18 +286,19 @@ public long[] getFollowerMatchIndices() { this.readOption = RaftServerConfigKeys.Read.option(properties); this.writeIndexCache = new WriteIndexCache(properties); this.transactionManager = new TransactionManager(id); - TraceUtils.setTracerWhenEnabled(properties); this.leaderElectionMetrics = LeaderElectionMetrics.getLeaderElectionMetrics( getMemberId(), state::getLastLeaderElapsedTimeMs); this.raftServerMetrics = RaftServerMetricsImpl.computeIfAbsentRaftServerMetrics( getMemberId(), this::getCommitIndex, retryCache::getStatistics); + this.startComplete = new AtomicBoolean(false); + this.threadGroup = new ThreadGroup(proxy.getThreadGroup(), getMemberId().toString()); + this.transferLeadership = new TransferLeadership(this, properties); this.snapshotRequestHandler = new SnapshotManagementRequestHandler(this); this.snapshotInstallationHandler = new SnapshotInstallationHandler(this, properties); - this.appendLogTermIndices = RaftServerConfigKeys.Log.appendEntriesComposeEnabled(properties) ? - new NavigableIndices() : null; + this.appendLogFuture = new AtomicReference<>(CompletableFuture.completedFuture(null)); this.serverExecutor = ConcurrentUtils.newThreadPoolWithMax( RaftServerConfigKeys.ThreadPool.serverCached(properties), @@ -301,7 +308,6 @@ public long[] getFollowerMatchIndices() { RaftServerConfigKeys.ThreadPool.clientCached(properties), RaftServerConfigKeys.ThreadPool.clientSize(properties), id + "-client"); - this.threadGroup = new ThreadGroup(proxy.getThreadGroup(), getMemberId().toString()); } private long getCommitIndex(RaftPeerId id) { @@ -406,7 +412,7 @@ boolean start() throws IOException { startAsPeer(RaftPeerRole.LISTENER); } else { LOG.info("{}: start with initializing state, conf={}", getMemberId(), conf); - setRole(RaftPeerRole.FOLLOWER, NOT_IN_CONF); + setRole(RaftPeerRole.FOLLOWER, "start"); } jmxAdapter.registerMBean(); @@ -557,12 +563,12 @@ public void close() { try { ConcurrentUtils.shutdownAndWait(clientExecutor); } catch (Exception e) { - LOG.warn("{}: Failed to shutdown clientExecutor", getMemberId(), e); + LOG.warn(getMemberId() + ": Failed to shutdown clientExecutor", e); } try { ConcurrentUtils.shutdownAndWait(serverExecutor); } catch (Exception e) { - LOG.warn("{}: Failed to shutdown serverExecutor", getMemberId(), e); + LOG.warn(getMemberId() + ": Failed to shutdown serverExecutor", e); } closeFinishedLatch.countDown(); }); @@ -587,7 +593,7 @@ private synchronized CompletableFuture changeToFollower( throw new IllegalStateException("Unexpected role " + old); } CompletableFuture future = CompletableFuture.completedFuture(null); - if (shouldSetFollower(old, force)) { + if ((old != RaftPeerRole.FOLLOWER || force) && old != RaftPeerRole.LISTENER) { setRole(RaftPeerRole.FOLLOWER, reason); if (old == RaftPeerRole.LEADER) { future = role.shutdownLeaderState(false) @@ -603,7 +609,7 @@ private synchronized CompletableFuture changeToFollower( state.setLeader(null, reason); } else if (old == RaftPeerRole.CANDIDATE) { future = role.shutdownLeaderElection(); - } else if (old == RaftPeerRole.FOLLOWER || old == RaftPeerRole.LISTENER) { + } else if (old == RaftPeerRole.FOLLOWER) { future = role.shutdownFollowerState(); } @@ -616,14 +622,6 @@ private synchronized CompletableFuture changeToFollower( return future; } - private boolean shouldSetFollower(RaftPeerRole old, boolean force) { - if (old == RaftPeerRole.LISTENER) { - final RaftConfigurationImpl conf = state.getRaftConf(); - return conf.isStable() && conf.containsInConf(getId()); - } - return old != RaftPeerRole.FOLLOWER || force; - } - synchronized CompletableFuture changeToFollowerAndPersistMetadata( long newTerm, boolean allowListener, @@ -649,6 +647,15 @@ synchronized void changeToLeader() { @Override public Collection getCommitInfos() { + try { + return getCommitInfosImpl(); + } catch (Throwable t) { + LOG.warn("{} Failed to getCommitInfos", getMemberId(), t); + return Collections.emptyList(); + } + } + + private Collection getCommitInfosImpl() { final List infos = new ArrayList<>(); // add the commit info of this server final long commitIndex = updateCommitInfoCache(); @@ -757,23 +764,23 @@ RaftClientReply newExceptionReply(RaftClientRequest request, RaftException excep } private CompletableFuture checkLeaderState(RaftClientRequest request) { - try { - assertGroup(getMemberId(), request); - } catch (GroupMismatchException e) { - return JavaUtils.completeExceptionally(e); - } - return checkLeaderState(request, null, null); + return checkLeaderState(request, null); } /** * @return null if the server is in leader state. */ - private CompletableFuture checkLeaderState( - RaftClientRequest request, CacheEntry entry, TransactionContextImpl context) { + private CompletableFuture checkLeaderState(RaftClientRequest request, CacheEntry entry) { + try { + assertGroup(getMemberId(), request); + } catch (GroupMismatchException e) { + return RetryCacheImpl.failWithException(e, entry); + } + if (!getInfo().isLeader()) { NotLeaderException exception = generateNotLeaderException(); final RaftClientReply reply = newExceptionReply(request, exception); - return failWithReply(reply, entry, context); + return RetryCacheImpl.failWithReply(reply, entry); } if (!getInfo().isLeaderReady()) { final CacheEntry cacheEntry = retryCache.getIfPresent(ClientInvocationId.valueOf(request)); @@ -782,13 +789,13 @@ private CompletableFuture checkLeaderState( } final LeaderNotReadyException lnre = new LeaderNotReadyException(getMemberId()); final RaftClientReply reply = newExceptionReply(request, lnre); - return failWithReply(reply, entry, context); + return RetryCacheImpl.failWithReply(reply, entry); } if (!request.isReadOnly() && isSteppingDown()) { final LeaderSteppingDownException lsde = new LeaderSteppingDownException(getMemberId() + " is stepping down"); final RaftClientReply reply = newExceptionReply(request, lsde); - return failWithReply(reply, entry, context); + return RetryCacheImpl.failWithReply(reply, entry); } return null; @@ -814,102 +821,67 @@ void assertLifeCycleState(Set expected) throws ServerNotReadyEx getMemberId() + " is not in " + expected + ": current state is " + c), expected); } - private CompletableFuture getResourceUnavailableReply(String op, - RaftClientRequest request, CacheEntry entry, TransactionContextImpl context) { - final ResourceUnavailableException e = new ResourceUnavailableException(getMemberId() - + ": Failed to " + op + " for " + request); - cancelTransaction(context, e); - return entry.failWithException(e); - } - - private CompletableFuture failWithReply( - RaftClientReply reply, CacheEntry entry, TransactionContextImpl context) { - if (context != null) { - cancelTransaction(context, reply.getException()); - } - - if (entry == null) { - return CompletableFuture.completedFuture(reply); - } - entry.failWithReply(reply); - return entry.getReplyFuture(); - } - - /** Cancel a transaction and notify the state machine. Set exception if provided to the transaction context. */ - private void cancelTransaction(TransactionContextImpl context, Exception exception) { - if (context == null) { - return; - } - - if (exception != null) { - context.setException(exception); - } - - try { - context.cancelTransaction(); - } catch (IOException ioe) { - LOG.warn("{}: Failed to cancel transaction {}", getMemberId(), context, ioe); - } - } - /** - * Handle a normal update request from client. + * Append a transaction to the log for processing a client request. + * Note that the given request could be different from {@link TransactionContext#getClientRequest()} + * since the request could be converted; see {@link #convertRaftClientRequest(RaftClientRequest)}. + * + * @param request The client request. + * @param context The context of the transaction. + * @param cacheEntry the entry in the retry cache. + * @return a future of the reply. */ private CompletableFuture appendTransaction( - RaftClientRequest request, TransactionContextImpl context, CacheEntry cacheEntry) throws IOException { + RaftClientRequest request, TransactionContextImpl context, CacheEntry cacheEntry) { + Objects.requireNonNull(request, "request == null"); CodeInjectionForTesting.execute(APPEND_TRANSACTION, getId(), request.getClientId(), request, context, cacheEntry); - assertLifeCycleState(LifeCycle.States.RUNNING); - - final LeaderStateImpl unsyncedLeaderState = role.getLeaderState().orElse(null); - if (unsyncedLeaderState == null) { - final NotLeaderException nle = generateNotLeaderException(); - final RaftClientReply reply = newExceptionReply(request, nle); - return failWithReply(reply, cacheEntry, context); - } - final PendingRequests.Permit unsyncedPermit = unsyncedLeaderState.tryAcquirePendingRequest(request.getMessage()); - if (unsyncedPermit == null) { - return getResourceUnavailableReply("acquire a pending write request", request, cacheEntry, context); - } - - final LeaderStateImpl leaderState; final PendingRequest pending; synchronized (this) { - final CompletableFuture reply = checkLeaderState(request, cacheEntry, context); + final CompletableFuture reply = checkLeaderState(request, cacheEntry); if (reply != null) { return reply; } - leaderState = role.getLeaderStateNonNull(); - final PendingRequests.Permit permit = leaderState == unsyncedLeaderState ? unsyncedPermit - : leaderState.tryAcquirePendingRequest(request.getMessage()); - if (permit == null) { - return getResourceUnavailableReply("acquire a pending write request", request, cacheEntry, context); - } - // append the message to its local log + final LeaderStateImpl leaderState = role.getLeaderStateNonNull(); writeIndexCache.add(request.getClientId(), context.getLogIndexFuture()); + + final PendingRequests.Permit permit = leaderState.tryAcquirePendingRequest(request.getMessage()); + if (permit == null) { + cacheEntry.failWithException(new ResourceUnavailableException( + getMemberId() + ": Failed to acquire a pending write request for " + request)); + return cacheEntry.getReplyFuture(); + } try { + assertLifeCycleState(LifeCycle.States.RUNNING); state.appendLog(context); } catch (StateMachineException e) { // the StateMachineException is thrown by the SM in the preAppend stage. // Return the exception in a RaftClientReply. - final RaftClientReply exceptionReply = newExceptionReply(request, e); + RaftClientReply exceptionReply = newExceptionReply(request, e); + cacheEntry.failWithReply(exceptionReply); // leader will step down here if (e.leaderShouldStepDown() && getInfo().isLeader()) { - leaderState.submitStepDownEvent(StepDownReason.STATE_MACHINE_EXCEPTION); + leaderState.submitStepDownEvent(LeaderState.StepDownReason.STATE_MACHINE_EXCEPTION); } - return failWithReply(exceptionReply, cacheEntry, null); + return CompletableFuture.completedFuture(exceptionReply); + } catch (ServerNotReadyException e) { + final RaftClientReply exceptionReply = newExceptionReply(request, e); + return CompletableFuture.completedFuture(exceptionReply); } // put the request into the pending queue pending = leaderState.addPendingRequest(permit, request, context); if (pending == null) { - return getResourceUnavailableReply("add a pending write request", request, cacheEntry, context); + cacheEntry.failWithException(new ResourceUnavailableException( + getMemberId() + ": Failed to add a pending write request for " + request)); + return cacheEntry.getReplyFuture(); } + leaderState.notifySenders(); } - leaderState.notifySenders(); + return pending.getFuture(); } @@ -947,14 +919,16 @@ private RaftClientReply combineReplies(RaftClientReply reply, RaftClientReply wa } void stepDownOnJvmPause() { - role.getLeaderState().ifPresent(leader -> leader.submitStepDownEvent(StepDownReason.JVM_PAUSE)); + role.getLeaderState().ifPresent(leader -> leader.submitStepDownEvent(LeaderState.StepDownReason.JVM_PAUSE)); } - private RaftClientRequest filterDataStreamRaftClientRequest(RaftClientRequest request) - throws InvalidProtocolBufferException { - return !request.is(TypeCase.FORWARD) ? request : ClientProtoUtils.toRaftClientRequest( - RaftClientRequestProto.parseFrom( - request.getMessage().getContent().asReadOnlyByteBuffer())); + /** If the given request is {@link TypeCase#FORWARD}, convert it. */ + static RaftClientRequest convertRaftClientRequest(RaftClientRequest request) throws InvalidProtocolBufferException { + if (!request.is(TypeCase.FORWARD)) { + return request; + } + return ClientProtoUtils.toRaftClientRequest(RaftClientRequestProto.parseFrom( + request.getMessage().getContent().asReadOnlyByteBuffer())); } CompletableFuture executeSubmitServerRequestAsync( @@ -964,35 +938,40 @@ CompletableFuture executeSubmitServerRequestAsync( serverExecutor).join(); } - CompletableFuture executeSubmitClientRequestAsync(RaftClientRequest request) { - return CompletableFuture.supplyAsync( - () -> JavaUtils.callAsUnchecked(() -> submitClientRequestAsync(request), CompletionException::new), - clientExecutor).join(); + CompletableFuture executeSubmitClientRequestAsync( + ReferenceCountedObject request) { + return CompletableFuture.supplyAsync(() -> submitClientRequestAsync(request), clientExecutor).join(); } @Override public CompletableFuture submitClientRequestAsync( - RaftClientRequest request) throws IOException { - return TraceServer.traceAsyncMethod( - () -> submitClientRequestAsyncInternal(request), - request, getMemberId().toString(), "raft.server.submitClientRequestAsync"); - } + ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.retain(); + try { + LOG.debug("{}: receive client request({})", getMemberId(), request); + assertLifeCycleState(LifeCycle.States.RUNNING); - private CompletableFuture submitClientRequestAsyncInternal( - RaftClientRequest request) throws IOException { - assertLifeCycleState(LifeCycle.States.RUNNING); - LOG.debug("{}: receive client request({})", getMemberId(), request); - final Timekeeper timer = raftServerMetrics.getClientRequestTimer(request.getType()); - final Optional timerContext = Optional.ofNullable(timer).map(Timekeeper::time); - return replyFuture(request).whenComplete((clientReply, exception) -> { - timerContext.ifPresent(Timekeeper.Context::stop); - if (exception != null || clientReply.getException() != null) { - raftServerMetrics.incFailedRequestCount(request.getType()); - } - }); + RaftClientRequest.Type type = request.getType(); + final Timekeeper timer = raftServerMetrics.getClientRequestTimer(type); + final Optional timerContext = Optional.ofNullable(timer).map(Timekeeper::time); + return replyFuture(requestRef).whenComplete((clientReply, exception) -> { + timerContext.ifPresent(Timekeeper.Context::stop); + if (exception != null || clientReply.getException() != null) { + raftServerMetrics.incFailedRequestCount(type); + } + }); + } catch (RaftException e) { + return CompletableFuture.completedFuture(newExceptionReply(request, e)); + } catch (Throwable t) { + LOG.error("{} Failed to submitClientRequestAsync for {}", getMemberId(), request, t); + return CompletableFuture.completedFuture(newExceptionReply(request, new RaftException(t))); + } finally { + requestRef.release(); + } } - private CompletableFuture replyFuture(RaftClientRequest request) throws IOException { + private CompletableFuture replyFuture(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); retryCache.invalidateRepliedRequests(request); final TypeCase type = request.getType().getTypeCase(); @@ -1004,17 +983,18 @@ private CompletableFuture replyFuture(RaftClientRequest request case WATCH: return watchAsync(request); case MESSAGESTREAM: - return messageStreamAsync(request); + return messageStreamAsync(requestRef); case WRITE: case FORWARD: - return writeAsync(request); + return writeAsync(requestRef); default: throw new IllegalStateException("Unexpected request type: " + type + ", request=" + request); } } - private CompletableFuture writeAsync(RaftClientRequest request) throws IOException { - final CompletableFuture future = writeAsyncImpl(request); + private CompletableFuture writeAsync(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); + final CompletableFuture future = writeAsyncImpl(requestRef); if (request.is(TypeCase.WRITE)) { // check replication final ReplicationLevel replication = request.getType().getWrite().getReplication(); @@ -1025,7 +1005,8 @@ private CompletableFuture writeAsync(RaftClientRequest request) return future; } - private CompletableFuture writeAsyncImpl(RaftClientRequest request) throws IOException { + private CompletableFuture writeAsyncImpl(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); final CompletableFuture reply = checkLeaderState(request); if (reply != null) { return reply; @@ -1038,30 +1019,30 @@ private CompletableFuture writeAsyncImpl(RaftClientRequest requ // return the cached future. return cacheEntry.getReplyFuture(); } - // This request will be added to pending requests later in appendTransaction. - // Any failure in between must invoke cancelTransaction. - final TransactionContextImpl context = (TransactionContextImpl) stateMachine.startTransaction( - filterDataStreamRaftClientRequest(request)); + // TODO: this client request will not be added to pending requests until + // later which means that any failure in between will leave partial state in + // the state machine. We should call cancelTransaction() for failed requests + final TransactionContextImpl context; + try { + context = (TransactionContextImpl) stateMachine.startTransaction(convertRaftClientRequest(request)); + } catch (IOException e) { + final RaftClientReply exceptionReply = newExceptionReply(request, + new RaftException("Failed to startTransaction for " + request, e)); + cacheEntry.failWithReply(exceptionReply); + return CompletableFuture.completedFuture(exceptionReply); + } if (context.getException() != null) { - final Exception exception = context.getException(); - final StateMachineException e = new StateMachineException(getMemberId(), exception); + final StateMachineException e = new StateMachineException(getMemberId(), context.getException()); final RaftClientReply exceptionReply = newExceptionReply(request, e); - return failWithReply(exceptionReply, cacheEntry, context); + cacheEntry.failWithReply(exceptionReply); + return CompletableFuture.completedFuture(exceptionReply); } - try { - return appendTransaction(request, context, cacheEntry); - } catch (Exception e) { - cancelTransaction(context, e); - throw e; - } + context.setDelegatedRef(requestRef); + return appendTransaction(request, context, cacheEntry); } private CompletableFuture watchAsync(RaftClientRequest request) { - if (OrderedAsync.DUMMY.getContent().equals(request.getMessage().getContent())) { - return CompletableFuture.completedFuture(RaftClientReply.newBuilder().setRequest(request).build()); - } - final CompletableFuture reply = checkLeaderState(request); if (reply != null) { return reply; @@ -1076,7 +1057,7 @@ private CompletableFuture watchAsync(RaftClientRequest request) private CompletableFuture staleReadAsync(RaftClientRequest request) { final long minIndex = request.getType().getStaleRead().getMinIndex(); final long commitIndex = state.getLog().getLastCommittedIndex(); - LOG.debug("{}: minIndex={}, commitIndex={} from {}", getMemberId(), minIndex, commitIndex, request.getClientId()); + LOG.debug("{}: minIndex={}, commitIndex={}", getMemberId(), minIndex, commitIndex); if (commitIndex < minIndex) { final StaleReadException e = new StaleReadException( "Unable to serve stale-read due to server commit index = " + commitIndex + " < min = " + minIndex); @@ -1160,28 +1141,36 @@ private RaftClientReply readException2Reply(RaftClientRequest request, Throwable } } - private CompletableFuture messageStreamAsync(RaftClientRequest request) throws IOException { + private CompletableFuture messageStreamAsync(ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.get(); final CompletableFuture reply = checkLeaderState(request); if (reply != null) { return reply; } if (request.getType().getMessageStream().getEndOfRequest()) { - final CompletableFuture f = streamEndOfRequestAsync(request); + final CompletableFuture> f = streamEndOfRequestAsync(requestRef); if (f.isCompletedExceptionally()) { return f.thenApply(r -> null); } // the message stream has ended and the request become a WRITE request - return replyFuture(f.join()); + ReferenceCountedObject joinedRequest = f.join(); + try { + return replyFuture(joinedRequest); + } finally { + // Released pending streaming requests. + joinedRequest.release(); + } } return role.getLeaderState() - .map(ls -> ls.streamAsync(request)) + .map(ls -> ls.streamAsync(requestRef)) .orElseGet(() -> CompletableFuture.completedFuture( newExceptionReply(request, generateNotLeaderException()))); } - private CompletableFuture streamEndOfRequestAsync(RaftClientRequest request) { + private CompletableFuture> streamEndOfRequestAsync( + ReferenceCountedObject request) { return role.getLeaderState() .map(ls -> ls.streamEndOfRequestAsync(request)) .orElse(null); @@ -1475,13 +1464,12 @@ public RequestVoteReplyProto requestVote(RequestVoteRequestProto r) throws IOExc RaftPeerId.valueOf(request.getRequestorId()), ProtoUtils.toRaftGroupId(request.getRaftGroupId()), r.getCandidateTerm(), - TermIndex.valueOf(r.getCandidateLastEntry()), - request.getCallId()); + TermIndex.valueOf(r.getCandidateLastEntry())); } private RequestVoteReplyProto requestVote(Phase phase, RaftPeerId candidateId, RaftGroupId candidateGroupId, - long candidateTerm, TermIndex candidateLastEntry, long callId) throws IOException { + long candidateTerm, TermIndex candidateLastEntry) throws IOException { CodeInjectionForTesting.execute(REQUEST_VOTE, getId(), candidateId, candidateTerm, candidateLastEntry); LOG.info("{}: receive requestVote({}, {}, {}, {}, {})", @@ -1516,7 +1504,7 @@ private RequestVoteReplyProto requestVote(Phase phase, shouldShutdown = true; } reply = toRequestVoteReplyProto(candidateId, getMemberId(), - voteGranted, state.getCurrentTerm(), shouldShutdown, state.getLastEntry(), callId); + voteGranted, state.getCurrentTerm(), shouldShutdown, state.getLastEntry()); if (LOG.isInfoEnabled()) { LOG.info("{} replies to {} vote request: {}. Peer's state: {}", getMemberId(), phase, toRequestVoteReplyString(reply), state); @@ -1532,15 +1520,16 @@ private RequestVoteReplyProto requestVote(Phase phase, public AppendEntriesReplyProto appendEntries(AppendEntriesRequestProto r) throws IOException { try { - return appendEntriesAsync(r).join(); + return appendEntriesAsync(ReferenceCountedObject.wrap(r)).join(); } catch (CompletionException e) { throw IOUtils.asIOException(JavaUtils.unwrapCompletionException(e)); } } @Override - public CompletableFuture appendEntriesAsync(AppendEntriesRequestProto r) - throws IOException { + public CompletableFuture appendEntriesAsync( + ReferenceCountedObject requestRef) throws IOException { + final AppendEntriesRequestProto r = requestRef.retain(); final RaftRpcRequestProto request = r.getServerRequest(); final TermIndex previous = r.hasPreviousLog()? TermIndex.valueOf(r.getPreviousLog()) : null; try { @@ -1556,11 +1545,13 @@ public CompletableFuture appendEntriesAsync(AppendEntri assertGroup(getMemberId(), leaderId, leaderGroupId); assertEntries(r, previous, state); - return appendEntriesAsync(leaderId, request.getCallId(), previous, r); + return appendEntriesAsync(leaderId, request.getCallId(), previous, requestRef); } catch(Exception t) { LOG.error("{}: Failed appendEntries* {}", getMemberId(), toAppendEntriesRequestString(r, stateMachine::toStateMachineLogEntryString), t); throw IOUtils.asIOException(t); + } finally { + requestRef.release(); } } @@ -1577,26 +1568,13 @@ public CompletableFuture readIndexAsync(ReadIndexRequestPro return getReadIndex(ClientProtoUtils.toRaftClientRequest(request.getClientRequest()), leader) .thenApply(index -> toReadIndexReplyProto(peerId, getMemberId(), true, index)) - .whenComplete((reply, exception) -> { - if (exception == null) { - // Leader should try to trigger heartbeat immediately after leader replies the ReadIndex to the follower - // so that the follower's commitIndex can be updated to the leader's commitIndex and the follower - // can start applying the logs up until the leader's commitIndex (instead of waiting for the next - // AppendEntries to happen through heartbeat or new transactions (which might increase the latency - // considerably)). - // Note that if the follower commitIndex is already equal to the leader's commitIndex, no heartbeat - // will be triggered, see GrpcLogAppender#isFollowerCommitBehindLastCommitIndex. - RaftPeerId requestorId = RaftPeerId.valueOf(reply.getServerReply().getRequestorId()); - leader.getLogAppender(requestorId).ifPresent(LogAppender::triggerHeartbeat); - } - }) .exceptionally(throwable -> toReadIndexReplyProto(peerId, getMemberId())); } static void logAppendEntries(boolean isHeartbeat, Supplier message) { if (isHeartbeat) { if (LOG.isTraceEnabled()) { - LOG.trace("HEARTBEAT: {}", message.get()); + LOG.trace("HEARTBEAT: " + message.get()); } } else { if (LOG.isDebugEnabled()) { @@ -1624,7 +1602,8 @@ ExecutorService getServerExecutor() { } private CompletableFuture appendEntriesAsync(RaftPeerId leaderId, long callId, - TermIndex previous, AppendEntriesRequestProto proto) throws IOException { + TermIndex previous, ReferenceCountedObject requestRef) throws IOException { + final AppendEntriesRequestProto proto = requestRef.get(); final List entries = proto.getEntriesList(); final boolean isHeartbeat = entries.isEmpty(); logAppendEntries(isHeartbeat, () -> getMemberId() + ": appendEntries* " @@ -1647,11 +1626,11 @@ leaderId, getMemberId(), currentTerm, followerCommit, state.getNextIndex(), AppendResult.NOT_LEADER, callId, RaftLog.INVALID_LOG_INDEX, isHeartbeat)); } try { - future = changeToFollowerAndPersistMetadata(leaderTerm, true, Op.APPEND_ENTRIES); + future = changeToFollowerAndPersistMetadata(leaderTerm, true, "appendEntries"); } catch (IOException e) { return JavaUtils.completeExceptionally(e); } - state.setLeader(leaderId, Op.APPEND_ENTRIES); + state.setLeader(leaderId, "appendEntries"); if (!proto.getInitializing() && lifeCycle.compareAndTransition(State.STARTING, State.RUNNING)) { role.startFollowerState(this, Op.APPEND_ENTRIES); @@ -1678,9 +1657,8 @@ leaderId, getMemberId(), currentTerm, followerCommit, inconsistencyReplyNextInde state.updateConfiguration(entries); } future.join(); - final CompletableFuture appendFuture = entries.isEmpty()? CompletableFuture.completedFuture(null) - : appendLogTermIndices != null ? appendLogTermIndices.append(entries, this::appendLog) - : JavaUtils.allOf(state.getLog().append(entries)); + final CompletableFuture appendLog = entries.isEmpty()? CompletableFuture.completedFuture(null) + : appendLog(requestRef.delegate(entries)); proto.getCommitInfosList().forEach(commitInfoCache::update); @@ -1695,12 +1673,7 @@ leaderId, getMemberId(), currentTerm, followerCommit, inconsistencyReplyNextInde final long commitIndex = effectiveCommitIndex(proto.getLeaderCommit(), previous, entries.size()); final long matchIndex = isHeartbeat? RaftLog.INVALID_LOG_INDEX: entries.get(entries.size() - 1).getIndex(); - return appendFuture.whenCompleteAsync((r, t) -> { - if (t != null) { - LOG.warn("{}: appendEntries* failed: {}", getMemberId(), toLogEntryTermIndexString(entries), t); - } else if (LOG.isDebugEnabled()) { - LOG.debug("{}: appendEntries* succeeded: {}", getMemberId(), toLogEntryTermIndexString(entries)); - } + return appendLog.whenCompleteAsync((r, t) -> { followerState.ifPresent(fs -> fs.updateLastRpcTime(FollowerState.UpdateType.APPEND_COMPLETE)); timer.stop(); }, getServerExecutor()).thenApply(v -> { @@ -1717,10 +1690,23 @@ leaderId, getMemberId(), currentTerm, followerCommit, inconsistencyReplyNextInde return reply; }); } + private CompletableFuture appendLog(ReferenceCountedObject> entriesRef) { + final List entriesTermIndices; + try(UncheckedAutoCloseableSupplier> entries = entriesRef.retainAndReleaseOnClose()) { + entriesTermIndices = ConsecutiveIndices.convert(entries.get()); + if (!appendLogTermIndices.append(entriesTermIndices)) { + // index already exists, return the last future + return appendLogFuture.get(); + } + } - private CompletableFuture appendLog(List entries) { - return CompletableFuture.completedFuture(null) - .thenComposeAsync(dummy -> JavaUtils.allOf(state.getLog().append(entries)), serverExecutor); + entriesRef.retain(); + return appendLogFuture.updateAndGet(f -> f.thenCompose( + ignored -> JavaUtils.allOf(state.getLog().append(entriesRef)))) + .whenComplete((v, e) -> { + entriesRef.release(); + appendLogTermIndices.removeExisting(entriesTermIndices); + }); } private long checkInconsistentAppendEntries(TermIndex previous, List entries) { @@ -1747,11 +1733,9 @@ private long checkInconsistentAppendEntries(TermIndex previous, List replyPendingRequest( } // update pending request - final LeaderStateImpl leader = role.getLeaderState().orElse(null); - if (leader != null) { - leader.replyPendingRequest(termIndex, r, cacheEntry); - } else { - cacheEntry.updateResult(r); - } + role.getLeaderState().ifPresent(leader -> leader.replyPendingRequest(termIndex, r)); + cacheEntry.updateResult(r); }); } @@ -1918,7 +1898,9 @@ TransactionContext getTransactionContext(LogEntryProto entry, Boolean createNew) MemoizedSupplier.valueOf(() -> stateMachine.startTransaction(entry, getInfo().getCurrentRole()))); } - CompletableFuture applyLogToStateMachine(LogEntryProto next) throws RaftLogIOException { + CompletableFuture applyLogToStateMachine(ReferenceCountedObject nextRef) + throws RaftLogIOException { + LogEntryProto next = nextRef.get(); CompletableFuture messageFuture = null; switch (next.getLogEntryBodyCase()) { @@ -1935,7 +1917,7 @@ CompletableFuture applyLogToStateMachine(LogEntryProto next) throws Raf Objects.requireNonNull(trx, "trx == null"); final ClientInvocationId invocationId = ClientInvocationId.valueOf(next.getStateMachineLogEntry()); writeIndexCache.add(invocationId.getClientId(), ((TransactionContextImpl) trx).getLogIndexFuture()); - + ((TransactionContextImpl) trx).setDelegatedRef(nextRef); try { // Let the StateMachine inject logic for committed transactions in sequential order. trx = stateMachine.applyTransactionSerial(trx); diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java index 8539fa99ec..2914c434f1 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerProxy.java @@ -52,8 +52,8 @@ import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.Preconditions; import org.apache.ratis.util.ProtoUtils; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.TimeDuration; -import org.apache.ratis.util.VersionInfo; import java.io.Closeable; import java.io.File; @@ -208,8 +208,6 @@ String toString(RaftGroupId groupId, CompletableFuture f) { RaftServerProxy(RaftPeerId id, StateMachine.Registry stateMachineRegistry, RaftProperties properties, Parameters parameters, ThreadGroup threadGroup) { - VersionInfo.load(getClass()).printStartupMessages(id, LOG::info); - this.properties = properties; this.stateMachineRegistry = stateMachineRegistry; @@ -454,9 +452,15 @@ public void close() { } @Override - public CompletableFuture submitClientRequestAsync(RaftClientRequest request) { - return getImplFuture(request.getRaftGroupId()) - .thenCompose(impl -> impl.executeSubmitClientRequestAsync(request)); + public CompletableFuture submitClientRequestAsync( + ReferenceCountedObject requestRef) { + final RaftClientRequest request = requestRef.retain(); + try { + return getImplFuture(request.getRaftGroupId()) + .thenCompose(impl -> impl.executeSubmitClientRequestAsync(requestRef)); + } finally { + requestRef.release(); + } } @Override @@ -648,11 +652,17 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ } @Override - public CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request) { - final RaftGroupId groupId = ProtoUtils.toRaftGroupId(request.getServerRequest().getRaftGroupId()); - return getImplFuture(groupId) - .thenCompose(impl -> JavaUtils.callAsUnchecked( - () -> impl.appendEntriesAsync(request), CompletionException::new)); + public CompletableFuture appendEntriesAsync( + ReferenceCountedObject requestRef) { + AppendEntriesRequestProto request = requestRef.retain(); + try { + final RaftGroupId groupId = ProtoUtils.toRaftGroupId(request.getServerRequest().getRaftGroupId()); + return getImplFuture(groupId) + .thenCompose(impl -> JavaUtils.callAsUnchecked( + () -> impl.appendEntriesAsync(requestRef), CompletionException::new)); + } finally { + requestRef.release(); + } } @Override diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java index 6112a46009..e63a23a0b8 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ReadRequests.java @@ -20,7 +20,7 @@ import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.protocol.exceptions.ReadException; import org.apache.ratis.server.RaftServerConfigKeys; -import org.apache.ratis.util.Preconditions; +import org.apache.ratis.statemachine.StateMachine; import org.apache.ratis.util.TimeDuration; import org.apache.ratis.util.TimeoutExecutor; import org.slf4j.Logger; @@ -29,7 +29,7 @@ import java.util.NavigableMap; import java.util.TreeMap; import java.util.concurrent.CompletableFuture; -import java.util.function.LongConsumer; +import java.util.function.Consumer; /** For supporting linearizable read. */ class ReadRequests { @@ -37,18 +37,10 @@ class ReadRequests { static class ReadIndexQueue { private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance(); - /** The log index known to be applied. */ - private long lastAppliedIndex; - /** - * Map : readIndex -> appliedIndexFuture (when completes, readIndex <= appliedIndex). - * Invariant: all keys > lastAppliedIndex. - */ private final NavigableMap> sorted = new TreeMap<>(); - private final TimeDuration readTimeout; - ReadIndexQueue(long lastAppliedIndex, TimeDuration readTimeout) { - this.lastAppliedIndex = lastAppliedIndex; + ReadIndexQueue(TimeDuration readTimeout) { this.readTimeout = readTimeout; } @@ -56,9 +48,6 @@ CompletableFuture add(long readIndex) { final CompletableFuture returned; final boolean create; synchronized (this) { - if (readIndex <= lastAppliedIndex) { - return CompletableFuture.completedFuture(lastAppliedIndex); - } // The same as computeIfAbsent except that it also tells if a new value is created. final CompletableFuture existing = sorted.get(readIndex); create = existing == null; @@ -90,19 +79,7 @@ private void handleTimeout(long readIndex) { /** Complete all the entries less than or equal to the given applied index. */ - synchronized void complete(long appliedIndex) { - if (appliedIndex > lastAppliedIndex) { - lastAppliedIndex = appliedIndex; - } else { - // appliedIndex <= lastAppliedIndex: nothing to do - if (!sorted.isEmpty()) { - // Assert: all keys > lastAppliedIndex. - final long first = sorted.firstKey(); - Preconditions.assertTrue(first > lastAppliedIndex, - () -> "first = " + first + " <= lastAppliedIndex = " + lastAppliedIndex); - } - return; - } + synchronized void complete(Long appliedIndex) { final NavigableMap> headMap = sorted.headMap(appliedIndex, true); headMap.values().forEach(f -> f.complete(appliedIndex)); headMap.clear(); @@ -110,16 +87,27 @@ synchronized void complete(long appliedIndex) { } private final ReadIndexQueue readIndexQueue; + private final StateMachine stateMachine; - ReadRequests(long appliedIndex, RaftProperties properties) { - this.readIndexQueue = new ReadIndexQueue(appliedIndex, RaftServerConfigKeys.Read.timeout(properties)); + ReadRequests(RaftProperties properties, StateMachine stateMachine) { + this.readIndexQueue = new ReadIndexQueue(RaftServerConfigKeys.Read.timeout(properties)); + this.stateMachine = stateMachine; } - LongConsumer getAppliedIndexConsumer() { + Consumer getAppliedIndexConsumer() { return readIndexQueue::complete; } CompletableFuture waitToAdvance(long readIndex) { - return readIndexQueue.add(readIndex); + final long lastApplied = stateMachine.getLastAppliedTermIndex().getIndex(); + if (lastApplied >= readIndex) { + return CompletableFuture.completedFuture(lastApplied); + } + final CompletableFuture f = readIndexQueue.add(readIndex); + final long current = stateMachine.getLastAppliedTermIndex().getIndex(); + if (current > lastApplied) { + readIndexQueue.complete(current); + } + return f; } } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java index 96ad62a531..50d238b07a 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java @@ -84,10 +84,9 @@ void failWithReply(RaftClientReply reply) { replyFuture.complete(reply); } - CompletableFuture failWithException(Throwable t) { + void failWithException(Throwable t) { failed = true; replyFuture.completeExceptionally(t); - return replyFuture; } @Override @@ -257,4 +256,24 @@ public synchronized void close() { cache.invalidateAll(); statistics.set(null); } + + static CompletableFuture failWithReply( + RaftClientReply reply, CacheEntry entry) { + if (entry != null) { + entry.failWithReply(reply); + return entry.getReplyFuture(); + } else { + return CompletableFuture.completedFuture(reply); + } + } + + static CompletableFuture failWithException( + Throwable t, CacheEntry entry) { + if (entry != null) { + entry.failWithException(t); + return entry.getReplyFuture(); + } else { + return JavaUtils.completeExceptionally(t); + } + } } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerImplUtils.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerImplUtils.java index 434f98d683..864b402a23 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerImplUtils.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerImplUtils.java @@ -47,10 +47,7 @@ import java.util.NavigableMap; import java.util.TreeMap; import java.util.Objects; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; /** Server utilities for internal use. */ public final class ServerImplUtils { @@ -122,8 +119,6 @@ Long getTerm(long index) { /** A data structure to support the {@link #contains(TermIndex)} method. */ static class NavigableIndices { private final NavigableMap map = new TreeMap<>(); - private final AtomicReference> future - = new AtomicReference<>(CompletableFuture.completedFuture(null)); boolean contains(TermIndex ti) { final Long term = getTerm(ti.getIndex()); @@ -142,15 +137,7 @@ synchronized Long getTerm(long index) { return floorEntry.getValue().getTerm(index); } - CompletableFuture append(List entries, - Function, CompletableFuture> appendLog) { - final List entriesTermIndices = ConsecutiveIndices.convert(entries); - return alreadyExists(entriesTermIndices) ? future.get() - : future.updateAndGet(f -> f.thenComposeAsync(ignored -> appendLog.apply(entries))) - .whenComplete((v, e) -> removeExisting(entriesTermIndices)); - } - - private synchronized boolean alreadyExists(List entriesTermIndices) { + synchronized boolean append(List entriesTermIndices) { for(int i = 0; i < entriesTermIndices.size(); i++) { final ConsecutiveIndices indices = entriesTermIndices.get(i); final ConsecutiveIndices previous = map.put(indices.startIndex, indices); @@ -160,10 +147,10 @@ private synchronized boolean alreadyExists(List entriesTermI for(int j = 0; j < i; j++) { map.remove(entriesTermIndices.get(j).startIndex); } - return true; + return false; } } - return false; + return true; } synchronized void removeExisting(List entriesTermIndices) { @@ -183,8 +170,8 @@ public static RaftServerProxy newRaftServer( RaftPeerId id, RaftGroup group, RaftStorage.StartupOption option, StateMachine.Registry stateMachineRegistry, ThreadGroup threadGroup, RaftProperties properties, Parameters parameters) throws IOException { RaftServer.LOG.debug("newRaftServer: {}, {}", id, group); - Objects.requireNonNull(id, "id == null"); if (group != null && !group.getPeers().isEmpty()) { + Objects.requireNonNull(id, () -> "RaftPeerId " + id + " is not in RaftGroup " + group); Objects.requireNonNull(group.getPeer(id), () -> "RaftPeerId " + id + " is not in RaftGroup " + group); } final RaftServerProxy proxy = newRaftServer(id, stateMachineRegistry, threadGroup, properties, parameters); diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java index 19d4ce6a75..494037f373 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerProtoUtils.java @@ -24,7 +24,6 @@ import org.apache.ratis.protocol.RaftGroupMemberId; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; -import org.apache.ratis.rpc.CallId; import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.server.raftlog.RaftLog; import org.apache.ratis.util.Preconditions; @@ -43,6 +42,17 @@ private static RaftRpcReplyProto.Builder toRaftRpcReplyProtoBuilder( requestorId.toByteString(), replyId.getPeerId().toByteString(), replyId.getGroupId(), null, success); } + static RequestVoteReplyProto toRequestVoteReplyProto( + RaftPeerId requestorId, RaftGroupMemberId replyId, boolean success, long term, boolean shouldShutdown, + TermIndex lastEntry) { + return RequestVoteReplyProto.newBuilder() + .setServerReply(toRaftRpcReplyProtoBuilder(requestorId, replyId, success)) + .setTerm(term) + .setShouldShutdown(shouldShutdown) + .setLastEntry((lastEntry != null? lastEntry : TermIndex.INITIAL_VALUE).toProto()) + .build(); + } + static RequestVoteReplyProto toRequestVoteReplyProto( RaftPeerId requestorId, RaftGroupMemberId replyId, boolean success, long term, boolean shouldShutdown, TermIndex lastEntry, long callId) { @@ -57,8 +67,7 @@ static RequestVoteReplyProto toRequestVoteReplyProto( static RequestVoteRequestProto toRequestVoteRequestProto( RaftGroupMemberId requestorId, RaftPeerId replyId, long term, TermIndex lastEntry, boolean preVote) { final RequestVoteRequestProto.Builder b = RequestVoteRequestProto.newBuilder() - .setServerRequest(ClientProtoUtils.toRaftRpcRequestProtoBuilder(requestorId, replyId) - .setCallId(CallId.getAndIncrement())) + .setServerRequest(ClientProtoUtils.toRaftRpcRequestProtoBuilder(requestorId, replyId)) .setCandidateTerm(term) .setPreVote(preVote); Optional.ofNullable(lastEntry).map(TermIndex::toProto).ifPresent(b::setCandidateLastEntry); @@ -180,12 +189,8 @@ static ServerRpcProto toServerRpcProto(RaftPeer peer, long delay) { // if no peer information return empty return ServerRpcProto.getDefaultInstance(); } - return toServerRpcProto(peer.getRaftPeerProto(), delay); - } - - static ServerRpcProto toServerRpcProto(RaftPeerProto peer, long delay) { return ServerRpcProto.newBuilder() - .setId(peer) + .setId(peer.getRaftPeerProto()) .setLastRpcElapsedTimeMs(delay) .build(); } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java index bcf11baf7a..c49e9554f0 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/ServerState.java @@ -24,7 +24,6 @@ import org.apache.ratis.server.RaftConfiguration; import org.apache.ratis.server.RaftServerConfigKeys; import org.apache.ratis.server.impl.LeaderElection.Phase; -import org.apache.ratis.server.protocol.RaftServerProtocol.Op; import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.server.raftlog.LogProtoUtils; import org.apache.ratis.server.raftlog.RaftLog; @@ -124,8 +123,13 @@ class ServerState { // On start the leader is null, start the clock now this.lastNoLeaderTime = new AtomicReference<>(Timestamp.currentTime()); this.noLeaderTimeout = RaftServerConfigKeys.Notification.noLeaderTimeout(prop); - this.log = JavaUtils.memoize(() -> initRaftLog(() -> getSnapshotIndexFromStateMachine(stateMachine), prop)); - this.readRequests = new ReadRequests(stateMachine.getLastAppliedTermIndex().getIndex(), prop); + + final LongSupplier getSnapshotIndexFromStateMachine = () -> Optional.ofNullable(stateMachine.getLatestSnapshot()) + .map(SnapshotInfo::getIndex) + .filter(i -> i >= 0) + .orElse(RaftLog.INVALID_LOG_INDEX); + this.log = JavaUtils.memoize(() -> initRaftLog(getSnapshotIndexFromStateMachine, prop)); + this.readRequests = new ReadRequests(prop, stateMachine); this.stateMachineUpdater = JavaUtils.memoize(() -> new StateMachineUpdater( stateMachine, server, this, getLog().getSnapshotIndex(), prop, this.readRequests.getAppliedIndexConsumer())); @@ -150,16 +154,6 @@ RaftGroupMemberId getMemberId() { return memberId; } - private long getSnapshotIndexFromStateMachine(StateMachine stateMachine) { - final SnapshotInfo latest = stateMachine.getLatestSnapshot(); - LOG.info("{}: getLatestSnapshot({}) returns {}", getMemberId(), stateMachine, latest); - if (latest == null) { - return RaftLog.INVALID_LOG_INDEX; - } - final long index = latest.getIndex(); - return index >= 0 ? index : RaftLog.INVALID_LOG_INDEX; - } - void writeRaftConfiguration(LogEntryProto conf) { getStorage().writeRaftConfiguration(conf); } @@ -258,7 +252,7 @@ RaftPeerId getVotedFor() { */ void grantVote(RaftPeerId candidateId) { votedFor = candidateId; - setLeader(null, Op.REQUEST_VOTE); + setLeader(null, "grantVote"); } void setLeader(RaftPeerId newLeaderId, Object op) { @@ -271,7 +265,7 @@ void setLeader(RaftPeerId newLeaderId, Object op) { suffix = ""; } else { final Timestamp previous = lastNoLeaderTime.getAndSet(null); - suffix = ", leader elected after " + (previous != null ? previous.elapsedTimeMs() : 0) + "ms"; + suffix = ", leader elected after " + previous.elapsedTimeMs() + "ms"; server.setFirstElection(op); server.getStateMachine().event().notifyLeaderChanged(getMemberId(), newLeaderId); } @@ -327,7 +321,7 @@ TermIndex getLastEntry() { void appendLog(TransactionContext operation) throws StateMachineException { getLog().append(currentTerm.get(), operation); - Objects.requireNonNull(operation.getLogEntry()); + Objects.requireNonNull(operation.getLogEntryUnsafe(), "transaction-logEntry"); } /** @return true iff the given peer id is recognized as the leader. */ @@ -376,12 +370,10 @@ boolean isConfCommitted() { return getLog().getLastCommittedIndex() >= getRaftConf().getLogEntryIndex(); } - private boolean setRaftConf(LogEntryProto entry) { + void setRaftConf(LogEntryProto entry) { if (entry.hasConfigurationEntry()) { setRaftConf(LogProtoUtils.toRaftConfiguration(entry)); - return true; } - return false; } void setRaftConf(RaftConfiguration conf) { @@ -399,19 +391,10 @@ void truncate(long logIndex) { configurationManager.removeConfigurations(logIndex); } - void updateConfiguration(List entries) throws IOException { - if (entries == null || entries.isEmpty()) { - return; - } - configurationManager.removeConfigurations(entries.get(0).getIndex()); - - boolean changed = false; - for(LogEntryProto entry : entries) { - changed |= setRaftConf(entry); - } - - if (changed && server.getRole().getCurrentRole() == RaftPeerRole.LISTENER) { - server.changeToFollowerAndPersistMetadata(getCurrentTerm(), true, "setRaftConf").join(); + void updateConfiguration(List entries) { + if (entries != null && !entries.isEmpty()) { + configurationManager.removeConfigurations(entries.get(0).getIndex()); + entries.forEach(this::setRaftConf); } } @@ -443,7 +426,7 @@ void close() { if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); } - LOG.warn("{}: Failed to join {}", getMemberId(), getStateMachineUpdater(), e); + LOG.warn(getMemberId() + ": Failed to join " + getStateMachineUpdater(), e); } try { @@ -451,7 +434,7 @@ void close() { getLog().close(); } } catch (Throwable e) { - LOG.warn("{}: Failed to close raft log {}", getMemberId(), getLog(), e); + LOG.warn(getMemberId() + ": Failed to close raft log " + getLog(), e); } try { @@ -459,7 +442,7 @@ void close() { getStorage().close(); } } catch (Throwable e) { - LOG.warn("{}: Failed to close raft storage {}", getMemberId(), getStorage(), e); + LOG.warn(getMemberId() + ": Failed to close raft storage " + getStorage(), e); } } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/SnapshotInstallationHandler.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/SnapshotInstallationHandler.java index 46b6aaf87f..4f1ac4177f 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/SnapshotInstallationHandler.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/SnapshotInstallationHandler.java @@ -29,10 +29,10 @@ import org.apache.ratis.proto.RaftProtos.ServerRpcProto; import org.apache.ratis.protocol.RaftGroupId; import org.apache.ratis.protocol.RaftGroupMemberId; +import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.server.RaftServerConfigKeys; -import org.apache.ratis.server.impl.FollowerState.UpdateType; -import org.apache.ratis.server.protocol.RaftServerProtocol.Op; +import org.apache.ratis.server.protocol.RaftServerProtocol; import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.server.raftlog.LogProtoUtils; import org.apache.ratis.server.util.ServerStringUtils; @@ -46,7 +46,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.Collections; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; @@ -145,8 +144,8 @@ private InstallSnapshotReplyProto installSnapshotImpl(InstallSnapshotRequestProt final LogEntryProto proto = request.getLastRaftConfigurationLogEntryProto(); state.truncate(proto.getIndex()); if (!state.getRaftConf().equals(LogProtoUtils.toRaftConfiguration(proto))) { - LOG.info("{}: set new configuration {} from snapshot", getMemberId(), ProtoUtils.shortDebugString(proto)); - state.updateConfiguration(Collections.singletonList(proto)); + LOG.info("{}: set new configuration {} from snapshot", getMemberId(), proto); + state.setRaftConf(proto); state.writeRaftConfiguration(proto); server.getStateMachine().event().notifyConfigurationChanged( proto.getTerm(), proto.getIndex(), proto.getConfigurationEntry()); @@ -173,16 +172,16 @@ private CompletableFuture checkAndInstallSnapshot(Ins final long lastIncludedIndex = lastIncluded.getIndex(); final CompletableFuture future; synchronized (server) { - final boolean recognized = state.recognizeLeader(Op.INSTALL_SNAPSHOT, leaderId, leaderTerm); + final boolean recognized = state.recognizeLeader(RaftServerProtocol.Op.INSTALL_SNAPSHOT, leaderId, leaderTerm); currentTerm = state.getCurrentTerm(); if (!recognized) { return CompletableFuture.completedFuture(toInstallSnapshotReplyProto(leaderId, getMemberId(), currentTerm, snapshotChunkRequest.getRequestIndex(), InstallSnapshotResult.NOT_LEADER)); } - future = server.changeToFollowerAndPersistMetadata(leaderTerm, true, Op.INSTALL_SNAPSHOT); - state.setLeader(leaderId, Op.INSTALL_SNAPSHOT); + future = server.changeToFollowerAndPersistMetadata(leaderTerm, true, "installSnapshot"); + state.setLeader(leaderId, "installSnapshot"); - server.updateLastRpcTime(UpdateType.INSTALL_SNAPSHOT_START); + server.updateLastRpcTime(FollowerState.UpdateType.INSTALL_SNAPSHOT_START); long callId = chunk0CallId.get(); // 1. leaderTerm < currentTerm will never come here // 2. leaderTerm == currentTerm && callId == request.getCallId() @@ -229,7 +228,7 @@ private CompletableFuture checkAndInstallSnapshot(Ins chunk0CallId.set(-1); } } finally { - server.updateLastRpcTime(UpdateType.INSTALL_SNAPSHOT_COMPLETE); + server.updateLastRpcTime(FollowerState.UpdateType.INSTALL_SNAPSHOT_COMPLETE); } } if (snapshotChunkRequest.getDone()) { @@ -249,15 +248,15 @@ private CompletableFuture notifyStateMachineToInstall final long firstAvailableLogIndex = firstAvailableLogTermIndex.getIndex(); final CompletableFuture future; synchronized (server) { - final boolean recognized = state.recognizeLeader(UpdateType.INSTALL_SNAPSHOT_NOTIFICATION, leaderId, leaderTerm); + final boolean recognized = state.recognizeLeader("notifyInstallSnapshot", leaderId, leaderTerm); currentTerm = state.getCurrentTerm(); if (!recognized) { return CompletableFuture.completedFuture(toInstallSnapshotReplyProto(leaderId, getMemberId(), currentTerm, InstallSnapshotResult.NOT_LEADER)); } - future = server.changeToFollowerAndPersistMetadata(leaderTerm, true, UpdateType.INSTALL_SNAPSHOT_NOTIFICATION); - state.setLeader(leaderId, UpdateType.INSTALL_SNAPSHOT_NOTIFICATION); - server.updateLastRpcTime(UpdateType.INSTALL_SNAPSHOT_NOTIFICATION); + future = server.changeToFollowerAndPersistMetadata(leaderTerm, true, "installSnapshot"); + state.setLeader(leaderId, "installSnapshot"); + server.updateLastRpcTime(FollowerState.UpdateType.INSTALL_SNAPSHOT_NOTIFICATION); if (inProgressInstallSnapshotIndex.compareAndSet(INVALID_LOG_INDEX, firstAvailableLogIndex)) { LOG.info("{}: Received notification to install snapshot at index {}", getMemberId(), firstAvailableLogIndex); @@ -292,7 +291,7 @@ private CompletableFuture notifyStateMachineToInstall // For the cases where RaftConf is empty on newly started peer with empty peer list, // we retrieve leader info from installSnapShotRequestProto. final RoleInfoProto proto = leaderProto == null || server.getRaftConf().getPeer(state.getLeaderId()) != null? - server.getRoleInfoProto(): getRoleInfoProto(leaderProto); + server.getRoleInfoProto(): getRoleInfoProto(ProtoUtils.toRaftPeer(leaderProto)); // This is the first installSnapshot notify request for this term and // index. Notify the state machine to install the snapshot. LOG.info("{}: notifyInstallSnapshot: nextIndex is {} but the leader's first available index is {}.", @@ -386,7 +385,7 @@ private CompletableFuture notifyStateMachineToInstall } } - private RoleInfoProto getRoleInfoProto(RaftPeerProto leader) { + private RoleInfoProto getRoleInfoProto(RaftPeer leader) { final RoleInfo role = server.getRole(); final Optional fs = role.getFollowerState(); final ServerRpcProto leaderInfo = toServerRpcProto(leader, diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java index 041693195f..9c5290efe4 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/StateMachineUpdater.java @@ -45,7 +45,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.function.LongConsumer; import java.util.stream.LongStream; /** @@ -91,12 +90,12 @@ enum State { private final MemoizedSupplier stateMachineMetrics; - private final LongConsumer appliedIndexConsumer; + private final Consumer appliedIndexConsumer; private volatile boolean isRemoving; StateMachineUpdater(StateMachine stateMachine, RaftServerImpl server, - ServerState serverState, long lastAppliedIndex, RaftProperties properties, LongConsumer appliedIndexConsumer) { + ServerState serverState, long lastAppliedIndex, RaftProperties properties, Consumer appliedIndexConsumer) { this.name = ServerStringUtils.generateUnifiedName(serverState.getMemberId(), getClass()); this.appliedIndexConsumer = appliedIndexConsumer; this.infoIndexChange = s -> LOG.info("{}: {}", name, s); @@ -116,7 +115,8 @@ enum State { final int numSnapshotFilesRetained = RaftServerConfigKeys.Snapshot.retentionFileNum(properties); this.snapshotRetentionPolicy = new SnapshotRetentionPolicy() { @Override - public int getNumSnapshotsRetained() { + @SuppressWarnings({"deprecation", "try"}) +public int getNumSnapshotsRetained() { return numSnapshotFilesRetained; } }; @@ -244,10 +244,17 @@ private CompletableFuture applyLog(CompletableFuture applyLogFutures final long committed = raftLog.getLastCommittedIndex(); for(long applied; (applied = getLastAppliedIndex()) < committed && state == State.RUNNING && !shouldStop(); ) { final long nextIndex = applied + 1; - final LogEntryProto next = raftLog.get(nextIndex); - if (next != null) { + final ReferenceCountedObject next = raftLog.retainLog(nextIndex); + if (next == null) { + LOG.debug("{}: logEntry {} is null. There may be snapshot to load. state:{}", + this, nextIndex, state); + break; + } + + try { + final LogEntryProto entry = next.get(); if (LOG.isTraceEnabled()) { - LOG.trace("{}: applying nextIndex={}, nextLog={}", this, nextIndex, LogProtoUtils.toLogEntryString(next)); + LOG.trace("{}: applying nextIndex={}, nextLog={}", this, nextIndex, LogProtoUtils.toLogEntryString(entry)); } else { LOG.debug("{}: applying nextIndex={}", this, nextIndex); } @@ -258,7 +265,7 @@ private CompletableFuture applyLog(CompletableFuture applyLogFutures if (f != null) { CompletableFuture exceptionHandledFuture = f.exceptionally(ex -> { LOG.error("Exception while {}: applying txn index={}, nextLog={}", this, nextIndex, - LogProtoUtils.toLogEntryString(next), ex); + LogProtoUtils.toLogEntryString(entry), ex); return null; }); applyLogFutures = applyLogFutures.thenCombine(exceptionHandledFuture, (v, message) -> null); @@ -266,10 +273,8 @@ private CompletableFuture applyLog(CompletableFuture applyLogFutures } else { notifyAppliedIndex(incremented); } - } else { - LOG.debug("{}: logEntry {} is null. There may be snapshot to load. state:{}", - this, nextIndex, state); - break; + } finally { + next.release(); } } return applyLogFutures; @@ -283,6 +288,7 @@ private void checkAndTakeSnapshot(CompletableFuture futures) } } + @SuppressWarnings("try") private void takeSnapshot(CompletableFuture applyLogFutures) throws ExecutionException, InterruptedException { final long i; applyLogFutures.get(); diff --git a/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderBase.java b/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderBase.java index f65ac1863c..874d390e1d 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderBase.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderBase.java @@ -33,27 +33,109 @@ import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.LifeCycle; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SizeInBytes; import org.apache.ratis.util.TimeDuration; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.LongUnaryOperator; /** * An abstract implementation of {@link LogAppender}. */ +@SuppressWarnings({"deprecation", "try"}) public abstract class LogAppenderBase implements LogAppender { + /** For buffering log entries to create an {@link EntryList}. */ + private static class EntryBuffer { + /** A queue for limiting the byte size, number of elements and poll time. */ + private final DataQueue queue; + /** A map for releasing {@link ReferenceCountedObject}s. */ + private final Map> references = new HashMap<>(); + + EntryBuffer(Object name, RaftProperties properties) { + final SizeInBytes bufferByteLimit = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties); + final int bufferElementLimit = RaftServerConfigKeys.Log.Appender.bufferElementLimit(properties); + this.queue = new DataQueue<>(name, bufferByteLimit, bufferElementLimit, EntryWithData::getSerializedSize); + } + + boolean putNew(long index, ReferenceCountedObject retained) { + if (!queue.offer(retained.get())) { + retained.release(); + return false; + } + final ReferenceCountedObject previous = references.put(index, retained); + Preconditions.assertNull(previous, () -> "previous with index " + index); + return true; + } + + void releaseAllAndClear() { + for (ReferenceCountedObject ref : references.values()) { + ref.release(); + } + references.clear(); + queue.clear(); + } + + EntryList pollList(long heartbeatWaitTimeMs) throws RaftLogIOException { + final List protos; + try { + protos = queue.pollList(heartbeatWaitTimeMs, EntryWithData::getEntry, null); + } catch (Exception e) { + releaseAllAndClear(); + throw e; + } finally { + for (EntryWithData entry : queue) { + // Remove and release remaining entries. + final ReferenceCountedObject removed = references.remove(entry.getIndex()); + Objects.requireNonNull(removed, "removed == null"); + removed.release(); + } + queue.clear(); + } + return new EntryList(protos, references); + } + } + + /** Storing log entries and their references. */ + private static class EntryList { + private final List protos; + private final Collection> references; + + EntryList(List protos, Map> references) { + Preconditions.assertSame(references.size(), protos.size(), "#entries"); + this.protos = Collections.unmodifiableList(protos); + this.references = Collections.unmodifiableCollection(references.values()); + } + + List getProtos() { + return protos; + } + + void retain() { + for (ReferenceCountedObject ref : references) { + ref.retain(); + } + } + + void release() { + for (ReferenceCountedObject ref : references) { + ref.release(); + } + } + } + private final String name; private final RaftServer.Division server; private final LeaderState leaderState; private final FollowerInfo follower; - private final DataQueue buffer; private final int snapshotChunkMaxSize; private final LogAppenderDaemon daemon; @@ -71,9 +153,6 @@ protected LogAppenderBase(RaftServer.Division server, LeaderState leaderState, F final RaftProperties properties = server.getRaftServer().getProperties(); this.snapshotChunkMaxSize = RaftServerConfigKeys.Log.Appender.snapshotChunkSizeMax(properties).getSizeInt(); - final SizeInBytes bufferByteLimit = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties); - final int bufferElementLimit = RaftServerConfigKeys.Log.Appender.bufferElementLimit(properties); - this.buffer = new DataQueue<>(this, bufferByteLimit, bufferElementLimit, EntryWithData::getSerializedSize); this.daemon = new LogAppenderDaemon(this); this.eventAwaitForSignal = new AwaitForSignal(name); @@ -124,14 +203,7 @@ public void start() { @Override public boolean isRunning() { - return daemon.isWorking() - && isLeaderAlive(); - } - - private boolean isLeaderAlive() { - return server.getInfo().isAlive() - && server.getInfo().isLeader() - && getRaftLog().isOpened(); + return daemon.isWorking() && server.getInfo().isLeader(); } @Override @@ -140,12 +212,8 @@ public CompletableFuture stopAsync() { } void restart() { - if (daemon.isClosingOrClosed()) { - LOG.warn("{}: daemon is closing or closed, skipping restart", this); - return; - } - if (!isLeaderAlive()) { - LOG.warn("{}: leader is not ready, skipping restart", this); + if (!server.getInfo().isAlive()) { + LOG.warn("Failed to restart {}: server {} is not alive", this, server.getMemberId()); return; } getLeaderState().restart(this); @@ -173,6 +241,28 @@ public boolean hasPendingDataRequests() { return false; } + @Override + public TermIndex getPrevious(long nextIndex) { + if (nextIndex == RaftLog.LEAST_VALID_LOG_INDEX) { + return null; + } + + final long previousIndex = nextIndex - 1; + final TermIndex previous = getRaftLog().getTermIndex(previousIndex); + if (previous != null) { + return previous; + } + + final SnapshotInfo snapshot = server.getStateMachine().getLatestSnapshot(); + if (snapshot != null) { + final TermIndex snapshotTermIndex = snapshot.getTermIndex(); + if (snapshotTermIndex.getIndex() == previousIndex) { + return snapshotTermIndex; + } + } + + return null; + } protected long getNextIndexForInconsistency(long requestFirstIndex, long replyNextIndex) { long next = replyNextIndex; @@ -196,59 +286,66 @@ protected LongUnaryOperator getNextIndexForError(long newNextIndex) { final long n = oldNextIndex <= 0L ? oldNextIndex : Math.min(oldNextIndex - 1, newNextIndex); if (m > n) { if (m > newNextIndex) { - LOG.info("Set nextIndex to matchIndex + 1 (= " + m + ")"); + LOG.info("{}: Set nextIndex to matchIndex + 1 (= {})", name, m); } return m; } else if (oldNextIndex <= 0L) { return oldNextIndex; // no change. } else { - LOG.info("Decrease nextIndex to " + n); + LOG.info("{}: Decrease nextIndex to {}", name, n); return n; } }; } - @Override - public AppendEntriesRequestProto newAppendEntriesRequest(long callId, boolean heartbeat) + public AppendEntriesRequestProto newAppendEntriesRequest(long callId, boolean heartbeat) { + throw new UnsupportedOperationException("Use nextAppendEntriesRequest(" + callId + ", " + heartbeat +") instead."); + } + + /** + * Create a {@link AppendEntriesRequestProto} object using the {@link FollowerInfo} of this {@link LogAppender}. + * The {@link AppendEntriesRequestProto} object may contain zero or more log entries. + * When there is zero log entries, the {@link AppendEntriesRequestProto} object is a heartbeat. + * + * @param callId The call id of the returned request. + * @param heartbeat the returned request must be a heartbeat. + * + * @return a retained reference of {@link AppendEntriesRequestProto} object. + * Since the returned reference is retained, + * the caller must call {@link ReferenceCountedObject#release()}} after use. + */ + protected ReferenceCountedObject nextAppendEntriesRequest(long callId, boolean heartbeat) throws RaftLogIOException { final long heartbeatWaitTimeMs = getHeartbeatWaitTimeMs(); final TermIndex previous = getPrevious(follower.getNextIndex()); if (heartbeatWaitTimeMs <= 0L || heartbeat) { // heartbeat - return leaderState.newAppendEntriesRequestProto(follower, Collections.emptyList(), - hasPendingDataRequests()? null : previous, callId); + AppendEntriesRequestProto heartbeatRequest = + leaderState.newAppendEntriesRequestProto(follower, Collections.emptyList(), + hasPendingDataRequests() ? null : previous, callId); + ReferenceCountedObject ref = ReferenceCountedObject.wrap(heartbeatRequest); + ref.retain(); + return ref; } - Preconditions.assertTrue(buffer.isEmpty(), () -> "buffer has " + buffer.getNumElements() + " elements."); - final long snapshotIndex = follower.getSnapshotIndex(); - final long leaderNext = getRaftLog().getNextIndex(); final long followerNext = follower.getNextIndex(); - - if (previous == null && followerNext > RaftLog.LEAST_VALID_LOG_INDEX && followerNext != snapshotIndex + 1) { - LOG.info("{}: Skipping appendEntries since the previous log entry is unavailable:" + - " follower {} nextIndex={} and snapshotIndex={} but leader startIndex={}", - this, follower.getName(), followerNext, snapshotIndex, getRaftLog().getStartIndex()); + final EntryBuffer entryBuffer = readLogEntries(followerNext, heartbeatWaitTimeMs); + if (entryBuffer == null) { return null; } - final long halfMs = heartbeatWaitTimeMs/2; - for (long next = followerNext; leaderNext > next && getHeartbeatWaitTimeMs() - halfMs > 0; ) { - if (!buffer.offer(getRaftLog().getEntryWithData(next++))) { - break; - } - } - if (buffer.isEmpty()) { - return null; - } - - final List protos = buffer.pollList(getHeartbeatWaitTimeMs(), EntryWithData::getEntry, - (entry, time, exception) -> LOG.warn("Failed to get " + entry - + " in " + time.toString(TimeUnit.MILLISECONDS, 3), exception)); - buffer.clear(); + final EntryList entryList = entryBuffer.pollList(heartbeatWaitTimeMs); + final List protos = entryList.getProtos(); assertProtos(protos, followerNext, previous, snapshotIndex); - return leaderState.newAppendEntriesRequestProto(follower, protos, previous, callId); + AppendEntriesRequestProto appendEntriesProto = + leaderState.newAppendEntriesRequestProto(follower, protos, previous, callId); + final ReferenceCountedObject ref = ReferenceCountedObject.wrap( + appendEntriesProto, entryList::retain, entryList::release); + ref.retain(); + entryList.release(); + return ref; } private void assertProtos(List protos, long nextIndex, TermIndex previous, long snapshotIndex) { @@ -270,6 +367,31 @@ private void assertProtos(List protos, long nextIndex, TermIndex } } + private EntryBuffer readLogEntries(long followerNext, long heartbeatWaitTimeMs) throws RaftLogIOException { + final RaftLog raftLog = getRaftLog(); + final long leaderNext = raftLog.getNextIndex(); + final long halfMs = heartbeatWaitTimeMs/2; + EntryBuffer entryBuffer = null; + for (long next = followerNext; leaderNext > next && getHeartbeatWaitTimeMs() - halfMs > 0; next++) { + final ReferenceCountedObject retained; + try { + retained = raftLog.retainEntryWithData(next); + if (entryBuffer == null) { + entryBuffer = new EntryBuffer(name, server.getRaftServer().getProperties()); + } + if (!entryBuffer.putNew(next, retained)) { + break; + } + } catch (Exception e) { + if (entryBuffer != null) { + entryBuffer.releaseAllAndClear(); + } + throw e; + } + } + return entryBuffer; + } + @Override public InstallSnapshotRequestProto newInstallSnapshotNotificationRequest(TermIndex firstAvailableLogTermIndex) { Preconditions.assertTrue(firstAvailableLogTermIndex.getIndex() >= 0); diff --git a/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderDefault.java b/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderDefault.java index 9d1edd4695..8c1675c7c3 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderDefault.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/leader/LogAppenderDefault.java @@ -27,6 +27,7 @@ import org.apache.ratis.server.raftlog.RaftLogIOException; import org.apache.ratis.server.util.ServerStringUtils; import org.apache.ratis.statemachine.SnapshotInfo; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.Timestamp; import java.io.IOException; @@ -58,44 +59,36 @@ public Comparator getCallIdComparator() { /** Send an appendEntries RPC; retry indefinitely. */ private AppendEntriesReplyProto sendAppendEntriesWithRetries(AtomicLong requestFirstIndex) throws InterruptedException, InterruptedIOException, RaftLogIOException { - int retry = 0; - - AppendEntriesRequestProto request = newAppendEntriesRequest(CallId.getAndIncrement(), false); - while (isRunning()) { // keep retrying for IOException + for(int retry = 0; isRunning(); retry++) { + final ReferenceCountedObject request = nextAppendEntriesRequest( + CallId.getAndIncrement(), false); + if (request == null) { + LOG.trace("{} no entries to send now, wait ...", this); + return null; + } try { - if (request == null || request.getEntriesCount() == 0) { - request = newAppendEntriesRequest(CallId.getAndIncrement(), false); - } - - if (request == null) { - LOG.trace("{} no entries to send now, wait ...", this); - return null; - } else if (!isRunning()) { + if (!isRunning()) { LOG.info("{} is stopped. Skip appendEntries.", this); return null; } - resetHeartbeatTrigger(); - final Timestamp sendTime = Timestamp.currentTime(); - getFollower().updateLastRpcSendTime(request.getEntriesCount() == 0); - final AppendEntriesRequestProto proto = request; - final AppendEntriesReplyProto reply = getServerRpc().appendEntries(proto); + final AppendEntriesRequestProto proto = request.get(); + final AppendEntriesReplyProto reply = sendAppendEntries(proto); final long first = proto.getEntriesCount() > 0 ? proto.getEntries(0).getIndex() : RaftLog.INVALID_LOG_INDEX; requestFirstIndex.set(first); - getFollower().updateLastRpcResponseTime(); - getFollower().updateLastRespondedAppendEntriesSendTime(sendTime); - - getLeaderState().onFollowerCommitIndex(getFollower(), reply.getFollowerCommit()); return reply; } catch (InterruptedIOException | RaftLogIOException e) { throw e; } catch (IOException ioe) { // TODO should have more detailed retry policy here. - if (retry++ % 10 == 0) { // to reduce the number of messages + if (retry % 10 == 0) { // to reduce the number of messages LOG.warn("{}: Failed to appendEntries (retry={})", this, retry, ioe); } handleException(ioe); + } finally { + request.release(); } + if (isRunning()) { getServer().properties().rpcSleepTime().sleep(); } @@ -103,6 +96,18 @@ private AppendEntriesReplyProto sendAppendEntriesWithRetries(AtomicLong requestF return null; } + private AppendEntriesReplyProto sendAppendEntries(AppendEntriesRequestProto request) throws IOException { + resetHeartbeatTrigger(); + final Timestamp sendTime = Timestamp.currentTime(); + getFollower().updateLastRpcSendTime(request.getEntriesCount() == 0); + final AppendEntriesReplyProto r = getServerRpc().appendEntries(request); + getFollower().updateLastRpcResponseTime(); + getFollower().updateLastRespondedAppendEntriesSendTime(sendTime); + + getLeaderState().onFollowerCommitIndex(getFollower(), r.getFollowerCommit()); + return r; + } + private InstallSnapshotReplyProto installSnapshot(SnapshotInfo snapshot) throws InterruptedIOException { String requestId = UUID.randomUUID().toString(); InstallSnapshotReplyProto reply = null; diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java index 3705c3bd4b..c3943f7d9f 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/LogProtoUtils.java @@ -26,6 +26,7 @@ import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage; import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; +import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException; import org.apache.ratis.util.Preconditions; import org.apache.ratis.util.ProtoUtils; @@ -155,8 +156,9 @@ public static LogEntryProto removeStateMachineData(LogEntryProto entry) { } private static LogEntryProto replaceStateMachineDataWithSerializedSize(LogEntryProto entry) { - return replaceStateMachineEntry(entry, + LogEntryProto replaced = replaceStateMachineEntry(entry, StateMachineEntryProto.newBuilder().setLogEntryProtoSerializedSize(entry.getSerializedSize())); + return copy(replaced); } private static LogEntryProto replaceStateMachineEntry(LogEntryProto proto, StateMachineEntryProto.Builder newEntry) { @@ -178,6 +180,13 @@ static LogEntryProto addStateMachineData(ByteString stateMachineData, LogEntryPr return replaceStateMachineEntry(entry, StateMachineEntryProto.newBuilder().setStateMachineData(stateMachineData)); } + public static boolean hasStateMachineData(LogEntryProto entry) { + return getStateMachineEntry(entry) + .map(StateMachineEntryProto::getStateMachineData) + .map(data -> !data.isEmpty()) + .orElse(false); + } + public static boolean isStateMachineDataEmpty(LogEntryProto entry) { return getStateMachineEntry(entry) .map(StateMachineEntryProto::getStateMachineData) @@ -240,4 +249,21 @@ public static RaftConfiguration toRaftConfiguration(LogEntryProto entry) { final List oldListener = ProtoUtils.toRaftPeers(proto.getOldListenersList()); return ServerImplUtils.newRaftConfiguration(conf, listener, entry.getIndex(), oldConf, oldListener); } + + public static LogEntryProto copy(LogEntryProto proto) { + if (proto == null) { + return null; + } + + if (!proto.hasStateMachineLogEntry() && !proto.hasMetadataEntry() && !proto.hasConfigurationEntry()) { + // empty entry, just return as is. + return proto; + } + + try { + return LogEntryProto.parseFrom(proto.toByteString()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Failed to copy log entry " + TermIndex.valueOf(proto), e); + } + } } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java index 48b410147c..024845fac4 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/RaftLogBase.java @@ -17,6 +17,7 @@ */ package org.apache.ratis.server.raftlog; +import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.proto.RaftProtos.LogEntryProto; @@ -31,7 +32,9 @@ import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.OpenCloseState; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.TimeDuration; +import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier; import java.io.IOException; import java.util.List; @@ -193,7 +196,8 @@ private long appendImpl(long term, TransactionContext operation) throws StateMac throw new StateMachineException(memberId, new RaftLogIOException( "Log entry size " + entrySize + " exceeds the max buffer limit of " + maxBufferSize)); } - appendEntry(e, operation).whenComplete((returned, t) -> { + + appendEntry(operation.wrap(e), operation).whenComplete((returned, t) -> { if (t != null) { LOG.error(name + ": Failed to write log entry " + toLogEntryString(e), t); } else if (returned != nextIndex) { @@ -240,9 +244,21 @@ private boolean shouldAppendMetadata(long newCommitIndex) { // do not log the first conf entry return false; } - final LogEntryProto last = lastMetadataEntry.get(); - // do not log entries with a smaller commit index. - return last == null || newCommitIndex > last.getMetadataEntry().getCommitIndex(); + ReferenceCountedObject ref = null; + try { + ref = retainLog(newCommitIndex); + if (ref.get().hasMetadataEntry()) { + // do not log the metadata entry + return false; + } + } catch(RaftLogIOException e) { + LOG.error("Failed to get log entry for index " + newCommitIndex, e); + } finally { + if (ref != null) { + ref.release(); + } + } + return true; } @Override @@ -346,22 +362,32 @@ public final CompletableFuture purge(long suggestedIndex) { @Override public final CompletableFuture appendEntry(LogEntryProto entry) { - return appendEntry(entry, null); + return appendEntry(ReferenceCountedObject.wrap(entry), null); } @Override - public final CompletableFuture appendEntry(LogEntryProto entry, TransactionContext context) { + public final CompletableFuture appendEntry(ReferenceCountedObject entry, + TransactionContext context) { return runner.runSequentially(() -> appendEntryImpl(entry, context)); } - protected abstract CompletableFuture appendEntryImpl(LogEntryProto entry, TransactionContext context); + protected abstract CompletableFuture appendEntryImpl(ReferenceCountedObject entry, + TransactionContext context); @Override - public final List> append(List entries) { + public final List> append(ReferenceCountedObject> entries) { return runner.runSequentially(() -> appendImpl(entries)); } - protected abstract List> appendImpl(List entries); + protected List> appendImpl(List entries) { + throw new UnsupportedOperationException(); + } + + protected List> appendImpl(ReferenceCountedObject> entriesRef) { + try(UncheckedAutoCloseableSupplier> entries = entriesRef.retainAndReleaseOnClose()) { + return appendImpl(entries.get()); + } + } @Override public String toString() { @@ -398,8 +424,43 @@ public String getName() { return name; } - protected EntryWithData newEntryWithData(LogEntryProto logEntry, CompletableFuture future) { - return new EntryWithDataImpl(logEntry, future); + protected ReferenceCountedObject newEntryWithData(ReferenceCountedObject retained) { + return retained.delegate(new EntryWithDataImpl(retained.get(), null)); + } + + protected ReferenceCountedObject newEntryWithData(ReferenceCountedObject retained, + CompletableFuture> stateMachineDataFuture) { + final EntryWithDataImpl impl = new EntryWithDataImpl(retained.get(), stateMachineDataFuture); + return new ReferenceCountedObject() { + private CompletableFuture> future + = Objects.requireNonNull(stateMachineDataFuture, "stateMachineDataFuture == null"); + + @Override + public EntryWithData get() { + return impl; + } + + synchronized void updateFuture(Consumer> action) { + future = future.whenComplete((ref, e) -> { + if (ref != null) { + action.accept(ref); + } + }); + } + + @Override + public EntryWithData retain() { + retained.retain(); + updateFuture(ReferenceCountedObject::retain); + return impl; + } + + @Override + public boolean release() { + updateFuture(ReferenceCountedObject::release); + return retained.release(); + } + }; } /** @@ -407,20 +468,25 @@ protected EntryWithData newEntryWithData(LogEntryProto logEntry, CompletableFutu */ class EntryWithDataImpl implements EntryWithData { private final LogEntryProto logEntry; - private final CompletableFuture future; + private final CompletableFuture> future; - EntryWithDataImpl(LogEntryProto logEntry, CompletableFuture future) { + EntryWithDataImpl(LogEntryProto logEntry, CompletableFuture> future) { this.logEntry = logEntry; this.future = future == null? null: future.thenApply(this::checkStateMachineData); } - private ByteString checkStateMachineData(ByteString data) { + private ReferenceCountedObject checkStateMachineData(ReferenceCountedObject data) { if (data == null) { - throw new IllegalStateException("State machine data is null for log entry " + logEntry); + throw new IllegalStateException("State machine data is null for log entry " + this); } return data; } + @Override + public long getIndex() { + return logEntry.getIndex(); + } + @Override public int getSerializedSize() { return LogProtoUtils.getSerializedSize(logEntry); @@ -428,14 +494,15 @@ public int getSerializedSize() { @Override public LogEntryProto getEntry(TimeDuration timeout) throws RaftLogIOException, TimeoutException { - LogEntryProto entryProto; if (future == null) { return logEntry; } + final LogEntryProto entryProto; + ReferenceCountedObject data; try { - entryProto = future.thenApply(data -> LogProtoUtils.addStateMachineData(data, logEntry)) - .get(timeout.getDuration(), timeout.getUnit()); + data = future.get(timeout.getDuration(), timeout.getUnit()); + entryProto = LogProtoUtils.addStateMachineData(data.get(), logEntry); } catch (TimeoutException t) { if (timeout.compareTo(stateMachineDataReadTimeout) > 0) { getRaftLogMetrics().onStateMachineDataReadTimeout(); @@ -445,14 +512,14 @@ public LogEntryProto getEntry(TimeDuration timeout) throws RaftLogIOException, T if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); } - final String err = getName() + ": Failed readStateMachineData for " + toLogEntryString(logEntry); + final String err = getName() + ": Failed readStateMachineData for " + this; LOG.error(err, e); throw new RaftLogIOException(err, JavaUtils.unwrapCompletionException(e)); } // by this time we have already read the state machine data, // so the log entry data should be set now if (LogProtoUtils.isStateMachineDataEmpty(entryProto)) { - final String err = getName() + ": State machine data not set for " + toLogEntryString(logEntry); + final String err = getName() + ": State machine data not set for " + this; LOG.error(err); throw new RaftLogIOException(err); } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java index ebb1e27d77..3579bb1f37 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLog.java @@ -22,12 +22,15 @@ import org.apache.ratis.server.metrics.RaftLogMetricsBase; import org.apache.ratis.server.protocol.TermIndex; import org.apache.ratis.proto.RaftProtos.LogEntryProto; +import org.apache.ratis.server.raftlog.LogProtoUtils; import org.apache.ratis.server.raftlog.RaftLogBase; import org.apache.ratis.server.raftlog.LogEntryHeader; +import org.apache.ratis.server.raftlog.RaftLogIOException; import org.apache.ratis.server.storage.RaftStorageMetadata; import org.apache.ratis.statemachine.TransactionContext; import org.apache.ratis.util.AutoCloseableLock; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import java.io.IOException; import java.util.ArrayList; @@ -40,6 +43,7 @@ /** * A simple RaftLog implementation in memory. Used only for testing. */ +@SuppressWarnings({"deprecation", "try"}) public class MemoryRaftLog extends RaftLogBase { static class EntryList { private final List entries = new ArrayList<>(); @@ -62,18 +66,22 @@ int size() { void truncate(int index) { if (entries.size() > index) { - entries.subList(index, entries.size()).clear(); + clear(index, entries.size()); } } void purge(int index) { if (entries.size() > index) { - entries.subList(0, index).clear(); + clear(0, index); } } - void add(LogEntryProto entry) { - entries.add(entry); + void clear(int from, int to) { + entries.subList(from, to).clear(); + } + + void add(LogEntryProto entryRef) { + entries.add(entryRef); } } @@ -100,16 +108,35 @@ public RaftLogMetricsBase getRaftLogMetrics() { } @Override - public LogEntryProto get(long index) { + public LogEntryProto get(long index) throws RaftLogIOException { + final ReferenceCountedObject ref = retainLog(index); + try { + return LogProtoUtils.copy(ref.get()); + } finally { + ref.release(); + } + } + + @Override + public ReferenceCountedObject retainLog(long index) { checkLogState(); - try(AutoCloseableLock readLock = readLock()) { - return entries.get(Math.toIntExact(index)); + try (AutoCloseableLock readLock = readLock()) { + final LogEntryProto entry = entries.get(Math.toIntExact(index)); + final ReferenceCountedObject ref = ReferenceCountedObject.wrap(entry); + ref.retain(); + return ref; } } @Override - public EntryWithData getEntryWithData(long index) { - return newEntryWithData(get(index), null); + public EntryWithData getEntryWithData(long index) throws RaftLogIOException { + throw new UnsupportedOperationException("Use retainEntryWithData(" + index + ") instead."); + } + + @Override + public ReferenceCountedObject retainEntryWithData(long index) { + final ReferenceCountedObject ref = retainLog(index); + return newEntryWithData(ref); } @Override @@ -166,11 +193,15 @@ public TermIndex getLastEntryTermIndex() { } @Override - protected CompletableFuture appendEntryImpl(LogEntryProto entry, TransactionContext context) { + protected CompletableFuture appendEntryImpl(ReferenceCountedObject entryRef, + TransactionContext context) { checkLogState(); - try(AutoCloseableLock writeLock = writeLock()) { + LogEntryProto entry = entryRef.retain(); + try (AutoCloseableLock writeLock = writeLock()) { validateLogEntry(entry); entries.add(entry); + } finally { + entryRef.release(); } return CompletableFuture.completedFuture(entry.getIndex()); } @@ -181,12 +212,14 @@ public long getStartIndex() { } @Override - public List> appendImpl(List logEntryProtos) { + public List> appendImpl(ReferenceCountedObject> entriesRef) { checkLogState(); + final List logEntryProtos = entriesRef.retain(); if (logEntryProtos == null || logEntryProtos.isEmpty()) { + entriesRef.release(); return Collections.emptyList(); } - try(AutoCloseableLock writeLock = writeLock()) { + try (AutoCloseableLock writeLock = writeLock()) { // Before truncating the entries, we first need to check if some // entries are duplicated. If the leader sends entry 6, entry 7, then // entry 6 again, without this check the follower may truncate entry 7 @@ -214,10 +247,12 @@ public List> appendImpl(List logEntryProt } for (int i = index; i < logEntryProtos.size(); i++) { LogEntryProto logEntryProto = logEntryProtos.get(i); - this.entries.add(logEntryProto); + entries.add(LogProtoUtils.copy(logEntryProto)); futures.add(CompletableFuture.completedFuture(logEntryProto.getIndex())); } return futures; + } finally { + entriesRef.release(); } } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java index bb2bde7edb..6b1696b960 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/LogSegment.java @@ -26,10 +26,12 @@ import org.apache.ratis.server.raftlog.RaftLogIOException; import org.apache.ratis.server.storage.RaftStorage; import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; -import org.apache.ratis.thirdparty.com.google.common.cache.CacheLoader; import org.apache.ratis.thirdparty.com.google.protobuf.CodedOutputStream; +import org.apache.ratis.util.CodeInjectionForTesting; import org.apache.ratis.util.FileUtils; +import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SizeInBytes; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,16 +40,16 @@ import java.io.IOException; import java.nio.file.Path; import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentNavigableMap; import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import org.apache.ratis.util.CodeInjectionForTesting; /** @@ -60,6 +62,8 @@ public final class LogSegment { static final Logger LOG = LoggerFactory.getLogger(LogSegment.class); + public static final String APPEND_RECORD = LogSegment.class.getSimpleName() + ".append"; + enum Op { LOAD_SEGMENT_FILE, REMOVE_CACHE, @@ -69,17 +73,20 @@ enum Op { } static long getEntrySize(LogEntryProto entry, Op op) { - LogEntryProto e = entry; - if (op == Op.CHECK_SEGMENT_FILE_FULL) { - e = LogProtoUtils.removeStateMachineData(entry); - } else if (op == Op.LOAD_SEGMENT_FILE || op == Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE) { - Preconditions.assertTrue(entry == LogProtoUtils.removeStateMachineData(entry), - () -> "Unexpected LogEntryProto with StateMachine data: op=" + op + ", entry=" + entry); - } else { - Preconditions.assertTrue(op == Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE || op == Op.REMOVE_CACHE, - () -> "Unexpected op " + op + ", entry=" + entry); + switch (op) { + case CHECK_SEGMENT_FILE_FULL: + case LOAD_SEGMENT_FILE: + case WRITE_CACHE_WITH_STATE_MACHINE_CACHE: + Preconditions.assertTrue(!LogProtoUtils.hasStateMachineData(entry), + () -> "Unexpected LogEntryProto with StateMachine data: op=" + op + ", entry=" + entry); + break; + case WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE: + case REMOVE_CACHE: + break; + default: + throw new IllegalStateException("Unexpected op " + op + ", entry=" + entry); } - final int serialized = e.getSerializedSize(); + final int serialized = entry.getSerializedSize(); return serialized + CodedOutputStream.computeUInt32SizeNoTag(serialized) + 4L; } @@ -164,7 +171,8 @@ static LogSegment newLogSegment(RaftStorage storage, LogSegmentStartEnd startEnd } public static int readSegmentFile(File file, LogSegmentStartEnd startEnd, SizeInBytes maxOpSize, - CorruptionPolicy corruptionPolicy, SegmentedRaftLogMetrics raftLogMetrics, Consumer entryConsumer) + CorruptionPolicy corruptionPolicy, SegmentedRaftLogMetrics raftLogMetrics, + Consumer> entryConsumer) throws IOException { int count = 0; try(SegmentedRaftLogInputStream in = new SegmentedRaftLogInputStream(file, startEnd, maxOpSize, raftLogMetrics)) { @@ -175,7 +183,8 @@ public static int readSegmentFile(File file, LogSegmentStartEnd startEnd, SizeIn } if (entryConsumer != null) { - entryConsumer.accept(next); + // TODO: use reference count to support zero buffer copying for readSegmentFile + entryConsumer.accept(ReferenceCountedObject.wrap(next)); } count++; } @@ -202,10 +211,7 @@ static LogSegment loadSegment(RaftStorage storage, File file, LogSegmentStartEnd final CorruptionPolicy corruptionPolicy = CorruptionPolicy.get(storage, RaftStorage::getLogCorruptionPolicy); final boolean isOpen = startEnd.isOpen(); final int entryCount = readSegmentFile(file, startEnd, maxOpSize, corruptionPolicy, raftLogMetrics, entry -> { - segment.append(keepEntryInCache || isOpen, entry, Op.LOAD_SEGMENT_FILE, true); - if (logConsumer != null) { - logConsumer.accept(entry); - } + segment.append(Op.LOAD_SEGMENT_FILE, entry, keepEntryInCache || isOpen, logConsumer); }); LOG.info("Successfully read {} entries from segment file {}", entryCount, file); @@ -262,30 +268,34 @@ private void assertSegment(long expectedStart, int expectedEntryCount, boolean c * * In the future we can make the cache loader configurable if necessary. */ - class LogEntryLoader extends CacheLoader { + class LogEntryLoader { private final SegmentedRaftLogMetrics raftLogMetrics; LogEntryLoader(SegmentedRaftLogMetrics raftLogMetrics) { this.raftLogMetrics = raftLogMetrics; } - @Override - public LogEntryProto load(LogRecord key) throws IOException { + ReferenceCountedObject load(TermIndex key) throws IOException { final File file = getFile(); // note the loading should not exceed the endIndex: it is possible that // the on-disk log file should be truncated but has not been done yet. - final AtomicReference toReturn = new AtomicReference<>(); + final AtomicReference> toReturn = new AtomicReference<>(); final LogSegmentStartEnd startEnd = LogSegmentStartEnd.valueOf(startIndex, endIndex, isOpen); - readSegmentFile(file, startEnd, maxOpSize, - getLogCorruptionPolicy(), raftLogMetrics, entry -> { - final TermIndex ti = TermIndex.valueOf(entry); - putEntryCache(ti, entry, Op.LOAD_SEGMENT_FILE); - if (ti.equals(key.getTermIndex())) { - toReturn.set(entry); + readSegmentFile(file, startEnd, maxOpSize, getLogCorruptionPolicy(), raftLogMetrics, entryRef -> { + final LogEntryProto entry = entryRef.retain(); + try { + final TermIndex ti = TermIndex.valueOf(entry); + putEntryCache(ti, entryRef, Op.LOAD_SEGMENT_FILE); + if (ti.equals(key)) { + entryRef.retain(); + toReturn.set(entryRef); + } + } finally { + entryRef.release(); } }); loadingTimes.incrementAndGet(); - final LogEntryProto proto = toReturn.get(); + final ReferenceCountedObject proto = toReturn.get(); if (proto == null) { throw new RaftLogIOException("Failed to load log entry " + key); } @@ -293,13 +303,102 @@ public LogEntryProto load(LogRecord key) throws IOException { } } + private static class Item { + private final AtomicReference> ref; + private final long serializedSize; + + Item(ReferenceCountedObject obj, long serializedSize) { + this.ref = new AtomicReference<>(obj); + this.serializedSize = serializedSize; + } + + ReferenceCountedObject get() { + return ref.get(); + } + + long release() { + final ReferenceCountedObject entry = ref.getAndSet(null); + if (entry == null) { + return 0; + } + entry.release(); + return serializedSize; + } + } + + class EntryCache { + private Map map = new HashMap<>(); + private final AtomicLong size = new AtomicLong(); + + @Override + public String toString() { + return JavaUtils.getClassSimpleName(getClass()) + "-" + LogSegment.this; + } + + long size() { + return size.get(); + } + + synchronized ReferenceCountedObject get(TermIndex ti) { + if (map == null) { + return null; + } + final Item ref = map.get(ti); + return ref == null? null: ref.get(); + } + + /** After close(), the cache CANNOT be used again. */ + synchronized void close() { + if (map == null) { + return; + } + evict(); + map = null; + LOG.info("Successfully closed {}", this); + } + + /** After evict(), the cache can be used again. */ + synchronized void evict() { + if (map == null) { + return; + } + for (Iterator> i = map.entrySet().iterator(); i.hasNext(); i.remove()) { + release(i.next().getValue()); + } + } + + synchronized void put(TermIndex key, ReferenceCountedObject valueRef, Op op) { + if (map == null) { + return; + } + valueRef.retain(); + final long serializedSize = getEntrySize(valueRef.get(), op); + release(map.put(key, new Item(valueRef, serializedSize))); + size.getAndAdd(serializedSize); + } + + private void release(Item ref) { + if (ref == null) { + return; + } + final long serializedSize = ref.release(); + size.getAndAdd(-serializedSize); + } + + synchronized void remove(TermIndex key) { + if (map == null) { + return; + } + release(map.remove(key)); + } + } + File getFile() { return LogSegmentStartEnd.valueOf(startIndex, endIndex, isOpen).getFile(storage); } private volatile boolean isOpen; private long totalFileSize = SegmentedRaftLogFormat.getHeaderLength(); - private AtomicLong totalCacheSize = new AtomicLong(0); /** Segment start index, inclusive. */ private final long startIndex; /** Segment end index, inclusive. */ @@ -317,7 +416,7 @@ File getFile() { /** * the entryCache caches the content of log entries. */ - private final Map entryCache = new ConcurrentHashMap<>(); + private final EntryCache entryCache = new EntryCache(); private LogSegment(RaftStorage storage, boolean isOpen, long start, long end, SizeInBytes maxOpSize, SegmentedRaftLogMetrics raftLogMetrics) { @@ -334,11 +433,7 @@ long getStartIndex() { } long getEndIndex() { - if (!isOpen) { - return endIndex; - } - final LogRecord last = records.getLast(); - return last == null ? getStartIndex() - 1 : last.getTermIndex().getIndex(); + return endIndex; } boolean isOpen() { @@ -346,76 +441,85 @@ boolean isOpen() { } int numOfEntries() { - return Math.toIntExact(getEndIndex() - startIndex + 1); + return Math.toIntExact(endIndex - startIndex + 1); } CorruptionPolicy getLogCorruptionPolicy() { return CorruptionPolicy.get(storage, RaftStorage::getLogCorruptionPolicy); } - void appendToOpenSegment(LogEntryProto entry, Op op, boolean verifyEntryIndex) { + void appendToOpenSegment(Op op, ReferenceCountedObject entryRef) { Preconditions.assertTrue(isOpen(), "The log segment %s is not open for append", this); - append(true, entry, op, verifyEntryIndex); + append(op, entryRef, true, null); } - public static final String APPEND_RECORD = LogSegment.class.getSimpleName() + ".append"; - private void append(boolean keepEntryInCache, LogEntryProto entry, Op op, boolean verifyEntryIndex) { - Objects.requireNonNull(entry, "entry == null"); - if (verifyEntryIndex) { - verifyEntryIndex(entry.getIndex()); - } - final LogRecord record = new LogRecord(totalFileSize, entry); - if (keepEntryInCache) { - // It is important to put the entry into the cache before appending the - // record to the record list. Otherwise, a reader thread may get the - // record from the list but not the entry from the cache. - putEntryCache(record.getTermIndex(), entry, op); - CodeInjectionForTesting.execute(APPEND_RECORD, this, record.getTermIndex()); - } - records.append(record); + private void append(Op op, ReferenceCountedObject entryRef, + boolean keepEntryInCache, Consumer logConsumer) { + final LogEntryProto entry = entryRef.retain(); + try { + final LogRecord record = new LogRecord(totalFileSize, entry); + if (keepEntryInCache) { + putEntryCache(record.getTermIndex(), entryRef, op); + CodeInjectionForTesting.execute(APPEND_RECORD, this, record.getTermIndex()); + } + appendLogRecord(op, record); + totalFileSize += getEntrySize(entry, op); - totalFileSize += getEntrySize(entry, op); - endIndex = entry.getIndex(); + if (logConsumer != null) { + logConsumer.accept(entry); + } + } finally { + entryRef.release(); + } } - void verifyEntryIndex(long entryIndex) { + private void appendLogRecord(Op op, LogRecord record) { + Objects.requireNonNull(record, "record == null"); final LogRecord currentLast = records.getLast(); + + final long index = record.getTermIndex().getIndex(); if (currentLast == null) { - Preconditions.assertTrue(entryIndex == startIndex, - "gap between start index %s and first entry to append %s", - startIndex, entryIndex); + Preconditions.assertTrue(index == startIndex, + "%s: gap between start index %s and the entry to append %s", op, startIndex, index); } else { - Preconditions.assertTrue(entryIndex == currentLast.getTermIndex().getIndex() + 1, - "gap between entries %s and %s", entryIndex, currentLast.getTermIndex().getIndex()); + final long currentLastIndex = currentLast.getTermIndex().getIndex(); + Preconditions.assertTrue(index == currentLastIndex + 1, + "%s: gap between last entry %s and the entry to append %s", op, currentLastIndex, index); } + + records.append(record); + endIndex = index; } - LogEntryProto getEntryFromCache(TermIndex ti) { + ReferenceCountedObject getEntryFromCache(TermIndex ti) { return entryCache.get(ti); } /** * Acquire LogSegment's monitor so that there is no concurrent loading. */ - synchronized LogEntryProto loadCache(LogRecord record) throws RaftLogIOException { - LogEntryProto entry = entryCache.get(record.getTermIndex()); + synchronized ReferenceCountedObject loadCache(TermIndex ti) throws RaftLogIOException { + final ReferenceCountedObject entry = entryCache.get(ti); if (entry != null) { - return entry; + try { + entry.retain(); + return entry; + } catch (IllegalStateException ignored) { + // The entry could be removed from the cache and released. + // The exception can be safely ignored since it is the same as cache miss. + } } try { - return cacheLoader.load(record); + return cacheLoader.load(ti); } catch (RaftLogIOException e) { throw e; } catch (Exception e) { - throw new RaftLogIOException("Failed to loadCache for log entry " + record, e); + throw new RaftLogIOException("Failed to loadCache for log entry " + ti, e); } } LogRecord getLogRecord(long index) { - if (index >= startIndex && index <= getEndIndex()) { - return records.get(index); - } - return null; + return records.get(index); } TermIndex getLastTermIndex() { @@ -428,7 +532,7 @@ long getTotalFileSize() { } long getTotalCacheSize() { - return totalCacheSize.get(); + return entryCache.size(); } /** @@ -439,7 +543,7 @@ synchronized void truncate(long fromIndex) { for (long index = endIndex; index >= fromIndex; index--) { final LogRecord removed = records.removeLast(); Preconditions.assertSame(index, removed.getTermIndex().getIndex(), "removedIndex"); - removeEntryCache(removed.getTermIndex(), Op.REMOVE_CACHE); + removeEntryCache(removed.getTermIndex()); totalFileSize = removed.offset; } isOpen = false; @@ -474,7 +578,7 @@ private int compareTo(Long l) { synchronized void clear() { records.clear(); - evictCache(); + entryCache.close(); endIndex = startIndex - 1; } @@ -483,33 +587,23 @@ int getLoadingTimes() { } void evictCache() { - entryCache.clear(); - totalCacheSize.set(0); + entryCache.evict(); } - void putEntryCache(TermIndex key, LogEntryProto value, Op op) { - final LogEntryProto previous = entryCache.put(key, value); - long previousSize = 0; - if (previous != null) { - // Different threads maybe load LogSegment file into cache at the same time, so duplicate maybe happen - previousSize = getEntrySize(value, Op.REMOVE_CACHE); - } - totalCacheSize.getAndAdd(getEntrySize(value, op) - previousSize); + void putEntryCache(TermIndex key, ReferenceCountedObject valueRef, Op op) { + entryCache.put(key, valueRef, op); } - void removeEntryCache(TermIndex key, Op op) { - LogEntryProto value = entryCache.remove(key); - if (value != null) { - totalCacheSize.getAndAdd(-getEntrySize(value, op)); - } + void removeEntryCache(TermIndex key) { + entryCache.remove(key); } boolean hasCache() { - return isOpen || !entryCache.isEmpty(); // open segment always has cache. + return isOpen || entryCache.size() > 0; // open segment always has cache. } boolean containsIndex(long index) { - return startIndex <= index && getEndIndex() >= index; + return startIndex <= index && endIndex >= index; } boolean hasEntries() { diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java index a6ea6e3caf..895531c897 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLog.java @@ -36,11 +36,13 @@ import org.apache.ratis.proto.RaftProtos.LogEntryProto; import org.apache.ratis.statemachine.StateMachine; import org.apache.ratis.statemachine.TransactionContext; +import org.apache.ratis.statemachine.impl.TransactionContextImpl; import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; import org.apache.ratis.util.AutoCloseableLock; import org.apache.ratis.util.AwaitToRun; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.StringUtils; import java.io.File; @@ -99,6 +101,9 @@ void done() { completeFuture(); } + void discard() { + } + final void completeFuture() { final boolean completed = future.complete(getEndIndex()); Preconditions.assertTrue(completed, @@ -178,11 +183,17 @@ public long getLastAppliedIndex() { @Override public void notifyTruncatedLogEntry(TermIndex ti) { + ReferenceCountedObject ref = null; try { - final LogEntryProto entry = get(ti.getIndex()); + ref = retainLog(ti.getIndex()); + final LogEntryProto entry = ref != null ? ref.get() : null; notifyTruncatedLogEntry.accept(entry); } catch (RaftLogIOException e) { LOG.error("{}: Failed to read log {}", getName(), ti, e); + } finally { + if (ref != null) { + ref.release(); + } } } @@ -202,7 +213,6 @@ public TransactionContext getTransactionContext(LogEntryProto entry, boolean cre private final long segmentMaxSize; private final boolean stateMachineCachingEnabled; private final SegmentedRaftLogMetrics metrics; - private final boolean readLockEnabled; @SuppressWarnings({"squid:S2095"}) // Suppress closeable warning private SegmentedRaftLog(Builder b) { @@ -218,12 +228,6 @@ private SegmentedRaftLog(Builder b) { this.fileLogWorker = new SegmentedRaftLogWorker(b.memberId, stateMachine, b.submitUpdateCommitEvent, b.server, storage, b.properties, getRaftLogMetrics()); stateMachineCachingEnabled = RaftServerConfigKeys.Log.StateMachineData.cachingEnabled(b.properties); - this.readLockEnabled = RaftServerConfigKeys.Log.readLockEnabled(b.properties); - } - - @Override - public AutoCloseableLock readLock() { - return readLockEnabled ? super.readLock() : null; } @Override @@ -277,54 +281,80 @@ private void loadLogSegments(long lastIndexInSnapshot, } @Override + @SuppressWarnings("deprecation") public LogEntryProto get(long index) throws RaftLogIOException { + final ReferenceCountedObject ref = retainLog(index); + if (ref == null) { + return null; + } + try { + return LogProtoUtils.copy(ref.get()); + } finally { + ref.release(); + } + } + + @Override + public ReferenceCountedObject retainLog(long index) throws RaftLogIOException { checkLogState(); - final LogSegment segment; - final LogRecord record; - try (AutoCloseableLock readLock = readLock()) { - segment = cache.getSegment(index); - if (segment == null) { - return null; - } - record = segment.getLogRecord(index); - if (record == null) { - return null; - } - final LogEntryProto entry = segment.getEntryFromCache(record.getTermIndex()); - if (entry != null) { + final LogSegment segment = cache.getSegment(index); + if (segment == null) { + return null; + } + final LogRecord record = segment.getLogRecord(index); + if (record == null) { + return null; + } + final TermIndex ti = record.getTermIndex(); + final ReferenceCountedObject entry = segment.getEntryFromCache(ti); + if (entry != null) { + try { + entry.retain(); getRaftLogMetrics().onRaftLogCacheHit(); return entry; + } catch (IllegalStateException ignored) { + // The entry could be removed from the cache and released. + // The exception can be safely ignored since it is the same as cache miss. } } // the entry is not in the segment's cache. Load the cache without holding the lock. getRaftLogMetrics().onRaftLogCacheMiss(); cacheEviction.signal(); - return segment.loadCache(record); + return segment.loadCache(ti); } @Override + @SuppressWarnings("deprecation") public EntryWithData getEntryWithData(long index) throws RaftLogIOException { - final LogEntryProto entry = get(index); - if (entry == null) { + throw new UnsupportedOperationException("Use retainEntryWithData(" + index + ") instead."); + } + + @Override + public ReferenceCountedObject retainEntryWithData(long index) throws RaftLogIOException { + final ReferenceCountedObject entryRef = retainLog(index); + if (entryRef == null) { throw new RaftLogIOException("Log entry not found: index = " + index); } + + final LogEntryProto entry = entryRef.get(); if (!LogProtoUtils.isStateMachineDataEmpty(entry)) { - return newEntryWithData(entry, null); + return newEntryWithData(entryRef); } try { - CompletableFuture future = null; + CompletableFuture> future = null; if (stateMachine != null) { - future = stateMachine.data().read(entry, server.getTransactionContext(entry, false)).exceptionally(ex -> { + future = stateMachine.data().retainRead(entry, server.getTransactionContext(entry, false)).exceptionally(ex -> { stateMachine.event().notifyLogFailed(ex, entry); throw new CompletionException("Failed to read state machine data for log entry " + entry, ex); }); } - return newEntryWithData(entry, future); + return future != null? newEntryWithData(entryRef, future): newEntryWithData(entryRef); } catch (Exception e) { final String err = getName() + ": Failed readStateMachineData for " + toLogEntryString(entry); LOG.error(err, e); + entryRef.release(); throw new RaftLogIOException(err, JavaUtils.unwrapCompletionException(e)); } } @@ -344,25 +374,19 @@ private void checkAndEvictCache() { @Override public TermIndex getTermIndex(long index) { checkLogState(); - try(AutoCloseableLock readLock = readLock()) { - return cache.getTermIndex(index); - } + return cache.getTermIndex(index); } @Override public LogEntryHeader[] getEntries(long startIndex, long endIndex) { checkLogState(); - try(AutoCloseableLock readLock = readLock()) { - return cache.getTermIndices(startIndex, endIndex); - } + return cache.getTermIndices(startIndex, endIndex); } @Override public TermIndex getLastEntryTermIndex() { checkLogState(); - try(AutoCloseableLock readLock = readLock()) { - return cache.getLastTermIndex(); - } + return cache.getLastTermIndex(); } @Override @@ -395,11 +419,14 @@ protected CompletableFuture purgeImpl(long index) { } @Override - protected CompletableFuture appendEntryImpl(LogEntryProto entry, TransactionContext context) { + protected CompletableFuture appendEntryImpl(ReferenceCountedObject entryRef, + TransactionContext context) { checkLogState(); + LogEntryProto entry = entryRef.retain(); if (LOG.isTraceEnabled()) { LOG.trace("{}: appendEntry {}", getName(), LogProtoUtils.toLogEntryString(entry)); } + final LogEntryProto removedStateMachineData = LogProtoUtils.removeStateMachineData(entry); try(AutoCloseableLock writeLock = writeLock()) { final Timekeeper.Context appendEntryTimerContext = getRaftLogMetrics().startAppendEntryTimer(); validateLogEntry(entry); @@ -408,7 +435,7 @@ protected CompletableFuture appendEntryImpl(LogEntryProto entry, Transacti if (currentOpenSegment == null) { cache.addOpenSegment(entry.getIndex()); fileLogWorker.startLogSegment(entry.getIndex()); - } else if (isSegmentFull(currentOpenSegment, entry)) { + } else if (isSegmentFull(currentOpenSegment, removedStateMachineData)) { rollOpenSegment = true; } else { final TermIndex last = currentOpenSegment.getLastTermIndex(); @@ -430,21 +457,25 @@ protected CompletableFuture appendEntryImpl(LogEntryProto entry, Transacti // If the entry has state machine data, then the entry should be inserted // to statemachine first and then to the cache. Not following the order // will leave a spurious entry in the cache. - cache.verifyAppendEntryIndex(entry); - CompletableFuture writeFuture = - fileLogWorker.writeLogEntry(entry, context).getFuture(); + final Task write = fileLogWorker.writeLogEntry(entryRef, removedStateMachineData, context); if (stateMachineCachingEnabled) { // The stateMachineData will be cached inside the StateMachine itself. - cache.appendEntry(LogProtoUtils.removeStateMachineData(entry), - LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE); + if (removedStateMachineData != entry) { + cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE, + ReferenceCountedObject.wrap(removedStateMachineData)); + } else { + cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITH_STATE_MACHINE_CACHE, + ReferenceCountedObject.wrap(LogProtoUtils.copy(removedStateMachineData))); + } } else { - cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); + cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, entryRef); } - writeFuture.whenComplete((clientReply, exception) -> appendEntryTimerContext.stop()); - return writeFuture; + return write.getFuture().whenComplete((clientReply, exception) -> appendEntryTimerContext.stop()); } catch (Exception e) { LOG.error("{}: Failed to append {}", getName(), toLogEntryString(entry), e); throw e; + } finally { + entryRef.release(); } } @@ -461,12 +492,14 @@ private boolean isSegmentFull(LogSegment segment, LogEntryProto entry) { } @Override - public List> appendImpl(List entries) { + protected List> appendImpl(ReferenceCountedObject> entriesRef) { checkLogState(); + final List entries = entriesRef.retain(); if (entries == null || entries.isEmpty()) { + entriesRef.release(); return Collections.emptyList(); } - try(AutoCloseableLock writeLock = writeLock()) { + try (AutoCloseableLock writeLock = writeLock()) { final TruncateIndices ti = cache.computeTruncateIndices(server::notifyTruncatedLogEntry, entries); final long truncateIndex = ti.getTruncateIndex(); final int index = ti.getArrayIndex(); @@ -481,9 +514,12 @@ public List> appendImpl(List entries) { } for (int i = index; i < entries.size(); i++) { final LogEntryProto entry = entries.get(i); - futures.add(appendEntry(entry, server.getTransactionContext(entry, true))); + TransactionContextImpl transactionContext = (TransactionContextImpl) server.getTransactionContext(entry, true); + futures.add(appendEntry(entriesRef.delegate(entry), transactionContext)); } return futures; + } finally { + entriesRef.release(); } } @@ -532,6 +568,7 @@ public CompletableFuture onSnapshotInstalled(long lastSnapshotIndex) { @Override public void close() throws IOException { try(AutoCloseableLock writeLock = writeLock()) { + LOG.info("Start closing {}", this); super.close(); cacheEviction.close(); cache.close(); @@ -539,6 +576,7 @@ public void close() throws IOException { fileLogWorker.close(); storage.close(); getRaftLogMetrics().unregister(); + LOG.info("Successfully closed {}", this); } SegmentedRaftLogCache getRaftLogCache() { diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java index 714943c49c..8d79c58d37 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogCache.java @@ -32,6 +32,7 @@ import org.apache.ratis.util.AutoCloseableReadWriteLock; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SizeInBytes; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,12 +44,14 @@ import java.util.Collections; import java.util.Comparator; import java.util.Iterator; +import java.util.LinkedList; import java.util.List; import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.Consumer; +import java.util.stream.Collectors; /** * In-memory RaftLog Cache. Currently we provide a simple implementation that @@ -364,16 +367,14 @@ TruncationSegments truncate(long index, LogSegment openSegment, Runnable clearOp TruncationSegments purge(long index) { try (AutoCloseableLock writeLock = writeLock()) { int segmentIndex = binarySearch(index); - List list = new ArrayList<>(); if (segmentIndex == -1) { // nothing to purge return null; } + List list = new LinkedList<>(); if (segmentIndex == -segments.size() - 1) { - for (LogSegment ls : segments) { - list.add(SegmentFileInfo.newClosedSegmentFileInfo(ls)); - } + list.addAll(segments); segments.clear(); sizeInBytes = 0; } else if (segmentIndex >= 0) { @@ -386,13 +387,16 @@ TruncationSegments purge(long index) { for (int i = 0; i <= startIndex; i++) { LogSegment segment = segments.remove(0); // must remove the first segment to avoid gaps. sizeInBytes -= segment.getTotalFileSize(); - list.add(SegmentFileInfo.newClosedSegmentFileInfo(segment)); + list.add(segment); } } else { throw new IllegalStateException("Unexpected gap in segments: binarySearch(" + index + ") returns " + segmentIndex + ", segments=" + segments); } - return list.isEmpty() ? null : new TruncationSegments("purge(" + index + ")", null, list); + list.forEach(LogSegment::evictCache); + List toDelete = list.stream().map(SegmentFileInfo::newClosedSegmentFileInfo) + .collect(Collectors.toList()); + return list.isEmpty() ? null : new TruncationSegments("purge(" + index + ")", null, toDelete); } } @@ -614,25 +618,18 @@ long getLastIndexInClosedSegments() { TermIndex getLastTermIndex() { try (AutoCloseableLock readLock = closedSegments.readLock()) { - LogSegment tmpSegment = openSegment; - return (tmpSegment != null && tmpSegment.getLastTermIndex() != null) ? - tmpSegment.getLastTermIndex() : + return (openSegment != null && openSegment.numOfEntries() > 0) ? + openSegment.getLastTermIndex() : (closedSegments.isEmpty() ? null : closedSegments.get(closedSegments.size() - 1).getLastTermIndex()); } } - void verifyAppendEntryIndex(LogEntryProto entry) { - // SegmentedRaftLog does the segment creation/rolling work. - Objects.requireNonNull(openSegment, "openSegment == null"); - openSegment.verifyEntryIndex(entry.getIndex()); - } - - void appendEntry(LogEntryProto entry, LogSegment.Op op) { + void appendEntry(LogSegment.Op op, ReferenceCountedObject entry) { // SegmentedRaftLog does the segment creation/rolling work. Here we just // simply append the entry into the open segment. Objects.requireNonNull(openSegment, "openSegment == null"); - openSegment.appendToOpenSegment(entry, op, false); + openSegment.appendToOpenSegment(op, entry); } /** diff --git a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java index cac35f55b3..98b20bade3 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/raftlog/segmented/SegmentedRaftLogWorker.java @@ -51,9 +51,8 @@ import java.util.Objects; import java.util.Optional; import java.util.Queue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; @@ -245,10 +244,11 @@ void start(long latestIndex, long evictIndex, File openSegmentFile) throws IOExc } void close() { + queue.close(); this.running = false; + ConcurrentUtils.shutdownAndWait(TimeDuration.ONE_MINUTE, workerThreadExecutor, + timeout -> LOG.warn("{}: shutdown timeout in {}", name, timeout)); Optional.ofNullable(flushExecutor).ifPresent(ExecutorService::shutdown); - ConcurrentUtils.shutdownAndWait(TimeDuration.ONE_SECOND.multiply(3), - workerThreadExecutor, timeout -> LOG.warn("{}: shutdown timeout in " + timeout, name)); IOUtils.cleanup(LOG, out); PlatformDependent.freeDirectBuffer(writeBuffer); LOG.info("{} close()", name); @@ -290,6 +290,7 @@ private Task addIOTask(Task task) { LOG.error("Failed to add IO task {}", task, e); Optional.ofNullable(server).ifPresent(RaftServer.Division::close); } + task.discard(); } task.startTimerOnEnqueue(raftLogMetrics.getEnqueuedTimer()); return task; @@ -343,7 +344,7 @@ private void run() { LOG.info(Thread.currentThread().getName() + " was interrupted, exiting. There are " + queue.getNumElements() + " tasks remaining in the queue."); - return; + break; } catch (Exception e) { if (!running) { LOG.info("{} got closed and hit exception", @@ -354,6 +355,8 @@ private void run() { } } } + + queue.clear(Task::discard); } private boolean shouldFlush() { @@ -448,8 +451,9 @@ void rollLogSegment(LogSegment segmentToClose) { addIOTask(new StartLogSegment(segmentToClose.getEndIndex() + 1)); } - Task writeLogEntry(LogEntryProto entry, TransactionContext context) { - return addIOTask(new WriteLog(entry, context)); + Task writeLogEntry(ReferenceCountedObject entry, + LogEntryProto removedStateMachineData, TransactionContext context) { + return addIOTask(new WriteLog(entry, removedStateMachineData, context)); } Task truncate(TruncationSegments ts, long index) { @@ -497,26 +501,32 @@ private class WriteLog extends Task { private final LogEntryProto entry; private final CompletableFuture stateMachineFuture; private final CompletableFuture combined; - - WriteLog(LogEntryProto entry, TransactionContext context) { - this.entry = LogProtoUtils.removeStateMachineData(entry); - if (this.entry == entry) { - final StateMachineLogEntryProto proto = entry.hasStateMachineLogEntry()? entry.getStateMachineLogEntry(): null; + private final AtomicReference> ref = new AtomicReference<>(); + + WriteLog(ReferenceCountedObject entryRef, LogEntryProto removedStateMachineData, + TransactionContext context) { + LogEntryProto origEntry = entryRef.get(); + this.entry = removedStateMachineData; + if (this.entry == origEntry) { + final StateMachineLogEntryProto proto = origEntry.hasStateMachineLogEntry() ? + origEntry.getStateMachineLogEntry(): null; if (stateMachine != null && proto != null && proto.getType() == StateMachineLogEntryProto.Type.DATASTREAM) { final ClientInvocationId invocationId = ClientInvocationId.valueOf(proto); final CompletableFuture removed = server.getDataStreamMap().remove(invocationId); - this.stateMachineFuture = removed == null? stateMachine.data().link(null, entry) - : removed.thenApply(stream -> stateMachine.data().link(stream, entry)); + this.stateMachineFuture = removed == null? stateMachine.data().link(null, origEntry) + : removed.thenApply(stream -> stateMachine.data().link(stream, origEntry)); } else { this.stateMachineFuture = null; } + entryRef.retain(); + this.ref.set(entryRef); } else { try { - // this.entry != entry iff the entry has state machine data - this.stateMachineFuture = stateMachine.data().write(entry, context); + // this.entry != origEntry if it has state machine data + this.stateMachineFuture = stateMachine.data().write(entryRef, context); } catch (Exception e) { - LOG.error(name + ": writeStateMachineData failed for index " + entry.getIndex() - + ", entry=" + LogProtoUtils.toLogEntryString(entry, stateMachine::toStateMachineLogEntryString), e); + LOG.error(name + ": writeStateMachineData failed for index " + origEntry.getIndex() + + ", entry=" + LogProtoUtils.toLogEntryString(origEntry, stateMachine::toStateMachineLogEntryString), e); throw e; } } @@ -528,6 +538,7 @@ private class WriteLog extends Task { void failed(IOException e) { stateMachine.event().notifyLogFailed(e, entry); super.failed(e); + discard(); } @Override @@ -543,6 +554,15 @@ CompletableFuture getFuture() { @Override void done() { writeTasks.offerOrCompleteFuture(this); + discard(); + } + + @Override + void discard() { + final ReferenceCountedObject entryRef = ref.getAndSet(null); + if (entryRef != null) { + entryRef.release(); + } } @Override diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java index 3f18ee538b..b97749f262 100644 --- a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java +++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/BaseStateMachine.java @@ -18,7 +18,7 @@ package org.apache.ratis.statemachine.impl; -import org.apache.ratis.proto.RaftProtos; +import org.apache.ratis.proto.RaftProtos.LogEntryProto; import org.apache.ratis.protocol.Message; import org.apache.ratis.protocol.RaftClientRequest; import org.apache.ratis.protocol.RaftGroupId; @@ -115,10 +115,10 @@ public TransactionContext applyTransactionSerial(TransactionContext trx) throws @Override public CompletableFuture applyTransaction(TransactionContext trx) { // return the same message contained in the entry - RaftProtos.LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry()); + final LogEntryProto entry = Objects.requireNonNull(trx.getLogEntryUnsafe()); updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex()); return CompletableFuture.completedFuture( - Message.valueOf(trx.getLogEntry().getStateMachineLogEntry().getLogData())); + Message.valueOf(entry.getStateMachineLogEntry().getLogData())); } @Override diff --git a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java index 8497b12f4d..58869f5edc 100644 --- a/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/statemachine/impl/TransactionContextImpl.java @@ -25,11 +25,14 @@ import org.apache.ratis.statemachine.StateMachine; import org.apache.ratis.statemachine.TransactionContext; import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; +import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import java.io.IOException; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; /** * Implementation of {@link TransactionContext} @@ -71,6 +74,13 @@ public class TransactionContextImpl implements TransactionContext { /** Committed LogEntry. */ @SuppressWarnings({"squid:S3077"}) // Suppress volatile for generic type private volatile LogEntryProto logEntry; + /** Committed LogEntry copy. */ + @SuppressWarnings({"squid:S3077"}) // Suppress volatile for generic type + private volatile Supplier logEntryCopy; + + /** For wrapping {@link #logEntry} in order to release the underlying buffer. */ + @SuppressWarnings({"squid:S3077"}) // Suppress volatile for generic type + private volatile ReferenceCountedObject delegatedRef; private final CompletableFuture logIndexFuture = new CompletableFuture<>(); @@ -112,7 +122,7 @@ private static StateMachineLogEntryProto get(StateMachineLogEntryProto stateMach */ TransactionContextImpl(RaftPeerRole serverRole, StateMachine stateMachine, LogEntryProto logEntry) { this(serverRole, null, stateMachine, logEntry.getStateMachineLogEntry()); - this.logEntry = logEntry; + setLogEntry(logEntry); this.logIndexFuture.complete(logEntry.getIndex()); } @@ -126,7 +136,24 @@ public RaftClientRequest getClientRequest() { return clientRequest; } + public void setDelegatedRef(ReferenceCountedObject ref) { + this.delegatedRef = ref; + } + + @Override + public ReferenceCountedObject wrap(LogEntryProto entry) { + if (delegatedRef == null) { + return TransactionContext.super.wrap(entry); + } + final LogEntryProto expected = getLogEntryUnsafe(); + Objects.requireNonNull(expected, "logEntry == null"); + Preconditions.assertSame(expected.getTerm(), entry.getTerm(), "entry.term"); + Preconditions.assertSame(expected.getIndex(), entry.getIndex(), "entry.index"); + return delegatedRef.delegate(entry); + } + @Override + @SuppressWarnings("deprecation") public StateMachineLogEntryProto getStateMachineLogEntry() { return stateMachineLogEntry; } @@ -154,18 +181,32 @@ public LogEntryProto initLogEntry(long term, long index) { Objects.requireNonNull(stateMachineLogEntry, "stateMachineLogEntry == null"); logIndexFuture.complete(index); - return logEntry = LogProtoUtils.toLogEntryProto(stateMachineLogEntry, term, index); + return setLogEntry(LogProtoUtils.toLogEntryProto(stateMachineLogEntry, term, index)); } public CompletableFuture getLogIndexFuture() { return logIndexFuture; } + private LogEntryProto setLogEntry(LogEntryProto entry) { + this.logEntry = entry; + this.logEntryCopy = MemoizedSupplier.valueOf(() -> LogProtoUtils.copy(entry)); + return entry; + } + + @Override + @SuppressWarnings("deprecation") public LogEntryProto getLogEntry() { + return logEntryCopy == null ? null : logEntryCopy.get(); + } + + @Override + public LogEntryProto getLogEntryUnsafe() { return logEntry; } + @Override public TransactionContext setException(Exception ioe) { this.exception = ioe; @@ -191,6 +232,12 @@ public TransactionContext preAppendTransaction() throws IOException { @Override public TransactionContext cancelTransaction() throws IOException { + // TODO: This is not called from Raft server / log yet. When an IOException happens, we should + // call this to let the SM know that Transaction cannot be synced return stateMachine.cancelTransaction(this); } + + public static LogEntryProto getLogEntry(TransactionContext context) { + return ((TransactionContextImpl) context).logEntry; + } } diff --git a/ratis-test/pom.xml b/ratis-test/pom.xml index 577262d84c..d2db083989 100644 --- a/ratis-test/pom.xml +++ b/ratis-test/pom.xml @@ -28,6 +28,12 @@ + + com.google.errorprone + error_prone_annotations + 2.29.2 + provided + ratis-common org.apache.ratis diff --git a/ratis-test/src/test/java/org/apache/ratis/TestRaftServerSlownessDetection.java b/ratis-test/src/test/java/org/apache/ratis/TestRaftServerSlownessDetection.java index fd4aa1e068..895acccfbf 100644 --- a/ratis-test/src/test/java/org/apache/ratis/TestRaftServerSlownessDetection.java +++ b/ratis-test/src/test/java/org/apache/ratis/TestRaftServerSlownessDetection.java @@ -49,6 +49,7 @@ */ //TODO: fix StateMachine.notifySlowness(..); see RATIS-370 @Disabled +@SuppressWarnings({"deprecation", "rawtypes"}) public class TestRaftServerSlownessDetection extends BaseTest { static { Slf4jUtils.setLogLevel(RaftServer.Division.LOG, Level.DEBUG); diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java index 697e746877..613bb69752 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamDisabled.java @@ -29,6 +29,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +@SuppressWarnings({"try"}) public class TestDataStreamDisabled extends BaseTest { @Test public void testDataStreamDisabled() throws Exception { diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java index a8cd6138ec..737325d72f 100644 --- a/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcMessageMetrics.java @@ -19,7 +19,6 @@ import org.apache.ratis.BaseTest; import org.apache.ratis.grpc.MiniRaftClusterWithGrpc; -import org.apache.ratis.grpc.metrics.MessageMetrics; import org.apache.ratis.server.impl.MiniRaftCluster; import org.apache.ratis.RaftTestUtil; import org.apache.ratis.client.RaftClient; @@ -70,10 +69,6 @@ static void assertMessageCount(RaftServer.Division server) { final GrpcServicesImpl services = (GrpcServicesImpl) RaftServerTestUtil.getServerRpc(server); final RatisMetricRegistry registry = services.getMessageMetrics().getRegistry(); String counter_prefix = serverId + "_" + "ratis.grpc.RaftServerProtocolService"; - final String metricPrefix = counter_prefix + "_" + "requestVote" + "_OK"; - final long before = registry.counter(metricPrefix + "_completed_total").getCount(); - services.getMessageMetrics().rpcCompleted(metricPrefix); - final long after = registry.counter(metricPrefix + "_completed_total").getCount(); - Assertions.assertEquals(before + 1, after); + Assertions.assertTrue(registry.counter(counter_prefix + "_" + "requestVote" + "_OK_completed_total").getCount() > 0); } } diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcServerMetrics.java b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcServerMetrics.java index 1b57730594..3e6257683c 100644 --- a/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcServerMetrics.java +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/server/TestGrpcServerMetrics.java @@ -43,6 +43,7 @@ import org.junit.jupiter.api.Test; import org.mockito.Mockito; +@SuppressWarnings({"rawtypes"}) public class TestGrpcServerMetrics { private static GrpcServerMetrics grpcServerMetrics; private static RatisMetricRegistry ratisMetricRegistry; diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java index ca2709270b..af7991a416 100644 --- a/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/GrpcZeroCopyTestServer.java @@ -75,6 +75,7 @@ public synchronized String toString() { private final Count zeroCopyCount = new Count(); private final Count nonZeroCopyCount = new Count(); + private final Count releasedCount = new Count(); private final Server server; // Allow tests to disable release to validate leak detection. @@ -82,7 +83,8 @@ public synchronized String toString() { private final ZeroCopyMessageMarshaller marshaller = new ZeroCopyMessageMarshaller<>( BinaryRequest.getDefaultInstance(), zeroCopyCount::inc, - nonZeroCopyCount::inc); + nonZeroCopyCount::inc, + releasedCount::inc); GrpcZeroCopyTestServer(int port) { this(port, true); diff --git a/ratis-test/src/test/java/org/apache/ratis/netty/TestTlsConfWithNetty.java b/ratis-test/src/test/java/org/apache/ratis/netty/TestTlsConfWithNetty.java index dbdcf1ebd8..abbc56934d 100644 --- a/ratis-test/src/test/java/org/apache/ratis/netty/TestTlsConfWithNetty.java +++ b/ratis-test/src/test/java/org/apache/ratis/netty/TestTlsConfWithNetty.java @@ -57,6 +57,7 @@ /** * Testing {@link TlsConf} and the security related utility methods in {@link NettyUtils}. */ +@SuppressWarnings({"try"}) public class TestTlsConfWithNetty { private static final Logger LOG = LoggerFactory.getLogger(TestTlsConfWithNetty.class); diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLogTest.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLogTest.java index 5a41f9ed9a..17c309f0bd 100644 --- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLogTest.java +++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/memory/MemoryRaftLogTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.event.Level; +@SuppressWarnings({"deprecation"}) public class MemoryRaftLogTest extends BaseTest { static { Slf4jUtils.setLogLevel(MemoryRaftLog.LOG, Level.DEBUG); diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestCacheEviction.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestCacheEviction.java index 6ad429249b..163c25da90 100644 --- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestCacheEviction.java +++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestCacheEviction.java @@ -51,6 +51,7 @@ import static org.apache.ratis.server.raftlog.segmented.SegmentedRaftLogTestUtils.MAX_OP_SIZE; +@SuppressWarnings({"deprecation"}) public class TestCacheEviction extends BaseTest { private static final CacheInvalidationPolicy POLICY = new CacheInvalidationPolicyDefault(); diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java index 6a75dfb36e..259f163070 100644 --- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java +++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestLogSegment.java @@ -33,6 +33,7 @@ import org.apache.ratis.thirdparty.com.google.protobuf.CodedOutputStream; import org.apache.ratis.util.FileUtils; import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SizeInBytes; import org.apache.ratis.util.TraditionalBinaryPrefix; import org.junit.jupiter.api.AfterEach; @@ -57,6 +58,7 @@ /** * Test basic functionality of {@link LogSegment} */ +@SuppressWarnings({"try"}) public class TestLogSegment extends BaseTest { public static final LogSegmentStartEnd ZERO_START_NULL_END = LogSegmentStartEnd.valueOf(0); @@ -142,11 +144,11 @@ static void checkLogSegment(LogSegment segment, long start, long end, Assertions.assertEquals(term, ti.getTerm()); Assertions.assertEquals(offset, record.getOffset()); - LogEntryProto entry = segment.getEntryFromCache(ti); - if (entry == null) { - entry = segment.loadCache(record); + ReferenceCountedObject entryRef = segment.getEntryFromCache(ti); + if (entryRef == null) { + entryRef = segment.loadCache(ti); } - offset += getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); + offset += getEntrySize(entryRef.get(), LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); } } @@ -206,7 +208,7 @@ public void testAppendEntries() throws Exception { SimpleOperation op = new SimpleOperation("m" + i); LogEntryProto entry = LogProtoUtils.toLogEntryProto(op.getLogEntryContent(), term, i++ + start); size += getEntrySize(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); - segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + segment.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); } Assertions.assertTrue(segment.getTotalFileSize() >= max); @@ -238,18 +240,18 @@ public void testAppendWithGap() throws Exception { final StateMachineLogEntryProto m = op.getLogEntryContent(); try { LogEntryProto entry = LogProtoUtils.toLogEntryProto(m, 0, 1001); - segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + segment.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); Assertions.fail("should fail since the entry's index needs to be 1000"); } catch (IllegalStateException e) { // the exception is expected. } LogEntryProto entry = LogProtoUtils.toLogEntryProto(m, 0, 1000); - segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + segment.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); try { entry = LogProtoUtils.toLogEntryProto(m, 0, 1002); - segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + segment.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); Assertions.fail("should fail since the entry's index needs to be 1001"); } catch (IllegalStateException e) { // the exception is expected. @@ -264,7 +266,7 @@ public void testTruncate() throws Exception { for (int i = 0; i < 100; i++) { LogEntryProto entry = LogProtoUtils.toLogEntryProto( new SimpleOperation("m" + i).getLogEntryContent(), term, i + start); - segment.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + segment.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); } // truncate an open segment (remove 1080~1099) diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java index 181d1fa430..1f04982823 100644 --- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java +++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLog.java @@ -87,6 +87,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.params.provider.Arguments.arguments; +@SuppressWarnings({"deprecation"}) public class TestSegmentedRaftLog extends BaseTest { static { Slf4jUtils.setLogLevel(SegmentedRaftLogWorker.LOG, Level.INFO); diff --git a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java index 3133fb36f6..d50b2d8a54 100644 --- a/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java +++ b/ratis-test/src/test/java/org/apache/ratis/server/raftlog/segmented/TestSegmentedRaftLogCache.java @@ -34,6 +34,7 @@ import org.apache.ratis.server.raftlog.segmented.SegmentedRaftLogCache.TruncationSegments; import org.apache.ratis.server.raftlog.segmented.LogSegment.LogRecord; import org.apache.ratis.proto.RaftProtos.LogEntryProto; +import org.apache.ratis.util.ReferenceCountedObject; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -63,7 +64,7 @@ private LogSegment prepareLogSegment(long start, long end, boolean isOpen) { for (long i = start; i <= end; i++) { SimpleOperation m = new SimpleOperation("m" + i); LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, i); - s.appendToOpenSegment(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, true); + s.appendToOpenSegment(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); } if (!isOpen) { s.close(); @@ -78,8 +79,8 @@ private void checkCache(long start, long end, int segmentSize) { for (long index = start; index <= end; index++) { final LogSegment segment = cache.getSegment(index); final LogRecord record = segment.getLogRecord(index); - final LogEntryProto entry = segment.getEntryFromCache(record.getTermIndex()); - Assertions.assertEquals(index, entry.getIndex()); + final ReferenceCountedObject entry = segment.getEntryFromCache(record.getTermIndex()); + Assertions.assertEquals(index, entry.get().getIndex()); } long[] offsets = new long[]{start, start + 1, start + (end - start) / 2, @@ -154,7 +155,7 @@ public void testAppendEntry() throws Exception { final SimpleOperation m = new SimpleOperation("m"); try { LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, 0); - cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); + cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); Assertions.fail("the open segment is null"); } catch (IllegalStateException | NullPointerException ignored) { } @@ -163,7 +164,7 @@ public void testAppendEntry() throws Exception { cache.addSegment(openSegment); for (long index = 101; index < 200; index++) { LogEntryProto entry = LogProtoUtils.toLogEntryProto(m.getLogEntryContent(), 0, index); - cache.appendEntry(entry, LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE); + cache.appendEntry(LogSegment.Op.WRITE_CACHE_WITHOUT_STATE_MACHINE_CACHE, ReferenceCountedObject.wrap(entry)); } Assertions.assertNotNull(cache.getOpenSegment()); diff --git a/ratis-test/src/test/java/org/apache/ratis/shell/cli/sh/TestRatisShell.java b/ratis-test/src/test/java/org/apache/ratis/shell/cli/sh/TestRatisShell.java index 7c42b69076..ccf9702f7f 100644 --- a/ratis-test/src/test/java/org/apache/ratis/shell/cli/sh/TestRatisShell.java +++ b/ratis-test/src/test/java/org/apache/ratis/shell/cli/sh/TestRatisShell.java @@ -44,6 +44,7 @@ /** * Test {@link RatisShell} */ +@SuppressWarnings({"rawtypes"}) public class TestRatisShell extends BaseTest { static final PrintStream OUT = System.out; static final Class[] ARG_CLASSES = new Class[] {Context.class}; diff --git a/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java b/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java index 094189827d..79c204e292 100644 --- a/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java +++ b/ratis-test/src/test/java/org/apache/ratis/statemachine/TestStateMachine.java @@ -57,6 +57,7 @@ /** * Test StateMachine related functionality */ +@SuppressWarnings({"deprecation"}) public class TestStateMachine extends BaseTest implements MiniRaftClusterWithSimulatedRpc.FactoryGet { static { Slf4jUtils.setLogLevel(RaftServer.Division.LOG, Level.DEBUG); diff --git a/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java b/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java index 564ce0bf07..7107977fbb 100644 --- a/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java +++ b/ratis-tools/src/main/java/org/apache/ratis/tools/ParseRatisLog.java @@ -60,7 +60,7 @@ public void dumpSegmentFile() throws IOException { System.out.println("Processing Raft Log file: " + file.getAbsolutePath() + " size:" + file.length()); final int entryCount = LogSegment.readSegmentFile(file, pi.getStartEnd(), maxOpSize, - RaftServerConfigKeys.Log.CorruptionPolicy.EXCEPTION, null, this::processLogEntry); + RaftServerConfigKeys.Log.CorruptionPolicy.EXCEPTION, null, entry -> processLogEntry(entry.get())); System.out.println("Num Total Entries: " + entryCount); System.out.println("Num Conf Entries: " + numConfEntries); System.out.println("Num Metadata Entries: " + numMetadataEntries);