diff --git a/README.md b/README.md index 4c3a03a..41cebc8 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ api: plugins: : # Can have as many of these as you like + protocol_version: 2 # Optional: Defaults to 1 for backwards compatibility source: labels: type: plugin-check @@ -51,6 +52,8 @@ plugins: policies: - - + policy_data: # Optional: Mapping for supported policies. Can be any data structure + : config: : : diff --git a/cmd/agent.go b/cmd/agent.go index 31c54b7..22098b4 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -43,6 +43,7 @@ import ( "golang.org/x/sync/singleflight" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" ) type apiAuthConfig struct { @@ -60,12 +61,13 @@ type agentPolicy string type agentPluginConfig map[string]string type agentPlugin struct { - ProtocolVersion int32 `mapstructure:"protocol_version"` - Schedule *string `mapstructure:"schedule,omitempty"` - Source string `mapstructure:"source"` - Policies []agentPolicy `mapstructure:"policies"` - Config agentPluginConfig `mapstructure:"config"` - Labels map[string]string `mapstructure:"labels"` + ProtocolVersion int32 `mapstructure:"protocol_version"` + Schedule *string `mapstructure:"schedule,omitempty"` + Source string `mapstructure:"source"` + Policies []agentPolicy `mapstructure:"policies"` + Config agentPluginConfig `mapstructure:"config"` + Labels map[string]string `mapstructure:"labels"` + PolicyData map[string]interface{} `mapstructure:"policy_data,omitempty"` protocolSet bool } @@ -394,6 +396,19 @@ func initRunner(name string, protocolVersion int32, runnerInstance runner.Runner return err } +func configureRunner(name string, runnerInstance runner.RunnerV2, config agentPluginConfig, policyData map[string]interface{}) error { + policyDataStruct, err := mapToStruct(policyData) + if err != nil { + return fmt.Errorf("invalid policy_data for plugin %s: %w", name, err) + } + + _, err = runnerInstance.Configure(&proto.ConfigureRequest{ + Config: config, + PolicyData: policyDataStruct, + }) + return err +} + func loadConfig(cmd *cobra.Command, v *viper.Viper) (*agentConfig, error) { err := v.ReadInConfig() if err != nil { @@ -946,6 +961,13 @@ func copyStringMap(input map[string]string) map[string]string { return output } +func mapToStruct(m map[string]interface{}) (*structpb.Struct, error) { + if m == nil { + return nil, nil + } + return structpb.NewStruct(m) +} + func pluginEvidenceLabels(config *agentConfig, pluginName string, pluginConfig *agentPlugin) map[string]string { return pluginEvidenceLabelsWithHash(config, pluginName, pluginConfig, agentConfigurationHash(config)) } @@ -1351,10 +1373,7 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error { if err := func() error { defer cleanupRunner() - _, err = runnerInstance.Configure(&proto.ConfigureRequest{ - Config: pluginConfig.Config, - }) - if err != nil { + if err := configureRunner(pluginName, runnerInstance, pluginConfig.Config, pluginConfig.PolicyData); err != nil { // What do we do here ? //endTimer := time.Now() //_, err = client.Results.Create(&sdk.Result{ @@ -1500,10 +1519,7 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent } defer cleanupRunner() - _, err = runnerInstance.Configure(&proto.ConfigureRequest{ - Config: plugin.Config, - }) - if err != nil { + if err := configureRunner(name, runnerInstance, plugin.Config, plugin.PolicyData); err != nil { return err } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index 1520ed3..a436547 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -27,7 +27,10 @@ import ( ) type initTestRunner struct { - initErr error + configureCalls int + configureErr error + configureRequest *proto.ConfigureRequest + initErr error } type emptyError struct{} @@ -37,7 +40,9 @@ func (e emptyError) Error() string { } func (r *initTestRunner) Configure(request *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { - return &proto.ConfigureResponse{}, nil + r.configureCalls++ + r.configureRequest = request + return &proto.ConfigureResponse{}, r.configureErr } func (r *initTestRunner) Eval(request *proto.EvalRequest, a runner.ApiHelper) (*proto.EvalResponse, error) { @@ -982,6 +987,53 @@ func TestInitRunner(t *testing.T) { }) } +func TestConfigureRunner(t *testing.T) { + t.Run("passes config and policy data to runner", func(t *testing.T) { + testRunner := &initTestRunner{} + + err := configureRunner( + "test-plugin", + testRunner, + agentPluginConfig{"endpoint": "localhost"}, + map[string]interface{}{"allowed_versions": map[string]interface{}{"wget": "1.20.3"}}, + ) + if err != nil { + t.Fatalf("configureRunner() error = %v, expected nil", err) + } + + if testRunner.configureCalls != 1 { + t.Fatalf("Configure called %d times, expected 1", testRunner.configureCalls) + } + if got := testRunner.configureRequest.Config["endpoint"]; got != "localhost" { + t.Fatalf("Configure config endpoint = %q, expected %q", got, "localhost") + } + allowedVersions := testRunner.configureRequest.PolicyData.Fields["allowed_versions"].GetStructValue() + if got := allowedVersions.Fields["wget"].GetStringValue(); got != "1.20.3" { + t.Fatalf("Configure policy_data allowed_versions.wget = %q, expected %q", got, "1.20.3") + } + }) + + t.Run("rejects unsupported policy data before configuring runner", func(t *testing.T) { + testRunner := &initTestRunner{} + + err := configureRunner( + "test-plugin", + testRunner, + nil, + map[string]interface{}{"unsupported": make(chan int)}, + ) + if err == nil { + t.Fatal("configureRunner() error = nil, expected invalid policy_data error") + } + if !strings.Contains(err.Error(), "invalid policy_data for plugin test-plugin") { + t.Fatalf("configureRunner() error = %q, expected plugin policy_data context", err.Error()) + } + if testRunner.configureCalls != 0 { + t.Fatalf("Configure called %d times, expected 0", testRunner.configureCalls) + } + }) +} + func TestAgentRunnerBuildsAuthenticatedSDKClient(t *testing.T) { var ( tokenRequests int diff --git a/docs/configuration.md b/docs/configuration.md index 2cca8cd..e762ef9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -29,6 +29,9 @@ plugins: : : ... + policy_data: # Optional dynamic data for policies + : + ... agent_evidence: enabled: true @@ -53,6 +56,12 @@ The `policies` field is a list of paths to the policy files that the plugin will The `config` field is a map of configuration values that the plugin will use to connect to the data source. The values will be passed to the plugin when it is run. +The `policy_data` field is an optional map of dynamic data that will be passed to the plugin's policy manager. This data +can be of any shape and is made available to OPA/Rego policies during evaluation. This allows you to provide runtime +configuration to policies without modifying the policy files themselves. + +Usage: `satisfied if input.value == data.allowed_value` + You can specify as many plugins as you wish, as long as each identifier is unique. You can even reuse the same plugin multiple times with different configurations. diff --git a/policy-manager/policy-manager.go b/policy-manager/policy-manager.go index cc2b043..382f677 100644 --- a/policy-manager/policy-manager.go +++ b/policy-manager/policy-manager.go @@ -14,6 +14,8 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/hashicorp/go-hclog" "github.com/open-policy-agent/opa/v1/rego" + "github.com/open-policy-agent/opa/v1/storage" + "github.com/open-policy-agent/opa/v1/storage/inmem" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -30,29 +32,108 @@ type EvalOutput struct { type PolicyManager struct { logger hclog.Logger loaderOptions []func(r *rego.Rego) + policyData map[string]interface{} } -func New(ctx context.Context, logger hclog.Logger, policyPath string) *PolicyManager { +func New(ctx context.Context, logger hclog.Logger, policyPath string, policyData map[string]interface{}) *PolicyManager { return &PolicyManager{ - logger: logger, - loaderOptions: []func(r *rego.Rego){ - rego.LoadBundle(policyPath), - }, + logger: logger, + policyData: policyData, + loaderOptions: []func(r *rego.Rego){rego.LoadBundle(policyPath)}, + } +} + +func (pm *PolicyManager) prepareForEval(ctx context.Context, regoArgs ...func(r *rego.Rego)) (rego.PreparedEvalQuery, error) { + store := inmem.New() + txn, err := store.NewTransaction(ctx, storage.TransactionParams{Write: true}) + if err != nil { + return rego.PreparedEvalQuery{}, err + } + + committed := false + defer func() { + if !committed { + store.Abort(ctx, txn) + } + }() + + args := make([]func(r *rego.Rego), 0, len(regoArgs)+len(pm.loaderOptions)+2) + args = append(args, + rego.Store(store), + rego.Transaction(txn), + ) + args = append(args, regoArgs...) + args = append(args, pm.loaderOptions...) + + query, err := rego.New(args...).PrepareForEval(ctx) + if err != nil { + return rego.PreparedEvalQuery{}, err } + + if err := writePolicyData(ctx, store, txn, pm.policyData); err != nil { + return rego.PreparedEvalQuery{}, err + } + + if err := store.Commit(ctx, txn); err != nil { + return rego.PreparedEvalQuery{}, err + } + committed = true + + // PreparedEvalQuery.Eval opens a fresh read transaction unless an + // EvalTransaction is provided, so committing this write transaction makes the + // loaded bundle and injected policy data visible to later evaluations. + return query, nil +} + +func writePolicyData(ctx context.Context, store storage.Store, txn storage.Transaction, data map[string]interface{}) error { + for key, value := range data { + if err := writePolicyDataValue(ctx, store, txn, storage.Path{key}, value); err != nil { + return err + } + } + return nil +} + +func writePolicyDataValue(ctx context.Context, store storage.Store, txn storage.Transaction, path storage.Path, value interface{}) error { + valueMap, valueIsMap := value.(map[string]interface{}) + if valueIsMap { + existing, err := store.Read(ctx, txn, path) + if err == nil { + if _, ok := existing.(map[string]interface{}); ok { + for key, nestedValue := range valueMap { + nestedPath := append(append(storage.Path{}, path...), key) + if err := writePolicyDataValue(ctx, store, txn, nestedPath, nestedValue); err != nil { + return err + } + } + return nil + } + } else if !storage.IsNotFound(err) { + return err + } + } + + op := storage.AddOp + if _, err := store.Read(ctx, txn, path); err == nil { + op = storage.ReplaceOp + } else if !storage.IsNotFound(err) { + return err + } + + if err := store.Write(ctx, txn, op, path, value); err != nil { + return fmt.Errorf("write policy data at %q: %w", path.String(), err) + } + return nil } func (pm *PolicyManager) Execute(ctx context.Context, input interface{}) ([]Result, error) { var output []Result pm.logger.Trace("Executing policy", "input", input) - regoArgs := []func(r *rego.Rego){ + query, err := pm.prepareForEval(ctx, rego.Query("data.compliance_framework"), rego.Package("compliance_framework"), - } - regoArgs = append(regoArgs, pm.loaderOptions...) - r := rego.New(regoArgs...) - - query, err := r.PrepareForEval(ctx) + ) if err != nil { return nil, err } @@ -71,14 +152,14 @@ func (pm *PolicyManager) Execute(ctx context.Context, input interface{}) ([]Resu }, } - regoArgs := []func(r *rego.Rego){ + subQuery, err := pm.prepareForEval(ctx, rego.Query(module.Package.Path.String()), rego.Package(module.Package.Path.String()), rego.Input(input), + ) + if err != nil { + return nil, err } - regoArgs = append(regoArgs, pm.loaderOptions...) - - subQuery := rego.New(regoArgs...) evaluation, err := subQuery.Eval(ctx) if err != nil { @@ -141,6 +222,7 @@ type PolicyProcessor struct { inventoryItems []*proto.InventoryItem actors []*proto.OriginActor activities []*proto.Activity + policyData map[string]interface{} } func NewPolicyProcessor( @@ -151,6 +233,7 @@ func NewPolicyProcessor( inventoryItems []*proto.InventoryItem, actors []*proto.OriginActor, activities []*proto.Activity, + policyData map[string]interface{}, ) *PolicyProcessor { return &PolicyProcessor{ logger: logger, @@ -160,6 +243,7 @@ func NewPolicyProcessor( inventoryItems: inventoryItems, actors: actors, activities: activities, + policyData: policyData, } } @@ -183,7 +267,7 @@ func (p *PolicyProcessor) GenerateResults(ctx context.Context, policyPath string }, }, }) - results, err := New(ctx, p.logger, policyPath).Execute(ctx, data) + results, err := New(ctx, p.logger, policyPath, p.policyData).Execute(ctx, data) if err != nil { p.logger.Error("Failed to evaluate against policy bundle", "error", err) resultErr = errors.Join(resultErr, err) @@ -304,14 +388,10 @@ func (p *PolicyProcessor) newEvidence(result Result, activities []*proto.Activit } func (pm *PolicyManager) GetRiskTemplates(ctx context.Context) (map[string][]*proto.RiskTemplate, error) { - regoArgs := []func(r *rego.Rego){ + query, err := pm.prepareForEval(ctx, rego.Query("data.compliance_framework"), rego.Package("compliance_framework"), - } - regoArgs = append(regoArgs, pm.loaderOptions...) - r := rego.New(regoArgs...) - - query, err := r.PrepareForEval(ctx) + ) if err != nil { return nil, err } @@ -366,12 +446,14 @@ func (pm *PolicyManager) GetRiskTemplates(ctx context.Context) (map[string][]*pr } func (pm *PolicyManager) evaluateRiskTemplates(ctx context.Context, policy Policy) ([]interface{}, error) { - regoArgs := []func(r *rego.Rego){ + query, err := pm.prepareForEval(ctx, rego.Query(fmt.Sprintf("%s.risk_templates", policy.Package)), + ) + if err != nil { + return nil, fmt.Errorf("prepare %q in %s: %w", "risk_templates", policy.File, err) } - regoArgs = append(regoArgs, pm.loaderOptions...) - evaluation, err := rego.New(regoArgs...).Eval(ctx) + evaluation, err := query.Eval(ctx) if err != nil { return nil, fmt.Errorf("evaluate %q in %s: %w", "risk_templates", policy.File, err) } diff --git a/policy-manager/policy-manager_test.go b/policy-manager/policy-manager_test.go index 9932bf8..e9394bf 100644 --- a/policy-manager/policy-manager_test.go +++ b/policy-manager/policy-manager_test.go @@ -52,7 +52,7 @@ func TestPolicyManager(t *testing.T) { policyManager := New(ctx, hclog.New(&hclog.LoggerOptions{ Level: hclog.Debug, JSONFormat: true, - }), "testdata/001/") + }), "testdata/001/", nil) results, err := policyManager.Execute(ctx, data) @@ -87,6 +87,49 @@ func TestPolicyManager(t *testing.T) { }, result.Violations[0]) }) + t.Run("Policy Manager injects dynamic policy data as OPA data", func(t *testing.T) { + ctx := context.Background() + policyDir := t.TempDir() + regoContents := []byte(`package compliance_framework.dynamic_policy_data + +title := "Wget version is safe" +description := sprintf("Minimum wget version is %s", [data.allowed_versions.wget]) + +violation[{ + "id": "wget_version", + "remarks": sprintf("Required wget version is %s", [data.allowed_versions.wget]), +}] if { + input.wget != data.allowed_versions.wget +} +`) + + err := os.WriteFile(filepath.Join(policyDir, "dynamic_policy_data.rego"), regoContents, 0o644) + assert.NoError(t, err) + + policyManager := New(ctx, hclog.New(&hclog.LoggerOptions{ + Level: hclog.Debug, + JSONFormat: true, + }), policyDir, map[string]interface{}{ + "allowed_versions": map[string]interface{}{ + "wget": "1.20.3", + }, + }) + + results, err := policyManager.Execute(ctx, map[string]interface{}{ + "wget": "1.19.0", + }) + + assert.NoError(t, err) + if assert.Len(t, results, 1) { + result := results[0] + assert.Equal(t, Pointer("Minimum wget version is 1.20.3"), result.Description) + if assert.Len(t, result.Violations, 1) { + assert.Equal(t, Pointer("wget_version"), result.Violations[0].ID) + assert.Equal(t, Pointer("Required wget version is 1.20.3"), result.Violations[0].Remarks) + } + } + }) + // Removed as we are unmarshalling a map, and any extra keys will just be ignored //t.Run("Policy Manager handles errors in specification", func(t *testing.T) { // ctx := context.Background() @@ -313,6 +356,7 @@ description := "Evidence was generated without a title" nil, nil, nil, + nil, ) evidences, err := processor.GenerateResults(ctx, policyDir, map[string]interface{}{}) @@ -428,6 +472,7 @@ violation[{ nil, nil, nil, + nil, ) evidences, err := processorSkip.GenerateResults(ctx, policyDirSkip, map[string]interface{}{}) @@ -463,6 +508,7 @@ skip_reason := "Invalid payload - missing required field" nil, nil, nil, + nil, ) evidences, err := processor.GenerateResults(ctx, policyDir, map[string]interface{}{}) diff --git a/runner/helpers.go b/runner/helpers.go index ec8e2ce..0faea47 100644 --- a/runner/helpers.go +++ b/runner/helpers.go @@ -26,7 +26,7 @@ func InitWithSubjectsAndRisksFromPolicies( } for _, path := range req.PolicyPaths { - pm := policyManager.New(ctx, logger, path) + pm := policyManager.New(ctx, logger, path, nil) temps, err := pm.GetRiskTemplates(ctx) if err != nil { logger.Error("Error getting risk templates for policy path", "path", path, "error", err) diff --git a/runner/proto/runner.pb.go b/runner/proto/runner.pb.go index bb34207..c0480a3 100644 --- a/runner/proto/runner.pb.go +++ b/runner/proto/runner.pb.go @@ -9,6 +9,7 @@ package proto import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + structpb "google.golang.org/protobuf/types/known/structpb" reflect "reflect" sync "sync" unsafe "unsafe" @@ -70,6 +71,7 @@ func (ExecutionStatus) EnumDescriptor() ([]byte, []int) { type ConfigureRequest struct { state protoimpl.MessageState `protogen:"open.v1"` Config map[string]string `protobuf:"bytes,1,rep,name=config,proto3" json:"config,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + PolicyData *structpb.Struct `protobuf:"bytes,2,opt,name=policy_data,json=policyData,proto3" json:"policy_data,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -111,6 +113,13 @@ func (x *ConfigureRequest) GetConfig() map[string]string { return nil } +func (x *ConfigureRequest) GetPolicyData() *structpb.Struct { + if x != nil { + return x.PolicyData + } + return nil +} + type ConfigureResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` @@ -347,9 +356,11 @@ var File_runner_proto_runner_proto protoreflect.FileDescriptor const file_runner_proto_runner_proto_rawDesc = "" + "\n" + - "\x19runner/proto/runner.proto\x12\x05proto\"\x8a\x01\n" + + "\x19runner/proto/runner.proto\x12\x05proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc4\x01\n" + "\x10ConfigureRequest\x12;\n" + - "\x06config\x18\x01 \x03(\v2#.proto.ConfigureRequest.ConfigEntryR\x06config\x1a9\n" + + "\x06config\x18\x01 \x03(\v2#.proto.ConfigureRequest.ConfigEntryR\x06config\x128\n" + + "\vpolicy_data\x18\x02 \x01(\v2\x17.google.protobuf.StructR\n" + + "policyData\x1a9\n" + "\vConfigEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\")\n" + @@ -395,21 +406,23 @@ var file_runner_proto_runner_proto_goTypes = []any{ (*EvalRequest)(nil), // 5: proto.EvalRequest (*EvalResponse)(nil), // 6: proto.EvalResponse nil, // 7: proto.ConfigureRequest.ConfigEntry + (*structpb.Struct)(nil), // 8: google.protobuf.Struct } var file_runner_proto_runner_proto_depIdxs = []int32{ 7, // 0: proto.ConfigureRequest.config:type_name -> proto.ConfigureRequest.ConfigEntry - 0, // 1: proto.EvalResponse.status:type_name -> proto.ExecutionStatus - 1, // 2: proto.Runner.Configure:input_type -> proto.ConfigureRequest - 5, // 3: proto.Runner.Eval:input_type -> proto.EvalRequest - 3, // 4: proto.Runner.Init:input_type -> proto.InitRequest - 2, // 5: proto.Runner.Configure:output_type -> proto.ConfigureResponse - 6, // 6: proto.Runner.Eval:output_type -> proto.EvalResponse - 4, // 7: proto.Runner.Init:output_type -> proto.InitResponse - 5, // [5:8] is the sub-list for method output_type - 2, // [2:5] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 8, // 1: proto.ConfigureRequest.policy_data:type_name -> google.protobuf.Struct + 0, // 2: proto.EvalResponse.status:type_name -> proto.ExecutionStatus + 1, // 3: proto.Runner.Configure:input_type -> proto.ConfigureRequest + 5, // 4: proto.Runner.Eval:input_type -> proto.EvalRequest + 3, // 5: proto.Runner.Init:input_type -> proto.InitRequest + 2, // 6: proto.Runner.Configure:output_type -> proto.ConfigureResponse + 6, // 7: proto.Runner.Eval:output_type -> proto.EvalResponse + 4, // 8: proto.Runner.Init:output_type -> proto.InitResponse + 6, // [6:9] is the sub-list for method output_type + 3, // [3:6] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_runner_proto_runner_proto_init() } diff --git a/runner/proto/runner.proto b/runner/proto/runner.proto index 9d25118..9a72bc6 100644 --- a/runner/proto/runner.proto +++ b/runner/proto/runner.proto @@ -3,6 +3,8 @@ package proto; option go_package = "./proto"; +import "google/protobuf/struct.proto"; + enum ExecutionStatus { SUCCESS = 0; FAILURE = 1; @@ -10,6 +12,7 @@ enum ExecutionStatus { message ConfigureRequest { map config = 1; + google.protobuf.Struct policy_data = 2; } message ConfigureResponse {