diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 93fcc332a..a93d0981a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -38,6 +38,7 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -317,9 +318,17 @@ public class McpAsyncClient { }; this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), - initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx)), - postInitializationHook); + initializationTimeout, ctx -> { + Function, ? extends Publisher> connectHook; + if (features.connectHook() != null) { + connectHook = con -> features.connectHook().apply(con.contextWrite(ctx)); + } + else { + connectHook = con -> con.contextWrite(ctx); + } + return new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, + connectHook); + }, postInitializationHook); this.transport.setExceptionHandler(this.initializer::handleException); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index 12f34e60a..1300a111b 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import java.time.Duration; @@ -195,6 +196,8 @@ class SyncSpec { private boolean enableCallToolSchemaCaching = false; // Default to false + private Function, ? extends Publisher> connectHook; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -479,6 +482,17 @@ public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) return this; } + /** + * Allows to add a reactive hook to the connection lifecycle. This hook can be + * used to intercept connection events, add retry logic, or handle errors. + * @param connectHook the connection hook. + * @return this builder instance for method chaining + */ + public SyncSpec connectHook(Function, ? extends Publisher> connectHook) { + this.connectHook = connectHook; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -488,7 +502,7 @@ public McpSyncClient build() { McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, - this.elicitationHandler, this.enableCallToolSchemaCaching); + this.elicitationHandler, this.enableCallToolSchemaCaching, this.connectHook); McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); @@ -549,6 +563,8 @@ class AsyncSpec { private boolean enableCallToolSchemaCaching = false; // Default to false + private Function, ? extends Publisher> connectHook; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -820,6 +836,17 @@ public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching return this; } + /** + * Allows to add a reactive hook to the connection lifecycle. This hook can be + * used to intercept connection events, add retry logic, or handle errors. + * @param connectHook the connection hook. + * @return this builder instance for method chaining + */ + public AsyncSpec connectHook(Function, ? extends Publisher> connectHook) { + this.connectHook = connectHook; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. @@ -833,7 +860,8 @@ public McpAsyncClient build() { new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, - this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); + this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching, + this.connectHook)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 127d53337..dcd49e0fd 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -15,6 +15,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -57,12 +58,14 @@ class McpClientFeatures { * @param roots the roots. * @param toolsChangeConsumers the tools change consumers. * @param resourcesChangeConsumers the resources change consumers. + * @param resourcesUpdateConsumers the resources update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. + * @param connectHook the connection hook. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, @@ -73,20 +76,23 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Function, ? extends Publisher> connectHook) { /** * Create an instance and validate the arguments. + * @param clientInfo the client implementation information. * @param clientCapabilities the client capabilities. * @param roots the roots. * @param toolsChangeConsumers the tools change consumers. * @param resourcesChangeConsumers the resources change consumers. + * @param resourcesUpdateConsumers the resources update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. + * @param connectHook the connection hook. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -98,7 +104,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, + Function, ? extends Publisher> connectHook) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -118,6 +125,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + this.connectHook = connectHook; } /** @@ -134,7 +142,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c Function> elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + elicitationHandler, false, null); } /** @@ -193,7 +201,7 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, - syncSpec.enableCallToolSchemaCaching); + syncSpec.enableCallToolSchemaCaching, syncSpec.connectHook()); } } @@ -206,12 +214,14 @@ public static Async fromSync(Sync syncSpec) { * @param roots the roots. * @param toolsChangeConsumers the tools change consumers. * @param resourcesChangeConsumers the resources change consumers. + * @param resourcesUpdateConsumers the resources update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. + * @param connectHook the connection hook. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -222,7 +232,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List> progressConsumers, Function samplingHandler, Function elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Function, ? extends Publisher> connectHook) { /** * Create an instance and validate the arguments. @@ -238,6 +248,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. + * @param connectHook the connection hook. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -248,7 +259,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List> progressConsumers, Function samplingHandler, Function elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, + Function, ? extends Publisher> connectHook) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -268,6 +280,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + this.connectHook = connectHook; } /** @@ -283,7 +296,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl Function elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + elicitationHandler, false, null); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 80b5ae246..f1392bbbd 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -119,7 +119,15 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(); + this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(v -> { + }, error -> { + logger.error("MCP session connection error", error); + this.pendingResponses.forEach((id, sink) -> { + logger.warn("Terminating exchange for request {} due to connection error", id); + sink.error(new RuntimeException("MCP session connection error", error)); + }); + this.pendingResponses.clear(); + }); } private void dismissPendingResponses() { diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientInitializationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientInitializationTests.java new file mode 100644 index 000000000..af5eaddb3 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientInitializationTests.java @@ -0,0 +1,57 @@ +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +public class McpClientInitializationTests { + + private static final McpJsonMapper JSON_MAPPER = McpJsonDefaults.getMapper(); + + @Test + void reproduceInitializeErrorShouldNotBeDropped() { + ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build(); + StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER); + + McpSyncClient client = McpClient.sync(transport).requestTimeout(Duration.ofSeconds(2)).build(); + + assertThatExceptionOfType(RuntimeException.class).isThrownBy(client::initialize) + .withMessageContaining("Client failed to initialize") + .satisfies(ex -> { + assertThat(ex.getCause().getMessage()).contains("MCP session connection error"); + assertThat(ex.getCause().getCause().getMessage()).contains("Failed to start process"); + }); + } + + @Test + void verifyConnectHook() { + ServerParameters stdioParams = ServerParameters.builder("non-existent-command").build(); + StdioClientTransport transport = new StdioClientTransport(stdioParams, JSON_MAPPER); + + AtomicReference hookError = new AtomicReference<>(); + + McpSyncClient client = McpClient.sync(transport) + .requestTimeout(Duration.ofSeconds(2)) + .connectHook(mono -> mono.doOnError(hookError::set)) + .build(); + + try { + client.initialize(); + } + catch (Exception e) { + // ignore + } + + assertThat(hookError.get()).isNotNull(); + assertThat(hookError.get().getMessage()).contains("Failed to start process"); + } + +}