counters = types.get(t);
+ LongCounter c = counters.get(metricNamePrefix);
+ if (c == null) {
+ synchronized (counters) {
+ c = counters.computeIfAbsent(metricNamePrefix, prefix -> getRegistry().counter(prefix + t.getSuffix()));
+ }
+ }
+ c.inc();
}
/**
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
new file mode 100644
index 0000000000..1fcc317f9d
--- /dev/null
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.metrics;
+
+import org.apache.ratis.metrics.LongCounter;
+import org.apache.ratis.metrics.MetricRegistryInfo;
+import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.metrics.RatisMetrics;
+import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting;
+import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage;
+
+import java.util.function.Supplier;
+
+/**
+ * Metrics for the gRPC zero-copy message path:
+ * counters of zero-copy, non-zero-copy and released messages,
+ * plus gauges of unreleased messages registered by name.
+ */
+public class ZeroCopyMetrics extends RatisMetrics {
+  private static final String RATIS_GRPC_METRICS_APP_NAME = "ratis_grpc";
+  private static final String RATIS_GRPC_METRICS_COMP_NAME = "zero_copy";
+  private static final String RATIS_GRPC_METRICS_DESC = "Metrics for Ratis Grpc Zero copy";
+
+  // Field initializers run after super(createRegistry()) has returned.
+  private final LongCounter zeroCopyMessages = getRegistry().counter("num_zero_copy_messages");
+  private final LongCounter nonZeroCopyMessages = getRegistry().counter("num_non_zero_copy_messages");
+  private final LongCounter releasedMessages = getRegistry().counter("num_released_messages");
+
+  public ZeroCopyMetrics() {
+    super(createRegistry());
+  }
+
+  private static RatisMetricRegistry createRegistry() {
+    return create(new MetricRegistryInfo("",
+        RATIS_GRPC_METRICS_APP_NAME,
+        RATIS_GRPC_METRICS_COMP_NAME, RATIS_GRPC_METRICS_DESC));
+  }
+
+  /**
+   * Register a gauge named {@code name + "_num_unreleased_messages"}.
+   *
+   * @param name the prefix of the gauge name.
+   * @param unreleased the supplier handed to the registry as the gauge.
+   */
+  public void addUnreleased(String name, Supplier unreleased) {
+    getRegistry().gauge(name + "_num_unreleased_messages", () -> unreleased);
+  }
+
+  /** Increment the zero-copy message counter; the message itself is not inspected. */
+  public void onZeroCopyMessage(AbstractMessage ignored) {
+    zeroCopyMessages.inc();
+  }
+
+  /** Increment the non-zero-copy message counter; the message itself is not inspected. */
+  public void onNonZeroCopyMessage(AbstractMessage ignored) {
+    nonZeroCopyMessages.inc();
+  }
+
+  /** Increment the released message counter; the message itself is not inspected. */
+  public void onReleasedMessage(AbstractMessage ignored) {
+    releasedMessages.inc();
+  }
+
+  /** @return the current value of the zero-copy message counter. */
+  @VisibleForTesting
+  public long zeroCopyMessages() {
+    return zeroCopyMessages.getCount();
+  }
+
+  /** @return the current value of the non-zero-copy message counter. */
+  @VisibleForTesting
+  public long nonZeroCopyMessages() {
+    return nonZeroCopyMessages.getCount();
+  }
+
+  /** @return the current value of the released message counter. */
+  @VisibleForTesting
+  public long releasedMessages() {
+    return releasedMessages.getCount();
+  }
+}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
index 1da3587e91..78ad8759fb 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java
@@ -19,10 +19,13 @@
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
+import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
import org.apache.ratis.protocol.*;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.protocol.exceptions.GroupMismatchException;
import org.apache.ratis.protocol.exceptions.RaftException;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto;
@@ -30,6 +33,7 @@
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.SlidingWindow;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -39,7 +43,6 @@
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,15 +51,21 @@
import java.util.function.Consumer;
import java.util.function.Supplier;
+import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
+import static org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getOrderedMethod;
+import static org.apache.ratis.proto.grpc.RaftClientProtocolServiceGrpc.getUnorderedMethod;
+
class GrpcClientProtocolService extends RaftClientProtocolServiceImplBase {
private static final Logger LOG = LoggerFactory.getLogger(GrpcClientProtocolService.class);
private static class PendingOrderedRequest implements SlidingWindow.ServerSideRequest {
+ private final ReferenceCountedObject requestRef;
private final RaftClientRequest request;
private final AtomicReference reply = new AtomicReference<>();
- PendingOrderedRequest(RaftClientRequest request) {
- this.request = request;
+ PendingOrderedRequest(ReferenceCountedObject requestRef) {
+ this.requestRef = requestRef;
+ this.request = requestRef != null ? requestRef.retain() : null;
}
@Override
@@ -76,15 +85,23 @@ public boolean hasReply() {
@Override
public void setReply(RaftClientReply r) {
final boolean set = reply.compareAndSet(null, r);
- Preconditions.assertTrue(set, () -> "Reply is already set: request=" + request + ", reply=" + reply);
+ Preconditions.assertTrue(set, () -> "Reply is already set: request=" // request is null for a null requestRef
+ + (request != null ? request.toStringShort() : null) + ", reply=" + reply);
}
RaftClientReply getReply() {
return reply.get();
}
- RaftClientRequest getRequest() {
- return request;
+ ReferenceCountedObject getRequestRef() {
+ return requestRef;
+ }
+
+ @Override
+ public void release() {
+ if (requestRef != null) {
+ requestRef.release();
+ }
}
@Override
@@ -135,18 +152,38 @@ void closeAllExisting(RaftGroupId groupId) {
private final ExecutorService executor;
private final OrderedStreamObservers orderedStreamObservers = new OrderedStreamObservers();
+ private final boolean zeroCopyEnabled;
+ private final ZeroCopyMessageMarshaller zeroCopyRequestMarshaller;
GrpcClientProtocolService(Supplier idSupplier, RaftClientAsynchronousProtocol protocol,
- ExecutorService executor) {
+ ExecutorService executor, boolean zeroCopyEnabled, ZeroCopyMetrics zeroCopyMetrics) {
this.idSupplier = idSupplier;
this.protocol = protocol;
this.executor = executor;
+ this.zeroCopyEnabled = zeroCopyEnabled;
+ this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(RaftClientRequestProto.getDefaultInstance(),
+ zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
+ zeroCopyMetrics.addUnreleased("client_protocol", zeroCopyRequestMarshaller::getUnclosedCount);
}
RaftPeerId getId() {
return idSupplier.get();
}
+ ServerServiceDefinition bindServiceWithZeroCopy() {
+ ServerServiceDefinition orig = super.bindService();
+ if (!zeroCopyEnabled) {
+ LOG.info("{}: Zero copy is disabled.", getId());
+ return orig;
+ }
+ ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());
+
+ addMethodWithCustomMarshaller(orig, builder, getOrderedMethod(), zeroCopyRequestMarshaller);
+ addMethodWithCustomMarshaller(orig, builder, getUnorderedMethod(), zeroCopyRequestMarshaller);
+
+ return builder.build();
+ }
+
@Override
public StreamObserver ordered(StreamObserver responseObserver) {
final OrderedRequestStreamObserver so = new OrderedRequestStreamObserver(responseObserver);
@@ -220,31 +257,43 @@ boolean isClosed() {
return isClosed.get();
}
- CompletableFuture processClientRequest(RaftClientRequest request, Consumer replyHandler) {
+ CompletableFuture processClientRequest(ReferenceCountedObject requestRef,
+ Consumer replyHandler) {
+ final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + requestRef.get() : "";
+ CompletableFuture replyFuture;
try {
- final String errMsg = LOG.isDebugEnabled() ? "processClientRequest for " + request : "";
- return protocol.submitClientRequestAsync(request
- ).thenAcceptAsync(replyHandler, executor
- ).exceptionally(exception -> {
- // TODO: the exception may be from either raft or state machine.
- // Currently we skip all the following responses when getting an
- // exception from the state machine.
- responseError(exception, () -> errMsg);
- return null;
- });
+ replyFuture = protocol.submitClientRequestAsync(requestRef);
} catch (IOException e) {
- throw new CompletionException("Failed processClientRequest for " + request + " in " + name, e);
+ replyFuture = new CompletableFuture<>();
+ replyFuture.completeExceptionally(e);
}
+ return replyFuture.thenAcceptAsync(replyHandler, executor).exceptionally(exception -> {
+ // TODO: the exception may be from either raft or state machine.
+ // Currently we skip all the following responses when getting an
+ // exception from the state machine.
+ responseError(exception, () -> errMsg);
+ return null;
+ });
}
- abstract void processClientRequest(RaftClientRequest request);
+ abstract void processClientRequest(ReferenceCountedObject requestRef);
@Override
public void onNext(RaftClientRequestProto request) {
+ ReferenceCountedObject requestRef = null;
try {
final RaftClientRequest r = ClientProtoUtils.toRaftClientRequest(request);
- processClientRequest(r);
+ requestRef = ReferenceCountedObject.wrap(r, () -> {}, released -> {
+ if (released) {
+ zeroCopyRequestMarshaller.release(request);
+ }
+ });
+
+ processClientRequest(requestRef);
} catch (Exception e) {
+ if (requestRef == null) {
+ zeroCopyRequestMarshaller.release(request);
+ }
responseError(e, () -> "onNext for " + ClientProtoUtils.toString(request) + " in " + name);
}
}
@@ -278,15 +327,21 @@ private class UnorderedRequestStreamObserver extends RequestStreamObserver {
}
@Override
- void processClientRequest(RaftClientRequest request) {
- final CompletableFuture f = processClientRequest(request, reply -> {
- if (!reply.isSuccess()) {
- LOG.info("Failed " + request + ", reply=" + reply);
- }
- final RaftClientReplyProto proto = ClientProtoUtils.toRaftClientReplyProto(reply);
- responseNext(proto);
- });
- final long callId = request.getCallId();
+ void processClientRequest(ReferenceCountedObject requestRef) {
+ final long callId = requestRef.retain().getCallId();
+ final CompletableFuture f;
+ try {
+ f = processClientRequest(requestRef, reply -> {
+ if (!reply.isSuccess()) {
+ LOG.info("Failed request cid={}, reply={}", callId, reply);
+ }
+ final RaftClientReplyProto proto = ClientProtoUtils.toRaftClientReplyProto(reply);
+ responseNext(proto);
+ });
+ } finally {
+ requestRef.release();
+ }
+
put(callId, f);
f.thenAccept(dummy -> remove(callId));
}
@@ -329,32 +384,48 @@ RaftGroupId getGroupId() {
void processClientRequest(PendingOrderedRequest pending) {
final long seq = pending.getSeqNum();
- processClientRequest(pending.getRequest(),
+ processClientRequest(pending.getRequestRef(),
reply -> slidingWindow.receiveReply(seq, reply, this::sendReply));
}
+ /**
+ * Check the given request against this stream (closed state and group id) and then
+ * pass it to the sliding window for ordered processing.
+ * The reference is retained for the duration of this call and released in the finally block;
+ * {@link PendingOrderedRequest} retains it again in its constructor.
+ */
@Override
- void processClientRequest(RaftClientRequest r) {
- if (isClosed()) {
- final AlreadyClosedException exception = new AlreadyClosedException(getName() + ": the stream is closed");
- responseError(exception, () -> "processClientRequest (stream already closed) for " + r);
- return;
- }
+ void processClientRequest(ReferenceCountedObject requestRef) {
+ final RaftClientRequest request = requestRef.retain();
+ try {
+ if (isClosed()) {
+ final AlreadyClosedException exception = new AlreadyClosedException(getName() + ": the stream is closed");
+ responseError(exception, () -> "processClientRequest (stream already closed) for " + request);
+ // The stream is closed and the error has been reported: do not enqueue the request.
+ return;
+ }
- final RaftGroupId requestGroupId = r.getRaftGroupId();
- // use the group id in the first request as the group id of this observer
- final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g: requestGroupId);
- final PendingOrderedRequest pending = new PendingOrderedRequest(r);
-
- if (!requestGroupId.equals(updated)) {
- final GroupMismatchException exception = new GroupMismatchException(getId()
- + ": The group (" + requestGroupId + ") of " + r.getClientId()
- + " does not match the group (" + updated + ") of the " + JavaUtils.getClassSimpleName(getClass()));
- responseError(exception, () -> "processClientRequest (Group mismatched) for " + r);
- return;
- }
+ final RaftGroupId requestGroupId = request.getRaftGroupId();
+ // use the group id in the first request as the group id of this observer
+ final RaftGroupId updated = groupId.updateAndGet(g -> g != null ? g : requestGroupId);
- slidingWindow.receivedRequest(pending, this::processClientRequest);
+ if (!requestGroupId.equals(updated)) {
+ final GroupMismatchException exception = new GroupMismatchException(getId()
+ + ": The group (" + requestGroupId + ") of " + request.getClientId()
+ + " does not match the group (" + updated + ") of the " + JavaUtils.getClassSimpleName(getClass()));
+ responseError(exception, () -> "processClientRequest (Group mismatched) for " + request);
+ return;
+ }
+ final PendingOrderedRequest pending = new PendingOrderedRequest(requestRef);
+ try {
+ slidingWindow.receivedRequest(pending, this::processClientRequest);
+ } catch (Exception e) {
+ pending.release();
+ throw e;
+ }
+ } finally {
+ requestRef.release();
+ }
}
private void sendReply(PendingOrderedRequest ready) {
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
index 053cc5c0f4..1e9519d24c 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcLogAppender.java
@@ -24,6 +24,7 @@
import org.apache.ratis.metrics.Timekeeper;
import org.apache.ratis.proto.RaftProtos.InstallSnapshotResult;
import org.apache.ratis.protocol.RaftPeerId;
+import org.apache.ratis.retry.MultipleLinearRandomRetry;
import org.apache.ratis.retry.RetryPolicy;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.RaftServerConfigKeys;
@@ -63,8 +64,6 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
-import static org.apache.ratis.server.raftlog.LogProtoUtils.toLogEntryTermIndexString;
-
/**
* A new log appender implementation using grpc bi-directional stream API.
*/
@@ -74,11 +73,7 @@ public class GrpcLogAppender extends LogAppenderBase {
private enum BatchLogKey implements BatchLogger.Key {
RESET_CLIENT,
INCONSISTENCY_REPLY,
- APPEND_LOG_RESPONSE_HANDLER_ON_ERROR,
- INSTALL_SNAPSHOT_NOTIFY,
- INSTALL_SNAPSHOT_REPLY,
- INSTALL_SNAPSHOT_IN_PROGRESS,
- SNAPSHOT_UNAVAILABLE
+ APPEND_LOG_RESPONSE_HANDLER_ON_ERROR
}
public static final int INSTALL_SNAPSHOT_NOTIFICATION_INDEX = 0;
@@ -199,9 +194,8 @@ public GrpcLogAppender(RaftServer.Division server, LeaderState leaderState, Foll
lock = new AutoCloseableReadWriteLock(this);
caller = LOG.isTraceEnabled()? JavaUtils.getCallerStackTraceElement(): null;
- errorRetryWaitPolicy = RetryPolicy.parse(
- RaftServerConfigKeys.Log.Appender.retryPolicy(properties),
- RaftServerConfigKeys.Log.Appender.RETRY_POLICY_KEY);
+ errorRetryWaitPolicy = MultipleLinearRandomRetry.parseCommaSeparated(
+ RaftServerConfigKeys.Log.Appender.retryPolicy(properties));
}
@Override
@@ -240,7 +234,7 @@ private void resetClient(AppendEntriesRequest request, Event event) {
}
getFollower().computeNextIndex(getNextIndexForError(nextIndex));
} catch (IOException ie) {
- LOG.warn("{}: Failed to resetClient for {}", this, getFollowerId(), ie);
+ LOG.warn(this + ": Failed to getClient for " + getFollowerId(), ie);
}
}
@@ -303,8 +297,8 @@ private void mayWait() {
getEventAwaitForSignal().await(getWaitTimeMs() + errorWaitTimeMs(),
TimeUnit.MILLISECONDS);
} catch (InterruptedException ie) {
+ LOG.warn(this + ": Wait interrupted by " + ie);
Thread.currentThread().interrupt();
- LOG.warn("{} is interrupted: {}", this, ie.toString());
}
}
@@ -315,8 +309,14 @@ private long errorWaitTimeMs() {
@Override
public CompletableFuture stopAsync() {
- grpcServerMetrics.unregister();
- return super.stopAsync();
+ try (AutoCloseableLock ignored = lock.writeLock(caller, LOG::trace)) {
+ if (appendLogRequestObserver != null) {
+ appendLogRequestObserver.stop();
+ appendLogRequestObserver = null;
+ }
+ grpcServerMetrics.unregister();
+ return super.stopAsync();
+ }
}
@Override
@@ -392,30 +392,42 @@ public Comparator getCallIdComparator() {
}
private void appendLog(boolean heartbeat) throws IOException {
- final AppendEntriesRequestProto pending;
+ final ReferenceCountedObject pending;
final AppendEntriesRequest request;
try (AutoCloseableLock writeLock = lock.writeLock(caller, LOG::trace)) {
+ if (!isRunning()) {
+ return;
+ }
// Prepare and send the append request.
// Note changes on follower's nextIndex and ops on pendingRequests should always be done under the write-lock
- pending = newAppendEntriesRequest(callId.getAndIncrement(), heartbeat);
+ pending = nextAppendEntriesRequest(callId.getAndIncrement(), heartbeat);
if (pending == null) {
return;
}
- request = new AppendEntriesRequest(pending, getFollowerId(), grpcServerMetrics);
- pendingRequests.put(request);
- increaseNextIndex(pending);
- if (appendLogRequestObserver == null) {
- appendLogRequestObserver = new StreamObservers(
- getClient(), new AppendLogResponseHandler(), useSeparateHBChannel, getWaitTimeMin());
+ try {
+ request = new AppendEntriesRequest(pending.get(), getFollowerId(), grpcServerMetrics);
+ pendingRequests.put(request);
+ increaseNextIndex(pending.get());
+ if (appendLogRequestObserver == null) {
+ appendLogRequestObserver = new StreamObservers(
+ getClient(), new AppendLogResponseHandler(), useSeparateHBChannel, getWaitTimeMin());
+ }
+ } catch (Exception e) {
+ pending.release();
+ throw e;
}
}
- final TimeDuration remaining = getRemainingWaitTime();
- if (remaining.isPositive()) {
- sleep(remaining, heartbeat);
- }
- if (isRunning()) {
- sendRequest(request, pending);
+ try {
+ final TimeDuration remaining = getRemainingWaitTime();
+ if (remaining.isPositive()) {
+ sleep(remaining, heartbeat);
+ }
+ if (isRunning()) {
+ sendRequest(request, pending.get());
+ }
+ } finally {
+ pending.release();
}
}
@@ -503,8 +515,8 @@ public void onNext(AppendEntriesReplyProto reply) {
try {
onNextImpl(request, reply);
} catch(Exception t) {
- LOG.error("Failed onNext(reply), request={}, reply={}",
- request, ServerStringUtils.toAppendEntriesReplyString(reply), t);
+ LOG.error("Failed onNext request=" + request
+ + ", reply=" + ServerStringUtils.toAppendEntriesReplyString(reply), t);
}
}
@@ -579,8 +591,8 @@ private void updateNextIndex(long replyNextIndex) {
}
private class InstallSnapshotResponseHandler implements StreamObserver {
- private final String name;
- private final Queue pending = new LinkedList<>();
+ private final String name = getFollower().getName() + "-" + JavaUtils.getClassSimpleName(getClass());
+ private final Queue pending;
private final CompletableFuture done = new CompletableFuture<>();
private final boolean isNotificationOnly;
@@ -589,8 +601,8 @@ private class InstallSnapshotResponseHandler implements StreamObserver();
this.isNotificationOnly = notifyOnly;
- this.name = getFollower().getName() + "-InstallSnapshot" + (isNotificationOnly ? "Notification" : "");
}
void addPending(InstallSnapshotRequestProto request) {
@@ -632,8 +644,8 @@ void onFollowerCatchup(long followerSnapshotIndex) {
final long leaderStartIndex = getRaftLog().getStartIndex();
final long followerNextIndex = followerSnapshotIndex + 1;
if (followerNextIndex >= leaderStartIndex) {
- LOG.info("{}: follower nextIndex = {} >= leader startIndex = {}",
- this, followerNextIndex, leaderStartIndex);
+ LOG.info("{}: Follower can catch up leader after install the snapshot, as leader's start index is {}",
+ this, leaderStartIndex); // the message names the leader's start index, so log that value
notifyInstallSnapshotFinished(InstallSnapshotResult.SUCCESS, followerSnapshotIndex);
}
}
@@ -665,10 +677,10 @@ boolean hasAllResponse() {
@Override
public void onNext(InstallSnapshotReplyProto reply) {
- BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_REPLY, name,
- suffix -> LOG.info("{}: received {} reply {} {}", this,
- replyState.isFirstReplyReceived() ? "a" : "the first",
- ServerStringUtils.toInstallSnapshotReplyString(reply), suffix));
+ if (LOG.isInfoEnabled()) {
+ LOG.info("{}: received {} reply {}", this, replyState.isFirstReplyReceived()? "a" : "the first",
+ ServerStringUtils.toInstallSnapshotReplyString(reply));
+ }
// update the last rpc time
getFollower().updateLastRpcResponseTime();
@@ -677,13 +689,12 @@ public void onNext(InstallSnapshotReplyProto reply) {
final long followerSnapshotIndex;
switch (reply.getResult()) {
case SUCCESS:
- LOG.info("{}: Completed", this);
+ LOG.info("{}: Completed InstallSnapshot. Reply: {}", this, reply);
getFollower().setAttemptedToInstallSnapshot();
removePending(reply);
break;
case IN_PROGRESS:
- BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_IN_PROGRESS, name,
- suffix -> LOG.info("{}: in progress, {}", this, suffix));
+ LOG.info("{}: InstallSnapshot in progress.", this);
removePending(reply);
break;
case ALREADY_INSTALLED:
@@ -699,7 +710,7 @@ public void onNext(InstallSnapshotReplyProto reply) {
onFollowerTerm(reply.getTerm());
break;
case CONF_MISMATCH:
- LOG.error("{}: CONF_MISMATCH ({}): Leader {} has it set to {} but follower {} has it set to {}",
+ LOG.error("{}: Configuration Mismatch ({}): Leader {} has it set to {} but follower {} has it set to {}",
this, RaftServerConfigKeys.Log.Appender.INSTALL_SNAPSHOT_ENABLED_KEY,
getServer().getId(), installSnapshotEnabled, getFollowerId(), !installSnapshotEnabled);
break;
@@ -714,19 +725,17 @@ public void onNext(InstallSnapshotReplyProto reply) {
removePending(reply);
break;
case SNAPSHOT_UNAVAILABLE:
- BatchLogger.print(BatchLogKey.SNAPSHOT_UNAVAILABLE, name,
- suffix -> LOG.info("{}: Follower failed since the snapshot is unavailable {}", this, suffix));
+ LOG.info("{}: Follower could not install snapshot as it is not available.", this);
getFollower().setAttemptedToInstallSnapshot();
notifyInstallSnapshotFinished(InstallSnapshotResult.SNAPSHOT_UNAVAILABLE, RaftLog.INVALID_LOG_INDEX);
removePending(reply);
break;
case UNRECOGNIZED:
- LOG.error("{}: Reply result {}, {}",
- name, reply.getResult(), ServerStringUtils.toInstallSnapshotReplyString(reply));
+ LOG.error("Unrecognized the reply result {}: Leader is {}, follower is {}",
+ reply.getResult(), getServer().getId(), getFollowerId());
break;
case SNAPSHOT_EXPIRED:
- LOG.warn("{}: Follower failed since the request expired, {}",
- name, ServerStringUtils.toInstallSnapshotReplyString(reply));
+ LOG.warn("{}: Follower could not install snapshot as it is expired.", this);
default:
break;
}
@@ -805,9 +814,8 @@ private void installSnapshot(SnapshotInfo snapshot) {
* @param firstAvailable the first available log's index on the Leader
*/
private void notifyInstallSnapshot(TermIndex firstAvailable) {
- BatchLogger.print(BatchLogKey.INSTALL_SNAPSHOT_NOTIFY, getFollower().getName(),
- suffix -> LOG.info("{}: notifyInstallSnapshot with firstAvailable={}, followerNextIndex={} {}",
- this, firstAvailable, getFollower().getNextIndex(), suffix));
+ LOG.info("{}: notifyInstallSnapshot with firstAvailable={}, followerNextIndex={}",
+ this, firstAvailable, getFollower().getNextIndex());
final InstallSnapshotResponseHandler responseHandler = new InstallSnapshotResponseHandler(true);
StreamObserver snapshotRequestObserver = null;
@@ -834,6 +842,45 @@ private void notifyInstallSnapshot(TermIndex firstAvailable) {
responseHandler.waitForResponse();
}
+  /**
+   * Determine whether the Leader should tell the Follower to install a snapshot
+   * through the Follower's own State Machine.
+   * @return the term-index of the leader's first available log if so; otherwise, null.
+   */
+  @Override
+  public TermIndex shouldNotifyToInstallSnapshot() {
+    final FollowerInfo info = getFollower();
+    final long nextOfLeader = getRaftLog().getNextIndex();
+    final boolean bootstrapping = getLeaderState().isFollowerBootstrapping(info);
+    final long startOfLeader = getRaftLog().getStartIndex();
+    final TermIndex first = Optional.ofNullable(getRaftLog().getTermIndex(startOfLeader))
+        .orElseGet(() -> TermIndex.valueOf(getServer().getInfo().getCurrentTerm(), nextOfLeader));
+
+    // A bootstrapping follower that has not attempted any snapshot installation yet
+    // is always notified, so that it tries to install at least one snapshot.
+    if (bootstrapping && !info.hasAttemptedToInstallSnapshot()) {
+      LOG.debug("{}: follower is bootstrapping, notify to install snapshot to {}.", this, first);
+      return first;
+    }
+
+    final long nextOfFollower = info.getNextIndex();
+    if (nextOfFollower >= nextOfLeader) {
+      // The follower's next index has reached the leader's next index.
+      return null;
+    }
+
+    // Notify in two cases:
+    // (1) the follower's next index is before the leader's start index, i.e. the leader
+    //     does not have the logs from the follower's last log index onwards;
+    // (2) the leader's start index is invalid, i.e. the leader has no logs to check from.
+    final boolean followerBehindStart = nextOfFollower < startOfLeader;
+    final boolean leaderHasNoLog = startOfLeader == RaftLog.INVALID_LOG_INDEX;
+    if (followerBehindStart || leaderHasNoLog) {
+      return first;
+    }
+
+    return null;
+  }
static class AppendEntriesRequest {
private final Timekeeper timer;
@@ -891,9 +938,13 @@ boolean isHeartbeat() {
@Override
public String toString() {
+ final String entries = entriesCount == 0? ""
+ : entriesCount == 1? ",entry=" + firstEntry
+ : ",entries=" + firstEntry + "..." + lastEntry;
return JavaUtils.getClassSimpleName(getClass())
+ ":cid=" + callId
- + ":" + toLogEntryTermIndexString(entriesCount, firstEntry, lastEntry);
+ + ",entriesCount=" + entriesCount
+ + entries;
}
}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
index d2748c7be2..4a280ab335 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java
@@ -1,4 +1,4 @@
-/*
+/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
@@ -17,11 +17,13 @@
*/
package org.apache.ratis.grpc.server;
+import org.apache.ratis.grpc.GrpcTlsConfig;
import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.grpc.util.StreamObserverWithTimeout;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.util.ServerStringUtils;
import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
+import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
import org.apache.ratis.thirdparty.io.grpc.stub.CallStreamObserver;
@@ -31,7 +33,7 @@
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceBlockingStub;
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceStub;
import org.apache.ratis.protocol.RaftPeer;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,10 +44,9 @@
* This is a RaftClient implementation that supports streaming data to the raft
* ring. The stream implementation utilizes gRPC.
*/
-class GrpcServerProtocolClient implements Closeable {
+public class GrpcServerProtocolClient implements Closeable {
// Common channel
private final ManagedChannel channel;
- private final GrpcStubPool pool;
// Channel and stub for heartbeat
private ManagedChannel hbChannel;
private RaftServerProtocolServiceStub hbAsyncStub;
@@ -58,34 +59,40 @@ class GrpcServerProtocolClient implements Closeable {
//visible for using in log / error messages AND to use in instrumented tests
private final RaftPeerId raftPeerId;
- GrpcServerProtocolClient(RaftPeer target, int connections, int flowControlWindow,
- TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel) {
+ public GrpcServerProtocolClient(RaftPeer target, int flowControlWindow,
+ TimeDuration requestTimeout, GrpcTlsConfig tlsConfig, boolean separateHBChannel) {
raftPeerId = target.getId();
LOG.info("Build channel for {}", target);
useSeparateHBChannel = separateHBChannel;
- channel = buildChannel(target, flowControlWindow, sslContext);
+ channel = buildChannel(target, flowControlWindow, tlsConfig);
blockingStub = RaftServerProtocolServiceGrpc.newBlockingStub(channel);
asyncStub = RaftServerProtocolServiceGrpc.newStub(channel);
if (useSeparateHBChannel) {
- hbChannel = buildChannel(target, flowControlWindow, sslContext);
+ hbChannel = buildChannel(target, flowControlWindow, tlsConfig);
hbAsyncStub = RaftServerProtocolServiceGrpc.newStub(hbChannel);
}
requestTimeoutDuration = requestTimeout;
- this.pool = connections == 1? null : newGrpcStubPool(target.getAddress(), sslContext, connections);
- }
-
- GrpcStubPool newGrpcStubPool(String address, SslContext sslContext, int connections) {
- return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16);
}
- private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslContext sslContext) {
+ private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow,
+ GrpcTlsConfig tlsConfig) {
NettyChannelBuilder channelBuilder =
NettyChannelBuilder.forTarget(target.getAddress());
// ignore any http proxy for grpc
channelBuilder.proxyDetector(uri -> null);
- if (sslContext != null) {
- channelBuilder.useTransportSecurity().sslContext(sslContext);
+ if (tlsConfig!= null) {
+ SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
+ GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager());
+ if (tlsConfig.getMtlsEnabled()) {
+ GrpcUtil.setKeyManager(sslContextBuilder, tlsConfig.getKeyManager());
+ }
+ try {
+ channelBuilder.useTransportSecurity().sslContext(sslContextBuilder.build());
+ } catch (Exception ex) {
+ throw new IllegalArgumentException("Failed to build SslContext, peerId=" + raftPeerId
+ + ", tlsConfig=" + tlsConfig, ex);
+ }
} else {
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
@@ -100,9 +107,6 @@ public void close() {
GrpcUtil.shutdownManagedChannel(hbChannel);
}
GrpcUtil.shutdownManagedChannel(channel);
- if (pool != null) {
- pool.close();
- }
}
public RequestVoteReplyProto requestVote(RequestVoteRequestProto request) {
@@ -121,44 +125,8 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ
}
void readIndex(ReadIndexRequestProto request, StreamObserver s) {
- if (pool == null) {
- asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
- .readIndex(request, s);
- } else {
- GrpcStubPool.Stub p;
- try {
- p = pool.acquire();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- s.onError(e);
- return;
- }
- p.getStub().withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
- .readIndex(request, new StreamObserver() {
- @Override
- public void onNext(ReadIndexReplyProto v) {
- s.onNext(v);
- }
-
- @Override
- public void onError(Throwable t) {
- try {
- s.onError(t);
- } finally {
- p.release();
- }
- }
-
- @Override
- public void onCompleted() {
- try {
- s.onCompleted();
- } finally {
- p.release();
- }
- }
- });
- }
+ asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
+ .readIndex(request, s);
}
CallStreamObserver appendEntries(
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
index a13e74b89d..7e17cb3cf4 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java
@@ -20,19 +20,21 @@
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
+import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.protocol.RaftServerProtocol;
import org.apache.ratis.server.util.ServerStringUtils;
-import org.apache.ratis.thirdparty.com.google.protobuf.MessageOrBuilder;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.Status;
-import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.proto.RaftProtos.*;
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceImplBase;
import org.apache.ratis.util.BatchLogger;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.ProtoUtils;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -43,6 +45,9 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
+import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
+import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getAppendEntriesMethod;
+
class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase {
public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class);
@@ -52,24 +57,31 @@ private enum BatchLogKey implements BatchLogger.Key {
}
static class PendingServerRequest {
- private final REQUEST request;
+ private final AtomicReference> requestRef;
private final CompletableFuture future = new CompletableFuture<>();
- PendingServerRequest(REQUEST request) {
- this.request = request;
+ PendingServerRequest(ReferenceCountedObject requestRef) {
+ requestRef.retain();
+ this.requestRef = new AtomicReference<>(requestRef);
}
REQUEST getRequest() {
- return request;
+ return Optional.ofNullable(requestRef.get())
+ .map(ReferenceCountedObject::get)
+ .orElse(null);
}
CompletableFuture getFuture() {
return future;
}
+
+ void release() {
+ Optional.ofNullable(requestRef.getAndSet(null))
+ .ifPresent(ReferenceCountedObject::release);
+ }
}
- abstract class ServerRequestStreamObserver
- implements StreamObserver {
+ abstract class ServerRequestStreamObserver implements StreamObserver {
private final RaftServer.Op op;
private final Supplier nameSupplier;
private final StreamObserver responseObserver;
@@ -97,24 +109,39 @@ private String getPreviousRequestString() {
.orElse(null);
}
- abstract CompletableFuture process(REQUEST request) throws IOException;
+ CompletableFuture process(REQUEST request) throws IOException {
+ throw new UnsupportedOperationException("This method is not supported.");
+ }
+
+ /**
+ * Process a reference-counted request.
+ * This default implementation retains the reference, passes the request to
+ * {@code process(REQUEST)} and releases the reference in the finally block,
+ * i.e. as soon as that call returns (or throws), not when the returned future completes.
+ */
+ CompletableFuture process(ReferenceCountedObject requestRef)
+ throws IOException {
+ try {
+ return process(requestRef.retain());
+ } finally {
+ requestRef.release();
+ }
+ }
+
+ void release(REQUEST req) {
+ }
abstract long getCallId(REQUEST request);
+ boolean isHeartbeat(REQUEST request) {
+ return false;
+ }
+
abstract String requestToString(REQUEST request);
abstract String replyToString(REPLY reply);
abstract boolean replyInOrder(REQUEST request);
- StatusRuntimeException wrapException(Throwable e, REQUEST request) {
- return GrpcUtil.wrapException(e, getCallId(request));
- }
-
- private void handleError(Throwable e, REQUEST request) {
- GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request " + requestToString(request), e);
+ private synchronized void handleError(Throwable e, long callId, boolean isHeartbeat) {
+ GrpcUtil.warn(LOG, () -> getId() + ": Failed " + op + " request cid=" + callId + ", isHeartbeat? "
+ + isHeartbeat, e);
if (isClosed.compareAndSet(false, true)) {
- responseObserver.onError(wrapException(e, request));
+ responseObserver.onError(GrpcUtil.wrapException(e, callId, isHeartbeat));
}
}
@@ -134,24 +161,32 @@ void composeRequest(CompletableFuture current) {
@Override
public void onNext(REQUEST request) {
+ ReferenceCountedObject requestRef = ReferenceCountedObject.wrap(request, () -> {}, released -> {
+ if (released) {
+ release(request);
+ }
+ });
+
if (!replyInOrder(request)) {
try {
- composeRequest(process(request).thenApply(this::handleReply));
+ composeRequest(process(requestRef).thenApply(this::handleReply));
} catch (Exception e) {
- handleError(e, request);
+ handleError(e, getCallId(request), isHeartbeat(request));
+ release(request);
}
return;
}
- final PendingServerRequest current = new PendingServerRequest<>(request);
- final PendingServerRequest previous = previousOnNext.getAndSet(current);
- final CompletableFuture previousFuture = Optional.ofNullable(previous)
- .map(PendingServerRequest::getFuture)
+ final PendingServerRequest current = new PendingServerRequest<>(requestRef);
+ final long callId = getCallId(current.getRequest());
+ final boolean isHeartbeat = isHeartbeat(current.getRequest());
+ final Optional> previous = Optional.ofNullable(previousOnNext.getAndSet(current));
+ final CompletableFuture previousFuture = previous.map(PendingServerRequest::getFuture)
.orElse(CompletableFuture.completedFuture(null));
try {
- final CompletableFuture f = process(request).exceptionally(e -> {
+ final CompletableFuture f = process(requestRef).exceptionally(e -> {
// Handle cases, such as RaftServer is paused
- handleError(e, request);
+ handleError(e, callId, isHeartbeat);
current.getFuture().completeExceptionally(e);
return null;
}).thenCombine(previousFuture, (reply, v) -> {
@@ -161,8 +196,14 @@ public void onNext(REQUEST request) {
});
composeRequest(f);
} catch (Exception e) {
- handleError(e, request);
+ handleError(e, callId, isHeartbeat);
current.getFuture().completeExceptionally(e);
+ } finally {
+ previous.ifPresent(PendingServerRequest::release);
+ if (isClosed.get()) {
+ // Some requests may come after onCompleted or onError, ensure they're released.
+ releaseLast();
+ }
}
}
@@ -174,12 +215,13 @@ public void onCompleted() {
getId(), op, getPreviousRequestString(), suffix));
requestFuture.get().thenAccept(reply -> {
BatchLogger.print(BatchLogKey.COMPLETED_REPLY, getName(),
- suffix -> LOG.info("{}: Completed {}, lastReply: {} {}",
- getId(), op, ProtoUtils.shortDebugString(reply), suffix));
+ suffix -> LOG.info("{}: Completed {}, lastReply: {} {}", getId(), op, reply, suffix));
responseObserver.onCompleted();
});
+ releaseLast();
}
}
+
@Override
public void onError(Throwable t) {
GrpcUtil.warn(LOG, () -> getId() + ": "+ op + " onError, lastRequest: " + getPreviousRequestString(), t);
@@ -188,22 +230,54 @@ public void onError(Throwable t) {
if (status != null && status.getCode() != Status.Code.CANCELLED) {
responseObserver.onCompleted();
}
+ releaseLast();
}
}
+
+ /**
+ * Release the pending request currently stored in {@code previousOnNext}, if there is one.
+ * The stored entry is only read here, not cleared.
+ */
+ private void releaseLast() {
+ Optional.ofNullable(previousOnNext.get()).ifPresent(PendingServerRequest::release);
+ }
}
private final Supplier idSupplier;
private final RaftServer server;
+ private final boolean zeroCopyEnabled;
+ private final ZeroCopyMessageMarshaller zeroCopyRequestMarshaller;
- GrpcServerProtocolService(Supplier idSupplier, RaftServer server) {
+ GrpcServerProtocolService(Supplier idSupplier, RaftServer server, boolean zeroCopyEnabled,
+ ZeroCopyMetrics zeroCopyMetrics) {
this.idSupplier = idSupplier;
this.server = server;
+ this.zeroCopyEnabled = zeroCopyEnabled;
+ this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(AppendEntriesRequestProto.getDefaultInstance(),
+ zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
+ zeroCopyMetrics.addUnreleased("server_protocol", zeroCopyRequestMarshaller::getUnclosedCount);
}
RaftPeerId getId() {
return idSupplier.get();
}
+ /**
+ * Build the service definition to register with the grpc server.
+ * When zero copy is disabled, the definition from {@code super.bindService()} is returned unchanged.
+ * Otherwise, a new definition with the same service name is built in which appendEntries
+ * uses {@code zeroCopyRequestMarshaller} and all the other methods are copied as is.
+ */
+ ServerServiceDefinition bindServiceWithZeroCopy() {
+ ServerServiceDefinition orig = super.bindService();
+ if (!zeroCopyEnabled) {
+ LOG.info("{}: Zero copy is disabled.", getId());
+ return orig;
+ }
+ ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());
+
+ // Add appendEntries with zero copy marshaller.
+ addMethodWithCustomMarshaller(orig, builder, getAppendEntriesMethod(), zeroCopyRequestMarshaller);
+ // Add remaining methods as is.
+ // appendEntries is filtered out by its full method name since it was already added above.
+ orig.getMethods().stream().filter(
+ x -> !x.getMethodDescriptor().getFullMethodName().equals(getAppendEntriesMethod().getFullMethodName())
+ ).forEach(
+ builder::addMethod
+ );
+
+ return builder.build();
+ }
+
@Override
public void requestVote(RequestVoteRequestProto request,
StreamObserver responseObserver) {
@@ -244,8 +318,14 @@ public StreamObserver appendEntries(
return new ServerRequestStreamObserver(
RaftServerProtocol.Op.APPEND_ENTRIES, responseObserver) {
@Override
- CompletableFuture process(AppendEntriesRequestProto request) throws IOException {
- return server.appendEntriesAsync(request);
+ CompletableFuture process(ReferenceCountedObject requestRef)
+ throws IOException {
+ return server.appendEntriesAsync(requestRef);
+ }
+
+ @Override
+ void release(AppendEntriesRequestProto req) {
+ zeroCopyRequestMarshaller.release(req);
}
@Override
@@ -253,6 +333,11 @@ long getCallId(AppendEntriesRequestProto request) {
return request.getServerRequest().getCallId();
}
+ @Override
+ boolean isHeartbeat(AppendEntriesRequestProto request) {
+ return request.getEntriesCount() == 0;
+ }
+
@Override
String requestToString(AppendEntriesRequestProto request) {
return ServerStringUtils.toAppendEntriesRequestString(request, null);
@@ -267,11 +352,6 @@ String replyToString(AppendEntriesReplyProto reply) {
boolean replyInOrder(AppendEntriesRequestProto request) {
return request.getEntriesCount() != 0;
}
-
- @Override
- StatusRuntimeException wrapException(Throwable e, AppendEntriesRequestProto request) {
- return GrpcUtil.wrapException(e, getCallId(request), request.getEntriesCount() == 0);
- }
};
}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
new file mode 100644
index 0000000000..e3c0a5eddb
--- /dev/null
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcService.java
@@ -0,0 +1,396 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.grpc.server;
+
+import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.GrpcConfigKeys;
+import org.apache.ratis.grpc.GrpcTlsConfig;
+import org.apache.ratis.grpc.GrpcUtil;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
+import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
+import org.apache.ratis.protocol.RaftGroupId;
+import org.apache.ratis.protocol.RaftPeerId;
+import org.apache.ratis.rpc.SupportedRpcType;
+import org.apache.ratis.server.RaftServer;
+import org.apache.ratis.server.RaftServerConfigKeys;
+import org.apache.ratis.server.RaftServerRpcWithProxy;
+import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol;
+import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting;
+import org.apache.ratis.thirdparty.io.grpc.ServerInterceptors;
+import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
+import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
+import org.apache.ratis.thirdparty.io.grpc.Server;
+import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
+import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
+
+import org.apache.ratis.proto.RaftProtos.*;
+import org.apache.ratis.util.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Supplier;
+
+import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
+
+/** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */
+public final class GrpcService extends RaftServerRpcWithProxy> {
+ static final Logger LOG = LoggerFactory.getLogger(GrpcService.class);
+ public static final String GRPC_SEND_SERVER_REQUEST =
+ JavaUtils.getClassSimpleName(GrpcService.class) + ".sendRequest";
+
+ /**
+ * The asynchronous server protocol exposed by {@link #async()}.
+ * Only readIndexAsync is implemented; appendEntriesAsync always throws.
+ */
+ class AsyncService implements RaftServerAsynchronousProtocol {
+
+ /** Not supported: always throws {@link UnsupportedOperationException}. */
+ @Override
+ public CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request)
+ throws IOException {
+ throw new UnsupportedOperationException("This method is not supported");
+ }
+
+ /**
+ * Send the given read-index request to the peer identified by the reply id
+ * in the request's server-request header.
+ *
+ * @return a future completed with the reply passed to onNext,
+ * or completed exceptionally with the throwable passed to onError.
+ */
+ @Override
+ public CompletableFuture readIndexAsync(ReadIndexRequestProto request) throws IOException {
+ CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(), null, request);
+
+ final CompletableFuture f = new CompletableFuture<>();
+ // Adapt the grpc stream callbacks to the future returned to the caller.
+ final StreamObserver s = new StreamObserver() {
+ @Override
+ public void onNext(ReadIndexReplyProto reply) {
+ f.complete(reply);
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ f.completeExceptionally(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ // Intentionally empty: the future is completed only in onNext or onError.
+ // NOTE(review): if the stream completes without a reply, this observer never completes f.
+ }
+ };
+
+ final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId());
+ getProxies().getProxy(target).readIndex(request, s);
+ return f;
+ }
+ }
+
+ /**
+ * A builder of {@link GrpcService}.
+ * The server and the admin, client and server TLS configs are passed to the constructor by {@link #build()}.
+ * The value set by {@link #setTlsConfig(GrpcTlsConfig)} is only stored and returned by {@link #getTlsConfig()};
+ * it is not passed to the constructor by {@link #build()}.
+ */
+ public static final class Builder {
+ private RaftServer server;
+ private GrpcTlsConfig tlsConfig;
+ private GrpcTlsConfig adminTlsConfig;
+ private GrpcTlsConfig clientTlsConfig;
+ private GrpcTlsConfig serverTlsConfig;
+
+ // Instances are obtained from newBuilder().
+ private Builder() {}
+
+ public Builder setServer(RaftServer raftServer) {
+ this.server = raftServer;
+ return this;
+ }
+
+ /** @return a new service built from the server and the admin, client and server TLS configs. */
+ public GrpcService build() {
+ return new GrpcService(server, adminTlsConfig, clientTlsConfig, serverTlsConfig);
+ }
+
+ // NOTE(review): tlsConfig is not used by build(); presumably callers read it back via getTlsConfig() -- confirm.
+ public Builder setTlsConfig(GrpcTlsConfig tlsConfig) {
+ this.tlsConfig = tlsConfig;
+ return this;
+ }
+
+ public Builder setAdminTlsConfig(GrpcTlsConfig config) {
+ this.adminTlsConfig = config;
+ return this;
+ }
+
+ public Builder setClientTlsConfig(GrpcTlsConfig config) {
+ this.clientTlsConfig = config;
+ return this;
+ }
+
+ public Builder setServerTlsConfig(GrpcTlsConfig config) {
+ this.serverTlsConfig = config;
+ return this;
+ }
+
+ public GrpcTlsConfig getTlsConfig() {
+ return tlsConfig;
+ }
+ }
+
+ /** @return a new, empty {@link Builder}. */
+ public static Builder newBuilder() {
+ return new Builder();
+ }
+
+ private final Map servers = new HashMap<>();
+ private final Supplier addressSupplier;
+ private final Supplier clientServerAddressSupplier;
+ private final Supplier adminServerAddressSupplier;
+
+ private final AsyncService asyncService = new AsyncService();
+
+ private final ExecutorService executor;
+ private final GrpcClientProtocolService clientProtocolService;
+
+ private final MetricServerInterceptor serverInterceptor;
+ private final ZeroCopyMetrics zeroCopyMetrics;
+
+ /** @return the metric interceptor that is installed on the services added by this class. */
+ public MetricServerInterceptor getServerInterceptor() {
+ return serverInterceptor;
+ }
+
+ /**
+ * Read the hosts, ports, size limits, request timeout, heartbeat-channel flag and zero-copy flag
+ * from the properties of the given server and delegate to the full constructor.
+ */
+ private GrpcService(RaftServer server,
+ GrpcTlsConfig adminTlsConfig, GrpcTlsConfig clientTlsConfig, GrpcTlsConfig serverTlsConfig) {
+ this(server, server::getId,
+ GrpcConfigKeys.Admin.host(server.getProperties()),
+ GrpcConfigKeys.Admin.port(server.getProperties()),
+ adminTlsConfig,
+ GrpcConfigKeys.Client.host(server.getProperties()),
+ GrpcConfigKeys.Client.port(server.getProperties()),
+ clientTlsConfig,
+ GrpcConfigKeys.Server.host(server.getProperties()),
+ GrpcConfigKeys.Server.port(server.getProperties()),
+ serverTlsConfig,
+ GrpcConfigKeys.messageSizeMax(server.getProperties(), LOG::info),
+ RaftServerConfigKeys.Log.Appender.bufferByteLimit(server.getProperties()),
+ GrpcConfigKeys.flowControlWindow(server.getProperties(), LOG::info),
+ RaftServerConfigKeys.Rpc.requestTimeout(server.getProperties()),
+ GrpcConfigKeys.Server.heartbeatChannel(server.getProperties()),
+ GrpcConfigKeys.Server.zeroCopyEnabled(server.getProperties()));
+ }
+
+ /**
+ * Build (but do not start) the grpc servers of this service.
+ * The server protocol service is always bound to serverHost:serverPort.
+ * The admin service and the client service share that server unless
+ * their port is positive and differs from serverPort, in which case a separate server is built.
+ *
+ * @throws IllegalArgumentException if appenderBufferSize is larger than grpcMessageSizeMax.
+ */
+ @SuppressWarnings("checkstyle:ParameterNumber") // private constructor
+ private GrpcService(RaftServer raftServer, Supplier idSupplier,
+ String adminHost, int adminPort, GrpcTlsConfig adminTlsConfig,
+ String clientHost, int clientPort, GrpcTlsConfig clientTlsConfig,
+ String serverHost, int serverPort, GrpcTlsConfig serverTlsConfig,
+ SizeInBytes grpcMessageSizeMax, SizeInBytes appenderBufferSize,
+ SizeInBytes flowControlWindow,TimeDuration requestTimeoutDuration,
+ boolean useSeparateHBChannel, boolean zeroCopyEnabled) {
+ // The proxy map creates one GrpcServerProtocolClient per peer, using serverTlsConfig.
+ super(idSupplier, id -> new PeerProxyMap<>(id.toString(),
+ p -> new GrpcServerProtocolClient(p, flowControlWindow.getSizeInt(),
+ requestTimeoutDuration, serverTlsConfig, useSeparateHBChannel)));
+ if (appenderBufferSize.getSize() > grpcMessageSizeMax.getSize()) {
+ throw new IllegalArgumentException("Illegal configuration: "
+ + RaftServerConfigKeys.Log.Appender.BUFFER_BYTE_LIMIT_KEY + " = " + appenderBufferSize
+ + " > " + GrpcConfigKeys.MESSAGE_SIZE_MAX_KEY + " = " + grpcMessageSizeMax);
+ }
+
+ final RaftProperties properties = raftServer.getProperties();
+ this.executor = ConcurrentUtils.newThreadPoolWithMax(
+ GrpcConfigKeys.Server.asyncRequestThreadPoolCached(properties),
+ GrpcConfigKeys.Server.asyncRequestThreadPoolSize(properties),
+ getId() + "-request-");
+ // A single ZeroCopyMetrics instance is shared by the client and the server protocol services.
+ this.zeroCopyMetrics = new ZeroCopyMetrics();
+ this.clientProtocolService = new GrpcClientProtocolService(idSupplier, raftServer, executor,
+ zeroCopyEnabled, zeroCopyMetrics);
+
+ this.serverInterceptor = new MetricServerInterceptor(
+ idSupplier,
+ JavaUtils.getClassSimpleName(getClass()) + "_" + serverPort
+ );
+
+ final boolean separateAdminServer = adminPort != serverPort && adminPort > 0;
+ final boolean separateClientServer = clientPort != serverPort && clientPort > 0;
+
+ final NettyServerBuilder serverBuilder =
+ startBuildingNettyServer(serverHost, serverPort, serverTlsConfig, grpcMessageSizeMax, flowControlWindow);
+ GrpcServerProtocolService serverProtocolService = new GrpcServerProtocolService(idSupplier, raftServer,
+ zeroCopyEnabled, zeroCopyMetrics);
+ serverBuilder.addService(ServerInterceptors.intercept(
+ serverProtocolService.bindServiceWithZeroCopy(), serverInterceptor));
+ if (!separateAdminServer) {
+ addAdminService(raftServer, serverBuilder);
+ }
+ if (!separateClientServer) {
+ addClientService(serverBuilder);
+ }
+
+ final Server server = serverBuilder.build();
+ // NOTE(review): this key uses getSimpleName() while the admin and client keys below use getName().
+ servers.put(GrpcServerProtocolService.class.getSimpleName(), server);
+ addressSupplier = newAddressSupplier(serverPort, server);
+
+ if (separateAdminServer) {
+ final NettyServerBuilder builder =
+ startBuildingNettyServer(adminHost, adminPort, adminTlsConfig, grpcMessageSizeMax, flowControlWindow);
+ addAdminService(raftServer, builder);
+ final Server adminServer = builder.build();
+ servers.put(GrpcAdminProtocolService.class.getName(), adminServer);
+ adminServerAddressSupplier = newAddressSupplier(adminPort, adminServer);
+ } else {
+ // The admin service shares the server built above.
+ adminServerAddressSupplier = addressSupplier;
+ }
+
+ if (separateClientServer) {
+ final NettyServerBuilder builder =
+ startBuildingNettyServer(clientHost, clientPort, clientTlsConfig, grpcMessageSizeMax, flowControlWindow);
+ addClientService(builder);
+ final Server clientServer = builder.build();
+ servers.put(GrpcClientProtocolService.class.getName(), clientServer);
+ clientServerAddressSupplier = newAddressSupplier(clientPort, clientServer);
+ } else {
+ // The client service shares the server built above.
+ clientServerAddressSupplier = addressSupplier;
+ }
+ }
+
+ /**
+ * @return a memoized supplier of an address built from a port only (no hostname):
+ * the given port if it is non-zero; otherwise, the port reported by the given server
+ * at the time the supplier is first evaluated.
+ */
+ private MemoizedSupplier newAddressSupplier(int port, Server server) {
+ return JavaUtils.memoize(() -> new InetSocketAddress(port != 0 ? port : server.getPort()));
+ }
+
+ /** Add the client protocol service, wrapped with the metric interceptor, to the given builder. */
+ private void addClientService(NettyServerBuilder builder) {
+ builder.addService(ServerInterceptors.intercept(clientProtocolService.bindServiceWithZeroCopy(),
+ serverInterceptor));
+ }
+
+ /** Add a new admin protocol service, wrapped with the metric interceptor, to the given builder. */
+ private void addAdminService(RaftServer raftServer, NettyServerBuilder nettyServerBuilder) {
+ nettyServerBuilder.addService(ServerInterceptors.intercept(
+ new GrpcAdminProtocolService(raftServer),
+ serverInterceptor));
+ }
+
+ /**
+ * Create a netty server builder for the given address with the message size limit,
+ * the flow control window and, when tlsConfig is non-null, TLS configured.
+ * No service is added here.
+ *
+ * @param hostname the host to bind; when null or empty, only the port is used for the address.
+ * @param tlsConfig the TLS config, or null for no TLS.
+ * @throws IllegalArgumentException if the SslContext cannot be built from tlsConfig.
+ */
+ private static NettyServerBuilder startBuildingNettyServer(String hostname, int port, GrpcTlsConfig tlsConfig,
+ SizeInBytes grpcMessageSizeMax, SizeInBytes flowControlWindow) {
+ InetSocketAddress address = hostname == null || hostname.isEmpty() ?
+ new InetSocketAddress(port) : new InetSocketAddress(hostname, port);
+ NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address)
+ .withChildOption(ChannelOption.SO_REUSEADDR, true)
+ .maxInboundMessageSize(grpcMessageSizeMax.getSizeInt())
+ .flowControlWindow(flowControlWindow.getSizeInt());
+
+ if (tlsConfig != null) {
+ SslContextBuilder sslContextBuilder = GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager());
+ // With mTLS, client certificates are required and checked with the configured trust manager.
+ if (tlsConfig.getMtlsEnabled()) {
+ sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
+ GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager());
+ }
+ sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, OPENSSL);
+ try {
+ nettyServerBuilder.sslContext(sslContextBuilder.build());
+ } catch (Exception ex) {
+ throw new IllegalArgumentException("Failed to build SslContext, tlsConfig=" + tlsConfig, ex);
+ }
+ }
+ return nettyServerBuilder;
+ }
+
+ /** @return {@link SupportedRpcType#GRPC}. */
+ @Override
+ public SupportedRpcType getRpcType() {
+ return SupportedRpcType.GRPC;
+ }
+
+ /**
+ * Start every grpc server built by the constructor.
+ * An {@link IOException} from a server start is passed to {@code ExitUtils.terminate} with exit status 1.
+ */
+ @Override
+ public void startImpl() {
+ for (Server server : servers.values()) {
+ try {
+ server.start();
+ } catch (IOException e) {
+ ExitUtils.terminate(1, "Failed to start Grpc server", e, LOG);
+ }
+ LOG.info("{}: {} started, listening on {}",
+ getId(), JavaUtils.getClassSimpleName(getClass()), server.getPort());
+ }
+ }
+
+ /**
+ * Shut down every grpc server and wait for its termination,
+ * then close the metric interceptor, shut down the request executor and unregister the zero-copy metrics.
+ * The final cleanup runs in a finally block so that it is not skipped
+ * when shutting down a server fails or the waiting thread is interrupted.
+ *
+ * @throws IOException if closing the proxies fails, or (as an interrupted IOException)
+ * if the thread is interrupted while awaiting a server termination.
+ */
+ @Override
+ public void closeImpl() throws IOException {
+ try {
+ for (Map.Entry server : servers.entrySet()) {
+ final String name = getId() + ": shutdown server " + server.getKey();
+ LOG.info("{} now", name);
+ final Server s = server.getValue().shutdownNow();
+ // NOTE(review): this runs once per server; presumably closing the proxies repeatedly is harmless -- confirm.
+ super.closeImpl();
+ try {
+ s.awaitTermination();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw IOUtils.toInterruptedIOException(name + " failed", e);
+ }
+ LOG.info("{} successfully", name);
+ }
+ } finally {
+ // Always release the remaining resources, even if the loop above has thrown.
+ serverInterceptor.close();
+ ConcurrentUtils.shutdownAndWait(executor);
+ zeroCopyMetrics.unregister();
+ }
+ }
+
+ /** Close all the ordered request stream observers of the given group in the client protocol service. */
+ @Override
+ public void notifyNotLeader(RaftGroupId groupId) {
+ clientProtocolService.closeAllOrderedRequestStreamObservers(groupId);
+ }
+
+ /** @return the address of the server hosting the server protocol service. */
+ @Override
+ public InetSocketAddress getInetSocketAddress() {
+ return addressSupplier.get();
+ }
+
+ /** @return the address of the server hosting the client service; same as the server address when shared. */
+ @Override
+ public InetSocketAddress getClientServerAddress() {
+ return clientServerAddressSupplier.get();
+ }
+
+ /** @return the address of the server hosting the admin service; same as the server address when shared. */
+ @Override
+ public InetSocketAddress getAdminServerAddress() {
+ return adminServerAddressSupplier.get();
+ }
+
+ /** @return the {@link AsyncService} instance of this service. */
+ @Override
+ public RaftServerAsynchronousProtocol async() {
+ return asyncService;
+ }
+
+ /** Not supported: the blocking call always throws {@link UnsupportedOperationException}. */
+ @Override
+ public AppendEntriesReplyProto appendEntries(AppendEntriesRequestProto request) {
+ throw new UnsupportedOperationException(
+ "Blocking " + JavaUtils.getCurrentStackTraceElement().getMethodName() + " call is not supported");
+ }
+
+ /** Not supported: the blocking call always throws {@link UnsupportedOperationException}. */
+ @Override
+ public InstallSnapshotReplyProto installSnapshot(InstallSnapshotRequestProto request) {
+ throw new UnsupportedOperationException(
+ "Blocking " + JavaUtils.getCurrentStackTraceElement().getMethodName() + " call is not supported");
+ }
+
+ /**
+ * Send the given request-vote request through the proxy of the peer identified by
+ * the reply id in the request's server-request header, and return that proxy's reply.
+ */
+ @Override
+ public RequestVoteReplyProto requestVote(RequestVoteRequestProto request)
+ throws IOException {
+ CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(),
+ null, request);
+
+ final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId());
+ return getProxies().getProxy(target).requestVote(request);
+ }
+
+ /**
+ * Send the given start-leader-election request through the proxy of the peer identified by
+ * the reply id in the request's server-request header, and return that proxy's reply.
+ */
+ @Override
+ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequestProto request) throws IOException {
+ CodeInjectionForTesting.execute(GRPC_SEND_SERVER_REQUEST, getId(), null, request);
+
+ final RaftPeerId target = RaftPeerId.valueOf(request.getServerRequest().getReplyId());
+ return getProxies().getProxy(target).startLeaderElection(request);
+ }
+
+ /** @return the zero-copy metrics shared by the client and the server protocol services. */
+ @VisibleForTesting
+ public ZeroCopyMetrics getZeroCopyMetrics() {
+ return zeroCopyMetrics;
+ }
+}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
index d554ca583a..d6f6a0c866 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java
@@ -19,7 +19,10 @@
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.GrpcConfigKeys;
+import org.apache.ratis.grpc.GrpcTlsConfig;
+import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.grpc.metrics.MessageMetrics;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.metrics.intercept.server.MetricServerInterceptor;
import org.apache.ratis.protocol.AdminAsynchronousProtocol;
import org.apache.ratis.protocol.RaftGroupId;
@@ -30,13 +33,17 @@
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.server.RaftServerRpcWithProxy;
import org.apache.ratis.server.protocol.RaftServerAsynchronousProtocol;
+import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting;
import org.apache.ratis.thirdparty.io.grpc.ServerInterceptor;
import org.apache.ratis.thirdparty.io.grpc.ServerInterceptors;
+import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
+import org.apache.ratis.thirdparty.io.grpc.netty.GrpcSslContexts;
import org.apache.ratis.thirdparty.io.grpc.netty.NettyServerBuilder;
import org.apache.ratis.thirdparty.io.grpc.Server;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
-import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.ClientAuth;
+import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
import org.apache.ratis.proto.RaftProtos.*;
import org.apache.ratis.util.*;
@@ -52,6 +59,8 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;
+import static org.apache.ratis.thirdparty.io.netty.handler.ssl.SslProvider.OPENSSL;
+
/** A grpc implementation of {@link org.apache.ratis.server.RaftServerRpc}. */
public final class GrpcServicesImpl
extends RaftServerRpcWithProxy>
@@ -100,20 +109,19 @@ public static final class Builder {
private String adminHost;
private int adminPort;
- private SslContext adminSslContext;
+ private GrpcTlsConfig adminTlsConfig;
private String clientHost;
private int clientPort;
- private SslContext clientSslContext;
+ private GrpcTlsConfig clientTlsConfig;
private String serverHost;
private int serverPort;
- private SslContext serverSslContextForServer;
- private SslContext serverSslContextForClient;
- private int serverStubPoolSize;
+ private GrpcTlsConfig serverTlsConfig;
private SizeInBytes messageSizeMax;
private SizeInBytes flowControlWindow;
private TimeDuration requestTimeoutDuration;
private boolean separateHeartbeatChannel;
+ private boolean zeroCopyEnabled;
private Builder() {}
@@ -131,7 +139,7 @@ public Builder setServer(RaftServer raftServer) {
this.flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::info);
this.requestTimeoutDuration = RaftServerConfigKeys.Rpc.requestTimeout(properties);
this.separateHeartbeatChannel = GrpcConfigKeys.Server.heartbeatChannel(properties);
- this.serverStubPoolSize = GrpcConfigKeys.Server.stubPoolSize(properties);
+ this.zeroCopyEnabled = GrpcConfigKeys.Server.zeroCopyEnabled(properties);
final SizeInBytes appenderBufferSize = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties);
final SizeInBytes gap = SizeInBytes.ONE_MB;
@@ -152,8 +160,8 @@ public Builder setCustomizer(Customizer customizer) {
}
private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer target) {
- return new GrpcServerProtocolClient(target, serverStubPoolSize, flowControlWindow.getSizeInt(),
- requestTimeoutDuration, serverSslContextForClient, separateHeartbeatChannel);
+ return new GrpcServerProtocolClient(target, flowControlWindow.getSizeInt(),
+ requestTimeoutDuration, serverTlsConfig, separateHeartbeatChannel);
}
private ExecutorService newExecutor() {
@@ -165,12 +173,12 @@ private ExecutorService newExecutor() {
}
private GrpcClientProtocolService newGrpcClientProtocolService(
- ExecutorService executor) {
- return new GrpcClientProtocolService(server::getId, server, executor);
+ ExecutorService executor, ZeroCopyMetrics zeroCopyMetrics) {
+ return new GrpcClientProtocolService(server::getId, server, executor, zeroCopyEnabled, zeroCopyMetrics);
}
- private GrpcServerProtocolService newGrpcServerProtocolService() {
- return new GrpcServerProtocolService(server::getId, server);
+ private GrpcServerProtocolService newGrpcServerProtocolService(ZeroCopyMetrics zeroCopyMetrics) {
+ return new GrpcServerProtocolService(server::getId, server, zeroCopyEnabled, zeroCopyMetrics);
}
private MetricServerInterceptor newMetricServerInterceptor() {
@@ -183,18 +191,18 @@ Server buildServer(NettyServerBuilder builder, EnumSet types)
}
private NettyServerBuilder newNettyServerBuilderForServer() {
- return newNettyServerBuilder(serverHost, serverPort, serverSslContextForServer);
+ return newNettyServerBuilder(serverHost, serverPort, serverTlsConfig);
}
private NettyServerBuilder newNettyServerBuilderForAdmin() {
- return newNettyServerBuilder(adminHost, adminPort, adminSslContext);
+ return newNettyServerBuilder(adminHost, adminPort, adminTlsConfig);
}
private NettyServerBuilder newNettyServerBuilderForClient() {
- return newNettyServerBuilder(clientHost, clientPort, clientSslContext);
+ return newNettyServerBuilder(clientHost, clientPort, clientTlsConfig);
}
- private NettyServerBuilder newNettyServerBuilder(String hostname, int port, SslContext sslContext) {
+ private NettyServerBuilder newNettyServerBuilder(String hostname, int port, GrpcTlsConfig tlsConfig) {
final InetSocketAddress address = hostname == null || hostname.isEmpty() ?
new InetSocketAddress(port) : new InetSocketAddress(hostname, port);
final NettyServerBuilder nettyServerBuilder = NettyServerBuilder.forAddress(address)
@@ -202,9 +210,19 @@ private NettyServerBuilder newNettyServerBuilder(String hostname, int port, SslC
.maxInboundMessageSize(messageSizeMax.getSizeInt())
.flowControlWindow(flowControlWindow.getSizeInt());
- if (sslContext != null) {
+ if (tlsConfig != null) {
LOG.info("Setting TLS for {}", address);
- nettyServerBuilder.sslContext(sslContext);
+ SslContextBuilder sslContextBuilder = GrpcUtil.initSslContextBuilderForServer(tlsConfig.getKeyManager());
+ if (tlsConfig.getMtlsEnabled()) {
+ sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
+ GrpcUtil.setTrustManager(sslContextBuilder, tlsConfig.getTrustManager());
+ }
+ sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder, OPENSSL);
+ try {
+ nettyServerBuilder.sslContext(sslContextBuilder.build());
+ } catch (Exception ex) {
+ throw new IllegalArgumentException("Failed to build SslContext, tlsConfig=" + tlsConfig, ex);
+ }
}
return nettyServerBuilder;
}
@@ -217,10 +235,10 @@ private boolean separateClientServer() {
return clientPort > 0 && clientPort != serverPort;
}
- Server newServer(GrpcClientProtocolService client, ServerInterceptor interceptor) {
+ Server newServer(GrpcClientProtocolService client, ZeroCopyMetrics zeroCopyMetrics, ServerInterceptor interceptor) {
final EnumSet types = EnumSet.of(GrpcServices.Type.SERVER);
final NettyServerBuilder serverBuilder = newNettyServerBuilderForServer();
- final GrpcServerProtocolService service = newGrpcServerProtocolService();
+ final ServerServiceDefinition service = newGrpcServerProtocolService(zeroCopyMetrics).bindServiceWithZeroCopy();
serverBuilder.addService(ServerInterceptors.intercept(service, interceptor));
if (!separateAdminServer()) {
@@ -238,23 +256,18 @@ public GrpcServicesImpl build() {
return new GrpcServicesImpl(this);
}
- public Builder setAdminSslContext(SslContext adminSslContext) {
- this.adminSslContext = adminSslContext;
- return this;
- }
-
- public Builder setClientSslContext(SslContext clientSslContext) {
- this.clientSslContext = clientSslContext;
+ public Builder setAdminTlsConfig(GrpcTlsConfig config) {
+ this.adminTlsConfig = config;
return this;
}
- public Builder setServerSslContextForServer(SslContext serverSslContextForServer) {
- this.serverSslContextForServer = serverSslContextForServer;
+ public Builder setClientTlsConfig(GrpcTlsConfig config) {
+ this.clientTlsConfig = config;
return this;
}
- public Builder setServerSslContextForClient(SslContext serverSslContextForClient) {
- this.serverSslContextForClient = serverSslContextForClient;
+ public Builder setServerTlsConfig(GrpcTlsConfig config) {
+ this.serverTlsConfig = config;
return this;
}
}
@@ -274,14 +287,15 @@ public static Builder newBuilder() {
private final GrpcClientProtocolService clientProtocolService;
private final MetricServerInterceptor serverInterceptor;
+ private final ZeroCopyMetrics zeroCopyMetrics = new ZeroCopyMetrics();
private GrpcServicesImpl(Builder b) {
super(b.server::getId, id -> new PeerProxyMap<>(id.toString(), b::newGrpcServerProtocolClient));
this.executor = b.newExecutor();
- this.clientProtocolService = b.newGrpcClientProtocolService(executor);
+ this.clientProtocolService = b.newGrpcClientProtocolService(executor, zeroCopyMetrics);
this.serverInterceptor = b.newMetricServerInterceptor();
- final Server server = b.newServer(clientProtocolService, serverInterceptor);
+ final Server server = b.newServer(clientProtocolService, zeroCopyMetrics, serverInterceptor);
servers.put(GrpcServerProtocolService.class.getSimpleName(), server);
addressSupplier = newAddressSupplier(b.serverPort, server);
@@ -313,7 +327,8 @@ static MemoizedSupplier newAddressSupplier(int port, Server s
static void addClientService(NettyServerBuilder builder, GrpcClientProtocolService client,
ServerInterceptor interceptor) {
- builder.addService(ServerInterceptors.intercept(client, interceptor));
+ final ServerServiceDefinition service = client.bindServiceWithZeroCopy();
+ builder.addService(ServerInterceptors.intercept(service, interceptor));
}
static void addAdminService(NettyServerBuilder builder, AdminAsynchronousProtocol admin,
@@ -341,40 +356,24 @@ public void startImpl() {
}
@Override
- public void closeImpl() {
- for (Server server : servers.values()) {
- server.shutdownNow();
- }
- boolean interrupted = false;
+ public void closeImpl() throws IOException {
for (Map.Entry server : servers.entrySet()) {
+ final String name = getId() + ": shutdown server " + server.getKey();
+ LOG.info("{} now", name);
+ final Server s = server.getValue().shutdownNow();
+ super.closeImpl();
try {
- server.getValue().awaitTermination();
- LOG.info("{}: Shutdown {} successfully", getId(), server.getKey());
+ s.awaitTermination();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
- LOG.warn("{}: Interrupted shutdown {}", getId(), server.getKey());
- interrupted = true;
- break;
+ throw IOUtils.toInterruptedIOException(name + " failed", e);
}
+ LOG.info("{} successfully", name);
}
- try {
- serverInterceptor.close();
- } catch (Exception e) {
- LOG.warn("{}: Failed to unregister metrics", getId(), e);
- }
-
- if (interrupted) {
- executor.shutdown(); // shutdown but not wait
- } else {
- ConcurrentUtils.shutdownAndWait(executor);
- }
-
- try {
- super.closeImpl();
- } catch (IOException e) {
- LOG.warn("{}: Failed to close proxies", getId(), e);
- }
+ serverInterceptor.close();
+ ConcurrentUtils.shutdownAndWait(executor);
+ zeroCopyMetrics.unregister();
}
@Override
@@ -432,7 +431,13 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ
return getProxies().getProxy(target).startLeaderElection(request);
}
+ @VisibleForTesting
MessageMetrics getMessageMetrics() {
return serverInterceptor.getMetrics();
}
+
+ @VisibleForTesting
+ public ZeroCopyMetrics getZeroCopyMetrics() {
+ return zeroCopyMetrics;
+ }
}
diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
index 3fb3f067be..eddf2495e4 100644
--- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
+++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java
@@ -62,12 +62,14 @@ public class ZeroCopyMessageMarshaller implements Prototy
private final Consumer zeroCopyCount;
private final Consumer nonZeroCopyCount;
+ private final Consumer releasedCount;
public ZeroCopyMessageMarshaller(T defaultInstance) {
- this(defaultInstance, m -> {}, m -> {});
+ this(defaultInstance, m -> {}, m -> {}, m -> {});
}
- public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount) {
+ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount,
+ Consumer releasedCount) {
this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) + "-Marshaller";
@SuppressWarnings("unchecked")
final Parser p = (Parser) defaultInstance.getParserForType();
@@ -76,6 +78,7 @@ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, C
this.zeroCopyCount = zeroCopyCount;
this.nonZeroCopyCount = nonZeroCopyCount;
+ this.releasedCount = releasedCount;
}
@Override
@@ -124,6 +127,7 @@ public void release(T message) {
}
try {
stream.close();
+ releasedCount.accept(message);
} catch (IOException e) {
LOG.error(name + ": Failed to close stream.", e);
}
@@ -223,6 +227,10 @@ public InputStream popStream(T message) {
return unclosedStreams.remove(message);
}
+ public int getUnclosedCount() {
+ return unclosedStreams.size();
+ }
+
void assertNoUnclosedStreams() {
// Intended for tests/teardown to fail fast if callers forgot to release streams.
final int size = unclosedStreams.size();
diff --git a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
index 0e4eb55544..bd1c72b241 100644
--- a/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
+++ b/ratis-grpc/src/test/java/org/apache/ratis/grpc/MiniRaftClusterWithGrpc.java
@@ -21,14 +21,19 @@
import org.apache.ratis.RaftTestUtil;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.server.GrpcServicesImpl;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.rpc.SupportedRpcType;
+import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.impl.DelayLocalExecutionInjection;
import org.apache.ratis.server.impl.MiniRaftCluster;
+import org.apache.ratis.server.impl.RaftServerTestUtil;
import org.apache.ratis.util.NetUtils;
+import org.apache.ratis.util.ReferenceCountedLeakDetector;
+import org.junit.jupiter.api.Assertions;
import java.util.Optional;
@@ -45,6 +50,11 @@ public MiniRaftClusterWithGrpc newCluster(String[] ids, String[] listenerIds, Ra
}
};
+ static {
+ // TODO move it to MiniRaftCluster for detecting non-gRPC cases
+ ReferenceCountedLeakDetector.enable(false);
+ }
+
public interface FactoryGet extends Factory.Get {
@Override
default Factory getFactory() {
@@ -54,7 +64,7 @@ default Factory getFactory() {
public static final DelayLocalExecutionInjection SEND_SERVER_REQUEST_INJECTION =
new DelayLocalExecutionInjection(GrpcServicesImpl.GRPC_SEND_SERVER_REQUEST);
-
+
public MiniRaftClusterWithGrpc(String[] ids, RaftProperties properties, Parameters parameters) {
this(ids, new String[0], properties, parameters);
}
@@ -70,6 +80,8 @@ protected Parameters setPropertiesAndInitParameters(RaftPeerId id, RaftGroup gro
GrpcConfigKeys.Client.setPort(properties, NetUtils.createSocketAddr(address).getPort()));
Optional.ofNullable(getAddress(id, group, RaftPeer::getAdminAddress)).ifPresent(address ->
GrpcConfigKeys.Admin.setPort(properties, NetUtils.createSocketAddr(address).getPort()));
+ // Always run grpc integration tests with zero-copy enabled because the non-zero-copy path is not risky.
+ GrpcConfigKeys.Server.setZeroCopyEnabled(properties, true);
return parameters;
}
@@ -79,4 +91,22 @@ protected void blockQueueAndSetDelay(String leaderId, int delayMs)
RaftTestUtil.blockQueueAndSetDelay(getServers(), SEND_SERVER_REQUEST_INJECTION,
leaderId, delayMs, getTimeoutMax());
}
+
+ @Override
+ public void shutdown() {
+ super.shutdown();
+ assertZeroCopyMetrics();
+ }
+
+ public void assertZeroCopyMetrics() {
+ getServers().forEach(server -> server.getGroupIds().forEach(id -> {
+ LOG.info("Checking {}-{}", server.getId(), id);
+ RaftServer.Division division = RaftServerTestUtil.getDivision(server, id);
+ final GrpcServicesImpl service = (GrpcServicesImpl) RaftServerTestUtil.getServerRpc(division);
+ ZeroCopyMetrics zeroCopyMetrics = service.getZeroCopyMetrics();
+ Assertions.assertEquals(0, zeroCopyMetrics.nonZeroCopyMessages());
+ Assertions.assertEquals(zeroCopyMetrics.zeroCopyMessages(), zeroCopyMetrics.releasedMessages(),
+ "Unreleased zero copy messages: please check logs to find the leaks. ");
+ }));
+ }
}
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java b/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java
index 33914fde7f..dc189a14aa 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/leader/LogAppender.java
@@ -125,7 +125,9 @@ default RaftPeerId getFollowerId() {
* @param heartbeat the returned request must be a heartbeat.
*
* @return a new {@link AppendEntriesRequestProto} object.
+ * @deprecated this is no longer a public API.
*/
+ @Deprecated
AppendEntriesRequestProto newAppendEntriesRequest(long callId, boolean heartbeat) throws RaftLogIOException;
/** @return a new {@link InstallSnapshotRequestProto} object. */
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java b/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java
index 8a904069ba..035e0a815f 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/protocol/RaftServerAsynchronousProtocol.java
@@ -22,14 +22,39 @@
import org.apache.ratis.proto.RaftProtos.ReadIndexReplyProto;
import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto;
import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto;
+import org.apache.ratis.util.ReferenceCountedObject;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
public interface RaftServerAsynchronousProtocol {
- CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request)
- throws IOException;
+ /**
+ * It is recommended to override {@link #appendEntriesAsync(ReferenceCountedObject)} instead.
+ * Then, this method does not have to be overridden.
+ */
+ default CompletableFuture appendEntriesAsync(AppendEntriesRequestProto request)
+ throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * A reference-counted request is submitted from a client for processing.
+ * Implementations of this method should retain the request, process it and then release it.
+ * The request may be retained even after the future returned by this method has completed.
+ *
+ * @return a future of the reply
+ * @see ReferenceCountedObject
+ */
+ default CompletableFuture appendEntriesAsync(
+ ReferenceCountedObject requestRef) throws IOException {
+ // Default implementation for backward compatibility.
+ try {
+ return appendEntriesAsync(requestRef.retain());
+ } finally {
+ requestRef.release();
+ }
+ }
CompletableFuture readIndexAsync(ReadIndexRequestProto request)
throws IOException;
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
index e194f865ed..07446282e7 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLog.java
@@ -21,6 +21,7 @@
import org.apache.ratis.server.metrics.RaftLogMetrics;
import org.apache.ratis.server.protocol.TermIndex;
import org.apache.ratis.server.storage.RaftStorageMetadata;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -60,16 +61,44 @@ default boolean contains(TermIndex ti) {
/**
* @return null if the log entry is not found in this log;
- * otherwise, return the log entry corresponding to the given index.
+ * otherwise, return a copy of the log entry corresponding to the given index.
+ * @deprecated use {@link RaftLog#retainLog(long)} instead in order to avoid copying.
*/
+ @Deprecated
LogEntryProto get(long index) throws RaftLogIOException;
+ /**
+ * @return a retained {@link ReferenceCountedObject} to the log entry corresponding to the given index if it exists;
+ * otherwise, return null.
+ * Since the returned reference is retained, the caller must call {@link ReferenceCountedObject#release()}
+ * after use.
+ */
+ default ReferenceCountedObject retainLog(long index) throws RaftLogIOException {
+ ReferenceCountedObject wrap = ReferenceCountedObject.wrap(get(index));
+ wrap.retain();
+ return wrap;
+ }
+
/**
* @return null if the log entry is not found in this log;
* otherwise, return the {@link EntryWithData} corresponding to the given index.
+ * @deprecated use {@link #retainEntryWithData(long)}.
*/
+ @Deprecated
EntryWithData getEntryWithData(long index) throws RaftLogIOException;
+ /**
+ * @return null if the log entry is not found in this log;
+ * otherwise, return a retained reference of the {@link EntryWithData} corresponding to the given index.
+ * Since the returned reference is retained, the caller must call {@link ReferenceCountedObject#release()}
+ * after use.
+ */
+ default ReferenceCountedObject retainEntryWithData(long index) throws RaftLogIOException {
+ final ReferenceCountedObject wrap = ReferenceCountedObject.wrap(getEntryWithData(index));
+ wrap.retain();
+ return wrap;
+}
+
/**
* @param startIndex the starting log index (inclusive)
* @param endIndex the ending log index (exclusive)
@@ -160,6 +189,15 @@ default long getNextIndex() {
* containing both the log entry and the state machine data.
*/
interface EntryWithData {
+ /** @return the index of this entry. */
+ default long getIndex() {
+ try {
+ return getEntry(TimeDuration.ONE_MINUTE).getIndex();
+ } catch (Exception e) {
+ throw new IllegalStateException("Failed to getIndex", e);
+ }
+ }
+
/** @return the serialized size including both log entry and state machine data. */
int getSerializedSize();
diff --git a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java
index 5e8bd6d784..5a25728830 100644
--- a/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java
+++ b/ratis-server-api/src/main/java/org/apache/ratis/server/raftlog/RaftLogSequentialOps.java
@@ -22,8 +22,10 @@
import org.apache.ratis.server.RaftConfiguration;
import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.util.Preconditions;
+import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.StringUtils;
import org.apache.ratis.util.function.CheckedSupplier;
+import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier;
import java.util.Arrays;
import java.util.List;
@@ -124,31 +126,56 @@