diff --git a/docs/gateway-rpc-api.md b/docs/gateway-rpc-api.md index 6425bd9ed..c3013e963 100644 --- a/docs/gateway-rpc-api.md +++ b/docs/gateway-rpc-api.md @@ -155,7 +155,8 @@ type BindStreamParams struct { ```go type RunInputMedia struct { - URI string `json:"uri"` + URI string `json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -175,6 +176,12 @@ type RunParams struct { } ``` +- 多模态图片约束: + - `type=image` 时 `media.mime_type` 必填。 + - `media.uri` 与 `media.asset_id` 必须二选一,不能同时为空或同时提供。 + - `media.uri` 仅用于后端可读取的本地路径;Web 浏览器上传图片应先通过 `POST /api/session-assets` 保存,再在 `gateway.run` 中使用 `media.asset_id` 引用。 + - `asset_id` 必须属于当前 `session_id`,不存在或跨 session 引用会在 runtime 输入准备阶段失败。 + - Response Schema: - Success(受理即返回): @@ -223,6 +230,49 @@ type RunParams struct { --- +## HTTP API: session assets + +浏览器图片上传不应把本地伪路径传给 Runtime。Web 客户端需要在发送前先创建或确认 `session_id`,再通过受鉴权保护的 HTTP API 保存图片,最后在 `gateway.run.input_parts[].media.asset_id` 中引用。 + +### POST /api/session-assets + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;单工作区或旧客户端可省略并回落到默认工作区。 +- Content-Type: `multipart/form-data` +- Fields: + - `session_id`: 目标会话 ID,必填。 + - `file`: 图片文件,必填。 +- Server-side validation: + - 仅接受 `image/png`、`image/jpeg`、`image/webp`。 + - MIME 以服务端文件头检测结果为准,不信任浏览器声明。 + - 空文件返回 `400`。 + - 超过 `MaxSessionAssetBytes` 返回 `413`。 + - 非图片或不支持类型返回 `415`。 + - 未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 + - 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- Response: + +```json +{ + "session_id": "sess-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +### GET /api/session-assets/{session_id}/{asset_id} + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;省略时回落到默认工作区。 +- 返回图片二进制,`Content-Type` 为保存时确认的 MIME。 +- 用于历史消息缩略图按需读取。 +- 工作区不存在返回 `404 workspace not found`;不存在或不可见的 asset 返回 `404 asset not found`。 + +--- + ## Method: gateway.compact - Stability: Stable diff --git a/docs/reference/gateway-error-catalog.md b/docs/reference/gateway-error-catalog.md index 1c3de61ea..c9a6712f1 100644 --- a/docs/reference/gateway-error-catalog.md +++ b/docs/reference/gateway-error-catalog.md @@ -10,7 +10,7 @@ | --- | --- | --- | --- | --- | --- | | `invalid_frame` | `200` | `-32700` / `-32600` / `-32602` | 请求帧结构或编码不合法。包括 JSON 解析失败、请求体包含多余 JSON 值、`id/jsonrpc` 非法、`params` 严格解码失败。 | 非法 JSON;`id` 为 `null`;`params` 含未知字段。 | 不要直接重试,先修复请求构造器。 | | `invalid_action` | `200` | `-32602` | 动作参数值非法,但方法本身存在。 | `params.channel` 不在 `all/ipc/ws/sse`;`params.decision` 非 `allow_once/allow_session/reject`。 | 视为调用方输入错误,修正参数后再发。 | -| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.uri` 或 `media.mime_type`;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | +| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.mime_type`,或 `media.uri` / `media.asset_id` 未满足二选一;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | | `missing_required_field` | `200` | `-32600` / `-32602` | 缺失必填字段。请求层字段缺失多映射为 `-32600`,方法参数层字段缺失多映射为 `-32602`。 | 缺失 `id`;缺失 `params`;`cancel` 缺失 `run_id`。 | 调整参数补齐必填项再重试。 | | `unsupported_action` | `200` | `-32601` | 方法未注册或不被网关识别。 | 调用不存在的方法名。 | 客户端按能力探测降级,或升级服务端版本。 | | `internal_error` | `200` | `-32603` | 网关内部异常或未分类下游异常。 | 结果编码失败;runtime port 不可用;未知运行时错误。 | 采用指数退避重试;持续失败时告警。 | diff --git a/docs/reference/gateway-rpc-api.md b/docs/reference/gateway-rpc-api.md index 0a9c8be45..82dad5784 100644 --- a/docs/reference/gateway-rpc-api.md +++ b/docs/reference/gateway-rpc-api.md @@ -306,6 +306,13 @@ type RunParams struct { Mode string `json:"mode,omitempty"` // Agent 工作模式:build|plan,可选,默认沿用 session 当前 mode } +type RunInputMedia struct { + URI string `json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` + MimeType string `json:"mime_type"` + FileName string `json:"file_name,omitempty"` +} + type RunInputPart struct { Type string `json:"type"` // text|image Text string `json:"text,omitempty"` // text MUST @@ -318,7 +325,7 @@ type RunInputPart struct { 1. `input_text` 与 `input_parts` 至少一项非空。 2. `input_parts` 中: 1. `type=text` 时 `text` `MUST` 非空。 -2. `type=image` 时 `media.uri` 与 `media.mime_type` `MUST` 非空。 +2. `type=image` 时 `media.mime_type` `MUST` 非空,`media.uri` 与 `media.asset_id` `MUST` 二选一且不能同时提供。Web 上传图片应先调用 `POST /api/session-assets`,再在 `gateway.run` 中用 `asset_id` 引用。 3. 未知字段会因严格解码触发 `invalid_frame`。 4. `run_id` 归一化顺序为:显式 `run_id` > `request_id` > 网关生成 `run_`。 5. `mode` 可选值为 `"build"` 或 `"plan"`,为空时默认沿用 session 当前 mode(新会话默认为 `"build"`)。切换 mode 后,后端会更新 session 并影响后续运行的工具可用性和 prompt 策略。 @@ -397,6 +404,37 @@ sequenceDiagram G-->>C: ack(cancel) ``` +### HTTP session asset API + +浏览器图片上传使用 HTTP API,不通过 JSON-RPC 传输文件内容。客户端发送图片前需要先拥有有效 `session_id`(新会话可先调用 `gateway.createSession`)。 + +`POST /api/session-assets` + +- Auth Required: `Yes`,使用 `Authorization: Bearer `。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送,省略时回落到默认工作区。 +- Content-Type: `multipart/form-data`。 +- 字段:`session_id`(必填)、`file`(必填)。 +- 仅接受 PNG/JPEG/WebP;服务端按文件头检测 MIME,不信任浏览器声明。 +- 空文件返回 `400`,超出 `MaxSessionAssetBytes` 返回 `413`,不支持 MIME 返回 `415`,未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 +- 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- 成功返回: + +```json +{ + "session_id": "session-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +`GET /api/session-assets/{session_id}/{asset_id}` + +- Auth Required: `Yes`。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送。 +- 返回图片二进制,用于历史消息缩略图。 +- 工作区不存在返回 `404 workspace not found`;不存在、跨 session 或不可见的 asset 返回 `404 asset not found`。 + Observation: 1. `gateway_requests_total{method="gateway.run",status="ok|error"}`。 diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index c3f8b61f1..11404bdc7 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -697,6 +697,66 @@ func (b *gatewayRuntimePortBridge) CreateSession(ctx context.Context, input gate return strings.TrimSpace(session.ID), nil } +// SaveSessionAsset 将浏览器上传的附件保存到当前工作区的 session asset store。 +func (b *gatewayRuntimePortBridge) SaveSessionAsset( + ctx context.Context, + input gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.SessionAssetMeta{}, err + } + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + return gateway.SessionAssetMeta{}, gateway.ErrRuntimeResourceNotFound + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.SessionAssetMeta{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + meta, err := assetStore.SaveAsset(ctx, sessionID, input.Reader, strings.TrimSpace(input.MimeType)) + if err != nil { + return gateway.SessionAssetMeta{}, err + } + return gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, nil +} + +// OpenSessionAsset 打开当前工作区的会话附件,供 Gateway HTTP 读取端点流式返回。 +func (b *gatewayRuntimePortBridge) OpenSessionAsset( + ctx context.Context, + input gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.OpenSessionAssetResult{}, err + } + sessionID := strings.TrimSpace(input.SessionID) + assetID := strings.TrimSpace(input.AssetID) + if sessionID == "" || assetID == "" { + return gateway.OpenSessionAssetResult{}, gateway.ErrRuntimeResourceNotFound + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.OpenSessionAssetResult{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + reader, meta, err := assetStore.Open(ctx, sessionID, assetID) + if err != nil { + return gateway.OpenSessionAssetResult{}, err + } + return gateway.OpenSessionAssetResult{ + Reader: reader, + Meta: gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, + }, nil +} + // DeleteSession 删除/归档指定会话。 func (b *gatewayRuntimePortBridge) DeleteSession(ctx context.Context, input gateway.DeleteSessionInput) (bool, error) { if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { @@ -1684,11 +1744,13 @@ func convertGatewayRunInput(input gateway.RunInput) agentruntime.PrepareInput { continue } path := strings.TrimSpace(part.Media.URI) - if path == "" { + assetID := strings.TrimSpace(part.Media.AssetID) + if path == "" && assetID == "" { continue } images = append(images, agentruntime.UserImageInput{ Path: path, + AssetID: assetID, MimeType: strings.TrimSpace(part.Media.MimeType), }) } @@ -1867,6 +1929,7 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM convertedMessage := gateway.SessionMessage{ Role: strings.TrimSpace(message.Role), Content: renderSessionMessageContent(message.Parts), + Parts: convertProviderContentParts(message.Parts), ToolCallID: strings.TrimSpace(message.ToolCallID), IsError: message.IsError, } @@ -1885,6 +1948,52 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM return converted } +// convertProviderContentParts 将 provider 通用内容分片转换为 Gateway 会话快照分片。 +func convertProviderContentParts(parts []providertypes.ContentPart) []gateway.InputPart { + if len(parts) == 0 { + return nil + } + converted := make([]gateway.InputPart, 0, len(parts)) + for _, part := range parts { + switch part.Kind { + case providertypes.ContentPartText: + if text := strings.TrimSpace(part.Text); text != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeText, + Text: text, + }) + } + case providertypes.ContentPartImage: + if part.Image == nil { + continue + } + switch part.Image.SourceType { + case providertypes.ImageSourceSessionAsset: + if part.Image.Asset == nil || strings.TrimSpace(part.Image.Asset.ID) == "" { + continue + } + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + AssetID: strings.TrimSpace(part.Image.Asset.ID), + MimeType: strings.TrimSpace(part.Image.Asset.MimeType), + }, + }) + case providertypes.ImageSourceRemote: + if url := strings.TrimSpace(part.Image.URL); url != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + URI: url, + }, + }) + } + } + } + } + return converted +} + // convertRuntimePlanTodoItem 将 session 计划中的 legacy todo 项映射为 gateway 展示结构。 func convertRuntimePlanTodoItem(item agentsession.TodoItem) gateway.PlanTodoItem { required := false diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 064fc332c..dcbd9f4bc 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -1,10 +1,12 @@ package cli import ( + "bytes" "context" "encoding/json" "errors" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -1550,6 +1552,7 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { {Type: gateway.InputPartTypeImage, Media: nil}, {Type: gateway.InputPartTypeImage, Media: &gateway.Media{URI: " "}}, {Type: gateway.InputPartTypeImage, Media: &gateway.Media{URI: " /tmp/a.png ", MimeType: " image/png "}}, + {Type: gateway.InputPartTypeImage, Media: &gateway.Media{AssetID: " asset-1 ", MimeType: " image/webp "}}, }, Workdir: " /tmp/work ", }) @@ -1559,8 +1562,14 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { if converted.Text != "base\ntext" { t.Fatalf("text = %q, want %q", converted.Text, "base\ntext") } - if len(converted.Images) != 1 || converted.Images[0].Path != "/tmp/a.png" { - t.Fatalf("images = %#v, want one valid image", converted.Images) + if len(converted.Images) != 2 { + t.Fatalf("images = %#v, want two valid images", converted.Images) + } + if converted.Images[0].Path != "/tmp/a.png" || converted.Images[0].MimeType != "image/png" { + t.Fatalf("local image = %#v, want normalized path/mime", converted.Images[0]) + } + if converted.Images[1].AssetID != "asset-1" || converted.Images[1].MimeType != "image/webp" { + t.Fatalf("asset image = %#v, want normalized asset_id/mime", converted.Images[1]) } if got := renderSessionMessageContent(nil); got != "" { @@ -1580,6 +1589,109 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { } } +func TestGatewayRuntimePortBridgeSessionAssets(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := agentsession.NewSQLiteStore(t.TempDir(), workdir) + session := agentsession.NewWithWorkdir("asset session", workdir) + if _, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), + }); err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + store, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + payload := []byte("image payload") + meta, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: " " + session.ID + " ", + Reader: bytes.NewReader(payload), + MimeType: " image/png ", + }) + if err != nil { + t.Fatalf("SaveSessionAsset() error = %v", err) + } + if meta.SessionID != session.ID || meta.AssetID == "" || meta.MimeType != "image/png" || meta.Size != int64(len(payload)) { + t.Fatalf("unexpected saved meta: %+v", meta) + } + + opened, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: session.ID, + AssetID: " " + meta.AssetID + " ", + }) + if err != nil { + t.Fatalf("OpenSessionAsset() error = %v", err) + } + defer opened.Reader.Close() + got, err := io.ReadAll(opened.Reader) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if string(got) != string(payload) || opened.Meta.AssetID != meta.AssetID || opened.Meta.MimeType != "image/png" { + t.Fatalf("unexpected opened asset meta=%+v payload=%q", opened.Meta, string(got)) + } +} + +func TestGatewayRuntimePortBridgeSessionAssetErrors(t *testing.T) { + t.Parallel() + + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + testSessionStore, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + if _, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: " ", + Reader: strings.NewReader("x"), + MimeType: "image/png", + }); err == nil { + t.Fatal("expected empty session id save error") + } + if _, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + AssetID: " ", + }); err == nil { + t.Fatal("expected empty asset id open error") + } + if _, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + Reader: strings.NewReader("x"), + MimeType: "image/png", + }); err == nil || !strings.Contains(err.Error(), "asset store is unavailable") { + t.Fatalf("expected unavailable asset store save error, got %v", err) + } + if _, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + AssetID: "asset-1", + }); err == nil || !strings.Contains(err.Error(), "asset store is unavailable") { + t.Fatalf("expected unavailable asset store open error, got %v", err) + } +} + func TestConvertRuntimeSessionToGatewaySessionIncludesCurrentPlan(t *testing.T) { required := true session := agentsession.New("plan session") diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 08ca8bbdc..2f5ebb98c 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1100,6 +1100,14 @@ func (stubRuntimePort) CreateSession(context.Context, gateway.CreateSessionInput return "", nil } +func (stubRuntimePort) SaveSessionAsset(context.Context, gateway.SaveSessionAssetInput) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (stubRuntimePort) OpenSessionAsset(context.Context, gateway.OpenSessionAssetInput) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (stubRuntimePort) ListSessionTodos(context.Context, gateway.ListSessionTodosInput) (gateway.TodoSnapshot, error) { return gateway.TodoSnapshot{}, nil } diff --git a/internal/cli/web_command.go b/internal/cli/web_command.go index 3722da13d..94e914c78 100644 --- a/internal/cli/web_command.go +++ b/internal/cli/web_command.go @@ -25,6 +25,8 @@ var ( webCommandStartGatewayServer = startGatewayServer webCommandBuildFrontend = buildFrontend webCommandLookPath = exec.LookPath + openBrowserFn = openBrowser + userHomeDirFn = os.UserHomeDir webCommandEmbeddedAssets = func() (fs.FS, bool) { if !webassets.IsAvailable() { return nil, false @@ -327,7 +329,7 @@ func waitForGatewayAndOpenBrowser(ctx context.Context, address string, logger *l browserURL += "/?token=" + token } logger.Printf("gateway is ready, opening browser: %s", baseURL) - if openErr := openBrowser(browserURL); openErr != nil { + if openErr := openBrowserFn(browserURL); openErr != nil { logger.Printf("failed to open browser: %v (open %s manually)", openErr, browserURL) } return @@ -340,7 +342,7 @@ func waitForGatewayAndOpenBrowser(ctx context.Context, address string, logger *l // readGatewayToken 从 ~/.neocode/auth.json 读取认证 token。 func readGatewayToken() string { - homeDir, err := os.UserHomeDir() + homeDir, err := userHomeDirFn() if err != nil { return "" } diff --git a/internal/cli/web_command_test.go b/internal/cli/web_command_test.go index dd48daab4..3f228094b 100644 --- a/internal/cli/web_command_test.go +++ b/internal/cli/web_command_test.go @@ -423,6 +423,13 @@ func TestBuildFrontendAndReadGatewayToken(t *testing.T) { if err := os.WriteFile(filepath.Join(authDir, "auth.json"), authData, 0o644); err != nil { t.Fatalf("write auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) @@ -449,6 +456,13 @@ func TestWaitForGatewayAndOpenBrowserAndResolveListenAddress(t *testing.T) { if err := os.WriteFile(filepath.Join(authDir, "auth.json"), authData, 0o644); err != nil { t.Fatalf("write auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) @@ -474,6 +488,13 @@ func TestWaitForGatewayAndOpenBrowserAndResolveListenAddress(t *testing.T) { t.Cleanup(func() { _ = os.Setenv("PATH", originalPath) }) + originalOpenBrowser := openBrowserFn + openBrowserFn = func(url string) error { + return os.WriteFile(openLog, []byte(url), 0o644) + } + t.Cleanup(func() { + openBrowserFn = originalOpenBrowser + }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/healthz" { @@ -543,6 +564,13 @@ func TestResolveWebStaticDirCurrentWorkdirAndReadGatewayTokenInvalid(t *testing. if err := os.WriteFile(filepath.Join(authDir, "auth.json"), []byte("{invalid"), 0o644); err != nil { t.Fatalf("write invalid auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) diff --git a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go index 70377cc7e..7e54f27a5 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go @@ -189,6 +189,20 @@ func (s *urlschemeIntegrationRuntimeStub) CreateSession( return strings.TrimSpace("session-review-integration"), nil } +func (s *urlschemeIntegrationRuntimeStub) SaveSessionAsset( + context.Context, + gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) OpenSessionAsset( + context.Context, + gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (s *urlschemeIntegrationRuntimeStub) ListSessionTodos( context.Context, gateway.ListSessionTodosInput, diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 69a788d34..945201d43 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -311,6 +311,14 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } +func (s *bootstrapRuntimeStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{SessionID: input.SessionID, AssetID: "asset_test", MimeType: input.MimeType}, nil +} + +func (s *bootstrapRuntimeStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { if s != nil && s.listCheckpointsFn != nil { return s.listCheckpointsFn(ctx, input) @@ -5335,6 +5343,12 @@ func (runtimeOnlyStub) GetRuntimeSnapshot(ctx context.Context, input GetRuntimeS func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { return "", nil } +func (runtimeOnlyStub) SaveSessionAsset(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (runtimeOnlyStub) OpenSessionAsset(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { return false, nil } diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 18d62a61f..43eed5168 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "io" "time" "neo-code/internal/config" @@ -227,6 +228,48 @@ type CreateSessionInput struct { SessionID string } +// SessionAssetMeta 描述 Gateway 可见的会话附件元数据。 +type SessionAssetMeta struct { + // SessionID 是附件所属会话标识。 + SessionID string `json:"session_id"` + // AssetID 是附件标识。 + AssetID string `json:"asset_id"` + // MimeType 是服务端确认后的 MIME 类型。 + MimeType string `json:"mime_type"` + // Size 是附件原始字节数。 + Size int64 `json:"size"` +} + +// SaveSessionAssetInput 表示保存浏览器上传附件的下游输入。 +type SaveSessionAssetInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是附件所属会话标识。 + SessionID string + // Reader 是附件二进制内容。 + Reader io.Reader + // MimeType 是服务端探测确认后的 MIME 类型。 + MimeType string +} + +// OpenSessionAssetInput 表示读取会话附件的下游输入。 +type OpenSessionAssetInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是附件所属会话标识。 + SessionID string + // AssetID 是附件标识。 + AssetID string +} + +// OpenSessionAssetResult 表示打开会话附件后的读取结果。 +type OpenSessionAssetResult struct { + // Reader 是附件内容流,调用方负责关闭。 + Reader io.ReadCloser + // Meta 是附件元数据。 + Meta SessionAssetMeta +} + // DeleteSessionInput 表示 gateway.deleteSession 动作的下游输入。 type DeleteSessionInput struct { // SubjectID 是请求方身份主体标识。 @@ -694,6 +737,8 @@ type SessionMessage struct { Role string `json:"role"` // Content 是消息内容。 Content string `json:"content"` + // Parts 是消息的结构化多模态分片,供支持图片的客户端渲染。 + Parts []InputPart `json:"parts,omitempty"` // ToolCalls 是 assistant 发起的工具调用元数据。 ToolCalls []ToolCall `json:"tool_calls,omitempty"` // ToolCallID 是工具消息关联的调用标识。 @@ -920,6 +965,10 @@ type RuntimePort interface { GetRuntimeSnapshot(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) // CreateSession 创建并返回可用会话标识。 CreateSession(ctx context.Context, input CreateSessionInput) (string, error) + // SaveSessionAsset 保存会话附件并返回元数据。 + SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) + // OpenSessionAsset 打开会话附件供 HTTP 读取接口返回。 + OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) // DeleteSession 删除/归档指定会话。 DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) // RenameSession 重命名指定会话。 diff --git a/internal/gateway/contracts_test.go b/internal/gateway/contracts_test.go index de1ef52cb..d67e4b57a 100644 --- a/internal/gateway/contracts_test.go +++ b/internal/gateway/contracts_test.go @@ -147,6 +147,14 @@ func (s *runtimePortCompileStub) CreateSession(_ context.Context, _ CreateSessio return "", nil } +func (s *runtimePortCompileStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *runtimePortCompileStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *runtimePortCompileStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { return nil, nil } diff --git a/internal/gateway/multi_workspace_runtime.go b/internal/gateway/multi_workspace_runtime.go index 33c3bf52b..a30b94cf1 100644 --- a/internal/gateway/multi_workspace_runtime.go +++ b/internal/gateway/multi_workspace_runtime.go @@ -402,6 +402,22 @@ func (m *MultiWorkspaceRuntime) CreateSession(ctx context.Context, input CreateS return port.CreateSession(ctx, input) } +func (m *MultiWorkspaceRuntime) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + port, err := m.getPort(ctx) + if err != nil { + return SessionAssetMeta{}, err + } + return port.SaveSessionAsset(ctx, input) +} + +func (m *MultiWorkspaceRuntime) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + port, err := m.getPort(ctx) + if err != nil { + return OpenSessionAssetResult{}, err + } + return port.OpenSessionAsset(ctx, input) +} + func (m *MultiWorkspaceRuntime) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { port, err := m.getPort(ctx) if err != nil { diff --git a/internal/gateway/multi_workspace_runtime_test.go b/internal/gateway/multi_workspace_runtime_test.go index f4919c7ef..4dfa084b4 100644 --- a/internal/gateway/multi_workspace_runtime_test.go +++ b/internal/gateway/multi_workspace_runtime_test.go @@ -28,6 +28,8 @@ type recordingPort struct { approvePlanCalls atomic.Int32 resolveUserCalls atomic.Int32 cancelCalls atomic.Int32 + saveAssetCalls atomic.Int32 + openAssetCalls atomic.Int32 closed atomic.Int32 closeOnce sync.Once @@ -135,6 +137,16 @@ func (p *recordingPort) CreateSession(_ context.Context, _ CreateSessionInput) ( return p.id, nil } +func (p *recordingPort) SaveSessionAsset(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + p.saveAssetCalls.Add(1) + return SessionAssetMeta{SessionID: input.SessionID, AssetID: p.id, MimeType: input.MimeType}, nil +} + +func (p *recordingPort) OpenSessionAsset(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + p.openAssetCalls.Add(1) + return OpenSessionAssetResult{Meta: SessionAssetMeta{SessionID: input.SessionID, AssetID: input.AssetID}}, nil +} + func (p *recordingPort) DeleteSession(_ context.Context, _ DeleteSessionInput) (bool, error) { return true, nil } @@ -783,6 +795,12 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if _, err := mw.ExecuteSystemTool(alphaCtx, ExecuteSystemToolInput{}); err != nil { t.Fatalf("ExecuteSystemTool alpha: %v", err) } + if _, err := mw.SaveSessionAsset(betaCtx, SaveSessionAssetInput{SessionID: "s-1", MimeType: "image/png"}); err != nil { + t.Fatalf("SaveSessionAsset beta: %v", err) + } + if _, err := mw.OpenSessionAsset(alphaCtx, OpenSessionAssetInput{SessionID: "s-1", AssetID: "asset-1"}); err != nil { + t.Fatalf("OpenSessionAsset alpha: %v", err) + } alphaPort := builder.portFor(alpha.Path) betaPort := builder.portFor(beta.Path) @@ -801,6 +819,12 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if got := alphaPort.executeSysCalls.Load(); got != 1 { t.Fatalf("alpha ExecuteSystemTool calls = %d, want 1", got) } + if got := betaPort.saveAssetCalls.Load(); got != 1 { + t.Fatalf("beta SaveSessionAsset calls = %d, want 1", got) + } + if got := alphaPort.openAssetCalls.Load(); got != 1 { + t.Fatalf("alpha OpenSessionAsset calls = %d, want 1", got) + } } func TestMultiWorkspaceRuntime_ListWorkspacesMatchesIndex(t *testing.T) { diff --git a/internal/gateway/network_server.go b/internal/gateway/network_server.go index cc78a2bf8..109047a23 100644 --- a/internal/gateway/network_server.go +++ b/internal/gateway/network_server.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "os" + "path" "strconv" "strings" "sync" @@ -21,6 +22,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) const ( @@ -40,6 +42,8 @@ const ( DefaultNetworkMaxStreamConnections = 128 // DefaultWSUnauthenticatedTimeout 定义 WS 未认证连接的最大等待时间。 DefaultWSUnauthenticatedTimeout = 3 * time.Second + // SessionAssetWorkspaceHeader 定义 Web 上传/读取会话附件时携带当前工作区的 HTTP Header。 + SessionAssetWorkspaceHeader = "X-NeoCode-Workspace-Hash" ) var ( @@ -367,6 +371,12 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { mux.HandleFunc("/rpc", func(writer http.ResponseWriter, request *http.Request) { s.handleRPCRequest(writer, request, runtimePort) }) + mux.HandleFunc("/api/session-assets", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetUpload(writer, request, runtimePort) + }) + mux.HandleFunc("/api/session-assets/", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetRead(writer, request, runtimePort) + }) mux.Handle("/ws", websocket.Server{ Handshake: func(_ *websocket.Config, request *http.Request) error { return s.validateWebSocketOrigin(request) @@ -387,6 +397,241 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { return mux } +// handleSessionAssetUpload 接收浏览器上传图片,并保存为当前会话的 session asset。 +func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + if request.Method != http.MethodPost { + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + return + } + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if !s.isHTTPControlPlaneMethodAllowed(sessionAssetUploadMethod) { + s.writeHTTPAccessDenied(writer, sessionAssetUploadMethod) + return + } + if runtimePort == nil { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + limit := agentsession.MaxSessionAssetBytes + request.Body = http.MaxBytesReader(writer, request.Body, limit+(1<<20)) + if err := request.ParseMultipartForm(limit + 4096); err != nil { + if strings.Contains(strings.ToLower(err.Error()), "too large") { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "invalid multipart form"}) + return + } + + sessionID := strings.TrimSpace(request.FormValue("session_id")) + if sessionID == "" { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "session_id is required"}) + return + } + + file, _, err := request.FormFile("file") + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is required"}) + return + } + defer func() { + _ = file.Close() + }() + + payload, err := io.ReadAll(io.LimitReader(file, limit+1)) + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "read uploaded file failed"}) + return + } + if len(payload) == 0 { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is empty"}) + return + } + if int64(len(payload)) > limit { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + + mimeType := detectAllowedUploadImageMime(payload) + if mimeType == "" { + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": "unsupported image type"}) + return + } + + meta, err := runtimePort.SaveSessionAsset(sessionAssetRequestContext(request), SaveSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + Reader: bytes.NewReader(payload), + MimeType: mimeType, + }) + if err != nil { + writeSessionAssetUploadHTTPError(writer, err) + return + } + writeJSONResponse(writer, http.StatusOK, meta) +} + +// handleSessionAssetRead 读取会话图片附件,供 Web 历史消息缩略图展示。 +func (s *NetworkServer) handleSessionAssetRead(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + if request.Method != http.MethodGet { + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + return + } + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if !s.isHTTPControlPlaneMethodAllowed(sessionAssetReadMethod) { + s.writeHTTPAccessDenied(writer, sessionAssetReadMethod) + return + } + if runtimePort == nil { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + sessionID, assetID, ok := parseSessionAssetPath(request.URL.Path) + if !ok { + http.NotFound(writer, request) + return + } + result, err := runtimePort.OpenSessionAsset(sessionAssetRequestContext(request), OpenSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + AssetID: assetID, + }) + if err != nil { + writeSessionAssetReadHTTPError(writer, err) + return + } + defer func() { + _ = result.Reader.Close() + }() + + writer.Header().Set("Content-Type", result.Meta.MimeType) + if result.Meta.Size > 0 { + writer.Header().Set("Content-Length", strconv.FormatInt(result.Meta.Size, 10)) + } + writer.Header().Set("Cache-Control", "private, max-age=300") + _, _ = io.Copy(writer, result.Reader) +} + +// sessionAssetRequestContext 将 HTTP Header 中的工作区哈希注入请求上下文,供多工作区 Runtime 路由。 +func sessionAssetRequestContext(request *http.Request) context.Context { + if request == nil { + return context.Background() + } + workspaceHash := strings.TrimSpace(request.Header.Get(SessionAssetWorkspaceHeader)) + if workspaceHash == "" { + return request.Context() + } + state := NewConnectionWorkspaceState() + state.SetWorkspaceHash(workspaceHash) + return WithConnectionWorkspaceState(request.Context(), state) +} + +// authenticatedHTTPSubjectID 校验 HTTP Bearer Token 并返回主体标识。 +func (s *NetworkServer) authenticatedHTTPSubjectID(request *http.Request) (string, bool) { + if s.authenticator == nil { + return "", false + } + token := extractBearerToken(request.Header.Get("Authorization")) + subjectID, ok := s.authenticator.ResolveSubjectID(token) + if !ok || strings.TrimSpace(subjectID) == "" { + return "", false + } + return strings.TrimSpace(subjectID), true +} + +// isHTTPControlPlaneMethodAllowed 按 HTTP 来源复用控制面 ACL,覆盖非 JSON-RPC 的 HTTP 端点。 +func (s *NetworkServer) isHTTPControlPlaneMethodAllowed(method string) bool { + if s == nil || s.acl == nil { + return true + } + return s.acl.IsAllowed(RequestSourceHTTP, method) +} + +// writeHTTPAccessDenied 记录 HTTP 端点 ACL 拒绝并返回统一的 403 JSON 响应。 +func (s *NetworkServer) writeHTTPAccessDenied(writer http.ResponseWriter, method string) { + if s != nil && s.metrics != nil { + s.metrics.IncACLDenied(string(RequestSourceHTTP), method) + } + writeJSONResponse(writer, http.StatusForbidden, map[string]string{"error": "access denied"}) +} + +// detectAllowedUploadImageMime 用文件头确认上传图片类型,只允许 PNG/JPEG/WebP。 +func detectAllowedUploadImageMime(payload []byte) string { + if len(payload) == 0 { + return "" + } + probe := payload + if len(probe) > 512 { + probe = probe[:512] + } + mimeType := strings.ToLower(strings.TrimSpace(http.DetectContentType(probe))) + switch mimeType { + case "image/png", "image/jpeg", "image/webp": + return mimeType + default: + return "" + } +} + +// parseSessionAssetPath 从 /api/session-assets/{session_id}/{asset_id} 提取路径参数。 +func parseSessionAssetPath(rawPath string) (string, string, bool) { + cleanPath := path.Clean("/" + strings.TrimSpace(rawPath)) + const prefix = "/api/session-assets/" + if !strings.HasPrefix(cleanPath, prefix) { + return "", "", false + } + parts := strings.Split(strings.TrimPrefix(cleanPath, prefix), "/") + if len(parts) != 2 { + return "", "", false + } + sessionID := strings.TrimSpace(parts[0]) + assetID := strings.TrimSpace(parts[1]) + return sessionID, assetID, sessionID != "" && assetID != "" +} + +// writeSessionAssetUploadHTTPError 将上传阶段的下游错误映射为明确 HTTP 状态。 +func writeSessionAssetUploadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "session not found") +} + +// writeSessionAssetReadHTTPError 将读取阶段的下游错误映射为明确 HTTP 状态。 +func writeSessionAssetReadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "asset not found") +} + +// writeSessionAssetHTTPError 将下游附件错误映射为明确 HTTP 状态。 +func writeSessionAssetHTTPError(writer http.ResponseWriter, err error, notFoundMessage string) { + if err == nil { + writeJSONResponse(writer, http.StatusInternalServerError, map[string]string{"error": "unknown asset error"}) + return + } + message := strings.ToLower(err.Error()) + switch { + case strings.Contains(message, "workspace") && strings.Contains(message, "not found"): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": "workspace not found"}) + case errors.Is(err, os.ErrNotExist) || errors.Is(err, ErrRuntimeResourceNotFound): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": notFoundMessage}) + case strings.Contains(message, "asset size exceeds"): + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": err.Error()}) + case strings.Contains(message, "unsupported") || strings.Contains(message, "not an image"): + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": err.Error()}) + case strings.Contains(message, "access denied"): + writeJSONResponse(writer, http.StatusForbidden, map[string]string{"error": "access denied"}) + default: + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": err.Error()}) + } +} + // withCORS 为网络入口注入 CORS 头,仅对白名单 Origin 回显允许值。 // WebSocket 升级请求不受 CORS 约束,直接放行交予 WS 握手阶段的 Origin 校验。 func (s *NetworkServer) withCORS(next http.Handler) http.Handler { @@ -406,7 +651,7 @@ func (s *NetworkServer) withCORS(next http.Handler) http.Handler { } writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, "+SessionAssetWorkspaceHeader) if request.Method == http.MethodOptions { writer.WriteHeader(http.StatusNoContent) return diff --git a/internal/gateway/network_server_test.go b/internal/gateway/network_server_test.go index 96301525c..5958b9658 100644 --- a/internal/gateway/network_server_test.go +++ b/internal/gateway/network_server_test.go @@ -2,10 +2,13 @@ package gateway import ( "bufio" + "bytes" "context" "encoding/json" + "fmt" "io" "log" + "mime/multipart" "net/http" "net/http/httptest" "strings" @@ -15,6 +18,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) func TestResolveNetworkListenAddress(t *testing.T) { @@ -400,6 +404,293 @@ func TestNetworkServerRPCErrorBranches(t *testing.T) { }) } +func TestNetworkServerSessionAssetUploadAndRead(t *testing.T) { + payload := gatewayMinimalPNGBytes() + var capturedUpload SaveSessionAssetInput + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + capturedUpload = input + got, err := io.ReadAll(input.Reader) + if err != nil { + t.Fatalf("read uploaded asset: %v", err) + } + if !bytes.Equal(got, payload) { + t.Fatalf("uploaded payload mismatch") + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(got)), + }, nil + }, + openAssetFn: func(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if input.SubjectID != "local_admin" || input.SessionID != "session-1" || input.AssetID != "asset-1" { + t.Fatalf("open input = %+v, want subject/session/asset", input) + } + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + var uploadResponse SessionAssetMeta + if err := json.Unmarshal(uploadRecorder.Body.Bytes(), &uploadResponse); err != nil { + t.Fatalf("decode upload response: %v", err) + } + if uploadResponse.AssetID != "asset-1" || uploadResponse.MimeType != "image/png" || uploadResponse.Size != int64(len(payload)) { + t.Fatalf("upload response = %+v", uploadResponse) + } + if capturedUpload.SubjectID != "local_admin" || capturedUpload.SessionID != "session-1" || capturedUpload.MimeType != "image/png" { + t.Fatalf("captured upload = %+v", capturedUpload) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } + if got := readRecorder.Header().Get("Content-Type"); got != "image/png" { + t.Fatalf("read content-type = %q, want image/png", got) + } + if !bytes.Equal(readRecorder.Body.Bytes(), payload) { + t.Fatalf("read payload mismatch") + } +} + +func TestNetworkServerSessionAssetsRespectHTTPACL(t *testing.T) { + deniedACL := &ControlPlaneACL{ + mode: ACLModeStrict, + allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: {}}, + enabled: true, + } + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + t.Fatal("SaveSessionAsset should not be called when ACL denies upload") + return SessionAssetMeta{}, nil + }, + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + t.Fatal("OpenSessionAsset should not be called when ACL denies read") + return OpenSessionAssetResult{}, nil + }, + } + server := &NetworkServer{ + authenticator: staticTokenAuthenticator{token: "gateway-token"}, + acl: deniedACL, + metrics: NewGatewayMetrics(), + } + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusForbidden { + t.Fatalf("upload status = %d body=%s, want %d", uploadRecorder.Code, uploadRecorder.Body.String(), http.StatusForbidden) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusForbidden { + t.Fatalf("read status = %d body=%s, want %d", readRecorder.Code, readRecorder.Body.String(), http.StatusForbidden) + } +} + +func TestNetworkServerSessionAssetWorkspaceHeader(t *testing.T) { + payload := gatewayMinimalPNGBytes() + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("upload workspace hash = %q, want workspace-b", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(payload)), + }, nil + }, + openAssetFn: func(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("read workspace hash = %q, want workspace-b", got) + } + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetWorkspaceHeaderEmptyFallback(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "" { + t.Fatalf("workspace hash = %q, want empty fallback", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(gatewayMinimalPNGBytes())), + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", recorder.Code, recorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetUploadErrors(t *testing.T) { + runtimePort := &runtimePortEventStub{} + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.withCORS(server.buildHandler(runtimePort)) + + t.Run("unauthorized", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + }) + + t.Run("forbidden origin", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set("Origin", "http://evil.example") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusForbidden) + } + }) + + t.Run("non image", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "bad.txt", []byte("not an image")) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnsupportedMediaType { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnsupportedMediaType) + } + }) + + t.Run("empty file", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "empty.png", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest) + } + }) + + t.Run("oversized file", func(t *testing.T) { + request := newSessionAssetUploadRequest( + t, + "session-1", + "huge.png", + bytes.Repeat([]byte{0}, int(agentsession.MaxSessionAssetBytes)+1), + ) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusRequestEntityTooLarge) + } + }) + + t.Run("workspace not found", func(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, fmt.Errorf("%w: workspace missing not found", ErrRuntimeResourceNotFound) + }, + } + handler := server.withCORS(server.buildHandler(runtimePort)) + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set(SessionAssetWorkspaceHeader, "missing") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } + if !strings.Contains(recorder.Body.String(), "workspace not found") { + t.Fatalf("body = %s, want workspace not found", recorder.Body.String()) + } + }) +} + +func TestNetworkServerSessionAssetReadNotFound(t *testing.T) { + runtimePort := &runtimePortEventStub{ + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, ErrRuntimeResourceNotFound + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/missing", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } +} + func TestNetworkServerWebSocketAndSSEPing(t *testing.T) { server := newTestNetworkServer(t, NetworkServerOptions{}) testContext, cancel := context.WithCancel(context.Background()) @@ -1322,6 +1613,45 @@ type noFlushResponseWriter struct { body strings.Builder } +func newSessionAssetUploadRequest(t *testing.T, sessionID, fileName string, payload []byte) *http.Request { + t.Helper() + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if sessionID != "" { + if err := writer.WriteField("session_id", sessionID); err != nil { + t.Fatalf("write session_id field: %v", err) + } + } + part, err := writer.CreateFormFile("file", fileName) + if err != nil { + t.Fatalf("create file part: %v", err) + } + if _, err := part.Write(payload); err != nil { + t.Fatalf("write file part: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/api/session-assets", &body) + request.Header.Set("Content-Type", writer.FormDataContentType()) + return request +} + +func gatewayMinimalPNGBytes() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} + type staticTokenAuthenticator struct { token string } diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index ac41d62aa..71ab1d6ed 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -250,6 +250,7 @@ const ( // RunInputMedia 用于承载 gateway.run 中图片分片的媒体元数据。 type RunInputMedia struct { URI string `json:"uri"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -1402,6 +1403,7 @@ func decodeRunParams(raw json.RawMessage) (RunParams, *JSONRPCError) { p.InputParts[i].Text = strings.TrimSpace(p.InputParts[i].Text) if m := p.InputParts[i].Media; m != nil { m.URI = strings.TrimSpace(m.URI) + m.AssetID = strings.TrimSpace(m.AssetID) m.MimeType = strings.TrimSpace(m.MimeType) m.FileName = strings.TrimSpace(m.FileName) } diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 3a7f7da48..c64211e6c 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -393,7 +393,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { "workdir":" /tmp/work ", "input_parts":[ {"type":" TEXT ","text":" world "}, - {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}} + {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}}, + {"type":" image ","media":{"asset_id":" asset-1 ","mime_type":" image/webp "}} ] }`), } @@ -414,8 +415,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputText != "hello" { t.Fatalf("run input_text = %q, want %q", runParams.InputText, "hello") } - if len(runParams.InputParts) != 2 { - t.Fatalf("run input_parts len = %d, want 2", len(runParams.InputParts)) + if len(runParams.InputParts) != 3 { + t.Fatalf("run input_parts len = %d, want 3", len(runParams.InputParts)) } if runParams.InputParts[0].Type != "text" || runParams.InputParts[0].Text != "world" { t.Fatalf("run text part = %#v, want normalized text part", runParams.InputParts[0]) @@ -426,6 +427,12 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputParts[1].Media.MimeType != "image/png" || runParams.InputParts[1].Media.FileName != "a.png" { t.Fatalf("run image media = %#v, want trimmed mime/file_name", runParams.InputParts[1].Media) } + if runParams.InputParts[2].Type != "image" || + runParams.InputParts[2].Media == nil || + runParams.InputParts[2].Media.AssetID != "asset-1" || + runParams.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("run image asset media = %#v, want trimmed asset_id/mime", runParams.InputParts[2]) + } compactNormalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ JSONRPC: JSONRPCVersion, diff --git a/internal/gateway/rpc_dispatch.go b/internal/gateway/rpc_dispatch.go index 3b5f38ae0..33639ff66 100644 --- a/internal/gateway/rpc_dispatch.go +++ b/internal/gateway/rpc_dispatch.go @@ -344,6 +344,7 @@ func convertProtocolRunInputParts(parts []protocol.RunInputPart) []InputPart { if part.Media != nil { convertedPart.Media = &Media{ URI: strings.TrimSpace(part.Media.URI), + AssetID: strings.TrimSpace(part.Media.AssetID), MimeType: strings.TrimSpace(part.Media.MimeType), FileName: strings.TrimSpace(part.Media.FileName), } diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 1994e3e51..737851253 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -234,6 +234,14 @@ func (s *rpcRunCaptureRuntimeStub) CreateSession(ctx context.Context, input Crea return s.createSessionID, nil } +func (s *rpcRunCaptureRuntimeStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *rpcRunCaptureRuntimeStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } @@ -1130,6 +1138,12 @@ func (s *runtimePortOnlyStub) GetRuntimeSnapshot(_ context.Context, _ GetRuntime func (s *runtimePortOnlyStub) CreateSession(_ context.Context, _ CreateSessionInput) (string, error) { return "", nil } +func (s *runtimePortOnlyStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (s *runtimePortOnlyStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (s *runtimePortOnlyStub) DeleteSession(_ context.Context, _ DeleteSessionInput) (bool, error) { return false, nil } @@ -1208,7 +1222,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { "session_id":"session-run-1", "input_parts":[ {"type":"text","text":"hello world"}, - {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}} + {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}}, + {"type":"image","media":{"asset_id":"asset-1","mime_type":"image/webp"}} ] }`), }, runtimeStub) @@ -1229,8 +1244,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.RunID != "req-run-hydrate" { t.Fatalf("runtime run run_id = %q, want %q", captured.RunID, "req-run-hydrate") } - if len(captured.InputParts) != 2 { - t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 2) + if len(captured.InputParts) != 3 { + t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 3) } if captured.InputParts[0].Type != InputPartTypeText { t.Fatalf("runtime text part type = %q, want %q", captured.InputParts[0].Type, InputPartTypeText) @@ -1241,6 +1256,11 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.InputParts[1].Media == nil || captured.InputParts[1].Media.URI != "C:/tmp/pic.png" { t.Fatalf("runtime image media = %#v, want uri %q", captured.InputParts[1].Media, "C:/tmp/pic.png") } + if captured.InputParts[2].Media == nil || + captured.InputParts[2].Media.AssetID != "asset-1" || + captured.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("runtime image asset media = %#v, want asset_id", captured.InputParts[2].Media) + } } func TestDispatchRPCRequest_DenyCrossSubjectLoadSession(t *testing.T) { diff --git a/internal/gateway/security.go b/internal/gateway/security.go index 66d741de8..9c1d70eb4 100644 --- a/internal/gateway/security.go +++ b/internal/gateway/security.go @@ -4,7 +4,11 @@ import ( "strings" ) -const pingMethod = "gateway.ping" +const ( + pingMethod = "gateway.ping" + sessionAssetUploadMethod = "gateway.sessionAssetUpload" + sessionAssetReadMethod = "gateway.sessionAssetRead" +) // RequestSource 表示控制面请求来源,用于 ACL 与日志分类。 type RequestSource string @@ -98,6 +102,8 @@ func fullControlPlaneMethods() map[string]struct{} { "gateway.renameWorkspace", "gateway.deleteWorkspace", "wake.openUrl", + sessionAssetUploadMethod, + sessionAssetReadMethod, } return normalizedMethodSet(methods...) } diff --git a/internal/gateway/security_test.go b/internal/gateway/security_test.go index 55e955399..3c9bb98c2 100644 --- a/internal/gateway/security_test.go +++ b/internal/gateway/security_test.go @@ -45,6 +45,9 @@ func TestStrictACLAllowlist(t *testing.T) { {source: RequestSourceSSE, method: "gateway.approvePlan", want: false}, {source: RequestSourceHTTP, method: "gateway.userQuestionAnswer", want: true}, {source: RequestSourceHTTP, method: "gateway.user_question_answer", want: true}, + {source: RequestSourceHTTP, method: sessionAssetUploadMethod, want: true}, + {source: RequestSourceHTTP, method: sessionAssetReadMethod, want: true}, + {source: RequestSourceSSE, method: sessionAssetReadMethod, want: false}, {source: RequestSourceUnknown, method: "gateway.ping", want: false}, {source: RequestSourceUnknown, method: "gateway.approvePlan", want: false}, } diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a61d8eca9..261a37a70 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -367,7 +367,9 @@ func TestServerHandleConnectionAuthenticateFlow(t *testing.T) { } type runtimePortEventStub struct { - events <-chan RuntimeEvent + events <-chan RuntimeEvent + saveAssetFn func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) + openAssetFn func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) } func (s *runtimePortEventStub) Run(_ context.Context, _ RunInput) error { @@ -467,6 +469,20 @@ func (s *runtimePortEventStub) CreateSession(_ context.Context, _ CreateSessionI return "", nil } +func (s *runtimePortEventStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if s.saveAssetFn != nil { + return s.saveAssetFn(ctx, input) + } + return SessionAssetMeta{}, nil +} + +func (s *runtimePortEventStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if s.openAssetFn != nil { + return s.openAssetFn(ctx, input) + } + return OpenSessionAssetResult{}, nil +} + func (s *runtimePortEventStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } diff --git a/internal/gateway/static_files.go b/internal/gateway/static_files.go index 936f7c78f..12aeb4c17 100644 --- a/internal/gateway/static_files.go +++ b/internal/gateway/static_files.go @@ -16,6 +16,7 @@ var knownAPIPrefixes = map[string]bool{ "/healthz": true, "/version": true, "/rpc": true, + "/api": true, "/ws": true, "/sse": true, "/metrics": true, diff --git a/internal/gateway/types.go b/internal/gateway/types.go index 16d207394..e2a98a938 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -136,6 +136,8 @@ const ( type Media struct { // URI 是媒体资源地址。 URI string `json:"uri"` + // AssetID 是已保存的 session asset 标识。 + AssetID string `json:"asset_id,omitempty"` // MimeType 是媒体 MIME 类型。 MimeType string `json:"mime_type"` // FileName 是媒体文件名。 diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index 985ee96e3..4684a323c 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -585,8 +585,10 @@ func validateInputPart(part InputPart, index int) *FrameError { if part.Media == nil { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media") } - if strings.TrimSpace(part.Media.URI) == "" { - return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.uri") + hasURI := strings.TrimSpace(part.Media.URI) != "" + hasAssetID := strings.TrimSpace(part.Media.AssetID) != "" + if hasURI == hasAssetID { + return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires exactly one of media.uri or media.asset_id") } if strings.TrimSpace(part.Media.MimeType) == "" { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.mime_type") diff --git a/internal/gateway/validate_test.go b/internal/gateway/validate_test.go index a2db95066..958e46b07 100644 --- a/internal/gateway/validate_test.go +++ b/internal/gateway/validate_test.go @@ -829,6 +829,23 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantNil: true, }, + { + name: "valid image asset part", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantNil: true, + }, { name: "text part with empty text", frame: MessageFrame{ @@ -852,7 +869,7 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, { - name: "image part missing media.uri", + name: "image part missing media.uri and media.asset_id", frame: MessageFrame{ Type: FrameTypeRequest, Action: FrameActionRun, @@ -865,6 +882,24 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, + { + name: "image part has both media.uri and media.asset_id", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + URI: "file:///a.png", + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantCode: ErrorCodeInvalidMultimodalPayload.String(), + }, { name: "image part missing media.mime_type", frame: MessageFrame{ diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index df749b868..8d8092f8e 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -41,6 +41,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } params, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/deepseek/provider.go b/internal/provider/deepseek/provider.go index 55956c7ef..1ffbab9dc 100644 --- a/internal/provider/deepseek/provider.go +++ b/internal/provider/deepseek/provider.go @@ -40,6 +40,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/deepseek/provider_more_test.go b/internal/provider/deepseek/provider_more_test.go index d517cb340..7f259ef44 100644 --- a/internal/provider/deepseek/provider_more_test.go +++ b/internal/provider/deepseek/provider_more_test.go @@ -89,6 +89,22 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { t.Fatalf("expected positive token estimate, got %+v", estimate) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } + events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go index 2f0499e8e..e94b59063 100644 --- a/internal/provider/estimate.go +++ b/internal/provider/estimate.go @@ -4,7 +4,9 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "math" + "strings" providertypes "neo-code/internal/provider/types" ) @@ -15,6 +17,8 @@ const ( EstimateGateAdvisory = "advisory" EstimateGateGateable = "gateable" localEstimateSlack = 1.15 + // DefaultImageInputTokenEstimate 是无法读取图片尺寸时单张图片的保守预算估算值。 + DefaultImageInputTokenEstimate = 2048 ) // EstimateSerializedPayloadTokens 基于最终协议载荷的序列化结果估算输入 token 数。 @@ -34,6 +38,109 @@ func EstimateTextTokens(text string) int { return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack)) } +// RequestContainsImagePart 判断请求中是否包含图片分片,供 provider 选择多模态投影估算路径。 +func RequestContainsImagePart(req providertypes.GenerateRequest) bool { + for _, message := range req.Messages { + for _, part := range message.Parts { + if part.Kind == providertypes.ContentPartImage { + return true + } + } + } + return false +} + +// ResolveRequestModel 按请求模型优先、配置默认模型兜底的规则解析实际模型名。 +func ResolveRequestModel(req providertypes.GenerateRequest, defaultModel string) string { + model := strings.TrimSpace(req.Model) + if model == "" { + model = strings.TrimSpace(defaultModel) + } + return model +} + +// EstimateProjectedInputTokens 只估算语义输入,不把图片的 base64 传输体计入 prompt token。 +func EstimateProjectedInputTokens(req providertypes.GenerateRequest, model string) (int, error) { + if strings.TrimSpace(model) == "" { + return 0, errors.New("model is empty") + } + + var textBuilder strings.Builder + textBuilder.WriteString(model) + textBuilder.WriteByte('\n') + textBuilder.WriteString(req.SystemPrompt) + textBuilder.WriteByte('\n') + + imageCount := 0 + for _, message := range req.Messages { + if err := providertypes.ValidateParts(message.Parts); err != nil { + return 0, err + } + textBuilder.WriteString(message.Role) + textBuilder.WriteByte('\n') + textBuilder.WriteString(message.ToolCallID) + textBuilder.WriteByte('\n') + for _, part := range message.Parts { + switch part.Kind { + case providertypes.ContentPartText: + textBuilder.WriteString(part.Text) + textBuilder.WriteByte('\n') + case providertypes.ContentPartImage: + imageCount++ + if part.Image != nil { + textBuilder.WriteString(string(part.Image.SourceType)) + textBuilder.WriteByte('\n') + textBuilder.WriteString(part.Image.URL) + textBuilder.WriteByte('\n') + if part.Image.Asset != nil { + textBuilder.WriteString(part.Image.Asset.ID) + textBuilder.WriteByte('\n') + textBuilder.WriteString(part.Image.Asset.MimeType) + textBuilder.WriteByte('\n') + } + } + } + } + for _, call := range message.ToolCalls { + textBuilder.WriteString(call.ID) + textBuilder.WriteByte('\n') + textBuilder.WriteString(call.Name) + textBuilder.WriteByte('\n') + textBuilder.WriteString(call.Arguments) + textBuilder.WriteByte('\n') + } + for key, value := range message.ToolMetadata { + textBuilder.WriteString(key) + textBuilder.WriteByte('=') + textBuilder.WriteString(value) + textBuilder.WriteByte('\n') + } + } + + for _, spec := range req.Tools { + textBuilder.WriteString(spec.Name) + textBuilder.WriteByte('\n') + textBuilder.WriteString(spec.Description) + textBuilder.WriteByte('\n') + normalized := NormalizeToolSchemaObject(spec.Schema) + encoded, err := json.Marshal(normalized) + if err != nil { + return 0, err + } + textBuilder.Write(encoded) + textBuilder.WriteByte('\n') + } + if req.ThinkingConfig != nil { + textBuilder.WriteString(req.ThinkingConfig.Effort) + textBuilder.WriteByte('\n') + if req.ThinkingConfig.Enabled { + textBuilder.WriteString("thinking_enabled") + } + } + + return EstimateTextTokens(textBuilder.String()) + imageCount*DefaultImageInputTokenEstimate, nil +} + // BuildGenerateRequestSignature 生成 GenerateRequest 的稳定签名,用于估算与发送阶段的请求复用匹配。 func BuildGenerateRequestSignature(req providertypes.GenerateRequest) string { encoded, err := json.Marshal(req) diff --git a/internal/provider/estimate_test.go b/internal/provider/estimate_test.go index 5fecb1351..7e830b66d 100644 --- a/internal/provider/estimate_test.go +++ b/internal/provider/estimate_test.go @@ -1,6 +1,7 @@ package provider import ( + "strings" "testing" providertypes "neo-code/internal/provider/types" @@ -40,6 +41,130 @@ func TestEstimateTextTokens(t *testing.T) { } } +func TestResolveRequestModel(t *testing.T) { + t.Parallel() + + req := providertypes.GenerateRequest{Model: " request-model "} + if got := ResolveRequestModel(req, "default-model"); got != "request-model" { + t.Fatalf("ResolveRequestModel() = %q, want request model", got) + } + + req.Model = " " + if got := ResolveRequestModel(req, " default-model "); got != "default-model" { + t.Fatalf("ResolveRequestModel() fallback = %q, want default model", got) + } +} + +func TestRequestContainsImagePart(t *testing.T) { + t.Parallel() + + textOnly := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + } + if RequestContainsImagePart(textOnly) { + t.Fatal("expected text-only request to report no images") + } + withImage := textOnly + withImage.Messages[0].Parts = append(withImage.Messages[0].Parts, providertypes.NewSessionAssetImagePart("asset-1", "image/png")) + if !RequestContainsImagePart(withImage) { + t.Fatal("expected image request to report images") + } +} + +func TestEstimateProjectedInputTokensDoesNotCountBase64Transport(t *testing.T) { + t.Parallel() + + tokens, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + SystemPrompt: "You are concise.", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe this"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }}, + Tools: []providertypes.ToolSpec{{ + Name: "filesystem_read_file", + Description: "Read a file", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + }}, + }, "gpt-4.1") + if err != nil { + t.Fatalf("EstimateProjectedInputTokens() error = %v", err) + } + if tokens <= DefaultImageInputTokenEstimate { + t.Fatalf("expected text and tool schema to add tokens, got %d", tokens) + } + if tokens > 10_000 { + t.Fatalf("projected estimate counted transport-sized payload, got %d", tokens) + } + + oneMiBDataURLTokens := EstimateTextTokens(strings.Repeat("x", int(EstimateDataURLTransportBytes(1024*1024, "image/png")))) + if tokens >= oneMiBDataURLTokens { + t.Fatalf("projected estimate = %d, want below data URL transport estimate %d", tokens, oneMiBDataURLTokens) + } +} + +func TestEstimateProjectedInputTokensValidatesPartsAndModel(t *testing.T) { + t.Parallel() + + if _, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{}, " "); err == nil { + t.Fatal("expected empty model error") + } + _, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + }, "gpt") + if err == nil { + t.Fatal("expected invalid parts error") + } + + _, err = EstimateProjectedInputTokens(providertypes.GenerateRequest{ + Model: "gpt", + Tools: []providertypes.ToolSpec{{Name: "bad", Schema: map[string]any{"unsupported": func() {}}}}, + }, "gpt") + if err == nil { + t.Fatal("expected invalid tool schema error") + } +} + +func TestEstimateProjectedInputTokensCoversMetadataAndImageSources(t *testing.T) { + t.Parallel() + + tokens, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + SystemPrompt: "system", + Messages: []providertypes.Message{{ + Role: providertypes.RoleTool, + ToolCallID: "tool-call-1", + Parts: []providertypes.ContentPart{ + providertypes.NewRemoteImagePart("https://example.com/a.png"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + ToolCalls: []providertypes.ToolCall{{ID: "call-1", Name: "bash", Arguments: `{"cmd":"pwd"}`}}, + ToolMetadata: map[string]string{ + "exit_code": "0", + }, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: true, Effort: "medium"}, + }, "gpt-4.1") + if err != nil { + t.Fatalf("EstimateProjectedInputTokens() error = %v", err) + } + if tokens <= 2*DefaultImageInputTokenEstimate { + t.Fatalf("expected metadata text to add tokens, got %d", tokens) + } +} + func TestBuildGenerateRequestSignature(t *testing.T) { t.Parallel() @@ -68,4 +193,10 @@ func TestBuildGenerateRequestSignature(t *testing.T) { if sigA == sigC { t.Fatalf("different requests should have different signatures: %q == %q", sigA, sigC) } + + bad := reqA + bad.Tools = []providertypes.ToolSpec{{Name: "bad", Schema: map[string]any{"unsupported": func() {}}}} + if got := BuildGenerateRequestSignature(bad); got != "" { + t.Fatalf("BuildGenerateRequestSignature(bad) = %q, want empty signature", got) + } } diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 2739e2cf6..e63785d7e 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -44,6 +44,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } model, contents, genConfig, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/mimo/provider.go b/internal/provider/mimo/provider.go index a5eac3580..688f2ed47 100644 --- a/internal/provider/mimo/provider.go +++ b/internal/provider/mimo/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/mimo/provider_more_test.go b/internal/provider/mimo/provider_more_test.go index 5caa66f95..5fb02c8ff 100644 --- a/internal/provider/mimo/provider_more_test.go +++ b/internal/provider/mimo/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/minimax/provider.go b/internal/provider/minimax/provider.go index 4f5b9d5d0..86cc6b1f8 100644 --- a/internal/provider/minimax/provider.go +++ b/internal/provider/minimax/provider.go @@ -40,6 +40,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/minimax/provider_more_test.go b/internal/provider/minimax/provider_more_test.go index 92e0ddc22..b375fb372 100644 --- a/internal/provider/minimax/provider_more_test.go +++ b/internal/provider/minimax/provider_more_test.go @@ -79,6 +79,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } p, err = New(provider.RuntimeConfig{ BaseURL: server.URL + "/chat/completions", APIKeyEnv: "TEST_KEY", diff --git a/internal/provider/openaicompat/glm/provider.go b/internal/provider/openaicompat/glm/provider.go index e4daae228..d5a245fc3 100644 --- a/internal/provider/openaicompat/glm/provider.go +++ b/internal/provider/openaicompat/glm/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/openaicompat/glm/provider_more_test.go b/internal/provider/openaicompat/glm/provider_more_test.go index ad8cc4e67..eebe44d62 100644 --- a/internal/provider/openaicompat/glm/provider_more_test.go +++ b/internal/provider/openaicompat/glm/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index db290abf6..a7c4690a4 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -276,6 +276,46 @@ func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { } } +func TestEstimateInputTokensWithImageUsesProjectedEstimate(t *testing.T) { + t.Parallel() + + p, err := New(resolvedConfig("", "gpt-4.1")) + if err != nil { + t.Fatalf("New() error = %v", err) + } + reader := &singleUseSessionAssetReader{ + assets: map[string]sessionAsset{ + "asset-1": {data: []byte(strings.Repeat("x", 1024*1024)), mime: "image/png"}, + }, + } + estimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Model: "gpt-4.1", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }}, + SessionAssetReader: reader, + }) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if reader.openCount != 0 { + t.Fatalf("expected estimate not to open session asset, got %d opens", reader.openCount) + } + if estimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected text+image estimate, got %+v", estimate) + } + if estimate.EstimatedInputTokens > 10_000 { + t.Fatalf("estimate counted base64 transport payload, got %+v", estimate) + } + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) + } +} + func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { t.Setenv(config.OpenAIDefaultAPIKeyEnv, "test-key") diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index e43f2e18c..15d79a552 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -61,6 +61,17 @@ func (p *Provider) EstimateInputTokens( } var tokens int + if provider.RequestContainsImagePart(req) { + tokens, err = provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } switch mode { case executionModeCompletions: payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) diff --git a/internal/provider/openaicompat/qwen/provider.go b/internal/provider/openaicompat/qwen/provider.go index 896c4efa0..500ca25ae 100644 --- a/internal/provider/openaicompat/qwen/provider.go +++ b/internal/provider/openaicompat/qwen/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/openaicompat/qwen/provider_more_test.go b/internal/provider/openaicompat/qwen/provider_more_test.go index 0cf27f966..dce1580c1 100644 --- a/internal/provider/openaicompat/qwen/provider_more_test.go +++ b/internal/provider/openaicompat/qwen/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 0752a370e..8b9e70827 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -148,6 +148,7 @@ func (p sessionInputPreparer) Prepare( for _, image := range input.Images { sessionImages = append(sessionImages, agentsession.PrepareImageInput{ Path: strings.TrimSpace(image.Path), + AssetID: strings.TrimSpace(image.AssetID), MimeType: strings.TrimSpace(image.MimeType), }) } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 2761784e8..326fd415f 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -75,6 +75,7 @@ type UserInput struct { // UserImageInput 表示用户输入中附带的单个图片引用(路径 + MIME)。 type UserImageInput struct { Path string + AssetID string MimeType string } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 5475ef24d..87de3bae8 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -6374,6 +6374,90 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateAdvisory(t *testing.T) } } +func TestServiceRunAllowsImageRequestWithinProjectedBudget(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 5000 + cfg.Context.Budget.FallbackPromptBudget = 5000 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, "gpt-4.1")) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + }, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("图片已收到")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{} + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-image-allow", + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + if scripted.callCount != 1 { + t.Fatalf("expected provider Generate to be called once, got %d", scripted.callCount) + } + if compactRunner := service.compactRunner.(*stubCompactRunner); len(compactRunner.calls) != 0 { + t.Fatalf("expected no proactive compact for projected image estimate, got %d calls", len(compactRunner.calls)) + } + + events := collectRuntimeEvents(service.Events()) + var budgetPayload *BudgetCheckedPayload + for _, event := range events { + if event.Type != EventBudgetChecked { + continue + } + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetPayload = &payload + break + } + if budgetPayload == nil { + t.Fatalf("expected budget_checked event, got %+v", events) + } + if budgetPayload.Action != string(controlplane.TurnBudgetActionAllow) || + budgetPayload.Reason != controlplane.BudgetDecisionReasonWithinBudget { + t.Fatalf("unexpected budget decision: %+v", budgetPayload) + } + if budgetPayload.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate || + budgetPayload.EstimatedInputTokens >= budgetPayload.PromptBudget { + t.Fatalf("unexpected projected image estimate: %+v", budgetPayload) + } +} + func TestServiceRunStopsAfterNoOpProactiveCompactWhenEstimateGateable(t *testing.T) { t.Parallel() diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 76b8dc6df..f5b9d9bab 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -20,6 +20,7 @@ const defaultSessionTitle = "New Session" // PrepareImageInput 表示一次用户输入中附带的本地图片引用。 type PrepareImageInput struct { Path string + AssetID string MimeType string } @@ -128,6 +129,32 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar savedAssets := make([]AssetMeta, 0, len(input.Images)) for index, image := range input.Images { path := strings.TrimSpace(image.Path) + assetID := strings.TrimSpace(image.AssetID) + if assetID != "" { + if path != "" { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: path, + Err: fmt.Errorf("image input cannot contain both path and asset id"), + } + } + meta, err := p.referenceImageAsset(ctx, session.ID, assetID, image.MimeType) + if err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: assetID, + Err: err, + } + } + parts = append(parts, providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType)) + continue + } if path == "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) p.cleanupSavedAssets(ctx, session.ID, savedAssets) @@ -220,6 +247,38 @@ func (p *InputPreparer) saveImageAsset( return meta, nil } +// referenceImageAsset 校验已保存附件属于当前会话,并返回可进入 provider 的图片元数据。 +func (p *InputPreparer) referenceImageAsset( + ctx context.Context, + sessionID string, + assetID string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + if p.assetStore == nil { + return AssetMeta{}, fmt.Errorf("session: asset store is not configured") + } + normalizedAssetID := strings.TrimSpace(assetID) + if normalizedAssetID == "" { + return AssetMeta{}, fmt.Errorf("image asset id is empty") + } + + meta, err := p.assetStore.Stat(ctx, sessionID, normalizedAssetID) + if err != nil { + return AssetMeta{}, fmt.Errorf("stat image asset: %w", err) + } + if !strings.HasPrefix(strings.ToLower(strings.TrimSpace(meta.MimeType)), "image/") { + return AssetMeta{}, fmt.Errorf("asset %q is not an image", normalizedAssetID) + } + declaredMime := normalizeMimeType(mimeType) + if declaredMime != "" && declaredMime != meta.MimeType { + return AssetMeta{}, fmt.Errorf("declared mime type %q mismatches saved asset %q", declaredMime, meta.MimeType) + } + return meta, nil +} + // resolveImageMimeType 解析图片 MIME 类型,仅允许 image/*,并要求声明值与文件头探测一致。 func resolveImageMimeType(ctx context.Context, path string, declared string, file *os.File) (string, error) { if err := ctx.Err(); err != nil { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index d45527799..356449cc7 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -1,6 +1,7 @@ package session import ( + "bytes" "context" "errors" "io" @@ -94,6 +95,46 @@ func TestInputPreparerPrepareTextAndImage(t *testing.T) { } } +func TestInputPreparerPrepareSavedAssetReference(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := newInputPreparerTestStore(t, workdir) + session := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + meta, err := store.SaveAsset(context.Background(), session.ID, bytes.NewReader(minimalPNGBytes()), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "describe it", + Images: []PrepareImageInput{{AssetID: meta.ID, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 0 { + t.Fatalf("expected no newly saved assets, got %+v", result.SavedAssets) + } + if len(result.Parts) != 2 { + t.Fatalf("expected text and image parts, got %+v", result.Parts) + } + imagePart := result.Parts[1] + if imagePart.Kind != providertypes.ContentPartImage || + imagePart.Image == nil || + imagePart.Image.Asset == nil || + imagePart.Image.Asset.ID != meta.ID || + imagePart.Image.Asset.MimeType != "image/png" { + t.Fatalf("unexpected image part: %+v", imagePart) + } +} + func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { t.Parallel() @@ -185,6 +226,51 @@ func TestInputPreparerPrepareErrors(t *testing.T) { } }) + t.Run("missing image reference is rejected", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: " ", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing image reference error") + } + if !strings.Contains(err.Error(), "image path is empty") { + t.Fatalf("expected image reference error, got %v", err) + } + }) + + t.Run("missing referenced asset is rejected", func(t *testing.T) { + localStore := newInputPreparerTestStore(t, workdir) + existing := NewWithWorkdir("asset-missing", workdir) + if err := createSessionForPreparerTest(context.Background(), localStore, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + preparer := NewInputPreparer(localStore, localStore) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: "asset-missing", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing referenced asset error") + } + }) + + t.Run("asset id and path cannot both be set", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: []PrepareImageInput{{Path: "a.png", AssetID: "asset-1", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset id and path conflict error") + } + }) + t.Run("asset save error is structured", func(t *testing.T) { preparer := NewInputPreparer(store, store) _, err := preparer.Prepare(context.Background(), PrepareInput{ @@ -384,6 +470,92 @@ func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { t.Fatalf("expected mismatch error, got %v", err) } }) + + t.Run("declared mime params are normalized", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-params.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared params", + Images: []PrepareImageInput{{Path: imagePath, MimeType: " IMAGE/PNG; charset=binary "}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 || result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("unexpected saved assets: %+v", result.SavedAssets) + } + }) + + t.Run("declared non image mime is rejected", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-text.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared text", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "text/plain"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected non-image mime error") + } + if !strings.Contains(err.Error(), "is not an image") { + t.Fatalf("expected non-image mime error, got %v", err) + } + }) + + t.Run("extension mismatch is rejected when mime omitted", func(t *testing.T) { + imagePath := filepath.Join(workdir, "wrong.jpg") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "extension mismatch", + Images: []PrepareImageInput{{Path: imagePath}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected extension mismatch error") + } + if !strings.Contains(err.Error(), "file extension mime") { + t.Fatalf("expected extension mismatch error, got %v", err) + } + }) +} + +func TestInputPreparerPrepareSavedAssetReferenceValidation(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := newInputPreparerTestStore(t, workdir) + session := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + meta, err := store.SaveAsset(context.Background(), session.ID, bytes.NewReader(minimalPNGBytes()), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err = preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "bad declared mime", + Images: []PrepareImageInput{{AssetID: meta.ID, MimeType: "image/jpeg"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected referenced asset mime mismatch") + } + if !strings.Contains(err.Error(), "mismatches saved asset") { + t.Fatalf("expected saved asset mismatch error, got %v", err) + } } func TestAssetSaveErrorMethods(t *testing.T) { diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts index bb4cf0eec..f022a7405 100644 --- a/web/src/api/gateway.test.ts +++ b/web/src/api/gateway.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { GatewayAPI } from './gateway' import { Method } from './protocol' @@ -13,6 +13,10 @@ describe('GatewayAPI', () => { api = new GatewayAPI(ws) }) + afterEach(() => { + vi.unstubAllGlobals() + }) + it('maps authenticate and run methods', async () => { await api.authenticate('tok') await api.run({ input_text: 'hello' }) @@ -21,6 +25,14 @@ describe('GatewayAPI', () => { expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' }) }) + it('maps createSession method', async () => { + await api.createSession() + await api.createSession('s1') + + expect(call).toHaveBeenNthCalledWith(1, Method.CreateSession, {}) + expect(call).toHaveBeenNthCalledWith(2, Method.CreateSession, { session_id: 's1' }) + }) + it('maps optional session_id in listModels', async () => { await api.listModels() await api.listModels('s1') @@ -60,5 +72,68 @@ describe('GatewayAPI', () => { expect(call).toHaveBeenNthCalledWith(2, Method.ApprovePlan, { session_id: 's1', plan_id: 'p1', revision: 2 }) expect(call).toHaveBeenNthCalledWith(3, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) }) + + it('uploads session assets with bearer auth, workspace header, and multipart body', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ session_id: 's1', asset_id: 'asset-1', mime_type: 'image/png', size: 3 }), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, 'http://localhost:1455/', ' token-1 ') + + const file = new File(['abc'], 'a.png', { type: 'image/png' }) + const result = await api.uploadSessionAsset('s1', file, 'workspace-b') + + expect(result.asset_id).toBe('asset-1') + expect(fetchMock).toHaveBeenCalledWith('http://localhost:1455/api/session-assets', expect.objectContaining({ + method: 'POST', + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + })) + const init = fetchMock.mock.calls[0][1] as RequestInit + expect(init.body).toBeInstanceOf(FormData) + expect((init.body as FormData).get('session_id')).toBe('s1') + expect((init.body as FormData).get('file')).toBe(file) + }) + + it('fetches session asset blobs with bearer auth and workspace header', async () => { + const blob = new Blob(['img'], { type: 'image/png' }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(blob), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '/gateway', 'token-1') + + await expect(api.fetchSessionAsset('s 1', 'asset/1', 'workspace-b')).resolves.toBe(blob) + expect(fetchMock).toHaveBeenCalledWith('/gateway/api/session-assets/s%201/asset%2F1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + }) + }) + + it('uses switched workspace as session asset HTTP fallback', async () => { + call.mockResolvedValueOnce({ type: 'ack', payload: { workspace_hash: 'workspace-c' } }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(new Blob(['img'])), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '', 'token-1') + + await api.switchWorkspace('workspace-c') + await api.fetchSessionAsset('s1', 'asset-1') + + expect(fetchMock).toHaveBeenCalledWith('/api/session-assets/s1/asset-1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-c' }, + }) + }) + + it('surfaces session asset HTTP errors', async () => { + vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ + ok: false, + status: 415, + json: () => Promise.resolve({ error: 'unsupported image type' }), + })) + await expect(api.uploadSessionAsset('s1', new File(['x'], 'x.txt'))).rejects.toThrow('unsupported image type') + }) }) diff --git a/web/src/api/gateway.ts b/web/src/api/gateway.ts index 18357270b..45ab4f313 100644 --- a/web/src/api/gateway.ts +++ b/web/src/api/gateway.ts @@ -5,6 +5,8 @@ import { type AuthenticateParams, type BindStreamParams, type RunParams, + type CreateSessionParams, + type CreateSessionResult, type CancelParams, type LoadSessionParams, type ListSessionTodosParams, @@ -73,14 +75,20 @@ import { type RenameWorkspaceResult, type DeleteWorkspaceParams, type DeleteWorkspaceResult, + type SessionAssetUploadResult, } from './protocol' /** Gateway 业务 API 客户端,基于 WebSocket 全双工通道 */ export class GatewayAPI { private ws: WSClient + private baseURL: string + private token: string + private currentWorkspaceHash = '' - constructor(ws: WSClient) { + constructor(ws: WSClient, baseURL = '', token = '') { this.ws = ws + this.baseURL = baseURL.replace(/\/+$/, '') + this.token = token.trim() } /** 认证,返回 ack 结果 */ @@ -93,11 +101,45 @@ export class GatewayAPI { return this.ws.call(Method.BindStream, params) } + /** 显式创建一个会话,供发送图片前建立 asset 归属 */ + async createSession(sessionId?: string) { + const params: CreateSessionParams = sessionId ? { session_id: sessionId } : {} + return this.ws.call(Method.CreateSession, params) + } + /** 发起一次 run,返回 ack 含 session_id 和 run_id */ async run(params: RunParams) { return this.ws.call(Method.Run, params) } + /** 上传会话图片附件,返回可在 input_parts 中引用的 asset_id */ + async uploadSessionAsset(sessionId: string, file: File, workspaceHash = '') { + const form = new FormData() + form.append('session_id', sessionId) + form.append('file', file) + const res = await fetch(`${this.baseURL}/api/session-assets`, { + method: 'POST', + headers: this.httpHeaders(workspaceHash), + body: form, + }) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Upload failed')) + } + return res.json() as Promise + } + + /** 读取会话图片附件 Blob,用于历史消息缩略图 */ + async fetchSessionAsset(sessionId: string, assetId: string, workspaceHash = '') { + const res = await fetch( + `${this.baseURL}/api/session-assets/${encodeURIComponent(sessionId)}/${encodeURIComponent(assetId)}`, + { headers: this.httpHeaders(workspaceHash) }, + ) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Asset fetch failed')) + } + return res.blob() + } + /** 取消运行,返回取消结果 */ async cancel(params: CancelParams) { return this.ws.call(Method.Cancel, params) @@ -290,7 +332,9 @@ export class GatewayAPI { /** 切换工作区 */ async switchWorkspace(workspaceHash: string) { - return this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + const result = await this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + this.currentWorkspaceHash = workspaceHash.trim() + return result } /** 重命名工作区 */ @@ -302,4 +346,21 @@ export class GatewayAPI { async deleteWorkspace(workspaceHash: string, removeData?: boolean) { return this.ws.call(Method.DeleteWorkspace, { workspace_hash: workspaceHash, remove_data: removeData } satisfies DeleteWorkspaceParams) } + + getCurrentWorkspaceHash() { + return this.currentWorkspaceHash + } + + private httpHeaders(workspaceHash = '') { + const headers: Record = {} + if (this.token) headers.Authorization = `Bearer ${this.token}` + const resolvedWorkspaceHash = workspaceHash.trim() || this.currentWorkspaceHash + if (resolvedWorkspaceHash) headers['X-NeoCode-Workspace-Hash'] = resolvedWorkspaceHash + return Object.keys(headers).length > 0 ? headers : undefined + } +} + +async function readHTTPError(res: Response, fallback: string) { + const data = await res.json().catch(() => null) as { error?: string } | null + return data?.error || `${fallback} (HTTP ${res.status})` } diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index 15bad0b11..52c4e37eb 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -11,6 +11,7 @@ export const Method = { Ping: "gateway.ping", BindStream: "gateway.bindStream", Run: "gateway.run", + CreateSession: "gateway.createSession", Cancel: "gateway.cancel", Compact: "gateway.compact", ListSessions: "gateway.listSessions", @@ -234,9 +235,15 @@ export interface RunParams { export interface RunInputPart { type: string; text?: string; - media?: { uri: string; mime_type: string; file_name?: string }; + media?: { uri?: string; asset_id?: string; mime_type: string; file_name?: string }; } +export interface CreateSessionParams { + session_id?: string; +} + +export type CreateSessionResult = RPCResult<{ session_id: string }>; + /** gateway.cancel 参数 */ export interface CancelParams { session_id?: string; @@ -307,11 +314,19 @@ export interface SessionSummary { export interface SessionMessage { role: string; content: string; + parts?: RunInputPart[]; tool_calls?: ToolCall[]; tool_call_id?: string; is_error?: boolean; } +export interface SessionAssetUploadResult { + session_id: string; + asset_id: string; + mime_type: string; + size: number; +} + /** 工具调用 */ export interface ToolCall { id: string; diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 53829db08..484fa90e8 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -6,10 +6,13 @@ import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { useGatewayStore } from '@/stores/useGatewayStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' const mockGatewayAPI = { listAvailableSkills: vi.fn(), listModels: vi.fn(), + createSession: vi.fn(), + uploadSessionAsset: vi.fn(), run: vi.fn(), bindStream: vi.fn(), cancel: vi.fn(), @@ -68,9 +71,23 @@ describe('ChatInput', () => { selected_model_id: '', }, }) - - useComposerStore.setState({ composerText: '' }) + mockGatewayAPI.createSession.mockResolvedValue({ payload: { session_id: 'session-created' } }) + mockGatewayAPI.uploadSessionAsset.mockResolvedValue({ + session_id: 'session-created', + asset_id: 'asset-1', + mime_type: 'image/png', + size: 3, + }) + mockGatewayAPI.run.mockResolvedValue({ session_id: 'session-created', run_id: 'run-1' }) + mockGatewayAPI.bindStream.mockResolvedValue({}) + if (typeof URL.createObjectURL !== 'function') { + Object.defineProperty(URL, 'createObjectURL', { configurable: true, value: vi.fn() }) + } + vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:preview-1') + + useComposerStore.setState({ composerText: '', attachments: [] }) useSessionStore.setState({ currentSessionId: '' } as never) + useWorkspaceStore.setState({ currentWorkspaceHash: 'workspace-b' } as never) useGatewayStore.setState({ currentRunId: '' } as never) useRuntimeInsightStore.getState().reset() useChatStore.setState({ @@ -157,12 +174,72 @@ describe('ChatInput', () => { }) }) - it('does not render the unimplemented attachment and mention buttons', () => { + it('renders the image attachment picker but keeps mention button absent', () => { render() - expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: /添加图片/ })).toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + + it('uploads selected image and sends image-only input parts after creating a session', async () => { + render() + + const file = new File(['img'], 'a.png', { type: 'image/png' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + await waitFor(() => { + expect(screen.getByAltText('a.png')).toBeInTheDocument() + }) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.createSession).toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-created', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith({ + session_id: 'session-created', + input_parts: [ + { type: 'image', media: { asset_id: 'asset-1', mime_type: 'image/png', file_name: 'a.png' } }, + ], + mode: 'build', + }) + }) + + expect(useChatStore.getState().messages[0]).toMatchObject({ + role: 'user', + attachments: [{ assetId: 'asset-1', previewUrl: 'blob:preview-1', workspaceHash: 'workspace-b' }], + }) + }) + + it('treats slash text as a normal message when an image is attached', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ + session_id: 'session-1', + asset_id: 'asset-2', + mime_type: 'image/png', + size: 3, + }) + render() + + const file = new File(['img'], 'slash.png', { type: 'image/png' }) + const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(fileInput, { target: { files: [file] } }) + fireEvent.change(screen.getByRole('textbox'), { target: { value: '/memo' } }) + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-1', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith(expect.objectContaining({ + session_id: 'session-1', + input_parts: [ + { type: 'text', text: '/memo' }, + { type: 'image', media: { asset_id: 'asset-2', mime_type: 'image/png', file_name: 'slash.png' } }, + ], + })) + }) + }) it('blocks normal sends while compaction is running', async () => { useChatStore.getState().startCompacting('manual', 'Compacting context...') render() diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 6291a702b..abf63a858 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -3,8 +3,14 @@ import { useChatStore, createUserMessage } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' import { useUIStore } from '@/stores/useUIStore' -import { useComposerStore } from '@/stores/useComposerStore' +import { + acceptedImageMimeTypes, + maxComposerAttachmentBytes, + useComposerStore, + type ComposerAttachment, +} from '@/stores/useComposerStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' import { formatTokenCount } from '@/utils/format' import { useGatewayAPI } from '@/context/RuntimeProvider' import { @@ -19,7 +25,7 @@ import { import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' -import { Send, Square } from 'lucide-react' +import { ImagePlus, Loader2, Send, Square, X } from 'lucide-react' const slashMenuAnchorStyle: React.CSSProperties = { position: 'absolute', @@ -123,14 +129,22 @@ function resolveBudgetRingState( export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) + const attachments = useComposerStore((state) => state.attachments) const setText = useComposerStore((state) => state.setComposerText) + const addAttachmentFiles = useComposerStore((state) => state.addAttachmentFiles) + const removeAttachment = useComposerStore((state) => state.removeAttachment) + const clearAttachments = useComposerStore((state) => state.clearAttachments) + const setAttachmentStatus = useComposerStore((state) => state.setAttachmentStatus) const [rows, setRows] = useState(1) + const [dragActive, setDragActive] = useState(false) const textareaRef = useRef(null) + const fileInputRef = useRef(null) const runCancelledRef = useRef(false) const composingRef = useRef(false) const isGenerating = useChatStore((state) => state.isGenerating) const isCompacting = useChatStore((state) => state.isCompacting) const addMessage = useChatStore((state) => state.addMessage) + const removeMessage = useChatStore((state) => state.removeMessage) const addSystemMessage = useChatStore((state) => state.addSystemMessage) const setGenerating = useChatStore((state) => state.setGenerating) const sessionId = useSessionStore((state) => state.currentSessionId) @@ -138,6 +152,7 @@ export default function ChatInput() { const setAgentMode = useChatStore((state) => state.setAgentMode) const permissionMode = useChatStore((state) => state.permissionMode) const setPermissionMode = useChatStore((state) => state.setPermissionMode) + const currentWorkspaceHash = useWorkspaceStore((state) => state.currentWorkspaceHash) const [showSlashMenu, setShowSlashMenu] = useState(false) const [selectedIndex, setSelectedIndex] = useState(0) @@ -302,7 +317,9 @@ export default function ChatInput() { async function handleSubmit() { const input = text.trim() - if (!input) return + const pendingAttachments = attachments + if (!input && pendingAttachments.length === 0) return + let submittedMessageId = '' if (isCompacting) { useUIStore.getState().showToast('Context compaction is still running', 'info') @@ -314,30 +331,65 @@ export default function ChatInput() { return } - if (isSlashCommand(input)) { + if (pendingAttachments.length === 0 && isSlashCommand(input)) { setText('') setShowSlashMenu(false) const handled = await executeSlashCommand(input) if (handled) return } - setText('') - const userMsg = createUserMessage(input) - addMessage(userMsg) - useRuntimeInsightStore.getState().setTodoSnapshot({ - items: [], - summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, - }) - setGenerating(true) - runCancelledRef.current = false - try { if (!gatewayAPI) return - const isNewSession = !isValidSessionId(sessionId) + let targetSessionId = sessionId + if (!isValidSessionId(targetSessionId)) { + const created = await gatewayAPI.createSession() + targetSessionId = created.payload?.session_id || '' + if (!isValidSessionId(targetSessionId)) throw new Error('Create session failed') + useSessionStore.getState().setCurrentSessionId(targetSessionId) + await gatewayAPI.bindStream({ session_id: targetSessionId, channel: 'all' }).catch(() => {}) + } + + const workspaceHash = currentWorkspaceHash.trim() + const uploaded = [] + for (const attachment of pendingAttachments) { + setAttachmentStatus(attachment.id, 'uploading') + try { + const meta = await gatewayAPI.uploadSessionAsset(targetSessionId, attachment.file, workspaceHash) + setAttachmentStatus(attachment.id, 'uploaded') + uploaded.push({ attachment, meta }) + } catch (err) { + const message = err instanceof Error ? err.message : 'Upload failed' + setAttachmentStatus(attachment.id, 'error', message) + throw err + } + } + + const inputParts = buildRunInputParts(input, uploaded) + const userMsg = createUserMessage(input, uploaded.map(({ attachment, meta }) => ({ + id: attachment.id, + sessionId: targetSessionId, + workspaceHash, + assetId: meta.asset_id, + mimeType: meta.mime_type, + name: attachment.file.name, + size: meta.size, + previewUrl: attachment.previewUrl, + }))) + + setText('') + clearAttachments(false) + addMessage(userMsg) + submittedMessageId = userMsg.id + useRuntimeInsightStore.getState().setTodoSnapshot({ + items: [], + summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, + }) + setGenerating(true) + runCancelledRef.current = false + const ack = await gatewayAPI.run({ - session_id: isNewSession ? undefined : sessionId, - new_session: isNewSession ? true : undefined, - input_text: input, + session_id: targetSessionId, + input_parts: inputParts, mode: agentMode, }) if (!runCancelledRef.current) { @@ -351,10 +403,12 @@ export default function ChatInput() { } } catch (err) { if (!runCancelledRef.current) { + if (submittedMessageId) { + removeMessage(submittedMessageId) + } setGenerating(false) - useChatStore.getState().removeMessage(userMsg.id) console.error('Run failed:', err) - useUIStore.getState().showToast('Failed to send message', 'error') + useUIStore.getState().showToast(err instanceof Error ? err.message : 'Failed to send message', 'error') } } } @@ -421,6 +475,33 @@ export default function ChatInput() { void executeSlashCommand(cmd.usage) } + function handleFilesSelected(files: FileList | File[]) { + const accepted: File[] = [] + for (const file of Array.from(files)) { + if (!acceptedImageMimeTypes.includes(file.type as any)) { + useUIStore.getState().showToast('Only PNG, JPEG, and WebP images are supported', 'error') + continue + } + if (file.size <= 0) { + useUIStore.getState().showToast('Cannot upload an empty file', 'error') + continue + } + if (file.size > maxComposerAttachmentBytes) { + useUIStore.getState().showToast('Image exceeds the 20 MiB limit', 'error') + continue + } + accepted.push(file) + } + if (accepted.length > 0) addAttachmentFiles(accepted) + } + + function handleDrop(e: React.DragEvent) { + e.preventDefault() + setDragActive(false) + if (controlsLocked) return + handleFilesSelected(e.dataTransfer.files) + } + async function handleCancel() { runCancelledRef.current = true const runId = useGatewayStore.getState().currentRunId @@ -439,7 +520,7 @@ export default function ChatInput() { } } - const isEmpty = !text.trim() + const isEmpty = !text.trim() && attachments.length === 0 const controlsLocked = isGenerating || isCompacting return ( @@ -460,7 +541,16 @@ export default function ChatInput() { /> )} -
+
{ + e.preventDefault() + if (!controlsLocked) setDragActive(true) + }} + onDragLeave={() => setDragActive(false)} + onDrop={handleDrop} + > +