From 64d1fe1fec770f3060a78a354539230d01da0985 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Fri, 29 May 2026 08:44:57 +0800 Subject: [PATCH 1/3] @ feat(hooks): add unified Hook Matcher DSL with tool_name/tool_name_regex/arguments_contains MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a match field on hook items that supports three matcher dimensions (AND within fields, OR across values). Matcher filtering is applied at the executor scheduling layer — non-matching hooks are skipped without affecting existing block/failure semantics. Includes backward-compatible bridging of the legacy warn_on_tool_call params, plus a migration notification when old and new parameters coexist. - New matcher types and compiler in internal/runtime/hooks/matcher.go - Executor integration to skip non-matching hooks before running handlers - tool_arguments_preview (sanitized + truncated) added to before_tool_call metadata - Config validation for match syntax and hook-point capability constraints - Backward-compatible warn_on_tool_call bridging with deprecation notification - Updated design docs and example configs Co-Authored-By: Claude Opus 4.7 @ --- docs/examples/hooks.yaml | 3 +- docs/examples/user-hooks-config.yaml | 3 +- docs/runtime-hooks-design.md | 22 ++ internal/config/runtime_hooks.go | 75 ++++++- internal/config/runtime_hooks_test.go | 73 +++++++ internal/runtime/hooks/executor.go | 43 ++++ internal/runtime/hooks/executor_test.go | 90 ++++++++- internal/runtime/hooks/matcher.go | 255 ++++++++++++++++++++++++ internal/runtime/hooks/matcher_test.go | 96 +++++++++ internal/runtime/hooks/types.go | 85 ++++++-- internal/runtime/hooks/types_test.go | 9 + internal/runtime/repo_hooks.go | 15 +- internal/runtime/repo_hooks_test.go | 43 ++++ internal/runtime/toolexec.go | 49 ++++- internal/runtime/user_hooks.go | 170 ++++++++++++++-- internal/runtime/user_hooks_test.go | 88 +++++++- 16 files changed, 1057 insertions(+), 62 deletions(-) create mode 100644 internal/runtime/hooks/matcher.go create mode 100644 internal/runtime/hooks/matcher_test.go diff --git a/docs/examples/hooks.yaml b/docs/examples/hooks.yaml index 7a89625cf..ad90911d8 100644 --- a/docs/examples/hooks.yaml +++ b/docs/examples/hooks.yaml @@ -17,8 +17,9 @@ hooks: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: require-readme-before-final diff --git a/docs/examples/user-hooks-config.yaml b/docs/examples/user-hooks-config.yaml index 055b35626..0228cc5f9 100644 --- a/docs/examples/user-hooks-config.yaml +++ b/docs/examples/user-hooks-config.yaml @@ -25,8 +25,9 @@ runtime: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: user-http-observe diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index 68c59cbf4..097a52a45 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -30,6 +30,10 @@ P2 仅支持: `before_tool_call`、`after_tool_result`、`before_completion_decision`、`accept_gate`、`after_tool_failure`、 `session_start`、`session_end`、`user_prompt_submit`、`post_compact`、`subagent_stop` - handler:`require_file_exists`、`warn_on_tool_call`、`add_context_note` +- `match`:统一 matcher DSL(字段间 AND、同字段多值 OR),支持: + - `tool_name`:精确匹配(`string` 或 `[]string`) + - `tool_name_regex`:正则匹配(`string` 或 `[]string`,单条最长 256) + - `arguments_contains`:参数预览包含匹配(`[]string`) - `kind=http + mode=observe`:允许发送 HTTP 观测回调(不支持 block) - `http observe` 默认不携带 metadata(`include_metadata=false`);即使显式开启也会剥离 `result_content_preview`、`execution_error` - `http observe` 回调端点仅允许 loopback 地址(`localhost` / `127.0.0.1` / `::1`),避免误配为公网外发 @@ -73,6 +77,7 @@ user/repo hook 接收的 `HookContext` 经过白名单裁剪,仅保留最小 - `run_id` / `session_id` - `point` / `tool_call_id` / `tool_name` +- `tool_arguments_preview`(脱敏+截断后的参数预览) - `is_error` / `error_class` - `result_content_preview` / `result_metadata_present` - `execution_error` @@ -109,6 +114,23 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 - `CanBlock=false` 的点位,hook 返回 `block` 会自动降级为观测结果,不中断主链。 - `CanUpdateInput` 在 `user_prompt_submit` 点位已开放:command hook 可通过 stdout JSON 的 `update_input` 字段改写用户输入。 - `UserAllowed=false` 的点位拒绝 user/repo 挂载(配置 fail-fast)。 +- matcher 字段会按点位能力矩阵做 fail-fast:不支持的维度会在配置加载阶段直接报错。 + +### matcher 点位维度矩阵(#684) + +| point | tool_name | tool_name_regex | arguments_contains | +|---|---|---|---| +| `before_tool_call` | ✅ | ✅ | ✅ | +| `after_tool_result` | ✅ | ✅ | ❌ | +| `after_tool_failure` | ✅ | ✅ | ✅ | +| `before_permission_decision` | ✅ | ✅ | ❌ | +| 其他点位 | ❌ | ❌ | ❌ | + +说明: + +- `arguments_contains` 基于 `tool_arguments_preview` 字段匹配,不读取 `tool_arguments` 原文。 +- `warn_on_tool_call` 的旧参数 `params.tool_name/tool_names` 仍兼容;未配置 `match` 时会自动桥接为 matcher。 +- 若 `match` 与旧参数共存,以 `match` 为准,并发出 `hook_notification` 迁移提示事件。 ### trust gate diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index fa237f527..86b4194c6 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -64,6 +64,7 @@ type RuntimeHookItemConfig struct { Kind string `yaml:"kind,omitempty"` Mode string `yaml:"mode,omitempty"` Handler string `yaml:"handler,omitempty"` + Match map[string]any `yaml:"match,omitempty"` Priority int `yaml:"priority,omitempty"` TimeoutSec int `yaml:"timeout_sec,omitempty"` FailurePolicy string `yaml:"failure_policy,omitempty"` @@ -189,6 +190,12 @@ func (c RuntimeHookItemConfig) Clone() RuntimeHookItemConfig { if c.Enabled != nil { cloned.Enabled = boolPtr(*c.Enabled) } + if len(c.Match) > 0 { + cloned.Match = make(map[string]any, len(c.Match)) + for key, value := range c.Match { + cloned.Match[key] = cloneRuntimeHookParamValue(value) + } + } if len(c.Params) > 0 { cloned.Params = make(map[string]any, len(c.Params)) for key, value := range c.Params { @@ -279,8 +286,15 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { default: return fmt.Errorf("handler %q is not supported", c.Handler) } - if handler == runtimeHookHandlerWarnOnToolCall && !hasWarnOnToolCallTargets(c.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", c.Handler) + hasExplicitMatcher := hooks.HasHookMatcherConfig(c.Match) + if handler == runtimeHookHandlerWarnOnToolCall && !hasExplicitMatcher && !hasWarnOnToolCallTargets(c.Params) { + return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", c.Handler) + } + matcherRaw := resolveRuntimeHookMatcherConfigForValidation(c, handler) + if matcherRaw != nil { + if err := hooks.ValidateHookMatcher(point, matcherRaw); err != nil { + return fmt.Errorf("match: %w", err) + } } case runtimeHookKindCommand: if normalizedMode != runtimeHookModeSync { @@ -289,6 +303,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if err := hooks.ValidateCommandParams(c.Params); err != nil { return err } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case runtimeHookKindHTTP: if normalizedMode != runtimeHookModeObserve { return fmt.Errorf("mode %q is not supported for kind http (only observe)", c.Mode) @@ -296,6 +315,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if err := validateRuntimeHTTPObserveItem(c, policy); err != nil { return err } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } } return nil } @@ -427,6 +451,52 @@ func hasWarnOnToolCallTargets(params map[string]any) bool { return false } +// resolveRuntimeHookMatcherConfigForValidation 返回配置校验阶段的 matcher 配置。 +// 对 warn_on_tool_call 保持旧参数兼容:当未配置 match 时自动桥接 tool_name/tool_names。 +func resolveRuntimeHookMatcherConfigForValidation(item RuntimeHookItemConfig, handler string) map[string]any { + if hooks.HasHookMatcherConfig(item.Match) { + return item.Match + } + if strings.EqualFold(strings.TrimSpace(handler), runtimeHookHandlerWarnOnToolCall) && hasWarnOnToolCallTargets(item.Params) { + return runtimeHookLegacyWarnMatcherConfig(item.Params) + } + return nil +} + +// runtimeHookLegacyWarnMatcherConfig 将 warn_on_tool_call 旧参数桥接为 matcher 配置。 +func runtimeHookLegacyWarnMatcherConfig(params map[string]any) map[string]any { + if len(params) == 0 { + return nil + } + var toolNames []string + if name := strings.TrimSpace(readRuntimeHookParamString(params, "tool_name")); name != "" { + toolNames = append(toolNames, name) + } + if raw, ok := params["tool_names"]; ok && raw != nil { + switch typed := raw.(type) { + case []string: + toolNames = append(toolNames, typed...) + case []any: + for _, value := range typed { + toolNames = append(toolNames, strings.TrimSpace(fmt.Sprintf("%v", value))) + } + } + } + filtered := make([]string, 0, len(toolNames)) + for _, value := range toolNames { + if strings.TrimSpace(value) == "" { + continue + } + filtered = append(filtered, value) + } + if len(filtered) == 0 { + return nil + } + return map[string]any{ + "tool_name": filtered, + } +} + // readRuntimeHookParamString 以兼容方式读取 runtime hook 参数中的字符串值。 func readRuntimeHookParamString(params map[string]any, key string) string { if len(params) == 0 { @@ -443,4 +513,3 @@ func readRuntimeHookParamString(params map[string]any, key string) string { return fmt.Sprintf("%v", typed) } } - diff --git a/internal/config/runtime_hooks_test.go b/internal/config/runtime_hooks_test.go index 8a988ae5f..c537caf13 100644 --- a/internal/config/runtime_hooks_test.go +++ b/internal/config/runtime_hooks_test.go @@ -489,6 +489,65 @@ func TestRuntimeHooksConfigValidateWarnOnToolCallRequiresTarget(t *testing.T) { } } +func TestRuntimeHooksConfigValidateWarnOnToolCallAllowsMatchWithoutLegacyTargets(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "warn-with-match", + Point: string(hooks.HookPointBeforeToolCall), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerWarnOnToolCall, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateRejectsUnsupportedMatcherDimensionForPoint(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "session-start-match", + Point: string(hooks.HookPointSessionStart), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerAddContextNote, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"note": "observe"}, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err == nil { + t.Fatal("expected unsupported matcher dimension to fail validation") + } +} + func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Parallel() @@ -620,6 +679,20 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { if hasWarnOnToolCallTargets(map[string]any{"tool_names": "bash"}) { t.Fatal("tool_names scalar should fail") } + + matchCfg := RuntimeHookItemConfig{ + Match: map[string]any{ + "tool_name_regex": []any{`^bash$`}, + }, + } + clonedCfg := matchCfg.Clone() + clonedRegexes := clonedCfg.Match["tool_name_regex"].([]any) + clonedRegexes[0] = "^filesystem$" + clonedCfg.Match["tool_name_regex"] = clonedRegexes + originalRegexes := matchCfg.Match["tool_name_regex"].([]any) + if originalRegexes[0] == "^filesystem$" { + t.Fatal("expected match field to be deep-cloned") + } }) } diff --git a/internal/runtime/hooks/executor.go b/internal/runtime/hooks/executor.go index 8f32faef9..be9a1f74b 100644 --- a/internal/runtime/hooks/executor.go +++ b/internal/runtime/hooks/executor.go @@ -23,6 +23,7 @@ type Executor struct { defaultTimeout time.Duration maxInFlight int32 inFlight atomic.Int32 + migrationWarns sync.Map now func() time.Time asyncSink AsyncResultSink } @@ -79,6 +80,10 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) if spec.Scope == HookScopeUser || spec.Scope == HookScopeRepo { hookInput = sanitizeUserHookContext(hookInput) } + e.emitMatcherMigrationWarning(ctx, spec) + if spec.Matcher != nil && !spec.Matcher.Match(hookInput) { + continue + } if spec.Mode == HookModeAsync || spec.Mode == HookModeAsyncRewake { e.runAsync(ctx, spec, hookInput) continue @@ -103,6 +108,43 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) return output } +// emitMatcherMigrationWarning 在 detect 到旧 warn_on_tool_call 参数与 match 共存时发出一次迁移提示事件。 +func (e *Executor) emitMatcherMigrationWarning(ctx context.Context, spec HookSpec) { + if e == nil { + return + } + message := strings.TrimSpace(spec.MatcherMigrationWarning) + if message == "" { + return + } + dedupeKey := strings.ToLower(strings.TrimSpace( + fmt.Sprintf("%s|%s|%s|%s", spec.ID, spec.Point, spec.Scope, spec.Source), + )) + if dedupeKey == "" { + dedupeKey = strings.ToLower(strings.TrimSpace(spec.ID)) + } + if dedupeKey == "" { + dedupeKey = "matcher_migration_warning" + } + if _, loaded := e.migrationWarns.LoadOrStore(dedupeKey, struct{}{}); loaded { + return + } + e.emitBestEffort(ctx, HookEvent{ + Type: HookEventNotification, + HookID: spec.ID, + Point: spec.Point, + Scope: spec.Scope, + Source: spec.Source, + Kind: spec.Kind, + Mode: spec.Mode, + Status: HookResultPass, + Message: message, + RewakeReason: "matcher_migration", + RewakeSummary: message, + DedupeKey: dedupeKey, + }) +} + // normalizeHookResultByCapability 根据 HookPoint 能力矩阵约束单条结果。 func normalizeHookResultByCapability(point HookPoint, result HookResult) HookResult { capability, ok := HookPointCapabilities(point) @@ -340,6 +382,7 @@ func sanitizeUserHookContext(input HookContext) HookContext { "point": {}, "tool_call_id": {}, "tool_name": {}, + "tool_arguments_preview": {}, "is_error": {}, "error_class": {}, "result_content_preview": {}, diff --git a/internal/runtime/hooks/executor_test.go b/internal/runtime/hooks/executor_test.go index 6a1372e4a..b3e2e25cb 100644 --- a/internal/runtime/hooks/executor_test.go +++ b/internal/runtime/hooks/executor_test.go @@ -955,10 +955,11 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -971,6 +972,9 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for user hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for user hook context") } @@ -999,10 +1003,11 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -1012,7 +1017,76 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for repo hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for repo hook context") } } + +func TestExecutorSkipsHookWhenMatcherMissed(t *testing.T) { + t.Parallel() + + registry := NewRegistry() + executor := NewExecutor(registry, nil, 100*time.Millisecond) + if err := registry.Register(HookSpec{ + ID: "matcher-hook", + Point: HookPointBeforeToolCall, + Scope: HookScopeUser, + Matcher: &HookMatcher{ToolNames: []string{"bash"}}, + Handler: func(context.Context, HookContext) HookResult { + return HookResult{Status: HookResultPass, Message: "should-not-run"} + }, + }); err != nil { + t.Fatalf("Register() error = %v", err) + } + + output := executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{ + Metadata: map[string]any{"tool_name": "filesystem"}, + }) + if output.Blocked { + t.Fatalf("Blocked = true, want false") + } + if len(output.Results) != 0 { + t.Fatalf("len(Results) = %d, want 0 when matcher missed", len(output.Results)) + } +} + +func TestExecutorEmitsMatcherMigrationWarningOnce(t *testing.T) { + t.Parallel() + + registry := NewRegistry() + emitter := &recordingEmitter{} + executor := NewExecutor(registry, emitter, 100*time.Millisecond) + if err := registry.Register(HookSpec{ + ID: "matcher-warning-hook", + Point: HookPointBeforeToolCall, + Scope: HookScopeUser, + Source: HookSourceUser, + MatcherMigrationWarning: "matcher migration warning", + Handler: func(context.Context, HookContext) HookResult { + return HookResult{Status: HookResultPass} + }, + }); err != nil { + t.Fatalf("Register() error = %v", err) + } + + _ = executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{}) + _ = executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{}) + + events := emitter.snapshot() + warningCount := 0 + for _, event := range events { + if event.Type != HookEventNotification { + continue + } + warningCount++ + if event.RewakeReason != "matcher_migration" { + t.Fatalf("notification reason = %q, want matcher_migration", event.RewakeReason) + } + } + if warningCount != 1 { + t.Fatalf("matcher migration warning count = %d, want 1", warningCount) + } +} diff --git a/internal/runtime/hooks/matcher.go b/internal/runtime/hooks/matcher.go new file mode 100644 index 000000000..485418b49 --- /dev/null +++ b/internal/runtime/hooks/matcher.go @@ -0,0 +1,255 @@ +package hooks + +import ( + "fmt" + "regexp" + "strings" +) + +const ( + // MaxHookMatcherRegexLength 限制 tool_name_regex 单条表达式长度,避免超长输入拖慢匹配。 + MaxHookMatcherRegexLength = 256 +) + +const ( + hookMatcherFieldToolName = "tool_name" + hookMatcherFieldToolNameRegex = "tool_name_regex" + hookMatcherFieldArgumentsContains = "arguments_contains" + hookMatcherMetadataToolName = "tool_name" + hookMatcherMetadataArguments = "tool_arguments_preview" +) + +// HookMatcher 描述编译后的 hook 匹配器。 +type HookMatcher struct { + ToolNames []string + ToolNameRegex []*regexp.Regexp + ArgumentsContains []string +} + +// HasHookMatcherConfig 判断 matcher 配置是否包含至少一个非空维度。 +func HasHookMatcherConfig(raw map[string]any) bool { + if len(raw) == 0 { + return false + } + names := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + if len(names) > 0 { + return true + } + regexes := readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + if len(regexes) > 0 { + return true + } + contains := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + return len(contains) > 0 +} + +// ValidateHookMatcher 校验 matcher 配置在指定点位上是否合法。 +func ValidateHookMatcher(point HookPoint, raw map[string]any) error { + _, err := CompileHookMatcher(point, raw) + return err +} + +// CompileHookMatcher 将 matcher 原始配置编译为可执行结构,并在点位能力上做 fail-fast 校验。 +func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, error) { + if !HasHookMatcherConfig(raw) { + return nil, nil + } + capability, ok := HookPointCapabilities(point) + if !ok { + return nil, fmt.Errorf("point %q is not supported", point) + } + + namesRaw := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + regexRaw := readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + containsRaw := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + + if len(namesRaw) > 0 && !capability.Matcher.ToolName { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolName) + } + if len(regexRaw) > 0 && !capability.Matcher.ToolNameRegex { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolNameRegex) + } + if len(containsRaw) > 0 && !capability.Matcher.ArgumentsContains { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldArgumentsContains) + } + + matcher := &HookMatcher{ + ToolNames: normalizeHookMatcherValues(namesRaw), + ArgumentsContains: normalizeHookMatcherValues(containsRaw), + } + for _, expression := range regexRaw { + trimmed := strings.TrimSpace(expression) + if trimmed == "" { + continue + } + if len(trimmed) > MaxHookMatcherRegexLength { + return nil, fmt.Errorf( + "matcher field %q expression length exceeds %d", + hookMatcherFieldToolNameRegex, + MaxHookMatcherRegexLength, + ) + } + compiled, err := regexp.Compile(trimmed) + if err != nil { + return nil, fmt.Errorf("matcher field %q has invalid regex %q: %w", hookMatcherFieldToolNameRegex, trimmed, err) + } + matcher.ToolNameRegex = append(matcher.ToolNameRegex, compiled) + } + if matcher.IsEmpty() { + return nil, fmt.Errorf("match must include at least one non-empty matcher field") + } + return matcher, nil +} + +// IsEmpty 判断 matcher 是否包含可执行维度。 +func (m *HookMatcher) IsEmpty() bool { + if m == nil { + return true + } + return len(m.ToolNames) == 0 && len(m.ToolNameRegex) == 0 && len(m.ArgumentsContains) == 0 +} + +// Match 根据 HookContext 执行 matcher 判定;字段间为 AND,同字段多值为 OR。 +func (m *HookMatcher) Match(input HookContext) bool { + if m == nil || m.IsEmpty() { + return true + } + toolName := strings.TrimSpace(readHookMatcherMetadataString(input.Metadata, hookMatcherMetadataToolName)) + if len(m.ToolNames) > 0 { + if toolName == "" || !containsEqualFold(m.ToolNames, toolName) { + return false + } + } + if len(m.ToolNameRegex) > 0 { + if toolName == "" { + return false + } + matched := false + for _, compiled := range m.ToolNameRegex { + if compiled.MatchString(toolName) { + matched = true + break + } + } + if !matched { + return false + } + } + if len(m.ArgumentsContains) > 0 { + argumentsPreview := strings.ToLower(strings.TrimSpace(readHookMatcherMetadataString( + input.Metadata, + hookMatcherMetadataArguments, + ))) + if argumentsPreview == "" { + return false + } + matched := false + for _, fragment := range m.ArgumentsContains { + if strings.Contains(argumentsPreview, fragment) { + matched = true + break + } + } + if !matched { + return false + } + } + return true +} + +// readHookMatcherStringValues 读取 matcher 字段中的字符串集合,兼容 string / []string / []any。 +func readHookMatcherStringValues(raw map[string]any, key string) []string { + if len(raw) == 0 { + return nil + } + value, ok := raw[key] + if !ok || value == nil { + return nil + } + switch typed := value.(type) { + case string: + if strings.TrimSpace(typed) == "" { + return nil + } + return []string{typed} + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if strings.TrimSpace(item) == "" { + continue + } + out = append(out, item) + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if item == nil { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", item)) + if text == "" { + continue + } + out = append(out, text) + } + return out + default: + text := strings.TrimSpace(fmt.Sprintf("%v", typed)) + if text == "" { + return nil + } + return []string{text} + } +} + +// normalizeHookMatcherValues 将 matcher 词条规范为小写并剔除空值。 +func normalizeHookMatcherValues(values []string) []string { + if len(values) == 0 { + return nil + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + text := strings.ToLower(strings.TrimSpace(value)) + if text == "" { + continue + } + normalized = append(normalized, text) + } + return normalized +} + +// containsEqualFold 判断字符串列表是否包含目标值(忽略大小写)。 +func containsEqualFold(values []string, target string) bool { + normalizedTarget := strings.ToLower(strings.TrimSpace(target)) + if normalizedTarget == "" { + return false + } + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), normalizedTarget) { + return true + } + } + return false +} + +// readHookMatcherMetadataString 从 metadata 中读取字符串,兼容大小写键和非字符串值。 +func readHookMatcherMetadataString(metadata map[string]any, key string) string { + if len(metadata) == 0 { + return "" + } + normalizedKey := strings.ToLower(strings.TrimSpace(key)) + if normalizedKey == "" { + return "" + } + if value, ok := metadata[normalizedKey]; ok && value != nil { + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + for currentKey, value := range metadata { + if !strings.EqualFold(strings.TrimSpace(currentKey), normalizedKey) || value == nil { + continue + } + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + return "" +} diff --git a/internal/runtime/hooks/matcher_test.go b/internal/runtime/hooks/matcher_test.go new file mode 100644 index 000000000..895aa28af --- /dev/null +++ b/internal/runtime/hooks/matcher_test.go @@ -0,0 +1,96 @@ +package hooks + +import "testing" + +func TestHasHookMatcherConfig(t *testing.T) { + t.Parallel() + + if HasHookMatcherConfig(nil) { + t.Fatal("nil matcher config should be false") + } + if HasHookMatcherConfig(map[string]any{}) { + t.Fatal("empty matcher config should be false") + } + if !HasHookMatcherConfig(map[string]any{"tool_name": "bash"}) { + t.Fatal("tool_name matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"tool_name_regex": []any{"^bash$"}}) { + t.Fatal("tool_name_regex matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"arguments_contains": []string{"rm -rf"}}) { + t.Fatal("arguments_contains matcher should be true") + } +} + +func TestCompileHookMatcherAndMatch(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": []any{"bash", "filesystem"}, + "tool_name_regex": []string{`^(bash|shell)$`}, + "arguments_contains": []string{"rm -rf"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher to be compiled") + } + + if !matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "sudo rm -rf /tmp/test", + }, + }) { + t.Fatal("expected matcher to pass for matching metadata") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "echo hello", + }, + }) { + t.Fatal("expected matcher to fail when arguments_contains not matched") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "filesystem", + "tool_arguments_preview": "rm -rf /tmp", + }, + }) { + t.Fatal("expected matcher to fail when tool_name_regex not matched") + } +} + +func TestCompileHookMatcherValidation(t *testing.T) { + t.Parallel() + + if _, err := CompileHookMatcher(HookPointSessionStart, map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start tool_name matcher to be rejected") + } + + if _, err := CompileHookMatcher(HookPointAfterToolResult, map[string]any{ + "arguments_contains": []string{"rm -rf"}, + }); err == nil { + t.Fatal("expected after_tool_result arguments_contains to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": "(", + }); err == nil { + t.Fatal("expected invalid regex to fail") + } + + longRegex := make([]byte, MaxHookMatcherRegexLength+1) + for i := range longRegex { + longRegex[i] = 'a' + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": string(longRegex), + }); err == nil { + t.Fatal("expected overlong regex to fail") + } +} diff --git a/internal/runtime/hooks/types.go b/internal/runtime/hooks/types.go index 2f928f5c5..e329555b7 100644 --- a/internal/runtime/hooks/types.go +++ b/internal/runtime/hooks/types.go @@ -45,22 +45,77 @@ type HookPointCapability struct { CanAnnotate bool CanUpdateInput bool UserAllowed bool + Matcher HookMatcherCapability +} + +// HookMatcherCapability 描述点位可用的 matcher 维度。 +type HookMatcherCapability struct { + ToolName bool + ToolNameRegex bool + ArgumentsContains bool } var hookPointCapabilities = map[HookPoint]HookPointCapability{ - HookPointBeforeToolCall: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAfterToolResult: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforeCompletionDecision: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAcceptGate: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforePermissionDecision: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointAfterToolFailure: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionStart: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionEnd: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointUserPromptSubmit: {CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true}, - HookPointPreCompact: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointPostCompact: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSubAgentStart: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointSubAgentStop: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, + HookPointBeforeToolCall: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointAfterToolResult: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointBeforeCompletionDecision: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointAcceptGate: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointBeforePermissionDecision: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointAfterToolFailure: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointSessionStart: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSessionEnd: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointUserPromptSubmit: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointPreCompact: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointPostCompact: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStart: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStop: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, } // HookScope 描述 hook 的权限/上下文裁剪等级。 @@ -142,6 +197,10 @@ type HookSpec struct { Timeout time.Duration FailurePolicy FailurePolicy Handler HookHandler + Matcher *HookMatcher + + // MatcherMigrationWarning 用于在运行时提示 warn_on_tool_call 旧参数与 match 共存时的迁移风险。 + MatcherMigrationWarning string } // normalizeAndValidate 将 HookSpec 归一化并校验当前阶段可用字段。 diff --git a/internal/runtime/hooks/types_test.go b/internal/runtime/hooks/types_test.go index 9a81cb07e..3664133cb 100644 --- a/internal/runtime/hooks/types_test.go +++ b/internal/runtime/hooks/types_test.go @@ -217,6 +217,9 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("before_permission_decision should allow block") } + if !capability.Matcher.ToolName || !capability.Matcher.ToolNameRegex { + t.Fatal("before_permission_decision should support tool_name/tool_name_regex matcher") + } capability, ok = HookPointCapabilities(HookPointAfterToolFailure) if !ok { @@ -225,6 +228,9 @@ func TestHookPointCapabilities(t *testing.T) { if capability.CanBlock { t.Fatal("after_tool_failure should be observe-only") } + if !capability.Matcher.ArgumentsContains { + t.Fatal("after_tool_failure should support arguments_contains matcher") + } capability, ok = HookPointCapabilities(HookPointBeforeCompletionDecision) if !ok { @@ -241,6 +247,9 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("accept_gate should allow block") } + if capability.Matcher.ToolName || capability.Matcher.ToolNameRegex || capability.Matcher.ArgumentsContains { + t.Fatal("accept_gate should not expose matcher dimensions") + } if _, exists := HookPointCapabilities(HookPoint("unknown")); exists { t.Fatal("unknown hook point should not have capability") diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index b22bb343d..7562f1795 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -359,13 +359,24 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { default: return fmt.Errorf("handler %q is not supported", item.Handler) } - if handler == "warn_on_tool_call" && !runtimeHasWarnOnToolCallTargets(item.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", item.Handler) + hasExplicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) + if handler == "warn_on_tool_call" && !hasExplicitMatcher && !runtimeHasWarnOnToolCallTargets(item.Params) { + return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", item.Handler) + } + if matcherRaw := resolveConfiguredHookMatcherRaw(item); matcherRaw != nil { + if err := runtimehooks.ValidateHookMatcher(point, matcherRaw); err != nil { + return fmt.Errorf("match: %w", err) + } } case repoHookKindCommand: if err := runtimehooks.ValidateCommandParams(item.Params); err != nil { return err } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } } return nil } diff --git a/internal/runtime/repo_hooks_test.go b/internal/runtime/repo_hooks_test.go index 64688256c..90e17b6c1 100644 --- a/internal/runtime/repo_hooks_test.go +++ b/internal/runtime/repo_hooks_test.go @@ -618,6 +618,49 @@ func TestRuntimeHasWarnOnToolCallTargetsBranches(t *testing.T) { } } +func TestValidateRepoHookItemAllowsWarnOnToolCallWithMatchOnly(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-warn-match", + Point: "before_tool_call", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "warn_on_tool_call", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Match: map[string]any{ + "tool_name": "bash", + }, + } + if err := validateRepoHookItem(item); err != nil { + t.Fatalf("validateRepoHookItem() error = %v", err) + } +} + +func TestValidateRepoHookItemRejectsUnsupportedMatcherDimension(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-session-match", + Point: "session_start", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "add_context_note", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{"note": "repo"}, + Match: map[string]any{ + "tool_name": "bash", + }, + } + if err := validateRepoHookItem(item); err == nil { + t.Fatal("expected unsupported matcher dimension to fail") + } +} + func TestResolveRepoHooksPathBranches(t *testing.T) { workspace := t.TempDir() hooksPath := filepath.Join(workspace, ".neocode", "hooks.yaml") diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 49a821585..1698b6e03 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "path/filepath" + "regexp" "sort" "strings" "sync" @@ -22,6 +23,12 @@ type indexedToolCall struct { call providertypes.ToolCall } +const hookToolArgumentsPreviewMaxChars = 512 + +var hookToolArgumentsSensitivePattern = regexp.MustCompile( + `(?i)(token|password|secret|api[_-]?key|access[_-]?key|auth)\s*[:=]\s*("[^"]*"|'[^']*'|[^\s]+)`, +) + // executeAssistantToolCalls 并发执行 assistant 返回的全部工具调用并返回结构化执行摘要。 func (s *Service) executeAssistantToolCalls( ctx context.Context, @@ -160,10 +167,11 @@ func (s *Service) executeOneToolCallWithoutPersistence( beforeToolHookOutput := s.runHookPoint(ctx, state, runtimehooks.HookPointBeforeToolCall, runtimehooks.HookContext{ Metadata: map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "tool_arguments": strings.TrimSpace(call.Arguments), - "workdir": strings.TrimSpace(snapshot.Workdir), + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments": strings.TrimSpace(call.Arguments), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "workdir": strings.TrimSpace(snapshot.Workdir), }, }) if beforeToolHookOutput.Blocked { @@ -752,6 +760,28 @@ func summarizeHookResultContent(content string) string { return trimmed[:256] } +// buildToolArgumentsPreview 生成 matcher 可用的参数预览,并对敏感键值执行脱敏。 +func buildToolArgumentsPreview(arguments string) string { + trimmed := strings.TrimSpace(arguments) + if trimmed == "" { + return "" + } + masked := hookToolArgumentsSensitivePattern.ReplaceAllString(trimmed, `$1=***`) + return truncateHookTextByChars(masked, hookToolArgumentsPreviewMaxChars) +} + +// truncateHookTextByChars 按字符长度截断文本,避免 metadata 放大。 +func truncateHookTextByChars(text string, maxChars int) string { + if maxChars <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= maxChars { + return text + } + return string(runes[:maxChars]) +} + // extractTodoIDsFromPayload 提取 todo 事件快照中的条目 ID,用于冲突事实去重统计。 func extractTodoIDsFromPayload(items []TodoViewItem) []string { if len(items) == 0 { @@ -828,11 +858,12 @@ func (s *Service) emitAfterToolFailureHook( workdir string, ) { afterToolFailureMetadata := map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "is_error": result.IsError, - "error_class": strings.TrimSpace(result.ErrorClass), - "workdir": strings.TrimSpace(workdir), + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "is_error": result.IsError, + "error_class": strings.TrimSpace(result.ErrorClass), + "workdir": strings.TrimSpace(workdir), } if execErr != nil { afterToolFailureMetadata["execution_error"] = strings.TrimSpace(execErr.Error()) diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index ec7a4f159..49accc51f 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -197,48 +197,55 @@ func buildConfiguredHookSpec( if err := validateConfiguredHookItemForP6Lite(item, scope); err != nil { return runtimehooks.HookSpec{}, err } + point := runtimehooks.HookPoint(strings.TrimSpace(item.Point)) + matcher, matcherWarning, sanitizedParams, err := buildConfiguredHookMatcher(item, point) + if err != nil { + return runtimehooks.HookSpec{}, err + } kind := strings.ToLower(strings.TrimSpace(item.Kind)) specKind := runtimehooks.HookKindFunction specMode := runtimehooks.HookModeSync var ( - handler runtimehooks.HookHandler - err error + handler runtimehooks.HookHandler + buildErr error ) switch kind { case configuredHookKindBuiltin: - handler, err = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) + handler, buildErr = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), sanitizedParams, defaultWorkdir) specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: - handler, err = buildUserCommandHookHandler( + handler, buildErr = buildUserCommandHookHandler( strings.TrimSpace(item.ID), - runtimehooks.HookPoint(strings.TrimSpace(item.Point)), + point, item.Params, defaultWorkdir, ) specKind = runtimehooks.HookKindCommand specMode = runtimehooks.HookModeSync case configuredHookKindHTTP: - handler, err = buildUserHTTPObserveHookHandler(item) + handler, buildErr = buildUserHTTPObserveHookHandler(item) specKind = runtimehooks.HookKindHTTP specMode = runtimehooks.HookModeObserve default: return runtimehooks.HookSpec{}, fmt.Errorf("kind %q is not supported", item.Kind) } - if err != nil { - return runtimehooks.HookSpec{}, err + if buildErr != nil { + return runtimehooks.HookSpec{}, buildErr } return runtimehooks.HookSpec{ - ID: strings.TrimSpace(item.ID), - Point: runtimehooks.HookPoint(strings.TrimSpace(item.Point)), - Scope: scope, - Source: source, - Kind: specKind, - Mode: specMode, - Priority: item.Priority, - Timeout: time.Duration(item.TimeoutSec) * time.Second, - FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), - Handler: handler, + ID: strings.TrimSpace(item.ID), + Point: point, + Scope: scope, + Source: source, + Kind: specKind, + Mode: specMode, + Priority: item.Priority, + Timeout: time.Duration(item.TimeoutSec) * time.Second, + FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), + Handler: handler, + Matcher: matcher, + MatcherMigrationWarning: matcherWarning, }, nil } @@ -257,6 +264,17 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported", item.Mode) } + handler := strings.ToLower(strings.TrimSpace(item.Handler)) + hasExplicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) + if handler == "warn_on_tool_call" && !hasExplicitMatcher && !runtimeHasWarnOnToolCallTargets(item.Params) { + return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", item.Handler) + } + matcherRaw := resolveConfiguredHookMatcherRaw(item) + if matcherRaw != nil { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), matcherRaw); err != nil { + return fmt.Errorf("match: %w", err) + } + } case configuredHookKindCommand: if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) @@ -264,6 +282,11 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if _, _, err := runtimehooks.ParseCommandParams(item.Params); err != nil { return err } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case configuredHookKindHTTP: if mode != configuredHookModeObserve { return fmt.Errorf("mode %q is not supported for kind http (only observe)", item.Mode) @@ -272,6 +295,11 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if policy == "fail_closed" { return fmt.Errorf("failure_policy %q is not supported for kind http observe", item.FailurePolicy) } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } default: if isExternalHookKind(kind) { return fmt.Errorf( @@ -284,6 +312,106 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop return nil } +// resolveConfiguredHookMatcherRaw 返回运行时装配阶段应使用的 matcher 原始配置。 +func resolveConfiguredHookMatcherRaw(item config.RuntimeHookItemConfig) map[string]any { + if runtimehooks.HasHookMatcherConfig(item.Match) { + return item.Match + } + if strings.EqualFold(strings.TrimSpace(item.Handler), "warn_on_tool_call") && + runtimeHasWarnOnToolCallTargets(item.Params) { + return buildLegacyWarnMatcherFromParams(item.Params) + } + return nil +} + +// buildConfiguredHookMatcher 编译 hook matcher 并生成迁移告警,同时返回供 handler 使用的参数副本。 +func buildConfiguredHookMatcher( + item config.RuntimeHookItemConfig, + point runtimehooks.HookPoint, +) (*runtimehooks.HookMatcher, string, map[string]any, error) { + sanitizedParams := cloneHookParams(item.Params) + matcherRaw := resolveConfiguredHookMatcherRaw(item) + if matcherRaw == nil { + return nil, "", sanitizedParams, nil + } + matcher, err := runtimehooks.CompileHookMatcher(point, matcherRaw) + if err != nil { + return nil, "", nil, fmt.Errorf("match: %w", err) + } + if matcher == nil { + return nil, "", sanitizedParams, nil + } + explicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) + legacyWarnTargets := strings.EqualFold(strings.TrimSpace(item.Handler), "warn_on_tool_call") && + runtimeHasWarnOnToolCallTargets(item.Params) + warning := "" + if explicitMatcher && legacyWarnTargets { + warning = "hook matcher migration: match is configured; params.tool_name/tool_names on warn_on_tool_call are ignored" + delete(sanitizedParams, "tool_name") + delete(sanitizedParams, "tool_names") + } + return matcher, warning, sanitizedParams, nil +} + +// cloneHookParams 深拷贝 params,避免装配阶段修改影响原始配置对象。 +func cloneHookParams(params map[string]any) map[string]any { + if len(params) == 0 { + return nil + } + cloned := make(map[string]any, len(params)) + for key, value := range params { + cloned[key] = cloneHookParamValue(value) + } + return cloned +} + +// cloneHookParamValue 深拷贝 matcher/params 结构,避免 map/slice 底层共享。 +func cloneHookParamValue(value any) any { + switch typed := value.(type) { + case map[string]any: + cloned := make(map[string]any, len(typed)) + for key, item := range typed { + cloned[key] = cloneHookParamValue(item) + } + return cloned + case []any: + cloned := make([]any, len(typed)) + for index, item := range typed { + cloned[index] = cloneHookParamValue(item) + } + return cloned + case []string: + cloned := make([]string, len(typed)) + copy(cloned, typed) + return cloned + default: + return value + } +} + +// buildLegacyWarnMatcherFromParams 将 warn_on_tool_call 旧参数桥接为 matcher 配置。 +func buildLegacyWarnMatcherFromParams(params map[string]any) map[string]any { + if len(params) == 0 { + return nil + } + var toolNames []string + if name := strings.TrimSpace(readHookParamString(params, "tool_name")); name != "" { + toolNames = append(toolNames, name) + } + for _, value := range readHookParamStringSlice(params, "tool_names") { + if strings.TrimSpace(value) == "" { + continue + } + toolNames = append(toolNames, value) + } + if len(toolNames) == 0 { + return nil + } + return map[string]any{ + "tool_name": toolNames, + } +} + func isExternalHookKind(kind string) bool { switch strings.ToLower(strings.TrimSpace(kind)) { case "command", "http", "prompt", "agent": @@ -342,9 +470,6 @@ func buildUserBuiltinHookHandler( case "warn_on_tool_call": targetTool := strings.ToLower(strings.TrimSpace(readHookParamString(params, "tool_name"))) targetTools := normalizeHookParamStringSlice(readHookParamStringSlice(params, "tool_names")) - if targetTool == "" && len(targetTools) == 0 { - return nil, fmt.Errorf("handler warn_on_tool_call requires params.tool_name or params.tool_names") - } defaultMessage := "tool call matched warn_on_tool_call" if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { defaultMessage = customMessage @@ -352,6 +477,9 @@ func buildUserBuiltinHookHandler( return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { _ = ctx toolName := strings.ToLower(strings.TrimSpace(readHookContextMetadataString(input, "tool_name"))) + if targetTool == "" && len(targetTools) == 0 { + return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} + } if toolName == "" { return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} } diff --git a/internal/runtime/user_hooks_test.go b/internal/runtime/user_hooks_test.go index f328687ec..5decb4ea6 100644 --- a/internal/runtime/user_hooks_test.go +++ b/internal/runtime/user_hooks_test.go @@ -834,8 +834,12 @@ func TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()); err == nil { - t.Fatal("expected missing target error") + handlerWithoutTarget, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()) + if err != nil { + t.Fatalf("build warn_on_tool_call without target error: %v", err) + } + if got := handlerWithoutTarget(context.Background(), runtimehooks.HookContext{}); got.Message == "" { + t.Fatalf("expected default warning message when no target is configured, got %+v", got) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing note/message error") @@ -1349,6 +1353,82 @@ func TestConfigureRuntimeHooksInjectsAsyncResultSinkIntoBaseExecutor(t *testing. t.Fatal("expected async rewake notification to be enqueued via configured async sink") } +func TestBuildUserHookSpecBridgesWarnOnToolCallLegacyParamsToMatcher(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "warn-legacy", + Point: "before_tool_call", + Scope: "user", + Kind: "builtin", + Mode: "sync", + Handler: "warn_on_tool_call", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{ + "tool_names": []any{"bash"}, + "message": "legacy warning", + }, + } + + spec, err := buildUserHookSpec(item, t.TempDir()) + if err != nil { + t.Fatalf("buildUserHookSpec() error = %v", err) + } + if spec.Matcher == nil { + t.Fatal("expected legacy warn_on_tool_call params to be bridged into matcher") + } + if spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) { + t.Fatal("matcher should reject unmatched tool") + } + if !spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("matcher should accept configured legacy tool") + } +} + +func TestBuildUserHookSpecWarnOnToolCallMatchTakesPrecedenceAndEmitsMigrationWarning(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "warn-conflict", + Point: "before_tool_call", + Scope: "user", + Kind: "builtin", + Mode: "sync", + Handler: "warn_on_tool_call", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Match: map[string]any{ + "tool_name": "filesystem", + }, + Params: map[string]any{ + "tool_name": "bash", + "message": "explicit matcher wins", + }, + } + + spec, err := buildUserHookSpec(item, t.TempDir()) + if err != nil { + t.Fatalf("buildUserHookSpec() error = %v", err) + } + if spec.Matcher == nil { + t.Fatal("expected matcher to be compiled") + } + if spec.MatcherMigrationWarning == "" { + t.Fatal("expected migration warning when match and legacy warn params coexist") + } + if !spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) { + t.Fatal("matcher should follow explicit match config") + } + if spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("legacy params should be ignored when explicit match exists") + } + result := spec.Handler(context.Background(), runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) + if result.Message != "explicit matcher wins" { + t.Fatalf("handler message = %q, want explicit matcher wins", result.Message) + } +} + type countingHookExecutor struct { calls atomic.Int32 output runtimehooks.RunOutput @@ -1518,8 +1598,8 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, workdir); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err == nil { - t.Fatal("expected missing tool target error") + if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err != nil { + t.Fatalf("warn_on_tool_call without target should be allowed for matcher-based filtering: %v", err) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, workdir); err == nil { t.Fatal("expected missing note/message error") From aea182ad076f385b0e56759b0e974608f25219e4 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Fri, 29 May 2026 10:00:35 +0800 Subject: [PATCH 2/3] @ refactor(hooks): remove warn_on_tool_call backward compat, reject unknown match fields - Remove legacy warn_on_tool_call bridging: params.tool_name/tool_names no longer route through matcher; handler only reads "message" param. - Reject unknown match keys at CompileHookMatcher level instead of silently ignoring, preventing misconfigured hooks from running unfiltered. - Strip MatcherMigrationWarning from HookSpec, emitMatcherMigrationWarning, and all migration-notification infrastructure. - Remove ~140 lines of compat helpers across config and runtime packages. - Raise hooks package coverage to 95.9% with new matcher edge-case tests. Co-Authored-By: Claude Opus 4.7 @ --- internal/config/loader_test.go | 3 +- internal/config/runtime_hooks.go | 91 +-------- internal/config/runtime_hooks_test.go | 23 +-- internal/runtime/hooks/executor.go | 39 ---- internal/runtime/hooks/executor_test.go | 37 ---- internal/runtime/hooks/matcher.go | 5 +- internal/runtime/hooks/matcher_test.go | 245 +++++++++++++++++++++++- internal/runtime/hooks/types.go | 5 +- internal/runtime/repo_hooks.go | 31 +-- internal/runtime/repo_hooks_test.go | 21 -- internal/runtime/user_hooks.go | 155 +++------------ internal/runtime/user_hooks_test.go | 129 +++---------- 12 files changed, 323 insertions(+), 461 deletions(-) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 65d14c6f9..1ebdca32d 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -126,8 +126,9 @@ runtime: priority: 100 timeout_sec: 2 failure_policy: warn_only - params: + match: tool_name: bash + params: message: "bash is called" ` writeLoaderConfig(t, loader, raw) diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index 86b4194c6..f5a87c98b 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -286,16 +286,14 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { default: return fmt.Errorf("handler %q is not supported", c.Handler) } - hasExplicitMatcher := hooks.HasHookMatcherConfig(c.Match) - if handler == runtimeHookHandlerWarnOnToolCall && !hasExplicitMatcher && !hasWarnOnToolCallTargets(c.Params) { - return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", c.Handler) - } - matcherRaw := resolveRuntimeHookMatcherConfigForValidation(c, handler) - if matcherRaw != nil { - if err := hooks.ValidateHookMatcher(point, matcherRaw); err != nil { - return fmt.Errorf("match: %w", err) + if handler == runtimeHookHandlerWarnOnToolCall && !hooks.HasHookMatcherConfig(c.Match) { + return fmt.Errorf("handler %q requires match", c.Handler) + } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } - } case runtimeHookKindCommand: if normalizedMode != runtimeHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", c.Mode) @@ -422,81 +420,6 @@ func cloneRuntimeHookParamValue(value any) any { } } -func hasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - toolNameRaw, hasToolName := params["tool_name"] - if hasToolName && strings.TrimSpace(fmt.Sprintf("%v", toolNameRaw)) != "" { - return true - } - toolNamesRaw, hasToolNames := params["tool_names"] - if !hasToolNames || toolNamesRaw == nil { - return false - } - switch typed := toolNamesRaw.(type) { - case []string: - for _, item := range typed { - if strings.TrimSpace(item) != "" { - return true - } - } - case []any: - for _, item := range typed { - if strings.TrimSpace(fmt.Sprintf("%v", item)) != "" { - return true - } - } - } - return false -} - -// resolveRuntimeHookMatcherConfigForValidation 返回配置校验阶段的 matcher 配置。 -// 对 warn_on_tool_call 保持旧参数兼容:当未配置 match 时自动桥接 tool_name/tool_names。 -func resolveRuntimeHookMatcherConfigForValidation(item RuntimeHookItemConfig, handler string) map[string]any { - if hooks.HasHookMatcherConfig(item.Match) { - return item.Match - } - if strings.EqualFold(strings.TrimSpace(handler), runtimeHookHandlerWarnOnToolCall) && hasWarnOnToolCallTargets(item.Params) { - return runtimeHookLegacyWarnMatcherConfig(item.Params) - } - return nil -} - -// runtimeHookLegacyWarnMatcherConfig 将 warn_on_tool_call 旧参数桥接为 matcher 配置。 -func runtimeHookLegacyWarnMatcherConfig(params map[string]any) map[string]any { - if len(params) == 0 { - return nil - } - var toolNames []string - if name := strings.TrimSpace(readRuntimeHookParamString(params, "tool_name")); name != "" { - toolNames = append(toolNames, name) - } - if raw, ok := params["tool_names"]; ok && raw != nil { - switch typed := raw.(type) { - case []string: - toolNames = append(toolNames, typed...) - case []any: - for _, value := range typed { - toolNames = append(toolNames, strings.TrimSpace(fmt.Sprintf("%v", value))) - } - } - } - filtered := make([]string, 0, len(toolNames)) - for _, value := range toolNames { - if strings.TrimSpace(value) == "" { - continue - } - filtered = append(filtered, value) - } - if len(filtered) == 0 { - return nil - } - return map[string]any{ - "tool_name": filtered, - } -} - // readRuntimeHookParamString 以兼容方式读取 runtime hook 参数中的字符串值。 func readRuntimeHookParamString(params map[string]any, key string) string { if len(params) == 0 { diff --git a/internal/config/runtime_hooks_test.go b/internal/config/runtime_hooks_test.go index c537caf13..e36145026 100644 --- a/internal/config/runtime_hooks_test.go +++ b/internal/config/runtime_hooks_test.go @@ -390,13 +390,15 @@ func TestRuntimeHooksConfigItemDefaultsAndClone(t *testing.T) { ID: "warn-bash", Point: string(hooks.HookPointBeforeToolCall), Handler: runtimeHookHandlerWarnOnToolCall, - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", - "tags": []any{"warn", "tool"}, + }, + Params: map[string]any{ + "tags": []any{"warn", "tool"}, }, }, }, - } +} cfg.ApplyDefaults(defaultRuntimeHooksConfig()) item := cfg.Items[0] @@ -664,21 +666,6 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Fatal("expected deep clone for nested map in slice") } - if hasWarnOnToolCallTargets(nil) { - t.Fatal("nil params should be false") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_name": "bash"}) { - t.Fatal("tool_name should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []string{"", "bash"}}) { - t.Fatal("tool_names []string should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []any{"", "bash"}}) { - t.Fatal("tool_names []any should pass") - } - if hasWarnOnToolCallTargets(map[string]any{"tool_names": "bash"}) { - t.Fatal("tool_names scalar should fail") - } matchCfg := RuntimeHookItemConfig{ Match: map[string]any{ diff --git a/internal/runtime/hooks/executor.go b/internal/runtime/hooks/executor.go index be9a1f74b..2586483a3 100644 --- a/internal/runtime/hooks/executor.go +++ b/internal/runtime/hooks/executor.go @@ -23,7 +23,6 @@ type Executor struct { defaultTimeout time.Duration maxInFlight int32 inFlight atomic.Int32 - migrationWarns sync.Map now func() time.Time asyncSink AsyncResultSink } @@ -80,7 +79,6 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) if spec.Scope == HookScopeUser || spec.Scope == HookScopeRepo { hookInput = sanitizeUserHookContext(hookInput) } - e.emitMatcherMigrationWarning(ctx, spec) if spec.Matcher != nil && !spec.Matcher.Match(hookInput) { continue } @@ -108,43 +106,6 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) return output } -// emitMatcherMigrationWarning 在 detect 到旧 warn_on_tool_call 参数与 match 共存时发出一次迁移提示事件。 -func (e *Executor) emitMatcherMigrationWarning(ctx context.Context, spec HookSpec) { - if e == nil { - return - } - message := strings.TrimSpace(spec.MatcherMigrationWarning) - if message == "" { - return - } - dedupeKey := strings.ToLower(strings.TrimSpace( - fmt.Sprintf("%s|%s|%s|%s", spec.ID, spec.Point, spec.Scope, spec.Source), - )) - if dedupeKey == "" { - dedupeKey = strings.ToLower(strings.TrimSpace(spec.ID)) - } - if dedupeKey == "" { - dedupeKey = "matcher_migration_warning" - } - if _, loaded := e.migrationWarns.LoadOrStore(dedupeKey, struct{}{}); loaded { - return - } - e.emitBestEffort(ctx, HookEvent{ - Type: HookEventNotification, - HookID: spec.ID, - Point: spec.Point, - Scope: spec.Scope, - Source: spec.Source, - Kind: spec.Kind, - Mode: spec.Mode, - Status: HookResultPass, - Message: message, - RewakeReason: "matcher_migration", - RewakeSummary: message, - DedupeKey: dedupeKey, - }) -} - // normalizeHookResultByCapability 根据 HookPoint 能力矩阵约束单条结果。 func normalizeHookResultByCapability(point HookPoint, result HookResult) HookResult { capability, ok := HookPointCapabilities(point) diff --git a/internal/runtime/hooks/executor_test.go b/internal/runtime/hooks/executor_test.go index b3e2e25cb..8953c9775 100644 --- a/internal/runtime/hooks/executor_test.go +++ b/internal/runtime/hooks/executor_test.go @@ -1053,40 +1053,3 @@ func TestExecutorSkipsHookWhenMatcherMissed(t *testing.T) { } } -func TestExecutorEmitsMatcherMigrationWarningOnce(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - emitter := &recordingEmitter{} - executor := NewExecutor(registry, emitter, 100*time.Millisecond) - if err := registry.Register(HookSpec{ - ID: "matcher-warning-hook", - Point: HookPointBeforeToolCall, - Scope: HookScopeUser, - Source: HookSourceUser, - MatcherMigrationWarning: "matcher migration warning", - Handler: func(context.Context, HookContext) HookResult { - return HookResult{Status: HookResultPass} - }, - }); err != nil { - t.Fatalf("Register() error = %v", err) - } - - _ = executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{}) - _ = executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{}) - - events := emitter.snapshot() - warningCount := 0 - for _, event := range events { - if event.Type != HookEventNotification { - continue - } - warningCount++ - if event.RewakeReason != "matcher_migration" { - t.Fatalf("notification reason = %q, want matcher_migration", event.RewakeReason) - } - } - if warningCount != 1 { - t.Fatalf("matcher migration warning count = %d, want 1", warningCount) - } -} diff --git a/internal/runtime/hooks/matcher.go b/internal/runtime/hooks/matcher.go index 485418b49..493be0126 100644 --- a/internal/runtime/hooks/matcher.go +++ b/internal/runtime/hooks/matcher.go @@ -51,9 +51,12 @@ func ValidateHookMatcher(point HookPoint, raw map[string]any) error { // CompileHookMatcher 将 matcher 原始配置编译为可执行结构,并在点位能力上做 fail-fast 校验。 func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, error) { - if !HasHookMatcherConfig(raw) { + if len(raw) == 0 { return nil, nil } + if !HasHookMatcherConfig(raw) { + return nil, fmt.Errorf("match contains no recognized matcher fields (expected: tool_name, tool_name_regex, arguments_contains)") + } capability, ok := HookPointCapabilities(point) if !ok { return nil, fmt.Errorf("point %q is not supported", point) diff --git a/internal/runtime/hooks/matcher_test.go b/internal/runtime/hooks/matcher_test.go index 895aa28af..99f088a53 100644 --- a/internal/runtime/hooks/matcher_test.go +++ b/internal/runtime/hooks/matcher_test.go @@ -1,6 +1,9 @@ package hooks -import "testing" +import ( + "regexp" + "testing" +) func TestHasHookMatcherConfig(t *testing.T) { t.Parallel() @@ -20,6 +23,9 @@ func TestHasHookMatcherConfig(t *testing.T) { if !HasHookMatcherConfig(map[string]any{"arguments_contains": []string{"rm -rf"}}) { t.Fatal("arguments_contains matcher should be true") } + if HasHookMatcherConfig(map[string]any{"tool_name": " "}) { + t.Fatal("whitespace-only tool_name should be false") + } } func TestCompileHookMatcherAndMatch(t *testing.T) { @@ -93,4 +99,241 @@ func TestCompileHookMatcherValidation(t *testing.T) { }); err == nil { t.Fatal("expected overlong regex to fail") } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_names": "bash", + }); err == nil { + t.Fatal("expected unknown matcher field to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "unknown": "value", + }); err == nil { + t.Fatal("expected completely unknown matcher field to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, nil); err != nil { + t.Fatal("nil raw should succeed with nil matcher") + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{}); err != nil { + t.Fatal("empty raw should succeed with nil matcher") + } +} + +func TestValidateHookMatcher(t *testing.T) { + t.Parallel() + + if err := ValidateHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + }); err != nil { + t.Fatalf("ValidateHookMatcher() error = %v", err) + } + if err := ValidateHookMatcher(HookPointSessionStart, map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start matcher to fail validation") + } +} + +func TestIsEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.IsEmpty() { + t.Fatal("nil matcher should be empty") + } + if !(&HookMatcher{}).IsEmpty() { + t.Fatal("zero-value matcher should be empty") + } + if (&HookMatcher{ToolNames: []string{"bash"}}).IsEmpty() { + t.Fatal("matcher with tool_name should not be empty") + } +} + +func TestMatchNilAndEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.Match(HookContext{}) { + t.Fatal("nil matcher should match everything") + } + empty := &HookMatcher{} + if !empty.Match(HookContext{}) { + t.Fatal("empty matcher should match everything") + } +} + +func TestMatchSingleDimension(t *testing.T) { + t.Parallel() + + t.Run("tool_name only", func(t *testing.T) { + t.Parallel() + m := &HookMatcher{ToolNames: []string{"bash", "filesystem"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name metadata missing") + } + }) + + t.Run("tool_name_regex only", func(t *testing.T) { + t.Parallel() + compiled := regexp.MustCompile(`^(bash|shell)$`) + m := &HookMatcher{ToolNameRegex: []*regexp.Regexp{compiled}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected regex match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected regex no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name missing for regex") + } + }) + + t.Run("arguments_contains only", func(t *testing.T) { + t.Parallel() + m := &HookMatcher{ArgumentsContains: []string{"rm -rf", "sudo"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "sudo rm -rf /tmp"}}) { + t.Fatal("expected arguments_contains match") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "echo hello"}}) { + t.Fatal("expected arguments_contains no match") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when arguments_preview missing") + } + }) +} + +func TestReadHookMatcherStringValues(t *testing.T) { + t.Parallel() + + if got := readHookMatcherStringValues(nil, "x"); len(got) != 0 { + t.Fatal("nil raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{}, "x"); len(got) != 0 { + t.Fatal("empty raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": nil}, "x"); len(got) != 0 { + t.Fatal("nil value should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": " "}, "x"); len(got) != 0 { + t.Fatal("whitespace-only string should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": 42}, "x"); len(got) != 1 || got[0] != "42" { + t.Fatalf("int value should be converted to string, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": []any{" a ", nil, 123}}, "x"); len(got) != 2 || got[0] != "a" || got[1] != "123" { + t.Fatalf("[]any with mixed values, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": "hello"}, "y"); len(got) != 0 { + t.Fatal("missing key should return nil") + } +} + +func TestNormalizeHookMatcherValues(t *testing.T) { + t.Parallel() + + if got := normalizeHookMatcherValues(nil); len(got) != 0 { + t.Fatal("nil values should return nil") + } + if got := normalizeHookMatcherValues([]string{}); len(got) != 0 { + t.Fatal("empty values should return nil") + } + if got := normalizeHookMatcherValues([]string{" ", "\t"}); len(got) != 0 { + t.Fatal("whitespace-only values should return empty") + } + if got := normalizeHookMatcherValues([]string{" BASH ", "", " Filesystem "}); len(got) != 2 || got[0] != "bash" || got[1] != "filesystem" { + t.Fatalf("mixed values should be normalized, got %v", got) + } +} + +func TestContainsEqualFold(t *testing.T) { + t.Parallel() + + if containsEqualFold(nil, "bash") { + t.Fatal("nil values should not match") + } + if containsEqualFold([]string{"bash"}, "") { + t.Fatal("empty target should not match") + } + if containsEqualFold([]string{"bash"}, " ") { + t.Fatal("whitespace-only target should not match") + } + if !containsEqualFold([]string{"BASH", "FILESYSTEM"}, " bash ") { + t.Fatal("case-insensitive match should work") + } + if containsEqualFold([]string{"bash"}, "python") { + t.Fatal("non-matching should return false") + } +} + +func TestReadHookMatcherMetadataString(t *testing.T) { + t.Parallel() + + if got := readHookMatcherMetadataString(nil, "x"); got != "" { + t.Fatal("nil metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{}, "x"); got != "" { + t.Fatal("empty metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": nil}, "x"); got != "" { + t.Fatal("nil value should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": 123}, "x"); got != "123" { + t.Fatalf("non-string value should be converted, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, " "); got != "" { + t.Fatal("empty key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, ""); got != "" { + t.Fatal("empty string key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"TOOL_NAME": "bash"}, "tool_name"); got != "bash" { + t.Fatalf("case-insensitive key lookup failed, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"y": "hello"}, "x"); got != "" { + t.Fatal("missing key should return empty") + } +} + +func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_name_regex": []string{" ", "\t"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled even when regex values are whitespace-only") + } + if len(matcher.ToolNameRegex) != 0 { + t.Fatalf("expected empty tool_name_regex slice, got %d entries", len(matcher.ToolNameRegex)) + } +} + + +func TestCompileHookMatcherRegexOnly(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": `^bash`, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled for regex-only config") + } + if !matcher.Match(HookContext{Metadata: map[string]any{"tool_name": "bash-script"}}) { + t.Fatal("expected regex to match") + } } diff --git a/internal/runtime/hooks/types.go b/internal/runtime/hooks/types.go index e329555b7..7cd378691 100644 --- a/internal/runtime/hooks/types.go +++ b/internal/runtime/hooks/types.go @@ -197,10 +197,7 @@ type HookSpec struct { Timeout time.Duration FailurePolicy FailurePolicy Handler HookHandler - Matcher *HookMatcher - - // MatcherMigrationWarning 用于在运行时提示 warn_on_tool_call 旧参数与 match 共存时的迁移风险。 - MatcherMigrationWarning string + Matcher *HookMatcher } // normalizeAndValidate 将 HookSpec 归一化并校验当前阶段可用字段。 diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index 7562f1795..292fdfeff 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -359,15 +359,14 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { default: return fmt.Errorf("handler %q is not supported", item.Handler) } - hasExplicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) - if handler == "warn_on_tool_call" && !hasExplicitMatcher && !runtimeHasWarnOnToolCallTargets(item.Params) { - return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", item.Handler) - } - if matcherRaw := resolveConfiguredHookMatcherRaw(item); matcherRaw != nil { - if err := runtimehooks.ValidateHookMatcher(point, matcherRaw); err != nil { - return fmt.Errorf("match: %w", err) + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } - } case repoHookKindCommand: if err := runtimehooks.ValidateCommandParams(item.Params); err != nil { return err @@ -381,22 +380,6 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { return nil } -// runtimeHasWarnOnToolCallTargets 判断 warn_on_tool_call 是否配置了至少一个目标工具。 -func runtimeHasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - if name := strings.TrimSpace(readHookParamString(params, "tool_name")); name != "" { - return true - } - for _, value := range readHookParamStringSlice(params, "tool_names") { - if strings.TrimSpace(value) != "" { - return true - } - } - return false -} - // evaluateWorkspaceTrust 根据 trust store 判断 workspace 是否可信并附带容错诊断。 func evaluateWorkspaceTrust(workspace string) trustDecision { storePath := resolveTrustedWorkspacesPath() diff --git a/internal/runtime/repo_hooks_test.go b/internal/runtime/repo_hooks_test.go index 90e17b6c1..a4f74c9fc 100644 --- a/internal/runtime/repo_hooks_test.go +++ b/internal/runtime/repo_hooks_test.go @@ -596,27 +596,6 @@ func TestValidateRepoHookItemRejectsExternalKindsWithP6LiteMessage(t *testing.T) } } -func TestRuntimeHasWarnOnToolCallTargetsBranches(t *testing.T) { - cases := []struct { - name string - params map[string]any - want bool - }{ - {name: "nil", params: nil, want: false}, - {name: "tool_name", params: map[string]any{"tool_name": "bash"}, want: true}, - {name: "tool_name blank", params: map[string]any{"tool_name": " "}, want: false}, - {name: "tool_names", params: map[string]any{"tool_names": []any{"bash"}}, want: true}, - {name: "tool_names blank", params: map[string]any{"tool_names": []any{" "}}, want: false}, - } - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - if got := runtimeHasWarnOnToolCallTargets(tc.params); got != tc.want { - t.Fatalf("runtimeHasWarnOnToolCallTargets() = %v, want %v", got, tc.want) - } - }) - } -} func TestValidateRepoHookItemAllowsWarnOnToolCallWithMatchOnly(t *testing.T) { t.Parallel() diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index 49accc51f..8040b82a9 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -12,7 +12,6 @@ import ( "os" "path/filepath" "runtime" - "slices" "strings" "time" @@ -198,7 +197,7 @@ func buildConfiguredHookSpec( return runtimehooks.HookSpec{}, err } point := runtimehooks.HookPoint(strings.TrimSpace(item.Point)) - matcher, matcherWarning, sanitizedParams, err := buildConfiguredHookMatcher(item, point) + matcher, err := buildConfiguredHookMatcher(item, point) if err != nil { return runtimehooks.HookSpec{}, err } @@ -211,7 +210,7 @@ func buildConfiguredHookSpec( ) switch kind { case configuredHookKindBuiltin: - handler, buildErr = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), sanitizedParams, defaultWorkdir) + handler, buildErr = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: @@ -245,7 +244,6 @@ func buildConfiguredHookSpec( FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), Handler: handler, Matcher: matcher, - MatcherMigrationWarning: matcherWarning, }, nil } @@ -264,17 +262,15 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported", item.Mode) } - handler := strings.ToLower(strings.TrimSpace(item.Handler)) - hasExplicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) - if handler == "warn_on_tool_call" && !hasExplicitMatcher && !runtimeHasWarnOnToolCallTargets(item.Params) { - return fmt.Errorf("handler %q requires match or params.tool_name/tool_names", item.Handler) - } - matcherRaw := resolveConfiguredHookMatcherRaw(item) - if matcherRaw != nil { - if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), matcherRaw); err != nil { - return fmt.Errorf("match: %w", err) + handler := strings.ToLower(strings.TrimSpace(item.Handler)) + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } - } case configuredHookKindCommand: if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) @@ -312,104 +308,17 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop return nil } -// resolveConfiguredHookMatcherRaw 返回运行时装配阶段应使用的 matcher 原始配置。 -func resolveConfiguredHookMatcherRaw(item config.RuntimeHookItemConfig) map[string]any { - if runtimehooks.HasHookMatcherConfig(item.Match) { - return item.Match - } - if strings.EqualFold(strings.TrimSpace(item.Handler), "warn_on_tool_call") && - runtimeHasWarnOnToolCallTargets(item.Params) { - return buildLegacyWarnMatcherFromParams(item.Params) - } - return nil -} -// buildConfiguredHookMatcher 编译 hook matcher 并生成迁移告警,同时返回供 handler 使用的参数副本。 -func buildConfiguredHookMatcher( - item config.RuntimeHookItemConfig, - point runtimehooks.HookPoint, -) (*runtimehooks.HookMatcher, string, map[string]any, error) { - sanitizedParams := cloneHookParams(item.Params) - matcherRaw := resolveConfiguredHookMatcherRaw(item) - if matcherRaw == nil { - return nil, "", sanitizedParams, nil +// buildConfiguredHookMatcher 编译 hook matcher。 +func buildConfiguredHookMatcher(item config.RuntimeHookItemConfig, point runtimehooks.HookPoint) (*runtimehooks.HookMatcher, error) { + if !runtimehooks.HasHookMatcherConfig(item.Match) { + return nil, nil } - matcher, err := runtimehooks.CompileHookMatcher(point, matcherRaw) + matcher, err := runtimehooks.CompileHookMatcher(point, item.Match) if err != nil { - return nil, "", nil, fmt.Errorf("match: %w", err) - } - if matcher == nil { - return nil, "", sanitizedParams, nil - } - explicitMatcher := runtimehooks.HasHookMatcherConfig(item.Match) - legacyWarnTargets := strings.EqualFold(strings.TrimSpace(item.Handler), "warn_on_tool_call") && - runtimeHasWarnOnToolCallTargets(item.Params) - warning := "" - if explicitMatcher && legacyWarnTargets { - warning = "hook matcher migration: match is configured; params.tool_name/tool_names on warn_on_tool_call are ignored" - delete(sanitizedParams, "tool_name") - delete(sanitizedParams, "tool_names") - } - return matcher, warning, sanitizedParams, nil -} - -// cloneHookParams 深拷贝 params,避免装配阶段修改影响原始配置对象。 -func cloneHookParams(params map[string]any) map[string]any { - if len(params) == 0 { - return nil - } - cloned := make(map[string]any, len(params)) - for key, value := range params { - cloned[key] = cloneHookParamValue(value) - } - return cloned -} - -// cloneHookParamValue 深拷贝 matcher/params 结构,避免 map/slice 底层共享。 -func cloneHookParamValue(value any) any { - switch typed := value.(type) { - case map[string]any: - cloned := make(map[string]any, len(typed)) - for key, item := range typed { - cloned[key] = cloneHookParamValue(item) - } - return cloned - case []any: - cloned := make([]any, len(typed)) - for index, item := range typed { - cloned[index] = cloneHookParamValue(item) - } - return cloned - case []string: - cloned := make([]string, len(typed)) - copy(cloned, typed) - return cloned - default: - return value - } -} - -// buildLegacyWarnMatcherFromParams 将 warn_on_tool_call 旧参数桥接为 matcher 配置。 -func buildLegacyWarnMatcherFromParams(params map[string]any) map[string]any { - if len(params) == 0 { - return nil - } - var toolNames []string - if name := strings.TrimSpace(readHookParamString(params, "tool_name")); name != "" { - toolNames = append(toolNames, name) - } - for _, value := range readHookParamStringSlice(params, "tool_names") { - if strings.TrimSpace(value) == "" { - continue - } - toolNames = append(toolNames, value) - } - if len(toolNames) == 0 { - return nil - } - return map[string]any{ - "tool_name": toolNames, + return nil, fmt.Errorf("match: %w", err) } + return matcher, nil } func isExternalHookKind(kind string) bool { @@ -467,30 +376,16 @@ func buildUserBuiltinHookHandler( } return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} }, nil - case "warn_on_tool_call": - targetTool := strings.ToLower(strings.TrimSpace(readHookParamString(params, "tool_name"))) - targetTools := normalizeHookParamStringSlice(readHookParamStringSlice(params, "tool_names")) - defaultMessage := "tool call matched warn_on_tool_call" - if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { - defaultMessage = customMessage - } - return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { - _ = ctx - toolName := strings.ToLower(strings.TrimSpace(readHookContextMetadataString(input, "tool_name"))) - if targetTool == "" && len(targetTools) == 0 { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} + case "warn_on_tool_call": + defaultMessage := "tool call matched warn_on_tool_call" + if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { + defaultMessage = customMessage } - if toolName == "" { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} - } - if targetTool != "" && toolName == targetTool { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} - } - if len(targetTools) > 0 && slices.Contains(targetTools, toolName) { + return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { + _ = ctx + _ = input return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} - } - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} - }, nil + }, nil case "add_context_note": note := strings.TrimSpace(readHookParamString(params, "note")) if note == "" { diff --git a/internal/runtime/user_hooks_test.go b/internal/runtime/user_hooks_test.go index 5decb4ea6..588361a9c 100644 --- a/internal/runtime/user_hooks_test.go +++ b/internal/runtime/user_hooks_test.go @@ -33,8 +33,10 @@ func TestBuildUserHookSpecMapsFailurePolicyAndScope(t *testing.T) { Priority: 99, TimeoutSec: 7, FailurePolicy: "warn_only", - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", + }, + Params: map[string]any{ "message": "tool call warning", }, } @@ -351,7 +353,6 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Parallel() warnHandler, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{ - "tool_name": "bash", "message": "bash was called", }, t.TempDir()) if err != nil { @@ -359,7 +360,6 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { } warnResult := warnHandler(context.Background(), runtimehooks.HookContext{ Metadata: map[string]any{ - "tool_name": "bash", }, }) if warnResult.Status != runtimehooks.HookResultPass { @@ -369,14 +369,14 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Fatalf("warn message = %q, want %q", warnResult.Message, "bash was called") } - ignoreResult := warnHandler(context.Background(), runtimehooks.HookContext{ - Metadata: map[string]any{ - "tool_name": "filesystem", - }, - }) - if strings.TrimSpace(ignoreResult.Message) != "" { - t.Fatalf("expected unmatched tool to have empty message, got %q", ignoreResult.Message) - } + anyToolResult := warnHandler(context.Background(), runtimehooks.HookContext{ + Metadata: map[string]any{ + "tool_name": "filesystem", + }, + }) + if anyToolResult.Message != "bash was called" { + t.Fatalf("warn message = %q, want %q", anyToolResult.Message, "bash was called") + } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{ "note": "manual check required", @@ -407,10 +407,10 @@ func TestConfigureRuntimeHooksFromConfig(t *testing.T) { Scope: "user", Kind: "builtin", Mode: "sync", - Handler: "warn_on_tool_call", - Params: map[string]any{ - "tool_name": "bash", - }, + Handler: "warn_on_tool_call", + Match: map[string]any{ + "tool_name": "bash", + }, }, } cfg.Runtime.Hooks.ApplyDefaults(config.StaticDefaults().Runtime.Hooks) @@ -455,9 +455,11 @@ func TestConfigureRuntimeHooksFromConfigKeepsBaseExecutorAndComposes(t *testing. Scope: "user", Kind: "builtin", Mode: "sync", + Match: map[string]any{ + "tool_name": "bash", + }, Handler: "warn_on_tool_call", Params: map[string]any{ - "tool_name": "bash", "message": "warn", }, }, @@ -818,11 +820,11 @@ func TestConfigureRuntimeHooksWithoutItemsKeepsBehaviorUnchanged(t *testing.T) { if service.hookExecutor == nil { t.Fatal("expected runtime hooks chain to remain available for repo discovery") } - out := service.hookExecutor.Run( - context.Background(), - runtimehooks.HookPointBeforeToolCall, - runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash", "workdir": cfg.Workdir}}, - ) + out := service.hookExecutor.Run( + context.Background(), + runtimehooks.HookPointBeforeToolCall, + runtimehooks.HookContext{Metadata: map[string]any{"workdir": cfg.Workdir}}, + ) if out.Blocked || len(out.Results) != 0 { t.Fatalf("unexpected hook output without user/repo config: %+v", out) } @@ -855,7 +857,7 @@ func TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { t.Fatalf("expected match message, got %q", pass.Message) } noTool := handler(context.Background(), runtimehooks.HookContext{}) - if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "" { + if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "hit" { t.Fatalf("unexpected no-tool result: %+v", noTool) } @@ -1353,81 +1355,6 @@ func TestConfigureRuntimeHooksInjectsAsyncResultSinkIntoBaseExecutor(t *testing. t.Fatal("expected async rewake notification to be enqueued via configured async sink") } -func TestBuildUserHookSpecBridgesWarnOnToolCallLegacyParamsToMatcher(t *testing.T) { - t.Parallel() - - item := config.RuntimeHookItemConfig{ - ID: "warn-legacy", - Point: "before_tool_call", - Scope: "user", - Kind: "builtin", - Mode: "sync", - Handler: "warn_on_tool_call", - TimeoutSec: 2, - FailurePolicy: "warn_only", - Params: map[string]any{ - "tool_names": []any{"bash"}, - "message": "legacy warning", - }, - } - - spec, err := buildUserHookSpec(item, t.TempDir()) - if err != nil { - t.Fatalf("buildUserHookSpec() error = %v", err) - } - if spec.Matcher == nil { - t.Fatal("expected legacy warn_on_tool_call params to be bridged into matcher") - } - if spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) { - t.Fatal("matcher should reject unmatched tool") - } - if !spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { - t.Fatal("matcher should accept configured legacy tool") - } -} - -func TestBuildUserHookSpecWarnOnToolCallMatchTakesPrecedenceAndEmitsMigrationWarning(t *testing.T) { - t.Parallel() - - item := config.RuntimeHookItemConfig{ - ID: "warn-conflict", - Point: "before_tool_call", - Scope: "user", - Kind: "builtin", - Mode: "sync", - Handler: "warn_on_tool_call", - TimeoutSec: 2, - FailurePolicy: "warn_only", - Match: map[string]any{ - "tool_name": "filesystem", - }, - Params: map[string]any{ - "tool_name": "bash", - "message": "explicit matcher wins", - }, - } - - spec, err := buildUserHookSpec(item, t.TempDir()) - if err != nil { - t.Fatalf("buildUserHookSpec() error = %v", err) - } - if spec.Matcher == nil { - t.Fatal("expected matcher to be compiled") - } - if spec.MatcherMigrationWarning == "" { - t.Fatal("expected migration warning when match and legacy warn params coexist") - } - if !spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) { - t.Fatal("matcher should follow explicit match config") - } - if spec.Matcher.Match(runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { - t.Fatal("legacy params should be ignored when explicit match exists") - } - result := spec.Handler(context.Background(), runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "filesystem"}}) - if result.Message != "explicit matcher wins" { - t.Fatalf("handler message = %q, want explicit matcher wins", result.Message) - } -} type countingHookExecutor struct { calls atomic.Int32 @@ -1618,10 +1545,10 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { if result.Message == "" { t.Fatalf("expected default warn message for matched tool") } - result = warnHandler(context.Background(), runtimehooks.HookContext{}) - if result.Message != "" { - t.Fatalf("expected empty message when no tool_name metadata, got %q", result.Message) - } + result = warnHandler(context.Background(), runtimehooks.HookContext{}) + if result.Message == "" { + t.Fatalf("expected default warn message, got empty") + } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{"message": "note-via-message"}, workdir) if err != nil { From 90df2abbed6622ba7aabbfa8b2717efcaf5572b4 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:11:59 +0800 Subject: [PATCH 3/3] fix hook matcher validation and preview sanitization --- docs/runtime-hooks-design.md | 3 +- internal/runtime/hooks/matcher.go | 23 ++++++ internal/runtime/hooks/matcher_test.go | 11 ++- internal/runtime/toolexec.go | 97 ++++++++++++++++++++++- internal/runtime/toolexec_preview_test.go | 81 +++++++++++++++++++ 5 files changed, 209 insertions(+), 6 deletions(-) create mode 100644 internal/runtime/toolexec_preview_test.go diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index 097a52a45..f4444a0f2 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -129,8 +129,7 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 说明: - `arguments_contains` 基于 `tool_arguments_preview` 字段匹配,不读取 `tool_arguments` 原文。 -- `warn_on_tool_call` 的旧参数 `params.tool_name/tool_names` 仍兼容;未配置 `match` 时会自动桥接为 matcher。 -- 若 `match` 与旧参数共存,以 `match` 为准,并发出 `hook_notification` 迁移提示事件。 +- `warn_on_tool_call` 当前要求显式配置 `match`;旧参数 `params.tool_name/tool_names` 不再承担匹配语义。 ### trust gate diff --git a/internal/runtime/hooks/matcher.go b/internal/runtime/hooks/matcher.go index 493be0126..16c133228 100644 --- a/internal/runtime/hooks/matcher.go +++ b/internal/runtime/hooks/matcher.go @@ -54,6 +54,9 @@ func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, erro if len(raw) == 0 { return nil, nil } + if err := validateHookMatcherFields(raw); err != nil { + return nil, err + } if !HasHookMatcherConfig(raw) { return nil, fmt.Errorf("match contains no recognized matcher fields (expected: tool_name, tool_name_regex, arguments_contains)") } @@ -104,6 +107,26 @@ func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, erro return matcher, nil } +// validateHookMatcherFields 校验 matcher 配置中不存在未支持字段,避免拼写错误被静默忽略。 +func validateHookMatcherFields(raw map[string]any) error { + if len(raw) == 0 { + return nil + } + for key := range raw { + normalized := strings.ToLower(strings.TrimSpace(key)) + switch normalized { + case hookMatcherFieldToolName, hookMatcherFieldToolNameRegex, hookMatcherFieldArgumentsContains: + continue + default: + return fmt.Errorf( + "match contains unknown field %q (allowed: tool_name, tool_name_regex, arguments_contains)", + key, + ) + } + } + return nil +} + // IsEmpty 判断 matcher 是否包含可执行维度。 func (m *HookMatcher) IsEmpty() bool { if m == nil { diff --git a/internal/runtime/hooks/matcher_test.go b/internal/runtime/hooks/matcher_test.go index 99f088a53..cf85dfd3f 100644 --- a/internal/runtime/hooks/matcher_test.go +++ b/internal/runtime/hooks/matcher_test.go @@ -111,6 +111,12 @@ func TestCompileHookMatcherValidation(t *testing.T) { }); err == nil { t.Fatal("expected completely unknown matcher field to be rejected") } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_names": []any{"filesystem"}, + }); err == nil { + t.Fatal("expected mixed matcher fields with typo to be rejected") + } if _, err := CompileHookMatcher(HookPointBeforeToolCall, nil); err != nil { t.Fatal("nil raw should succeed with nil matcher") @@ -306,8 +312,8 @@ func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) { t.Parallel() matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ - "tool_name": "bash", - "tool_name_regex": []string{" ", "\t"}, + "tool_name": "bash", + "tool_name_regex": []string{" ", "\t"}, }) if err != nil { t.Fatalf("CompileHookMatcher() error = %v", err) @@ -320,7 +326,6 @@ func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) { } } - func TestCompileHookMatcherRegexOnly(t *testing.T) { t.Parallel() diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 1698b6e03..8b57a2477 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -10,6 +10,7 @@ import ( "sort" "strings" "sync" + "unicode" "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" @@ -766,10 +767,104 @@ func buildToolArgumentsPreview(arguments string) string { if trimmed == "" { return "" } - masked := hookToolArgumentsSensitivePattern.ReplaceAllString(trimmed, `$1=***`) + masked := sanitizeHookToolArguments(trimmed) return truncateHookTextByChars(masked, hookToolArgumentsPreviewMaxChars) } +// sanitizeHookToolArguments 优先按 JSON 结构递归脱敏,非 JSON 输入回退为轻量正则脱敏。 +func sanitizeHookToolArguments(arguments string) string { + if masked, ok := sanitizeHookToolArgumentsJSON(arguments); ok { + return masked + } + return hookToolArgumentsSensitivePattern.ReplaceAllString(arguments, `$1=***`) +} + +// sanitizeHookToolArgumentsJSON 尝试解析 JSON 并按敏感键递归替换值。 +func sanitizeHookToolArgumentsJSON(arguments string) (string, bool) { + var decoded any + if err := json.Unmarshal([]byte(arguments), &decoded); err != nil { + return "", false + } + sanitized := maskHookToolArgumentValue(decoded) + encoded, err := json.Marshal(sanitized) + if err != nil { + return "", false + } + return string(encoded), true +} + +// maskHookToolArgumentValue 递归处理 JSON 节点,对敏感键对应的值统一替换为 "***"。 +func maskHookToolArgumentValue(value any) any { + switch typed := value.(type) { + case map[string]any: + masked := make(map[string]any, len(typed)) + for key, item := range typed { + if isSensitiveHookToolArgumentKey(key) { + masked[key] = "***" + continue + } + masked[key] = maskHookToolArgumentValue(item) + } + return masked + case []any: + masked := make([]any, len(typed)) + for index, item := range typed { + masked[index] = maskHookToolArgumentValue(item) + } + return masked + default: + return value + } +} + +// isSensitiveHookToolArgumentKey 判断参数键名是否属于敏感信息字段。 +func isSensitiveHookToolArgumentKey(key string) bool { + tokens := tokenizeHookToolArgumentKey(key) + if len(tokens) == 0 { + return false + } + for index, token := range tokens { + switch token { + case "password", "passwd", "secret", "token", "auth", "authorization": + return true + case "apikey", "accesskey", "authtoken", "accesstoken": + return true + case "api", "access": + if index+1 < len(tokens) && tokens[index+1] == "key" { + return true + } + case "key": + if index > 0 && (tokens[index-1] == "api" || tokens[index-1] == "access") { + return true + } + } + } + return false +} + +// tokenizeHookToolArgumentKey 将参数键拆分为小写词元,兼容 snake/kebab/camelCase。 +func tokenizeHookToolArgumentKey(key string) []string { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + return nil + } + var builder strings.Builder + var prev rune + for _, current := range trimmed { + switch { + case unicode.IsLetter(current) || unicode.IsDigit(current): + if unicode.IsUpper(current) && unicode.IsLower(prev) { + builder.WriteByte(' ') + } + builder.WriteRune(unicode.ToLower(current)) + default: + builder.WriteByte(' ') + } + prev = current + } + return strings.Fields(builder.String()) +} + // truncateHookTextByChars 按字符长度截断文本,避免 metadata 放大。 func truncateHookTextByChars(text string, maxChars int) string { if maxChars <= 0 { diff --git a/internal/runtime/toolexec_preview_test.go b/internal/runtime/toolexec_preview_test.go new file mode 100644 index 000000000..1b2723565 --- /dev/null +++ b/internal/runtime/toolexec_preview_test.go @@ -0,0 +1,81 @@ +package runtime + +import ( + "strings" + "testing" +) + +func TestBuildToolArgumentsPreviewMaskJSONSensitiveFields(t *testing.T) { + t.Parallel() + + raw := `{"api_key":"sk-123","password":"p@ss","nested":{"secret":"abc"},"safe":"ok"}` + preview := buildToolArgumentsPreview(raw) + if strings.Contains(preview, "sk-123") { + t.Fatalf("preview leaked api_key: %q", preview) + } + if strings.Contains(preview, "p@ss") { + t.Fatalf("preview leaked password: %q", preview) + } + if strings.Contains(preview, `"secret":"abc"`) { + t.Fatalf("preview leaked nested secret: %q", preview) + } + if !strings.Contains(preview, `"api_key":"***"`) { + t.Fatalf("preview should mask api_key: %q", preview) + } + if !strings.Contains(preview, `"password":"***"`) { + t.Fatalf("preview should mask password: %q", preview) + } + if !strings.Contains(preview, `"secret":"***"`) { + t.Fatalf("preview should mask nested secret: %q", preview) + } + if !strings.Contains(preview, `"safe":"ok"`) { + t.Fatalf("preview should keep non-sensitive keys: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewMaskNonJSONFallback(t *testing.T) { + t.Parallel() + + preview := buildToolArgumentsPreview(`token=abc password:xyz arg=ok`) + if strings.Contains(preview, "abc") || strings.Contains(preview, "xyz") { + t.Fatalf("preview leaked fallback credentials: %q", preview) + } + if !strings.Contains(preview, "token=***") { + t.Fatalf("preview should mask token in fallback mode: %q", preview) + } + if !strings.Contains(preview, "password=***") { + t.Fatalf("preview should mask password in fallback mode: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewTruncate(t *testing.T) { + t.Parallel() + + raw := strings.Repeat("a", hookToolArgumentsPreviewMaxChars+20) + preview := buildToolArgumentsPreview(raw) + if len([]rune(preview)) != hookToolArgumentsPreviewMaxChars { + t.Fatalf("preview length=%d, want %d", len([]rune(preview)), hookToolArgumentsPreviewMaxChars) + } +} + +func TestIsSensitiveHookToolArgumentKey(t *testing.T) { + t.Parallel() + + cases := []struct { + key string + want bool + }{ + {key: "api_key", want: true}, + {key: "accessKey", want: true}, + {key: "authorization", want: true}, + {key: "auth_token", want: true}, + {key: "password", want: true}, + {key: "author", want: false}, + {key: "tool_name", want: false}, + } + for _, tc := range cases { + if got := isSensitiveHookToolArgumentKey(tc.key); got != tc.want { + t.Fatalf("isSensitiveHookToolArgumentKey(%q)=%v, want %v", tc.key, got, tc.want) + } + } +}