Skip to content

Commit f6c7145

Browse files
committed
fix: preserve cors_origin backward compat, use runtime checks for protected headers
- Add cors_origin convenience parameter back to all 3 server constructors (merged into response_headers_ internally, response_headers takes priority) - Replace assert() with std::invalid_argument for protected header validation so it works in release builds - Remove unused <cassert> includes - Add tests for backward-compat cors_origin, response_headers map, protected header rejection, and cors_origin vs response_headers priority
1 parent afd5131 commit f6c7145

7 files changed

Lines changed: 154 additions & 28 deletions

File tree

include/fastmcpp/server/http_server.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ class HttpServerWrapper
2727
* @param port Port to listen on (default: 18080)
2828
* To bind to any random available port provided by the OS use port number 0.
2929
* @param auth_token Optional auth token for Bearer authentication (empty = no auth required)
30-
* @param response_headers Additional HTTP headers added to responses (e.g.
31-
* "Access-Control-Allow-Origin"...)
30+
* @param cors_origin Optional CORS origin (shorthand for Access-Control-Allow-Origin header)
31+
* @param response_headers Additional HTTP headers added to responses
3232
*/
3333
HttpServerWrapper(std::shared_ptr<Server> core, std::string host = "127.0.0.1",
3434
int port = 18080, std::string auth_token = "",
35+
std::string cors_origin = "",
3536
std::unordered_map<std::string, std::string> response_headers = {});
3637
~HttpServerWrapper();
3738

include/fastmcpp/server/sse_server.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ class SseServerWrapper
5454
* @param sse_path Path for SSE GET endpoint (default: "/sse")
5555
* @param message_path Path for POST message endpoint (default: "/messages")
5656
* @param auth_token Optional auth token for Bearer authentication (empty = no auth required)
57-
* @param response_headers Additional HTTP headers added to responses (e.g.
58-
* "Access-Control-Allow-Origin"...)
57+
* @param cors_origin Optional CORS origin (shorthand for Access-Control-Allow-Origin header)
58+
* @param response_headers Additional HTTP headers added to responses
5959
*/
6060
explicit SseServerWrapper(McpHandler handler, std::string host = "127.0.0.1", int port = 18080,
6161
std::string sse_path = "/sse", std::string message_path = "/messages",
62-
std::string auth_token = "",
62+
std::string auth_token = "", std::string cors_origin = "",
6363
std::unordered_map<std::string, std::string> response_headers = {});
6464

6565
~SseServerWrapper();

include/fastmcpp/server/streamable_http_server.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,13 @@ class StreamableHttpServerWrapper
5252
* To bind to any random available port provided by the OS use port number 0.
5353
* @param mcp_path Path for the MCP POST endpoint (default: "/mcp")
5454
* @param auth_token Optional auth token for Bearer authentication (empty = no auth required)
55-
* @param response_headers Additional HTTP headers added to responses (e.g.
56-
* "Access-Control-Allow-Origin"...)
55+
* @param cors_origin Optional CORS origin (shorthand for Access-Control-Allow-Origin header)
56+
* @param response_headers Additional HTTP headers added to responses
5757
*/
5858
explicit StreamableHttpServerWrapper(
5959
McpHandler handler, std::string host = "127.0.0.1", int port = 18080,
6060
std::string mcp_path = "/mcp", std::string auth_token = "",
61+
std::string cors_origin = "",
6162
std::unordered_map<std::string, std::string> response_headers = {});
6263

6364
~StreamableHttpServerWrapper();

src/server/http_server.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ namespace fastmcpp::server
99
{
1010

1111
HttpServerWrapper::HttpServerWrapper(std::shared_ptr<Server> core, std::string host, int port,
12-
std::string auth_token,
12+
std::string auth_token, std::string cors_origin,
1313
std::unordered_map<std::string, std::string> response_headers)
1414
: core_(std::move(core)), host_(std::move(host)), requested_port_(port),
1515
auth_token_(std::move(auth_token)), response_headers_(std::move(response_headers))
1616
{
17+
if (!cors_origin.empty() &&
18+
response_headers_.find("Access-Control-Allow-Origin") == response_headers_.end())
19+
response_headers_["Access-Control-Allow-Origin"] = std::move(cors_origin);
1720
}
1821

1922
HttpServerWrapper::~HttpServerWrapper()

src/server/sse_server.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "fastmcpp/util/json.hpp"
55

66
#include <algorithm>
7-
#include <cassert>
87
#include <cctype>
98
#include <chrono>
109
#include <ctime>
@@ -80,26 +79,25 @@ std::optional<TaskNotificationInfo> extract_task_notification_info(const fastmcp
8079

8180
SseServerWrapper::SseServerWrapper(McpHandler handler, std::string host, int port,
8281
std::string sse_path, std::string message_path,
83-
std::string auth_token,
82+
std::string auth_token, std::string cors_origin,
8483
std::unordered_map<std::string, std::string> response_headers)
8584
: handler_(std::move(handler)), host_(std::move(host)), requested_port_(port),
8685
sse_path_(std::move(sse_path)), message_path_(std::move(message_path)),
8786
auth_token_(std::move(auth_token)), response_headers_(std::move(response_headers))
8887
{
89-
assert(port >= 0 && "'port' is expected to be non-negative.");
88+
if (!cors_origin.empty() &&
89+
response_headers_.find("Access-Control-Allow-Origin") == response_headers_.end())
90+
response_headers_["Access-Control-Allow-Origin"] = std::move(cors_origin);
9091

9192
for (const auto& [name, value] : response_headers_)
9293
{
9394
std::string lower_name = name;
9495
std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(),
9596
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
9697

97-
assert(lower_name != "content-type" &&
98-
"'response_headers' must not override SSE Content-Type.");
99-
assert(lower_name != "connection" &&
100-
"'response_headers' must not override SSE Connection.");
101-
assert(lower_name != "cache-control" &&
102-
"'response_headers' must not override SSE Cache-Control.");
98+
if (lower_name == "content-type" || lower_name == "connection" ||
99+
lower_name == "cache-control")
100+
throw std::invalid_argument("response_headers must not override '" + name + "'");
103101
}
104102
}
105103

src/server/streamable_http_server.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "fastmcpp/util/json.hpp"
55

66
#include <algorithm>
7-
#include <cassert>
87
#include <cctype>
98
#include <chrono>
109
#include <httplib.h>
@@ -19,21 +18,24 @@ namespace fastmcpp::server
1918
StreamableHttpServerWrapper::StreamableHttpServerWrapper(McpHandler handler, std::string host,
2019
int port, std::string mcp_path,
2120
std::string auth_token,
21+
std::string cors_origin,
2222
std::unordered_map<std::string, std::string> response_headers)
2323
: handler_(std::move(handler)), host_(std::move(host)), requested_port_(port),
2424
mcp_path_(std::move(mcp_path)), auth_token_(std::move(auth_token)),
2525
response_headers_(std::move(response_headers))
2626
{
27-
assert(port >= 0 && "'port' is expected to be non-negative.");
27+
if (!cors_origin.empty() &&
28+
response_headers_.find("Access-Control-Allow-Origin") == response_headers_.end())
29+
response_headers_["Access-Control-Allow-Origin"] = std::move(cors_origin);
2830

2931
for (const auto& [name, value] : response_headers_)
3032
{
3133
std::string lower_name = name;
3234
std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(),
3335
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
3436

35-
assert(lower_name != "content-type" &&
36-
"'response_headers' must not override streamable HTTP Content-Type.");
37+
if (lower_name == "content-type")
38+
throw std::invalid_argument("response_headers must not override '" + name + "'");
3739
}
3840
}
3941

tests/server/auth_cors_security.cpp

Lines changed: 128 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,14 @@ int main()
151151
std::cout << " [PASS] HTTP server does not set CORS by default\n";
152152
}
153153

154-
// Test 5: HTTP server should set CORS header when explicitly configured
154+
// Test 5: HTTP server should set CORS header when using cors_origin convenience param
155155
{
156-
std::cout << "Test: HTTP server sets CORS header when configured...\n";
156+
std::cout << "Test: HTTP server sets CORS header via cors_origin param...\n";
157157

158158
auto srv = std::make_shared<Server>();
159159
srv->route("test", [](const Json&) { return Json{{"result", "ok"}}; });
160160

161-
HttpServerWrapper http_server(srv, "127.0.0.1", 18603, "",
162-
{{"Access-Control-Allow-Origin", "https://example.com"}});
161+
HttpServerWrapper http_server(srv, "127.0.0.1", 18603, "", "https://example.com");
163162
if (!http_server.start())
164163
{
165164
std::cerr << "Failed to start HTTP server\n";
@@ -231,8 +230,7 @@ int main()
231230
auto srv = std::make_shared<Server>();
232231
srv->route("test", [](const Json&) { return Json{{"result", "ok"}}; });
233232

234-
HttpServerWrapper http_server(srv, "127.0.0.1", 18605, "",
235-
{{"Access-Control-Allow-Origin", "https://example.com"}});
233+
HttpServerWrapper http_server(srv, "127.0.0.1", 18605, "", "https://example.com");
236234
if (!http_server.start())
237235
{
238236
std::cerr << "Failed to start HTTP server with CORS config\n";
@@ -284,7 +282,7 @@ int main()
284282
{ return Json{{"jsonrpc", "2.0"}, {"id", req["id"]}, {"result", {}}}; };
285283

286284
SseServerWrapper sse_server(handler, "127.0.0.1", 18605, "/sse", "/messages", "",
287-
{{"Access-Control-Allow-Origin", "https://example.com"}});
285+
"https://example.com");
288286
if (!sse_server.start())
289287
{
290288
std::cerr << "Failed to start SSE server with CORS config\n";
@@ -328,6 +326,129 @@ int main()
328326
std::cout << " [PASS] SSE message endpoint handles CORS preflight\n";
329327
}
330328

329+
// Test 9: HTTP server CORS via response_headers map (new API)
330+
{
331+
std::cout << "Test: HTTP server sets CORS header via response_headers map...\n";
332+
333+
auto srv = std::make_shared<Server>();
334+
srv->route("test", [](const Json&) { return Json{{"result", "ok"}}; });
335+
336+
HttpServerWrapper http_server(srv, "127.0.0.1", 18606, "", "",
337+
{{"Access-Control-Allow-Origin", "*"},
338+
{"X-Custom-Header", "custom-value"}});
339+
if (!http_server.start())
340+
{
341+
std::cerr << "Failed to start HTTP server\n";
342+
return 1;
343+
}
344+
345+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
346+
347+
httplib::Client client("127.0.0.1", 18606);
348+
auto res = client.Post("/test", R"({"x":1})", "application/json");
349+
350+
if (!res || res->status != 200)
351+
{
352+
std::cerr << " [FAIL] Request failed\n";
353+
http_server.stop();
354+
return 1;
355+
}
356+
357+
auto cors_it = res->headers.find("Access-Control-Allow-Origin");
358+
if (cors_it == res->headers.end() || cors_it->second != "*")
359+
{
360+
std::cerr << " [FAIL] CORS header missing or incorrect\n";
361+
http_server.stop();
362+
return 1;
363+
}
364+
365+
auto custom_it = res->headers.find("X-Custom-Header");
366+
if (custom_it == res->headers.end() || custom_it->second != "custom-value")
367+
{
368+
std::cerr << " [FAIL] Custom header missing or incorrect\n";
369+
http_server.stop();
370+
return 1;
371+
}
372+
373+
http_server.stop();
374+
std::cout << " [PASS] HTTP server sets headers via response_headers map\n";
375+
}
376+
377+
// Test 10: SSE server rejects protected headers (Content-Type)
378+
{
379+
std::cout << "Test: SSE server rejects protected Content-Type header...\n";
380+
381+
auto handler = [](const Json& req) -> Json
382+
{ return Json{{"jsonrpc", "2.0"}, {"id", req["id"]}, {"result", {}}}; };
383+
384+
bool threw = false;
385+
try
386+
{
387+
SseServerWrapper sse_server(handler, "127.0.0.1", 18607, "/sse", "/messages", "", "",
388+
{{"Content-Type", "text/plain"}});
389+
}
390+
catch (const std::invalid_argument& e)
391+
{
392+
threw = true;
393+
std::string msg = e.what();
394+
if (msg.find("Content-Type") == std::string::npos)
395+
{
396+
std::cerr << " [FAIL] Exception message should mention Content-Type: " << msg
397+
<< "\n";
398+
return 1;
399+
}
400+
}
401+
402+
if (!threw)
403+
{
404+
std::cerr << " [FAIL] Should have thrown std::invalid_argument\n";
405+
return 1;
406+
}
407+
408+
std::cout << " [PASS] SSE server rejects protected Content-Type header\n";
409+
}
410+
411+
// Test 11: cors_origin is overridden when response_headers also sets it
412+
{
413+
std::cout << "Test: response_headers takes priority over cors_origin...\n";
414+
415+
auto srv = std::make_shared<Server>();
416+
srv->route("test", [](const Json&) { return Json{{"result", "ok"}}; });
417+
418+
// cors_origin="*" but response_headers overrides with specific origin
419+
HttpServerWrapper http_server(srv, "127.0.0.1", 18608, "", "*",
420+
{{"Access-Control-Allow-Origin", "https://specific.com"}});
421+
if (!http_server.start())
422+
{
423+
std::cerr << "Failed to start HTTP server\n";
424+
return 1;
425+
}
426+
427+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
428+
429+
httplib::Client client("127.0.0.1", 18608);
430+
auto res = client.Post("/test", R"({"x":1})", "application/json");
431+
432+
if (!res || res->status != 200)
433+
{
434+
std::cerr << " [FAIL] Request failed\n";
435+
http_server.stop();
436+
return 1;
437+
}
438+
439+
auto cors_it = res->headers.find("Access-Control-Allow-Origin");
440+
if (cors_it == res->headers.end() || cors_it->second != "https://specific.com")
441+
{
442+
std::cerr << " [FAIL] response_headers should take priority, got: "
443+
<< (cors_it != res->headers.end() ? cors_it->second : "missing") << "\n";
444+
http_server.stop();
445+
return 1;
446+
}
447+
448+
http_server.stop();
449+
std::cout << " [PASS] response_headers takes priority over cors_origin\n";
450+
}
451+
331452
std::cout << "\n[OK] All HTTP/SSE auth and CORS security tests passed!\n";
332453
return 0;
333454
}

0 commit comments

Comments
 (0)