Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -317,9 +318,17 @@ public class McpAsyncClient {
};

this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(),
initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
notificationHandlers, con -> con.contextWrite(ctx)),
postInitializationHook);
initializationTimeout, ctx -> {
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;
if (features.connectHook() != null) {
connectHook = con -> features.connectHook().apply(con.contextWrite(ctx));
}
else {
connectHook = con -> con.contextWrite(ctx);
}
return new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
connectHook);
}, postInitializationHook);

this.transport.setExceptionHandler(this.initializer::handleException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpTransport;
import io.modelcontextprotocol.util.Assert;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;

import java.time.Duration;
Expand Down Expand Up @@ -195,6 +196,8 @@ class SyncSpec {

private boolean enableCallToolSchemaCaching = false; // Default to false

private Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;

private SyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
Expand Down Expand Up @@ -479,6 +482,17 @@ public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching)
return this;
}

/**
* Allows to add a reactive hook to the connection lifecycle. This hook can be
* used to intercept connection events, add retry logic, or handle errors.
* @param connectHook the connection hook.
* @return this builder instance for method chaining
*/
public SyncSpec connectHook(Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
this.connectHook = connectHook;
return this;
}

/**
* Create an instance of {@link McpSyncClient} with the provided configurations or
* sensible defaults.
Expand All @@ -488,7 +502,7 @@ public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler, this.enableCallToolSchemaCaching);
this.elicitationHandler, this.enableCallToolSchemaCaching, this.connectHook);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

Expand Down Expand Up @@ -549,6 +563,8 @@ class AsyncSpec {

private boolean enableCallToolSchemaCaching = false; // Default to false

private Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook;

private AsyncSpec(McpClientTransport transport) {
Assert.notNull(transport, "Transport must not be null");
this.transport = transport;
Expand Down Expand Up @@ -820,6 +836,17 @@ public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching
return this;
}

/**
* Allows to add a reactive hook to the connection lifecycle. This hook can be
* used to intercept connection events, add retry logic, or handle errors.
* @param connectHook the connection hook.
* @return this builder instance for method chaining
*/
public AsyncSpec connectHook(Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
this.connectHook = connectHook;
return this;
}

/**
* Create an instance of {@link McpAsyncClient} with the provided configurations
* or sensible defaults.
Expand All @@ -833,7 +860,8 @@ public McpAsyncClient build() {
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching));
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching,
this.connectHook));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -57,12 +58,14 @@ class McpClientFeatures {
* @param roots the roots.
* @param toolsChangeConsumers the tools change consumers.
* @param resourcesChangeConsumers the resources change consumers.
* @param resourcesUpdateConsumers the resources update consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
* @param connectHook the connection hook.
*/
record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Map<String, McpSchema.Root> roots, List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers,
Expand All @@ -73,20 +76,23 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
boolean enableCallToolSchemaCaching) {
boolean enableCallToolSchemaCaching, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {

/**
* Create an instance and validate the arguments.
* @param clientInfo the client implementation information.
* @param clientCapabilities the client capabilities.
* @param roots the roots.
* @param toolsChangeConsumers the tools change consumers.
* @param resourcesChangeConsumers the resources change consumers.
* @param resourcesUpdateConsumers the resources update consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
* @param connectHook the connection hook.
*/
public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Map<String, McpSchema.Root> roots,
Expand All @@ -98,7 +104,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
boolean enableCallToolSchemaCaching) {
boolean enableCallToolSchemaCaching,
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {

Assert.notNull(clientInfo, "Client info must not be null");
this.clientInfo = clientInfo;
Expand All @@ -118,6 +125,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
this.connectHook = connectHook;
}

/**
Expand All @@ -134,7 +142,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
elicitationHandler, false);
elicitationHandler, false, null);
}

/**
Expand Down Expand Up @@ -193,7 +201,7 @@ public static Async fromSync(Sync syncSpec) {
return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers,
loggingConsumers, progressConsumers, samplingHandler, elicitationHandler,
syncSpec.enableCallToolSchemaCaching);
syncSpec.enableCallToolSchemaCaching, syncSpec.connectHook());
}
}

Expand All @@ -206,12 +214,14 @@ public static Async fromSync(Sync syncSpec) {
* @param roots the roots.
* @param toolsChangeConsumers the tools change consumers.
* @param resourcesChangeConsumers the resources change consumers.
* @param resourcesUpdateConsumers the resources update consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
* @param connectHook the connection hook.
*/
public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
Expand All @@ -222,7 +232,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
boolean enableCallToolSchemaCaching) {
boolean enableCallToolSchemaCaching, Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {

/**
* Create an instance and validate the arguments.
Expand All @@ -238,6 +248,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
* @param connectHook the connection hook.
*/
public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
Expand All @@ -248,7 +259,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
boolean enableCallToolSchemaCaching) {
boolean enableCallToolSchemaCaching,
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {

Assert.notNull(clientInfo, "Client info must not be null");
this.clientInfo = clientInfo;
Expand All @@ -268,6 +280,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
this.connectHook = connectHook;
}

/**
Expand All @@ -283,7 +296,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
elicitationHandler, false);
elicitationHandler, false, null);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,15 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
this.requestHandlers.putAll(requestHandlers);
this.notificationHandlers.putAll(notificationHandlers);

this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(v -> {
}, error -> {
logger.error("MCP session connection error", error);
this.pendingResponses.forEach((id, sink) -> {
logger.warn("Terminating exchange for request {} due to connection error", id);
sink.error(new RuntimeException("MCP session connection error", error));
});
this.pendingResponses.clear();
});
}

private void dismissPendingResponses() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package io.modelcontextprotocol.client;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicReference;

import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class McpClientInitializationTests {

private static final McpJsonMapper JSON_MAPPER = McpJsonDefaults.getMapper();

@Test
void reproduceInitializeErrorShouldNotBeDropped() {
ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build();
StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER);

McpSyncClient client = McpClient.sync(transport).requestTimeout(Duration.ofSeconds(2)).build();

assertThatExceptionOfType(RuntimeException.class).isThrownBy(client::initialize)
.withMessageContaining("Client failed to initialize")
.satisfies(ex -> {
assertThat(ex.getCause().getMessage()).contains("MCP session connection error");
assertThat(ex.getCause().getCause().getMessage()).contains("Failed to start process");
});
}

@Test
void verifyConnectHook() {
ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build();
StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER);

AtomicReference<Throwable> hookError = new AtomicReference<>();

McpSyncClient client = McpClient.sync(transport)
.requestTimeout(Duration.ofSeconds(2))
.connectHook(mono -> mono.doOnError(hookError::set))
.build();

try {
client.initialize();
}
catch (Exception e) {
// ignore
}

assertThat(hookError.get()).isNotNull();
assertThat(hookError.get().getMessage()).contains("Failed to start process");
}

}