diff --git a/docs/examples/hooks.yaml b/docs/examples/hooks.yaml index 7a89625c..ad90911d 100644 --- a/docs/examples/hooks.yaml +++ b/docs/examples/hooks.yaml @@ -17,8 +17,9 @@ hooks: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: require-readme-before-final diff --git a/docs/examples/user-hooks-config.yaml b/docs/examples/user-hooks-config.yaml index 055b3562..0228cc5f 100644 --- a/docs/examples/user-hooks-config.yaml +++ b/docs/examples/user-hooks-config.yaml @@ -25,8 +25,9 @@ runtime: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: user-http-observe diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index 68c59cbf..f4444a0f 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -30,6 +30,10 @@ P2 仅支持: `before_tool_call`、`after_tool_result`、`before_completion_decision`、`accept_gate`、`after_tool_failure`、 `session_start`、`session_end`、`user_prompt_submit`、`post_compact`、`subagent_stop` - handler:`require_file_exists`、`warn_on_tool_call`、`add_context_note` +- `match`:统一 matcher DSL(字段间 AND、同字段多值 OR),支持: + - `tool_name`:精确匹配(`string` 或 `[]string`) + - `tool_name_regex`:正则匹配(`string` 或 `[]string`,单条最长 256) + - `arguments_contains`:参数预览包含匹配(`[]string`) - `kind=http + mode=observe`:允许发送 HTTP 观测回调(不支持 block) - `http observe` 默认不携带 metadata(`include_metadata=false`);即使显式开启也会剥离 `result_content_preview`、`execution_error` - `http observe` 回调端点仅允许 loopback 地址(`localhost` / `127.0.0.1` / `::1`),避免误配为公网外发 @@ -73,6 +77,7 @@ user/repo hook 接收的 `HookContext` 经过白名单裁剪,仅保留最小 - `run_id` / `session_id` - `point` / `tool_call_id` / `tool_name` +- `tool_arguments_preview`(脱敏+截断后的参数预览) - `is_error` / `error_class` - `result_content_preview` / `result_metadata_present` - `execution_error` @@ -109,6 +114,22 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 - `CanBlock=false` 的点位,hook 返回 `block` 会自动降级为观测结果,不中断主链。 - `CanUpdateInput` 在 `user_prompt_submit` 点位已开放:command hook 可通过 stdout JSON 的 `update_input` 字段改写用户输入。 - `UserAllowed=false` 的点位拒绝 user/repo 挂载(配置 fail-fast)。 +- matcher 字段会按点位能力矩阵做 fail-fast:不支持的维度会在配置加载阶段直接报错。 + +### matcher 点位维度矩阵(#684) + +| point | tool_name | tool_name_regex | arguments_contains | +|---|---|---|---| +| `before_tool_call` | ✅ | ✅ | ✅ | +| `after_tool_result` | ✅ | ✅ | ❌ | +| `after_tool_failure` | ✅ | ✅ | ✅ | +| `before_permission_decision` | ✅ | ✅ | ❌ | +| 其他点位 | ❌ | ❌ | ❌ | + +说明: + +- `arguments_contains` 基于 `tool_arguments_preview` 字段匹配,不读取 `tool_arguments` 原文。 +- `warn_on_tool_call` 当前要求显式配置 `match`;旧参数 `params.tool_name/tool_names` 不再承担匹配语义。 ### trust gate diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 65d14c6f..1ebdca32 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -126,8 +126,9 @@ runtime: priority: 100 timeout_sec: 2 failure_policy: warn_only - params: + match: tool_name: bash + params: message: "bash is called" ` writeLoaderConfig(t, loader, raw) diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index fa237f52..f5a87c98 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -64,6 +64,7 @@ type RuntimeHookItemConfig struct { Kind string `yaml:"kind,omitempty"` Mode string `yaml:"mode,omitempty"` Handler string `yaml:"handler,omitempty"` + Match map[string]any `yaml:"match,omitempty"` Priority int `yaml:"priority,omitempty"` TimeoutSec int `yaml:"timeout_sec,omitempty"` FailurePolicy string `yaml:"failure_policy,omitempty"` @@ -189,6 +190,12 @@ func (c RuntimeHookItemConfig) Clone() RuntimeHookItemConfig { if c.Enabled != nil { cloned.Enabled = boolPtr(*c.Enabled) } + if len(c.Match) > 0 { + cloned.Match = make(map[string]any, len(c.Match)) + for key, value := range c.Match { + cloned.Match[key] = cloneRuntimeHookParamValue(value) + } + } if len(c.Params) > 0 { cloned.Params = make(map[string]any, len(c.Params)) for key, value := range c.Params { @@ -279,9 +286,14 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { default: return fmt.Errorf("handler %q is not supported", c.Handler) } - if handler == runtimeHookHandlerWarnOnToolCall && !hasWarnOnToolCallTargets(c.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", c.Handler) - } + if handler == runtimeHookHandlerWarnOnToolCall && !hooks.HasHookMatcherConfig(c.Match) { + return fmt.Errorf("handler %q requires match", c.Handler) + } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case runtimeHookKindCommand: if normalizedMode != runtimeHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", c.Mode) @@ -289,6 +301,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if err := hooks.ValidateCommandParams(c.Params); err != nil { return err } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case runtimeHookKindHTTP: if normalizedMode != runtimeHookModeObserve { return fmt.Errorf("mode %q is not supported for kind http (only observe)", c.Mode) @@ -296,6 +313,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if err := validateRuntimeHTTPObserveItem(c, policy); err != nil { return err } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } } return nil } @@ -398,35 +420,6 @@ func cloneRuntimeHookParamValue(value any) any { } } -func hasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - toolNameRaw, hasToolName := params["tool_name"] - if hasToolName && strings.TrimSpace(fmt.Sprintf("%v", toolNameRaw)) != "" { - return true - } - toolNamesRaw, hasToolNames := params["tool_names"] - if !hasToolNames || toolNamesRaw == nil { - return false - } - switch typed := toolNamesRaw.(type) { - case []string: - for _, item := range typed { - if strings.TrimSpace(item) != "" { - return true - } - } - case []any: - for _, item := range typed { - if strings.TrimSpace(fmt.Sprintf("%v", item)) != "" { - return true - } - } - } - return false -} - // readRuntimeHookParamString 以兼容方式读取 runtime hook 参数中的字符串值。 func readRuntimeHookParamString(params map[string]any, key string) string { if len(params) == 0 { @@ -443,4 +436,3 @@ func readRuntimeHookParamString(params map[string]any, key string) string { return fmt.Sprintf("%v", typed) } } - diff --git a/internal/config/runtime_hooks_test.go b/internal/config/runtime_hooks_test.go index 8a988ae5..e3614502 100644 --- a/internal/config/runtime_hooks_test.go +++ b/internal/config/runtime_hooks_test.go @@ -390,13 +390,15 @@ func TestRuntimeHooksConfigItemDefaultsAndClone(t *testing.T) { ID: "warn-bash", Point: string(hooks.HookPointBeforeToolCall), Handler: runtimeHookHandlerWarnOnToolCall, - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", - "tags": []any{"warn", "tool"}, + }, + Params: map[string]any{ + "tags": []any{"warn", "tool"}, }, }, }, - } +} cfg.ApplyDefaults(defaultRuntimeHooksConfig()) item := cfg.Items[0] @@ -489,6 +491,65 @@ func TestRuntimeHooksConfigValidateWarnOnToolCallRequiresTarget(t *testing.T) { } } +func TestRuntimeHooksConfigValidateWarnOnToolCallAllowsMatchWithoutLegacyTargets(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "warn-with-match", + Point: string(hooks.HookPointBeforeToolCall), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerWarnOnToolCall, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateRejectsUnsupportedMatcherDimensionForPoint(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "session-start-match", + Point: string(hooks.HookPointSessionStart), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerAddContextNote, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"note": "observe"}, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err == nil { + t.Fatal("expected unsupported matcher dimension to fail validation") + } +} + func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Parallel() @@ -605,20 +666,19 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Fatal("expected deep clone for nested map in slice") } - if hasWarnOnToolCallTargets(nil) { - t.Fatal("nil params should be false") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_name": "bash"}) { - t.Fatal("tool_name should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []string{"", "bash"}}) { - t.Fatal("tool_names []string should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []any{"", "bash"}}) { - t.Fatal("tool_names []any should pass") + + matchCfg := RuntimeHookItemConfig{ + Match: map[string]any{ + "tool_name_regex": []any{`^bash$`}, + }, } - if hasWarnOnToolCallTargets(map[string]any{"tool_names": "bash"}) { - t.Fatal("tool_names scalar should fail") + clonedCfg := matchCfg.Clone() + clonedRegexes := clonedCfg.Match["tool_name_regex"].([]any) + clonedRegexes[0] = "^filesystem$" + clonedCfg.Match["tool_name_regex"] = clonedRegexes + originalRegexes := matchCfg.Match["tool_name_regex"].([]any) + if originalRegexes[0] == "^filesystem$" { + t.Fatal("expected match field to be deep-cloned") } }) } diff --git a/internal/runtime/hooks/executor.go b/internal/runtime/hooks/executor.go index 8f32faef..2586483a 100644 --- a/internal/runtime/hooks/executor.go +++ b/internal/runtime/hooks/executor.go @@ -79,6 +79,9 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) if spec.Scope == HookScopeUser || spec.Scope == HookScopeRepo { hookInput = sanitizeUserHookContext(hookInput) } + if spec.Matcher != nil && !spec.Matcher.Match(hookInput) { + continue + } if spec.Mode == HookModeAsync || spec.Mode == HookModeAsyncRewake { e.runAsync(ctx, spec, hookInput) continue @@ -340,6 +343,7 @@ func sanitizeUserHookContext(input HookContext) HookContext { "point": {}, "tool_call_id": {}, "tool_name": {}, + "tool_arguments_preview": {}, "is_error": {}, "error_class": {}, "result_content_preview": {}, diff --git a/internal/runtime/hooks/executor_test.go b/internal/runtime/hooks/executor_test.go index 6a1372e4..8953c977 100644 --- a/internal/runtime/hooks/executor_test.go +++ b/internal/runtime/hooks/executor_test.go @@ -955,10 +955,11 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -971,6 +972,9 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for user hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for user hook context") } @@ -999,10 +1003,11 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -1012,7 +1017,39 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for repo hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for repo hook context") } } + +func TestExecutorSkipsHookWhenMatcherMissed(t *testing.T) { + t.Parallel() + + registry := NewRegistry() + executor := NewExecutor(registry, nil, 100*time.Millisecond) + if err := registry.Register(HookSpec{ + ID: "matcher-hook", + Point: HookPointBeforeToolCall, + Scope: HookScopeUser, + Matcher: &HookMatcher{ToolNames: []string{"bash"}}, + Handler: func(context.Context, HookContext) HookResult { + return HookResult{Status: HookResultPass, Message: "should-not-run"} + }, + }); err != nil { + t.Fatalf("Register() error = %v", err) + } + + output := executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{ + Metadata: map[string]any{"tool_name": "filesystem"}, + }) + if output.Blocked { + t.Fatalf("Blocked = true, want false") + } + if len(output.Results) != 0 { + t.Fatalf("len(Results) = %d, want 0 when matcher missed", len(output.Results)) + } +} + diff --git a/internal/runtime/hooks/matcher.go b/internal/runtime/hooks/matcher.go new file mode 100644 index 00000000..16c13322 --- /dev/null +++ b/internal/runtime/hooks/matcher.go @@ -0,0 +1,281 @@ +package hooks + +import ( + "fmt" + "regexp" + "strings" +) + +const ( + // MaxHookMatcherRegexLength 限制 tool_name_regex 单条表达式长度,避免超长输入拖慢匹配。 + MaxHookMatcherRegexLength = 256 +) + +const ( + hookMatcherFieldToolName = "tool_name" + hookMatcherFieldToolNameRegex = "tool_name_regex" + hookMatcherFieldArgumentsContains = "arguments_contains" + hookMatcherMetadataToolName = "tool_name" + hookMatcherMetadataArguments = "tool_arguments_preview" +) + +// HookMatcher 描述编译后的 hook 匹配器。 +type HookMatcher struct { + ToolNames []string + ToolNameRegex []*regexp.Regexp + ArgumentsContains []string +} + +// HasHookMatcherConfig 判断 matcher 配置是否包含至少一个非空维度。 +func HasHookMatcherConfig(raw map[string]any) bool { + if len(raw) == 0 { + return false + } + names := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + if len(names) > 0 { + return true + } + regexes := readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + if len(regexes) > 0 { + return true + } + contains := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + return len(contains) > 0 +} + +// ValidateHookMatcher 校验 matcher 配置在指定点位上是否合法。 +func ValidateHookMatcher(point HookPoint, raw map[string]any) error { + _, err := CompileHookMatcher(point, raw) + return err +} + +// CompileHookMatcher 将 matcher 原始配置编译为可执行结构,并在点位能力上做 fail-fast 校验。 +func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, error) { + if len(raw) == 0 { + return nil, nil + } + if err := validateHookMatcherFields(raw); err != nil { + return nil, err + } + if !HasHookMatcherConfig(raw) { + return nil, fmt.Errorf("match contains no recognized matcher fields (expected: tool_name, tool_name_regex, arguments_contains)") + } + capability, ok := HookPointCapabilities(point) + if !ok { + return nil, fmt.Errorf("point %q is not supported", point) + } + + namesRaw := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + regexRaw := readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + containsRaw := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + + if len(namesRaw) > 0 && !capability.Matcher.ToolName { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolName) + } + if len(regexRaw) > 0 && !capability.Matcher.ToolNameRegex { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolNameRegex) + } + if len(containsRaw) > 0 && !capability.Matcher.ArgumentsContains { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldArgumentsContains) + } + + matcher := &HookMatcher{ + ToolNames: normalizeHookMatcherValues(namesRaw), + ArgumentsContains: normalizeHookMatcherValues(containsRaw), + } + for _, expression := range regexRaw { + trimmed := strings.TrimSpace(expression) + if trimmed == "" { + continue + } + if len(trimmed) > MaxHookMatcherRegexLength { + return nil, fmt.Errorf( + "matcher field %q expression length exceeds %d", + hookMatcherFieldToolNameRegex, + MaxHookMatcherRegexLength, + ) + } + compiled, err := regexp.Compile(trimmed) + if err != nil { + return nil, fmt.Errorf("matcher field %q has invalid regex %q: %w", hookMatcherFieldToolNameRegex, trimmed, err) + } + matcher.ToolNameRegex = append(matcher.ToolNameRegex, compiled) + } + if matcher.IsEmpty() { + return nil, fmt.Errorf("match must include at least one non-empty matcher field") + } + return matcher, nil +} + +// validateHookMatcherFields 校验 matcher 配置中不存在未支持字段,避免拼写错误被静默忽略。 +func validateHookMatcherFields(raw map[string]any) error { + if len(raw) == 0 { + return nil + } + for key := range raw { + normalized := strings.ToLower(strings.TrimSpace(key)) + switch normalized { + case hookMatcherFieldToolName, hookMatcherFieldToolNameRegex, hookMatcherFieldArgumentsContains: + continue + default: + return fmt.Errorf( + "match contains unknown field %q (allowed: tool_name, tool_name_regex, arguments_contains)", + key, + ) + } + } + return nil +} + +// IsEmpty 判断 matcher 是否包含可执行维度。 +func (m *HookMatcher) IsEmpty() bool { + if m == nil { + return true + } + return len(m.ToolNames) == 0 && len(m.ToolNameRegex) == 0 && len(m.ArgumentsContains) == 0 +} + +// Match 根据 HookContext 执行 matcher 判定;字段间为 AND,同字段多值为 OR。 +func (m *HookMatcher) Match(input HookContext) bool { + if m == nil || m.IsEmpty() { + return true + } + toolName := strings.TrimSpace(readHookMatcherMetadataString(input.Metadata, hookMatcherMetadataToolName)) + if len(m.ToolNames) > 0 { + if toolName == "" || !containsEqualFold(m.ToolNames, toolName) { + return false + } + } + if len(m.ToolNameRegex) > 0 { + if toolName == "" { + return false + } + matched := false + for _, compiled := range m.ToolNameRegex { + if compiled.MatchString(toolName) { + matched = true + break + } + } + if !matched { + return false + } + } + if len(m.ArgumentsContains) > 0 { + argumentsPreview := strings.ToLower(strings.TrimSpace(readHookMatcherMetadataString( + input.Metadata, + hookMatcherMetadataArguments, + ))) + if argumentsPreview == "" { + return false + } + matched := false + for _, fragment := range m.ArgumentsContains { + if strings.Contains(argumentsPreview, fragment) { + matched = true + break + } + } + if !matched { + return false + } + } + return true +} + +// readHookMatcherStringValues 读取 matcher 字段中的字符串集合,兼容 string / []string / []any。 +func readHookMatcherStringValues(raw map[string]any, key string) []string { + if len(raw) == 0 { + return nil + } + value, ok := raw[key] + if !ok || value == nil { + return nil + } + switch typed := value.(type) { + case string: + if strings.TrimSpace(typed) == "" { + return nil + } + return []string{typed} + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if strings.TrimSpace(item) == "" { + continue + } + out = append(out, item) + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if item == nil { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", item)) + if text == "" { + continue + } + out = append(out, text) + } + return out + default: + text := strings.TrimSpace(fmt.Sprintf("%v", typed)) + if text == "" { + return nil + } + return []string{text} + } +} + +// normalizeHookMatcherValues 将 matcher 词条规范为小写并剔除空值。 +func normalizeHookMatcherValues(values []string) []string { + if len(values) == 0 { + return nil + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + text := strings.ToLower(strings.TrimSpace(value)) + if text == "" { + continue + } + normalized = append(normalized, text) + } + return normalized +} + +// containsEqualFold 判断字符串列表是否包含目标值(忽略大小写)。 +func containsEqualFold(values []string, target string) bool { + normalizedTarget := strings.ToLower(strings.TrimSpace(target)) + if normalizedTarget == "" { + return false + } + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), normalizedTarget) { + return true + } + } + return false +} + +// readHookMatcherMetadataString 从 metadata 中读取字符串,兼容大小写键和非字符串值。 +func readHookMatcherMetadataString(metadata map[string]any, key string) string { + if len(metadata) == 0 { + return "" + } + normalizedKey := strings.ToLower(strings.TrimSpace(key)) + if normalizedKey == "" { + return "" + } + if value, ok := metadata[normalizedKey]; ok && value != nil { + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + for currentKey, value := range metadata { + if !strings.EqualFold(strings.TrimSpace(currentKey), normalizedKey) || value == nil { + continue + } + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + return "" +} diff --git a/internal/runtime/hooks/matcher_test.go b/internal/runtime/hooks/matcher_test.go new file mode 100644 index 00000000..cf85dfd3 --- /dev/null +++ b/internal/runtime/hooks/matcher_test.go @@ -0,0 +1,344 @@ +package hooks + +import ( + "regexp" + "testing" +) + +func TestHasHookMatcherConfig(t *testing.T) { + t.Parallel() + + if HasHookMatcherConfig(nil) { + t.Fatal("nil matcher config should be false") + } + if HasHookMatcherConfig(map[string]any{}) { + t.Fatal("empty matcher config should be false") + } + if !HasHookMatcherConfig(map[string]any{"tool_name": "bash"}) { + t.Fatal("tool_name matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"tool_name_regex": []any{"^bash$"}}) { + t.Fatal("tool_name_regex matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"arguments_contains": []string{"rm -rf"}}) { + t.Fatal("arguments_contains matcher should be true") + } + if HasHookMatcherConfig(map[string]any{"tool_name": " "}) { + t.Fatal("whitespace-only tool_name should be false") + } +} + +func TestCompileHookMatcherAndMatch(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": []any{"bash", "filesystem"}, + "tool_name_regex": []string{`^(bash|shell)$`}, + "arguments_contains": []string{"rm -rf"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher to be compiled") + } + + if !matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "sudo rm -rf /tmp/test", + }, + }) { + t.Fatal("expected matcher to pass for matching metadata") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "echo hello", + }, + }) { + t.Fatal("expected matcher to fail when arguments_contains not matched") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "filesystem", + "tool_arguments_preview": "rm -rf /tmp", + }, + }) { + t.Fatal("expected matcher to fail when tool_name_regex not matched") + } +} + +func TestCompileHookMatcherValidation(t *testing.T) { + t.Parallel() + + if _, err := CompileHookMatcher(HookPointSessionStart, map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start tool_name matcher to be rejected") + } + + if _, err := CompileHookMatcher(HookPointAfterToolResult, map[string]any{ + "arguments_contains": []string{"rm -rf"}, + }); err == nil { + t.Fatal("expected after_tool_result arguments_contains to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": "(", + }); err == nil { + t.Fatal("expected invalid regex to fail") + } + + longRegex := make([]byte, MaxHookMatcherRegexLength+1) + for i := range longRegex { + longRegex[i] = 'a' + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": string(longRegex), + }); err == nil { + t.Fatal("expected overlong regex to fail") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_names": "bash", + }); err == nil { + t.Fatal("expected unknown matcher field to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "unknown": "value", + }); err == nil { + t.Fatal("expected completely unknown matcher field to be rejected") + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_names": []any{"filesystem"}, + }); err == nil { + t.Fatal("expected mixed matcher fields with typo to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, nil); err != nil { + t.Fatal("nil raw should succeed with nil matcher") + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{}); err != nil { + t.Fatal("empty raw should succeed with nil matcher") + } +} + +func TestValidateHookMatcher(t *testing.T) { + t.Parallel() + + if err := ValidateHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + }); err != nil { + t.Fatalf("ValidateHookMatcher() error = %v", err) + } + if err := ValidateHookMatcher(HookPointSessionStart, map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start matcher to fail validation") + } +} + +func TestIsEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.IsEmpty() { + t.Fatal("nil matcher should be empty") + } + if !(&HookMatcher{}).IsEmpty() { + t.Fatal("zero-value matcher should be empty") + } + if (&HookMatcher{ToolNames: []string{"bash"}}).IsEmpty() { + t.Fatal("matcher with tool_name should not be empty") + } +} + +func TestMatchNilAndEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.Match(HookContext{}) { + t.Fatal("nil matcher should match everything") + } + empty := &HookMatcher{} + if !empty.Match(HookContext{}) { + t.Fatal("empty matcher should match everything") + } +} + +func TestMatchSingleDimension(t *testing.T) { + t.Parallel() + + t.Run("tool_name only", func(t *testing.T) { + t.Parallel() + m := &HookMatcher{ToolNames: []string{"bash", "filesystem"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name metadata missing") + } + }) + + t.Run("tool_name_regex only", func(t *testing.T) { + t.Parallel() + compiled := regexp.MustCompile(`^(bash|shell)$`) + m := &HookMatcher{ToolNameRegex: []*regexp.Regexp{compiled}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected regex match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected regex no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name missing for regex") + } + }) + + t.Run("arguments_contains only", func(t *testing.T) { + t.Parallel() + m := &HookMatcher{ArgumentsContains: []string{"rm -rf", "sudo"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "sudo rm -rf /tmp"}}) { + t.Fatal("expected arguments_contains match") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "echo hello"}}) { + t.Fatal("expected arguments_contains no match") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when arguments_preview missing") + } + }) +} + +func TestReadHookMatcherStringValues(t *testing.T) { + t.Parallel() + + if got := readHookMatcherStringValues(nil, "x"); len(got) != 0 { + t.Fatal("nil raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{}, "x"); len(got) != 0 { + t.Fatal("empty raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": nil}, "x"); len(got) != 0 { + t.Fatal("nil value should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": " "}, "x"); len(got) != 0 { + t.Fatal("whitespace-only string should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": 42}, "x"); len(got) != 1 || got[0] != "42" { + t.Fatalf("int value should be converted to string, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": []any{" a ", nil, 123}}, "x"); len(got) != 2 || got[0] != "a" || got[1] != "123" { + t.Fatalf("[]any with mixed values, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": "hello"}, "y"); len(got) != 0 { + t.Fatal("missing key should return nil") + } +} + +func TestNormalizeHookMatcherValues(t *testing.T) { + t.Parallel() + + if got := normalizeHookMatcherValues(nil); len(got) != 0 { + t.Fatal("nil values should return nil") + } + if got := normalizeHookMatcherValues([]string{}); len(got) != 0 { + t.Fatal("empty values should return nil") + } + if got := normalizeHookMatcherValues([]string{" ", "\t"}); len(got) != 0 { + t.Fatal("whitespace-only values should return empty") + } + if got := normalizeHookMatcherValues([]string{" BASH ", "", " Filesystem "}); len(got) != 2 || got[0] != "bash" || got[1] != "filesystem" { + t.Fatalf("mixed values should be normalized, got %v", got) + } +} + +func TestContainsEqualFold(t *testing.T) { + t.Parallel() + + if containsEqualFold(nil, "bash") { + t.Fatal("nil values should not match") + } + if containsEqualFold([]string{"bash"}, "") { + t.Fatal("empty target should not match") + } + if containsEqualFold([]string{"bash"}, " ") { + t.Fatal("whitespace-only target should not match") + } + if !containsEqualFold([]string{"BASH", "FILESYSTEM"}, " bash ") { + t.Fatal("case-insensitive match should work") + } + if containsEqualFold([]string{"bash"}, "python") { + t.Fatal("non-matching should return false") + } +} + +func TestReadHookMatcherMetadataString(t *testing.T) { + t.Parallel() + + if got := readHookMatcherMetadataString(nil, "x"); got != "" { + t.Fatal("nil metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{}, "x"); got != "" { + t.Fatal("empty metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": nil}, "x"); got != "" { + t.Fatal("nil value should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": 123}, "x"); got != "123" { + t.Fatalf("non-string value should be converted, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, " "); got != "" { + t.Fatal("empty key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, ""); got != "" { + t.Fatal("empty string key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"TOOL_NAME": "bash"}, "tool_name"); got != "bash" { + t.Fatalf("case-insensitive key lookup failed, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"y": "hello"}, "x"); got != "" { + t.Fatal("missing key should return empty") + } +} + +func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_name_regex": []string{" ", "\t"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled even when regex values are whitespace-only") + } + if len(matcher.ToolNameRegex) != 0 { + t.Fatalf("expected empty tool_name_regex slice, got %d entries", len(matcher.ToolNameRegex)) + } +} + +func TestCompileHookMatcherRegexOnly(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": `^bash`, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled for regex-only config") + } + if !matcher.Match(HookContext{Metadata: map[string]any{"tool_name": "bash-script"}}) { + t.Fatal("expected regex to match") + } +} diff --git a/internal/runtime/hooks/types.go b/internal/runtime/hooks/types.go index 2f928f5c..7cd37869 100644 --- a/internal/runtime/hooks/types.go +++ b/internal/runtime/hooks/types.go @@ -45,22 +45,77 @@ type HookPointCapability struct { CanAnnotate bool CanUpdateInput bool UserAllowed bool + Matcher HookMatcherCapability +} + +// HookMatcherCapability 描述点位可用的 matcher 维度。 +type HookMatcherCapability struct { + ToolName bool + ToolNameRegex bool + ArgumentsContains bool } var hookPointCapabilities = map[HookPoint]HookPointCapability{ - HookPointBeforeToolCall: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAfterToolResult: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforeCompletionDecision: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAcceptGate: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforePermissionDecision: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointAfterToolFailure: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionStart: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionEnd: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointUserPromptSubmit: {CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true}, - HookPointPreCompact: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointPostCompact: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSubAgentStart: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointSubAgentStop: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, + HookPointBeforeToolCall: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointAfterToolResult: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointBeforeCompletionDecision: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointAcceptGate: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointBeforePermissionDecision: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointAfterToolFailure: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointSessionStart: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSessionEnd: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointUserPromptSubmit: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointPreCompact: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointPostCompact: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStart: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStop: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, } // HookScope 描述 hook 的权限/上下文裁剪等级。 @@ -142,6 +197,7 @@ type HookSpec struct { Timeout time.Duration FailurePolicy FailurePolicy Handler HookHandler + Matcher *HookMatcher } // normalizeAndValidate 将 HookSpec 归一化并校验当前阶段可用字段。 diff --git a/internal/runtime/hooks/types_test.go b/internal/runtime/hooks/types_test.go index 9a81cb07..3664133c 100644 --- a/internal/runtime/hooks/types_test.go +++ b/internal/runtime/hooks/types_test.go @@ -217,6 +217,9 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("before_permission_decision should allow block") } + if !capability.Matcher.ToolName || !capability.Matcher.ToolNameRegex { + t.Fatal("before_permission_decision should support tool_name/tool_name_regex matcher") + } capability, ok = HookPointCapabilities(HookPointAfterToolFailure) if !ok { @@ -225,6 +228,9 @@ func TestHookPointCapabilities(t *testing.T) { if capability.CanBlock { t.Fatal("after_tool_failure should be observe-only") } + if !capability.Matcher.ArgumentsContains { + t.Fatal("after_tool_failure should support arguments_contains matcher") + } capability, ok = HookPointCapabilities(HookPointBeforeCompletionDecision) if !ok { @@ -241,6 +247,9 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("accept_gate should allow block") } + if capability.Matcher.ToolName || capability.Matcher.ToolNameRegex || capability.Matcher.ArgumentsContains { + t.Fatal("accept_gate should not expose matcher dimensions") + } if _, exists := HookPointCapabilities(HookPoint("unknown")); exists { t.Fatal("unknown hook point should not have capability") diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index b22bb343..292fdfef 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -359,31 +359,25 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { default: return fmt.Errorf("handler %q is not supported", item.Handler) } - if handler == "warn_on_tool_call" && !runtimeHasWarnOnToolCallTargets(item.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", item.Handler) - } + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case repoHookKindCommand: if err := runtimehooks.ValidateCommandParams(item.Params); err != nil { return err } - } - return nil -} - -// runtimeHasWarnOnToolCallTargets 判断 warn_on_tool_call 是否配置了至少一个目标工具。 -func runtimeHasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - if name := strings.TrimSpace(readHookParamString(params, "tool_name")); name != "" { - return true - } - for _, value := range readHookParamStringSlice(params, "tool_names") { - if strings.TrimSpace(value) != "" { - return true + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } } - return false + return nil } // evaluateWorkspaceTrust 根据 trust store 判断 workspace 是否可信并附带容错诊断。 diff --git a/internal/runtime/repo_hooks_test.go b/internal/runtime/repo_hooks_test.go index 64688256..a4f74c9f 100644 --- a/internal/runtime/repo_hooks_test.go +++ b/internal/runtime/repo_hooks_test.go @@ -596,25 +596,47 @@ func TestValidateRepoHookItemRejectsExternalKindsWithP6LiteMessage(t *testing.T) } } -func TestRuntimeHasWarnOnToolCallTargetsBranches(t *testing.T) { - cases := []struct { - name string - params map[string]any - want bool - }{ - {name: "nil", params: nil, want: false}, - {name: "tool_name", params: map[string]any{"tool_name": "bash"}, want: true}, - {name: "tool_name blank", params: map[string]any{"tool_name": " "}, want: false}, - {name: "tool_names", params: map[string]any{"tool_names": []any{"bash"}}, want: true}, - {name: "tool_names blank", params: map[string]any{"tool_names": []any{" "}}, want: false}, + +func TestValidateRepoHookItemAllowsWarnOnToolCallWithMatchOnly(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-warn-match", + Point: "before_tool_call", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "warn_on_tool_call", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Match: map[string]any{ + "tool_name": "bash", + }, } - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - if got := runtimeHasWarnOnToolCallTargets(tc.params); got != tc.want { - t.Fatalf("runtimeHasWarnOnToolCallTargets() = %v, want %v", got, tc.want) - } - }) + if err := validateRepoHookItem(item); err != nil { + t.Fatalf("validateRepoHookItem() error = %v", err) + } +} + +func TestValidateRepoHookItemRejectsUnsupportedMatcherDimension(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-session-match", + Point: "session_start", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "add_context_note", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{"note": "repo"}, + Match: map[string]any{ + "tool_name": "bash", + }, + } + if err := validateRepoHookItem(item); err == nil { + t.Fatal("expected unsupported matcher dimension to fail") } } diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 49a82158..8b57a247 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -6,9 +6,11 @@ import ( "errors" "fmt" "path/filepath" + "regexp" "sort" "strings" "sync" + "unicode" "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" @@ -22,6 +24,12 @@ type indexedToolCall struct { call providertypes.ToolCall } +const hookToolArgumentsPreviewMaxChars = 512 + +var hookToolArgumentsSensitivePattern = regexp.MustCompile( + `(?i)(token|password|secret|api[_-]?key|access[_-]?key|auth)\s*[:=]\s*("[^"]*"|'[^']*'|[^\s]+)`, +) + // executeAssistantToolCalls 并发执行 assistant 返回的全部工具调用并返回结构化执行摘要。 func (s *Service) executeAssistantToolCalls( ctx context.Context, @@ -160,10 +168,11 @@ func (s *Service) executeOneToolCallWithoutPersistence( beforeToolHookOutput := s.runHookPoint(ctx, state, runtimehooks.HookPointBeforeToolCall, runtimehooks.HookContext{ Metadata: map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "tool_arguments": strings.TrimSpace(call.Arguments), - "workdir": strings.TrimSpace(snapshot.Workdir), + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments": strings.TrimSpace(call.Arguments), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "workdir": strings.TrimSpace(snapshot.Workdir), }, }) if beforeToolHookOutput.Blocked { @@ -752,6 +761,122 @@ func summarizeHookResultContent(content string) string { return trimmed[:256] } +// buildToolArgumentsPreview 生成 matcher 可用的参数预览,并对敏感键值执行脱敏。 +func buildToolArgumentsPreview(arguments string) string { + trimmed := strings.TrimSpace(arguments) + if trimmed == "" { + return "" + } + masked := sanitizeHookToolArguments(trimmed) + return truncateHookTextByChars(masked, hookToolArgumentsPreviewMaxChars) +} + +// sanitizeHookToolArguments 优先按 JSON 结构递归脱敏,非 JSON 输入回退为轻量正则脱敏。 +func sanitizeHookToolArguments(arguments string) string { + if masked, ok := sanitizeHookToolArgumentsJSON(arguments); ok { + return masked + } + return hookToolArgumentsSensitivePattern.ReplaceAllString(arguments, `$1=***`) +} + +// sanitizeHookToolArgumentsJSON 尝试解析 JSON 并按敏感键递归替换值。 +func sanitizeHookToolArgumentsJSON(arguments string) (string, bool) { + var decoded any + if err := json.Unmarshal([]byte(arguments), &decoded); err != nil { + return "", false + } + sanitized := maskHookToolArgumentValue(decoded) + encoded, err := json.Marshal(sanitized) + if err != nil { + return "", false + } + return string(encoded), true +} + +// maskHookToolArgumentValue 递归处理 JSON 节点,对敏感键对应的值统一替换为 "***"。 +func maskHookToolArgumentValue(value any) any { + switch typed := value.(type) { + case map[string]any: + masked := make(map[string]any, len(typed)) + for key, item := range typed { + if isSensitiveHookToolArgumentKey(key) { + masked[key] = "***" + continue + } + masked[key] = maskHookToolArgumentValue(item) + } + return masked + case []any: + masked := make([]any, len(typed)) + for index, item := range typed { + masked[index] = maskHookToolArgumentValue(item) + } + return masked + default: + return value + } +} + +// isSensitiveHookToolArgumentKey 判断参数键名是否属于敏感信息字段。 +func isSensitiveHookToolArgumentKey(key string) bool { + tokens := tokenizeHookToolArgumentKey(key) + if len(tokens) == 0 { + return false + } + for index, token := range tokens { + switch token { + case "password", "passwd", "secret", "token", "auth", "authorization": + return true + case "apikey", "accesskey", "authtoken", "accesstoken": + return true + case "api", "access": + if index+1 < len(tokens) && tokens[index+1] == "key" { + return true + } + case "key": + if index > 0 && (tokens[index-1] == "api" || tokens[index-1] == "access") { + return true + } + } + } + return false +} + +// tokenizeHookToolArgumentKey 将参数键拆分为小写词元,兼容 snake/kebab/camelCase。 +func tokenizeHookToolArgumentKey(key string) []string { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + return nil + } + var builder strings.Builder + var prev rune + for _, current := range trimmed { + switch { + case unicode.IsLetter(current) || unicode.IsDigit(current): + if unicode.IsUpper(current) && unicode.IsLower(prev) { + builder.WriteByte(' ') + } + builder.WriteRune(unicode.ToLower(current)) + default: + builder.WriteByte(' ') + } + prev = current + } + return strings.Fields(builder.String()) +} + +// truncateHookTextByChars 按字符长度截断文本,避免 metadata 放大。 +func truncateHookTextByChars(text string, maxChars int) string { + if maxChars <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= maxChars { + return text + } + return string(runes[:maxChars]) +} + // extractTodoIDsFromPayload 提取 todo 事件快照中的条目 ID,用于冲突事实去重统计。 func extractTodoIDsFromPayload(items []TodoViewItem) []string { if len(items) == 0 { @@ -828,11 +953,12 @@ func (s *Service) emitAfterToolFailureHook( workdir string, ) { afterToolFailureMetadata := map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "is_error": result.IsError, - "error_class": strings.TrimSpace(result.ErrorClass), - "workdir": strings.TrimSpace(workdir), + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "is_error": result.IsError, + "error_class": strings.TrimSpace(result.ErrorClass), + "workdir": strings.TrimSpace(workdir), } if execErr != nil { afterToolFailureMetadata["execution_error"] = strings.TrimSpace(execErr.Error()) diff --git a/internal/runtime/toolexec_preview_test.go b/internal/runtime/toolexec_preview_test.go new file mode 100644 index 00000000..1b272356 --- /dev/null +++ b/internal/runtime/toolexec_preview_test.go @@ -0,0 +1,81 @@ +package runtime + +import ( + "strings" + "testing" +) + +func TestBuildToolArgumentsPreviewMaskJSONSensitiveFields(t *testing.T) { + t.Parallel() + + raw := `{"api_key":"sk-123","password":"p@ss","nested":{"secret":"abc"},"safe":"ok"}` + preview := buildToolArgumentsPreview(raw) + if strings.Contains(preview, "sk-123") { + t.Fatalf("preview leaked api_key: %q", preview) + } + if strings.Contains(preview, "p@ss") { + t.Fatalf("preview leaked password: %q", preview) + } + if strings.Contains(preview, `"secret":"abc"`) { + t.Fatalf("preview leaked nested secret: %q", preview) + } + if !strings.Contains(preview, `"api_key":"***"`) { + t.Fatalf("preview should mask api_key: %q", preview) + } + if !strings.Contains(preview, `"password":"***"`) { + t.Fatalf("preview should mask password: %q", preview) + } + if !strings.Contains(preview, `"secret":"***"`) { + t.Fatalf("preview should mask nested secret: %q", preview) + } + if !strings.Contains(preview, `"safe":"ok"`) { + t.Fatalf("preview should keep non-sensitive keys: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewMaskNonJSONFallback(t *testing.T) { + t.Parallel() + + preview := buildToolArgumentsPreview(`token=abc password:xyz arg=ok`) + if strings.Contains(preview, "abc") || strings.Contains(preview, "xyz") { + t.Fatalf("preview leaked fallback credentials: %q", preview) + } + if !strings.Contains(preview, "token=***") { + t.Fatalf("preview should mask token in fallback mode: %q", preview) + } + if !strings.Contains(preview, "password=***") { + t.Fatalf("preview should mask password in fallback mode: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewTruncate(t *testing.T) { + t.Parallel() + + raw := strings.Repeat("a", hookToolArgumentsPreviewMaxChars+20) + preview := buildToolArgumentsPreview(raw) + if len([]rune(preview)) != hookToolArgumentsPreviewMaxChars { + t.Fatalf("preview length=%d, want %d", len([]rune(preview)), hookToolArgumentsPreviewMaxChars) + } +} + +func TestIsSensitiveHookToolArgumentKey(t *testing.T) { + t.Parallel() + + cases := []struct { + key string + want bool + }{ + {key: "api_key", want: true}, + {key: "accessKey", want: true}, + {key: "authorization", want: true}, + {key: "auth_token", want: true}, + {key: "password", want: true}, + {key: "author", want: false}, + {key: "tool_name", want: false}, + } + for _, tc := range cases { + if got := isSensitiveHookToolArgumentKey(tc.key); got != tc.want { + t.Fatalf("isSensitiveHookToolArgumentKey(%q)=%v, want %v", tc.key, got, tc.want) + } + } +} diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index ec7a4f15..8040b82a 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -12,7 +12,6 @@ import ( "os" "path/filepath" "runtime" - "slices" "strings" "time" @@ -197,48 +196,54 @@ func buildConfiguredHookSpec( if err := validateConfiguredHookItemForP6Lite(item, scope); err != nil { return runtimehooks.HookSpec{}, err } + point := runtimehooks.HookPoint(strings.TrimSpace(item.Point)) + matcher, err := buildConfiguredHookMatcher(item, point) + if err != nil { + return runtimehooks.HookSpec{}, err + } kind := strings.ToLower(strings.TrimSpace(item.Kind)) specKind := runtimehooks.HookKindFunction specMode := runtimehooks.HookModeSync var ( - handler runtimehooks.HookHandler - err error + handler runtimehooks.HookHandler + buildErr error ) switch kind { case configuredHookKindBuiltin: - handler, err = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) + handler, buildErr = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: - handler, err = buildUserCommandHookHandler( + handler, buildErr = buildUserCommandHookHandler( strings.TrimSpace(item.ID), - runtimehooks.HookPoint(strings.TrimSpace(item.Point)), + point, item.Params, defaultWorkdir, ) specKind = runtimehooks.HookKindCommand specMode = runtimehooks.HookModeSync case configuredHookKindHTTP: - handler, err = buildUserHTTPObserveHookHandler(item) + handler, buildErr = buildUserHTTPObserveHookHandler(item) specKind = runtimehooks.HookKindHTTP specMode = runtimehooks.HookModeObserve default: return runtimehooks.HookSpec{}, fmt.Errorf("kind %q is not supported", item.Kind) } - if err != nil { - return runtimehooks.HookSpec{}, err + if buildErr != nil { + return runtimehooks.HookSpec{}, buildErr } return runtimehooks.HookSpec{ - ID: strings.TrimSpace(item.ID), - Point: runtimehooks.HookPoint(strings.TrimSpace(item.Point)), - Scope: scope, - Source: source, - Kind: specKind, - Mode: specMode, - Priority: item.Priority, - Timeout: time.Duration(item.TimeoutSec) * time.Second, - FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), - Handler: handler, + ID: strings.TrimSpace(item.ID), + Point: point, + Scope: scope, + Source: source, + Kind: specKind, + Mode: specMode, + Priority: item.Priority, + Timeout: time.Duration(item.TimeoutSec) * time.Second, + FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), + Handler: handler, + Matcher: matcher, }, nil } @@ -257,6 +262,15 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported", item.Mode) } + handler := strings.ToLower(strings.TrimSpace(item.Handler)) + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case configuredHookKindCommand: if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) @@ -264,6 +278,11 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if _, _, err := runtimehooks.ParseCommandParams(item.Params); err != nil { return err } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case configuredHookKindHTTP: if mode != configuredHookModeObserve { return fmt.Errorf("mode %q is not supported for kind http (only observe)", item.Mode) @@ -272,6 +291,11 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if policy == "fail_closed" { return fmt.Errorf("failure_policy %q is not supported for kind http observe", item.FailurePolicy) } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } default: if isExternalHookKind(kind) { return fmt.Errorf( @@ -284,6 +308,19 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop return nil } + +// buildConfiguredHookMatcher 编译 hook matcher。 +func buildConfiguredHookMatcher(item config.RuntimeHookItemConfig, point runtimehooks.HookPoint) (*runtimehooks.HookMatcher, error) { + if !runtimehooks.HasHookMatcherConfig(item.Match) { + return nil, nil + } + matcher, err := runtimehooks.CompileHookMatcher(point, item.Match) + if err != nil { + return nil, fmt.Errorf("match: %w", err) + } + return matcher, nil +} + func isExternalHookKind(kind string) bool { switch strings.ToLower(strings.TrimSpace(kind)) { case "command", "http", "prompt", "agent": @@ -339,30 +376,16 @@ func buildUserBuiltinHookHandler( } return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} }, nil - case "warn_on_tool_call": - targetTool := strings.ToLower(strings.TrimSpace(readHookParamString(params, "tool_name"))) - targetTools := normalizeHookParamStringSlice(readHookParamStringSlice(params, "tool_names")) - if targetTool == "" && len(targetTools) == 0 { - return nil, fmt.Errorf("handler warn_on_tool_call requires params.tool_name or params.tool_names") - } - defaultMessage := "tool call matched warn_on_tool_call" - if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { - defaultMessage = customMessage - } - return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { - _ = ctx - toolName := strings.ToLower(strings.TrimSpace(readHookContextMetadataString(input, "tool_name"))) - if toolName == "" { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} - } - if targetTool != "" && toolName == targetTool { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} + case "warn_on_tool_call": + defaultMessage := "tool call matched warn_on_tool_call" + if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { + defaultMessage = customMessage } - if len(targetTools) > 0 && slices.Contains(targetTools, toolName) { + return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + _ = ctx + _ = input return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} - } - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} - }, nil + }, nil case "add_context_note": note := strings.TrimSpace(readHookParamString(params, "note")) if note == "" { diff --git a/internal/runtime/user_hooks_test.go b/internal/runtime/user_hooks_test.go index f328687e..588361a9 100644 --- a/internal/runtime/user_hooks_test.go +++ b/internal/runtime/user_hooks_test.go @@ -33,8 +33,10 @@ func TestBuildUserHookSpecMapsFailurePolicyAndScope(t *testing.T) { Priority: 99, TimeoutSec: 7, FailurePolicy: "warn_only", - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", + }, + Params: map[string]any{ "message": "tool call warning", }, } @@ -351,7 +353,6 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Parallel() warnHandler, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{ - "tool_name": "bash", "message": "bash was called", }, t.TempDir()) if err != nil { @@ -359,7 +360,6 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { } warnResult := warnHandler(context.Background(), runtimehooks.HookContext{ Metadata: map[string]any{ - "tool_name": "bash", }, }) if warnResult.Status != runtimehooks.HookResultPass { @@ -369,14 +369,14 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Fatalf("warn message = %q, want %q", warnResult.Message, "bash was called") } - ignoreResult := warnHandler(context.Background(), runtimehooks.HookContext{ - Metadata: map[string]any{ - "tool_name": "filesystem", - }, - }) - if strings.TrimSpace(ignoreResult.Message) != "" { - t.Fatalf("expected unmatched tool to have empty message, got %q", ignoreResult.Message) - } + anyToolResult := warnHandler(context.Background(), runtimehooks.HookContext{ + Metadata: map[string]any{ + "tool_name": "filesystem", + }, + }) + if anyToolResult.Message != "bash was called" { + t.Fatalf("warn message = %q, want %q", anyToolResult.Message, "bash was called") + } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{ "note": "manual check required", @@ -407,10 +407,10 @@ func TestConfigureRuntimeHooksFromConfig(t *testing.T) { Scope: "user", Kind: "builtin", Mode: "sync", - Handler: "warn_on_tool_call", - Params: map[string]any{ - "tool_name": "bash", - }, + Handler: "warn_on_tool_call", + Match: map[string]any{ + "tool_name": "bash", + }, }, } cfg.Runtime.Hooks.ApplyDefaults(config.StaticDefaults().Runtime.Hooks) @@ -455,9 +455,11 @@ func TestConfigureRuntimeHooksFromConfigKeepsBaseExecutorAndComposes(t *testing. Scope: "user", Kind: "builtin", Mode: "sync", + Match: map[string]any{ + "tool_name": "bash", + }, Handler: "warn_on_tool_call", Params: map[string]any{ - "tool_name": "bash", "message": "warn", }, }, @@ -818,11 +820,11 @@ func TestConfigureRuntimeHooksWithoutItemsKeepsBehaviorUnchanged(t *testing.T) { if service.hookExecutor == nil { t.Fatal("expected runtime hooks chain to remain available for repo discovery") } - out := service.hookExecutor.Run( - context.Background(), - runtimehooks.HookPointBeforeToolCall, - runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash", "workdir": cfg.Workdir}}, - ) + out := service.hookExecutor.Run( + context.Background(), + runtimehooks.HookPointBeforeToolCall, + runtimehooks.HookContext{Metadata: map[string]any{"workdir": cfg.Workdir}}, + ) if out.Blocked || len(out.Results) != 0 { t.Fatalf("unexpected hook output without user/repo config: %+v", out) } @@ -834,8 +836,12 @@ func TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()); err == nil { - t.Fatal("expected missing target error") + handlerWithoutTarget, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()) + if err != nil { + t.Fatalf("build warn_on_tool_call without target error: %v", err) + } + if got := handlerWithoutTarget(context.Background(), runtimehooks.HookContext{}); got.Message == "" { + t.Fatalf("expected default warning message when no target is configured, got %+v", got) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing note/message error") @@ -851,7 +857,7 @@ func TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { t.Fatalf("expected match message, got %q", pass.Message) } noTool := handler(context.Background(), runtimehooks.HookContext{}) - if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "" { + if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "hit" { t.Fatalf("unexpected no-tool result: %+v", noTool) } @@ -1349,6 +1355,7 @@ func TestConfigureRuntimeHooksInjectsAsyncResultSinkIntoBaseExecutor(t *testing. t.Fatal("expected async rewake notification to be enqueued via configured async sink") } + type countingHookExecutor struct { calls atomic.Int32 output runtimehooks.RunOutput @@ -1518,8 +1525,8 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, workdir); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err == nil { - t.Fatal("expected missing tool target error") + if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err != nil { + t.Fatalf("warn_on_tool_call without target should be allowed for matcher-based filtering: %v", err) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, workdir); err == nil { t.Fatal("expected missing note/message error") @@ -1538,10 +1545,10 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { if result.Message == "" { t.Fatalf("expected default warn message for matched tool") } - result = warnHandler(context.Background(), runtimehooks.HookContext{}) - if result.Message != "" { - t.Fatalf("expected empty message when no tool_name metadata, got %q", result.Message) - } + result = warnHandler(context.Background(), runtimehooks.HookContext{}) + if result.Message == "" { + t.Fatalf("expected default warn message, got empty") + } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{"message": "note-via-message"}, workdir) if err != nil {