Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions internal/governance/engine_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package governance

import (
"os"
"path/filepath"
"testing"
)

// writeConfig writes a temporary agentguard.yaml and returns its path.
func writeConfig(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "agentguard.yaml")
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
t.Fatalf("writeConfig: %v", err)
}
return path
}

const enforceConfig = `
mode: enforce
policies:
- name: no-destructive-rm
description: Block rm commands
match:
command: rm
action: deny
message: rm is not allowed in enforce mode
- name: no-git-push
description: Block git push
match:
command: git
args_contain: ["push"]
action: deny
message: git push is not allowed
- name: monitor-writes
description: Log all writes
match:
command: write_file
action: monitor
message: write observed
`

const monitorConfig = `
mode: monitor
policies:
- name: no-destructive-rm
match:
command: rm
action: deny
message: rm is not allowed
`

const timeoutConfig = `
mode: enforce
policies:
- name: long-running-budget
match:
command: "*"
action: monitor
message: budget policy
timeout_seconds: 600
`

// TestEvaluate_EnforceDeny verifies that deny policies block execution in enforce mode.
func TestEvaluate_EnforceDeny(t *testing.T) {
path := writeConfig(t, enforceConfig)
eng, err := NewEngine(path)
if err != nil {
t.Fatalf("NewEngine: %v", err)
}

d := eng.Evaluate("run_shell", map[string]string{"command": "rm -rf /tmp/work"})
if d.Allowed {
t.Error("rm should be denied in enforce mode")
}
if d.PolicyName != "no-destructive-rm" {
t.Errorf("PolicyName = %q, want %q", d.PolicyName, "no-destructive-rm")
}
if d.Mode != "enforce" {
t.Errorf("Mode = %q, want %q", d.Mode, "enforce")
}
}

// TestEvaluate_EnforceDeny_ArgsContain verifies args_contain matching.
func TestEvaluate_EnforceDeny_ArgsContain(t *testing.T) {
path := writeConfig(t, enforceConfig)
eng, _ := NewEngine(path)

d := eng.Evaluate("run_shell", map[string]string{"command": "git push origin main"})
if d.Allowed {
t.Error("git push should be denied in enforce mode")
}
if d.PolicyName != "no-git-push" {
t.Errorf("PolicyName = %q, want %q", d.PolicyName, "no-git-push")
}

// git pull should not match the git push policy
d2 := eng.Evaluate("run_shell", map[string]string{"command": "git pull origin main"})
if !d2.Allowed {
t.Error("git pull should be allowed (only push is denied)")
}
}

// TestEvaluate_MonitorAllow verifies that deny policies only log in monitor mode.
func TestEvaluate_MonitorAllow(t *testing.T) {
path := writeConfig(t, monitorConfig)
eng, _ := NewEngine(path)

d := eng.Evaluate("run_shell", map[string]string{"command": "rm -rf /tmp"})
if !d.Allowed {
t.Error("rm should be allowed in monitor mode (deny = log only)")
}
if d.PolicyName != "no-destructive-rm" {
t.Errorf("PolicyName = %q, want %q", d.PolicyName, "no-destructive-rm")
}
if d.Mode != "monitor" {
t.Errorf("Mode = %q, want %q", d.Mode, "monitor")
}
}

// TestEvaluate_MonitorAction verifies monitor-action policies always allow.
func TestEvaluate_MonitorAction(t *testing.T) {
path := writeConfig(t, enforceConfig)
eng, _ := NewEngine(path)

// monitor-writes policy matches write_file and should always allow
d := eng.Evaluate("write_file", map[string]string{"command": "write_file", "path": "foo.go"})
if !d.Allowed {
t.Error("monitor policy should always allow")
}
}

// TestEvaluate_DefaultAllow verifies that unmatched commands are allowed.
func TestEvaluate_DefaultAllow(t *testing.T) {
path := writeConfig(t, enforceConfig)
eng, _ := NewEngine(path)

d := eng.Evaluate("run_shell", map[string]string{"command": "go test ./..."})
if !d.Allowed {
t.Errorf("go test should be default-allowed, got reason: %q", d.Reason)
}
if d.PolicyName != "default-allow" {
t.Errorf("PolicyName = %q, want %q", d.PolicyName, "default-allow")
}
}

// TestGetTimeout_PolicyTimeout verifies policy-level timeout is respected.
func TestGetTimeout_PolicyTimeout(t *testing.T) {
path := writeConfig(t, timeoutConfig)
eng, _ := NewEngine(path)

got := eng.GetTimeout()
if got != 600 {
t.Errorf("GetTimeout() = %d, want 600", got)
}
}

// TestGetTimeout_Default verifies the 300s default when no policy sets a timeout.
func TestGetTimeout_Default(t *testing.T) {
path := writeConfig(t, enforceConfig)
eng, _ := NewEngine(path)

got := eng.GetTimeout()
if got != 300 {
t.Errorf("GetTimeout() = %d, want 300 (default)", got)
}
}

// TestNewEngine_DefaultMonitorMode verifies that missing mode defaults to monitor.
func TestNewEngine_DefaultMonitorMode(t *testing.T) {
cfg := `
policies:
- name: test
match:
command: rm
action: deny
message: denied
`
path := writeConfig(t, cfg)
eng, err := NewEngine(path)
if err != nil {
t.Fatalf("NewEngine: %v", err)
}
if eng.Mode != "monitor" {
t.Errorf("Mode = %q, want %q (default)", eng.Mode, "monitor")
}
}

// TestNewEngine_MissingFile verifies error on missing config.
func TestNewEngine_MissingFile(t *testing.T) {
_, err := NewEngine("/no/such/file.yaml")
if err == nil {
t.Error("expected error for missing config file, got nil")
}
}

// TestNewEngine_InvalidYAML verifies error on malformed config.
func TestNewEngine_InvalidYAML(t *testing.T) {
path := writeConfig(t, "mode: [\ninvalid yaml")
_, err := NewEngine(path)
if err == nil {
t.Error("expected error for invalid YAML, got nil")
}
}
13 changes: 1 addition & 12 deletions internal/intent/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,7 @@ func flattenParams(m map[string]any) map[string]string {
case string:
result[k] = val
case float64:
if val == float64(int(val)) {
result[k] = strings.TrimRight(strings.TrimRight(
strings.Replace(
strings.Replace(
fmt.Sprintf("%f", val), ".", "", 1),
"0", "", -1),
"0"), "")
// Simpler: just use Sprintf
result[k] = fmt.Sprintf("%g", val)
} else {
result[k] = fmt.Sprintf("%g", val)
}
result[k] = fmt.Sprintf("%g", val)
case bool:
result[k] = fmt.Sprintf("%t", val)
default:
Expand Down
162 changes: 162 additions & 0 deletions internal/intent/parser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package intent

import (
"testing"
)

// TestParse_JSONBlock verifies extraction from ```json ... ``` code blocks.
func TestParse_JSONBlock(t *testing.T) {
content := "I'll write the file now.\n```json\n{\"tool\": \"write_file\", \"params\": {\"path\": \"main.go\", \"content\": \"package main\"}}\n```"
a := Parse(content)
if a == nil {
t.Fatal("expected action, got nil")
}
if a.Tool != "write_file" {
t.Errorf("Tool = %q, want %q", a.Tool, "write_file")
}
if a.Params["path"] != "main.go" {
t.Errorf("params[path] = %q, want %q", a.Params["path"], "main.go")
}
if a.Source != SourceJSONBlock {
t.Errorf("Source = %q, want %q", a.Source, SourceJSONBlock)
}
}

// TestParse_XMLTag verifies extraction from <tool>...</tool> XML tags.
func TestParse_XMLTag(t *testing.T) {
content := `Running the command now.<tool>{"tool": "run_shell", "params": {"command": "go test ./..."}}</tool>`
a := Parse(content)
if a == nil {
t.Fatal("expected action, got nil")
}
if a.Tool != "run_shell" {
t.Errorf("Tool = %q, want %q", a.Tool, "run_shell")
}
if a.Params["command"] != "go test ./..." {
t.Errorf("params[command] = %q, want %q", a.Params["command"], "go test ./...")
}
if a.Source != SourceXMLTag {
t.Errorf("Source = %q, want %q", a.Source, SourceXMLTag)
}
}

// TestParse_FunctionCall verifies extraction from OpenAI function_call format.
func TestParse_FunctionCall(t *testing.T) {
content := `{"name": "read_file", "arguments": "{\"path\": \"/etc/hosts\"}"}`
a := Parse(content)
if a == nil {
t.Fatal("expected action, got nil")
}
if a.Tool != "read_file" {
t.Errorf("Tool = %q, want %q", a.Tool, "read_file")
}
if a.Params["path"] != "/etc/hosts" {
t.Errorf("params[path] = %q, want %q", a.Params["path"], "/etc/hosts")
}
if a.Source != SourceFunctionCall {
t.Errorf("Source = %q, want %q", a.Source, SourceFunctionCall)
}
}

// TestParse_BareJSON verifies extraction from inline JSON objects.
func TestParse_BareJSON(t *testing.T) {
content := `Let me list the files: {"tool": "list_files", "directory": "/tmp"}`
a := Parse(content)
if a == nil {
t.Fatal("expected action, got nil")
}
if a.Tool != "list_files" {
t.Errorf("Tool = %q, want %q", a.Tool, "list_files")
}
if a.Source != SourceBareJSON {
t.Errorf("Source = %q, want %q", a.Source, SourceBareJSON)
}
}

// TestParse_NoAction verifies that plain prose returns nil.
func TestParse_NoAction(t *testing.T) {
cases := []string{
"I've finished the task. The code looks good.",
"Based on the analysis, the bug is in line 42.",
"",
"Here is the answer: 42.",
}
for _, c := range cases {
if a := Parse(c); a != nil {
t.Errorf("Parse(%q) = %+v, want nil", c, a)
}
}
}

// TestParse_ToolAliases verifies that model-emitted aliases map to canonical names.
func TestParse_ToolAliases(t *testing.T) {
cases := []struct {
raw string
wantTool string
}{
{`{"tool": "Bash", "params": {"command": "ls"}}`, "run_shell"},
{`{"tool": "Read", "params": {"path": "main.go"}}`, "read_file"},
{`{"tool": "Write", "params": {"path": "out.go", "content": "x"}}`, "write_file"},
{`{"tool": "Glob", "params": {"directory": "."}}`, "list_files"},
{`{"tool": "Grep", "params": {"directory": ".", "pattern": "foo"}}`, "search_files"},
}
for _, tc := range cases {
content := "```json\n" + tc.raw + "\n```"
a := Parse(content)
if a == nil {
t.Errorf("Parse(%q): got nil, want tool %q", tc.raw, tc.wantTool)
continue
}
if a.Tool != tc.wantTool {
t.Errorf("Parse(%q): Tool = %q, want %q", tc.raw, a.Tool, tc.wantTool)
}
}
}

// TestParse_ParamAliases verifies that param aliases are normalized.
func TestParse_ParamAliases(t *testing.T) {
content := "```json\n{\"tool\": \"write_file\", \"params\": {\"file_path\": \"main.go\", \"text\": \"hello\"}}\n```"
a := Parse(content)
if a == nil {
t.Fatal("expected action, got nil")
}
if a.Params["path"] != "main.go" {
t.Errorf("file_path should normalize to path, got %q", a.Params["path"])
}
if a.Params["content"] != "hello" {
t.Errorf("text should normalize to content, got %q", a.Params["content"])
}
}

// TestParse_UnknownTool verifies that unknown tool names return nil.
func TestParse_UnknownTool(t *testing.T) {
content := "```json\n{\"tool\": \"do_something_weird\", \"params\": {}}\n```"
a := Parse(content)
if a != nil {
t.Errorf("Parse with unknown tool: got %+v, want nil", a)
}
}

// TestFlattenParams verifies numeric and bool conversions.
func TestFlattenParams(t *testing.T) {
input := map[string]any{
"name": "foo",
"count": float64(42),
"ratio": float64(3.14),
"enabled": true,
}
got := flattenParams(input)

if got["name"] != "foo" {
t.Errorf("name = %q, want %q", got["name"], "foo")
}
if got["count"] != "42" {
t.Errorf("count = %q, want %q", got["count"], "42")
}
if got["ratio"] != "3.14" {
t.Errorf("ratio = %q, want %q", got["ratio"], "3.14")
}
if got["enabled"] != "true" {
t.Errorf("enabled = %q, want %q", got["enabled"], "true")
}
}
Loading
Loading