From a5fbd2b1b7d42b85d268fc239d81af6809c92945 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Thu, 21 May 2026 14:14:46 +0100 Subject: [PATCH 1/3] feat: add support for policy definition configuration --- cmd/agent.go | 40 +++-- cmd/agent_test.go | 10 +- runner/proto/policy_behavior.go | 108 +++++++++++++ runner/proto/policy_behavior_test.go | 216 ++++++++++++++++++++++++++ runner/proto/runner.pb.go | 217 +++++++++++++++++++-------- runner/proto/runner.proto | 7 + 6 files changed, 524 insertions(+), 74 deletions(-) create mode 100644 runner/proto/policy_behavior.go create mode 100644 runner/proto/policy_behavior_test.go diff --git a/cmd/agent.go b/cmd/agent.go index 22098b4..7e44e73 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -68,6 +68,7 @@ type agentPlugin struct { Config agentPluginConfig `mapstructure:"config"` Labels map[string]string `mapstructure:"labels"` PolicyData map[string]interface{} `mapstructure:"policy_data,omitempty"` + PolicyBehavior map[string][]string `mapstructure:"policy_behavior,omitempty"` protocolSet bool } @@ -377,13 +378,14 @@ func runnerDispenseName(protocolVersion int32) (string, error) { } } -func initRunner(name string, protocolVersion int32, runnerInstance runner.RunnerV2, policyPaths []string, resultsHelper runner.ApiHelper) error { +func initRunner(name string, protocolVersion int32, runnerInstance runner.RunnerV2, policyPaths []string, policyBehavior map[string]*proto.StringList, resultsHelper runner.ApiHelper) error { if protocolVersion <= DefaultProtocolVersion { return nil } _, err := runnerInstance.Init(&proto.InitRequest{ - PolicyPaths: policyPaths, + PolicyPaths: policyPaths, + PolicyBehavior: policyBehavior, }, resultsHelper) if err == nil { return nil @@ -396,15 +398,16 @@ func initRunner(name string, protocolVersion int32, runnerInstance runner.Runner return err } -func configureRunner(name string, runnerInstance runner.RunnerV2, config agentPluginConfig, policyData map[string]interface{}) error { +func configureRunner(name string, runnerInstance runner.RunnerV2, config agentPluginConfig, policyData map[string]interface{}, policyBehavior map[string][]string) error { policyDataStruct, err := mapToStruct(policyData) if err != nil { return fmt.Errorf("invalid policy_data for plugin %s: %w", name, err) } _, err = runnerInstance.Configure(&proto.ConfigureRequest{ - Config: config, - PolicyData: policyDataStruct, + Config: config, + PolicyData: policyDataStruct, + PolicyBehavior: policyBehaviorToProto(policyBehavior), }) return err } @@ -968,6 +971,17 @@ func mapToStruct(m map[string]interface{}) (*structpb.Struct, error) { return structpb.NewStruct(m) } +func policyBehaviorToProto(policyBehavior map[string][]string) map[string]*proto.StringList { + if policyBehavior == nil { + return nil + } + result := make(map[string]*proto.StringList) + for key, values := range policyBehavior { + result[key] = &proto.StringList{Values: values} + } + return result +} + func pluginEvidenceLabels(config *agentConfig, pluginName string, pluginConfig *agentPlugin) map[string]string { return pluginEvidenceLabelsWithHash(config, pluginName, pluginConfig, agentConfigurationHash(config)) } @@ -1373,7 +1387,7 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error { if err := func() error { defer cleanupRunner() - if err := configureRunner(pluginName, runnerInstance, pluginConfig.Config, pluginConfig.PolicyData); err != nil { + if err := configureRunner(pluginName, runnerInstance, pluginConfig.Config, pluginConfig.PolicyData, pluginConfig.PolicyBehavior); err != nil { // What do we do here ? //endTimer := time.Now() //_, err = client.Results.Create(&sdk.Result{ @@ -1402,13 +1416,15 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error { ) resultsHelper := runner.NewApiHelper(logger, client, labels, pluginName) - if err := initRunner(pluginName, pluginConfig.ProtocolVersion, runnerInstance, policyPaths, resultsHelper); err != nil { + policyBehaviorProto := policyBehaviorToProto(pluginConfig.PolicyBehavior) + if err := initRunner(pluginName, pluginConfig.ProtocolVersion, runnerInstance, policyPaths, policyBehaviorProto, resultsHelper); err != nil { return err } // TODO: Send failed results to the database? _, err = runnerInstance.Eval(&proto.EvalRequest{ - PolicyPaths: policyPaths, + PolicyPaths: policyPaths, + PolicyBehavior: policyBehaviorProto, }, resultsHelper) if err != nil { @@ -1519,7 +1535,7 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent } defer cleanupRunner() - if err := configureRunner(name, runnerInstance, plugin.Config, plugin.PolicyData); err != nil { + if err := configureRunner(name, runnerInstance, plugin.Config, plugin.PolicyData, plugin.PolicyBehavior); err != nil { return err } @@ -1531,13 +1547,15 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent ) resultsHelper := runner.NewApiHelper(pluginLogger, client, labels, name) - if err := initRunner(name, plugin.ProtocolVersion, runnerInstance, policyPaths, resultsHelper); err != nil { + policyBehaviorProto := policyBehaviorToProto(plugin.PolicyBehavior) + if err := initRunner(name, plugin.ProtocolVersion, runnerInstance, policyPaths, policyBehaviorProto, resultsHelper); err != nil { return err } // TODO: Send failed results to the database? _, err = runnerInstance.Eval(&proto.EvalRequest{ - PolicyPaths: policyPaths, + PolicyPaths: policyPaths, + PolicyBehavior: policyBehaviorProto, }, resultsHelper) if err != nil { diff --git a/cmd/agent_test.go b/cmd/agent_test.go index a436547..903de12 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "reflect" "strings" "sync/atomic" "testing" @@ -948,7 +949,7 @@ func activePluginClientCount(agentRunner *AgentRunner) int { func TestInitRunner(t *testing.T) { t.Run("skips init for v1", func(t *testing.T) { - err := initRunner("test-plugin", DefaultProtocolVersion, &initTestRunner{}, nil, nil) + err := initRunner("test-plugin", DefaultProtocolVersion, &initTestRunner{}, nil, nil, nil) if err != nil { t.Fatalf("initRunner() error = %v, expected nil", err) } @@ -961,6 +962,7 @@ func TestInitRunner(t *testing.T) { &initTestRunner{initErr: status.Error(codes.Unimplemented, "not implemented")}, nil, nil, + nil, ) if err == nil { t.Fatal("initRunner() error = nil, expected wrapped error") @@ -980,6 +982,7 @@ func TestInitRunner(t *testing.T) { &initTestRunner{initErr: expectedErr}, nil, nil, + nil, ) if !errors.Is(err, expectedErr) { t.Fatalf("initRunner() error = %v, expected %v", err, expectedErr) @@ -996,6 +999,7 @@ func TestConfigureRunner(t *testing.T) { testRunner, agentPluginConfig{"endpoint": "localhost"}, map[string]interface{}{"allowed_versions": map[string]interface{}{"wget": "1.20.3"}}, + map[string][]string{"policy-bundle": {"vpc", "sg"}}, ) if err != nil { t.Fatalf("configureRunner() error = %v, expected nil", err) @@ -1011,6 +1015,9 @@ func TestConfigureRunner(t *testing.T) { if got := allowedVersions.Fields["wget"].GetStringValue(); got != "1.20.3" { t.Fatalf("Configure policy_data allowed_versions.wget = %q, expected %q", got, "1.20.3") } + if got := testRunner.configureRequest.PolicyBehavior["policy-bundle"].Values; !reflect.DeepEqual(got, []string{"vpc", "sg"}) { + t.Fatalf("Configure policyBehavior policy-bundle = %#v, expected %#v", got, []string{"vpc", "sg"}) + } }) t.Run("rejects unsupported policy data before configuring runner", func(t *testing.T) { @@ -1021,6 +1028,7 @@ func TestConfigureRunner(t *testing.T) { testRunner, nil, map[string]interface{}{"unsupported": make(chan int)}, + nil, ) if err == nil { t.Fatal("configureRunner() error = nil, expected invalid policy_data error") diff --git a/runner/proto/policy_behavior.go b/runner/proto/policy_behavior.go new file mode 100644 index 0000000..970a4f0 --- /dev/null +++ b/runner/proto/policy_behavior.go @@ -0,0 +1,108 @@ +package proto + +import ( + "slices" + "strings" +) + +func (r *EvalRequest) WithDefaultPolicyBehavior(defaults map[string][]string) *EvalRequest { + if r == nil { + return nil + } + + return &EvalRequest{ + PolicyPaths: slices.Clone(r.PolicyPaths), + ApiServer: r.ApiServer, + PolicyBehavior: mergePolicyBehavior(defaults, r.PolicyBehavior), + } +} + +func (r *EvalRequest) WithUndefinedMappedTo(behavior []string) *EvalRequest { + if r == nil { + return nil + } + + copy := &EvalRequest{ + PolicyPaths: slices.Clone(r.PolicyPaths), + ApiServer: r.ApiServer, + PolicyBehavior: mergePolicyBehavior(nil, r.PolicyBehavior), + } + + for _, path := range copy.PolicyPaths { + if pathCoveredByPolicyBehavior(path, copy.PolicyBehavior) { + continue + } + if copy.PolicyBehavior == nil { + copy.PolicyBehavior = make(map[string]*StringList) + } + copy.PolicyBehavior[path] = &StringList{Values: slices.Clone(behavior)} + } + + return copy +} + +func (r *EvalRequest) PolicyPathsForBehavior(behavior string) []string { + if r == nil { + return nil + } + + if len(r.PolicyBehavior) == 0 { + return []string{} + } + + matchingKeys := make([]string, 0, len(r.PolicyBehavior)) + for key, list := range r.PolicyBehavior { + if list == nil || !slices.Contains(list.Values, behavior) { + continue + } + matchingKeys = append(matchingKeys, key) + } + + if len(matchingKeys) == 0 { + return []string{} + } + + filtered := make([]string, 0, len(r.PolicyPaths)) +outer: + for _, path := range r.PolicyPaths { + for _, key := range matchingKeys { + if pathCoveredByPolicyBehavior(path, map[string]*StringList{key: nil}) { + filtered = append(filtered, path) + continue outer + } + } + } + + return filtered +} + +func mergePolicyBehavior(defaults map[string][]string, configured map[string]*StringList) map[string]*StringList { + if len(defaults) == 0 && len(configured) == 0 { + return nil + } + + merged := make(map[string]*StringList, len(defaults)+len(configured)) + for key, values := range defaults { + merged[key] = &StringList{Values: slices.Clone(values)} + } + + for key, list := range configured { + if list == nil { + merged[key] = nil + continue + } + merged[key] = &StringList{Values: slices.Clone(list.Values)} + } + + return merged +} + +func pathCoveredByPolicyBehavior(path string, behavior map[string]*StringList) bool { + for key := range behavior { + if strings.Contains(path, key) { + return true + } + } + + return false +} diff --git a/runner/proto/policy_behavior_test.go b/runner/proto/policy_behavior_test.go new file mode 100644 index 0000000..e8f026c --- /dev/null +++ b/runner/proto/policy_behavior_test.go @@ -0,0 +1,216 @@ +package proto + +import ( + "reflect" + "testing" +) + +func TestEvalRequestWithUndefinedMappedTo(t *testing.T) { + t.Run("nil request returns nil", func(t *testing.T) { + var request *EvalRequest + if got := request.WithUndefinedMappedTo([]string{"vpc"}); got != nil { + t.Fatalf("WithUndefinedMappedTo() = %#v, want nil", got) + } + }) + + t.Run("adds full path mapping for uncovered policies", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{"/tmp/unmapped/vpc.rego"}, + } + + got := request.WithUndefinedMappedTo([]string{"vpc"}) + + wantBehavior := map[string]*StringList{ + "/tmp/unmapped/vpc.rego": {Values: []string{"vpc"}}, + } + if !reflect.DeepEqual(got.PolicyBehavior, wantBehavior) { + t.Fatalf("WithUndefinedMappedTo().PolicyBehavior = %#v, want %#v", got.PolicyBehavior, wantBehavior) + } + + if request.PolicyBehavior != nil { + t.Fatalf("original request.PolicyBehavior = %#v, want nil", request.PolicyBehavior) + } + }) + + t.Run("does not add full path mapping for covered policies", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{"/tmp/plugin-aws-networking-security-policies/vpc.rego"}, + PolicyBehavior: map[string]*StringList{ + "plugin-aws-networking-security-policies": {Values: []string{"vpc"}}, + }, + } + + got := request.WithUndefinedMappedTo([]string{"vpc"}) + + wantBehavior := map[string]*StringList{ + "plugin-aws-networking-security-policies": {Values: []string{"vpc"}}, + } + if !reflect.DeepEqual(got.PolicyBehavior, wantBehavior) { + t.Fatalf("WithUndefinedMappedTo().PolicyBehavior = %#v, want %#v", got.PolicyBehavior, wantBehavior) + } + }) + + t.Run("chains after defaults and fills only remaining uncovered paths", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{ + "/tmp/plugin-aws-networking-security-policies/vpc.rego", + "/tmp/custom/unmapped.rego", + }, + } + + got := request. + WithDefaultPolicyBehavior(map[string][]string{ + "plugin-aws-networking-security-policies": {"vpc"}, + }). + WithUndefinedMappedTo([]string{"vpc"}). + PolicyPathsForBehavior("vpc") + + want := []string{ + "/tmp/plugin-aws-networking-security-policies/vpc.rego", + "/tmp/custom/unmapped.rego", + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("defaults then undefined chain result = %#v, want %#v", got, want) + } + }) +} + +func TestEvalRequestWithDefaultPolicyBehavior(t *testing.T) { + t.Run("nil request returns nil", func(t *testing.T) { + var request *EvalRequest + if got := request.WithDefaultPolicyBehavior(map[string][]string{"default": {"vpc"}}); got != nil { + t.Fatalf("WithDefaultPolicyBehavior() = %#v, want nil", got) + } + }) + + t.Run("returns copied request with defaults when config is empty", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{"/tmp/default-policies/vpc.rego"}, + } + + got := request.WithDefaultPolicyBehavior(map[string][]string{ + "default-policies": {"vpc"}, + }) + + want := &EvalRequest{ + PolicyPaths: []string{"/tmp/default-policies/vpc.rego"}, + PolicyBehavior: map[string]*StringList{ + "default-policies": {Values: []string{"vpc"}}, + }, + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("WithDefaultPolicyBehavior() = %#v, want %#v", got, want) + } + + if request.PolicyBehavior != nil { + t.Fatalf("original request.PolicyBehavior = %#v, want nil", request.PolicyBehavior) + } + }) + + t.Run("configured values take precedence over defaults", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{ + "/tmp/default-policies/vpc.rego", + "/tmp/configured-policies/subnet.rego", + }, + PolicyBehavior: map[string]*StringList{ + "default-policies": {Values: []string{"subnet"}}, + "configured-policies": {Values: []string{"subnet"}}, + }, + } + + got := request.WithDefaultPolicyBehavior(map[string][]string{ + "default-policies": {"vpc"}, + "extra-policies": {"vpc"}, + }) + + wantBehavior := map[string]*StringList{ + "default-policies": {Values: []string{"subnet"}}, + "configured-policies": {Values: []string{"subnet"}}, + "extra-policies": {Values: []string{"vpc"}}, + } + + if !reflect.DeepEqual(got.PolicyBehavior, wantBehavior) { + t.Fatalf("WithDefaultPolicyBehavior().PolicyBehavior = %#v, want %#v", got.PolicyBehavior, wantBehavior) + } + + if !reflect.DeepEqual(request.PolicyBehavior["default-policies"].Values, []string{"subnet"}) { + t.Fatalf("original configured values were mutated: %#v", request.PolicyBehavior["default-policies"].Values) + } + }) + + t.Run("chains into PolicyPathsForBehavior", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{"/tmp/default-policies/vpc.rego", "/tmp/other-policies/general.rego"}, + } + + got := request.WithDefaultPolicyBehavior(map[string][]string{ + "default-policies": {"vpc"}, + }).PolicyPathsForBehavior("vpc") + + want := []string{"/tmp/default-policies/vpc.rego"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("WithDefaultPolicyBehavior(...).PolicyPathsForBehavior() = %#v, want %#v", got, want) + } + }) +} + +func TestEvalRequestPolicyPathsForBehavior(t *testing.T) { + tests := []struct { + name string + request *EvalRequest + behavior string + want []string + }{ + { + name: "nil request returns nil", + request: nil, + behavior: "vpc", + want: nil, + }, + { + name: "no policy behavior returns empty list", + request: &EvalRequest{ + PolicyPaths: []string{"/tmp/a", "/tmp/b"}, + }, + behavior: "vpc", + want: []string{}, + }, + { + name: "matching behavior filters paths by matching keys", + request: &EvalRequest{ + PolicyPaths: []string{ + "/tmp/plugin-aws-networking-security-policies/vpc.rego", + "/tmp/other-policies/general.rego", + }, + PolicyBehavior: map[string]*StringList{ + "plugin-aws-networking-security-policies": {Values: []string{"vpc"}}, + "other-policies": {Values: []string{"subnet"}}, + }, + }, + behavior: "vpc", + want: []string{"/tmp/plugin-aws-networking-security-policies/vpc.rego"}, + }, + { + name: "no matching behavior returns empty list", + request: &EvalRequest{ + PolicyPaths: []string{"/tmp/a", "/tmp/b"}, + PolicyBehavior: map[string]*StringList{ + "plugin-aws-networking-security-policies": {Values: []string{"vpc"}}, + }, + }, + behavior: "subnet", + want: []string{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.request.PolicyPathsForBehavior(test.behavior) + if !reflect.DeepEqual(got, test.want) { + t.Fatalf("PolicyPathsForBehavior() = %#v, want %#v", got, test.want) + } + }) + } +} diff --git a/runner/proto/runner.pb.go b/runner/proto/runner.pb.go index c0480a3..2d2cff1 100644 --- a/runner/proto/runner.pb.go +++ b/runner/proto/runner.pb.go @@ -68,17 +68,62 @@ func (ExecutionStatus) EnumDescriptor() ([]byte, []int) { return file_runner_proto_runner_proto_rawDescGZIP(), []int{0} } -type ConfigureRequest struct { +type StringList struct { state protoimpl.MessageState `protogen:"open.v1"` - Config map[string]string `protobuf:"bytes,1,rep,name=config,proto3" json:"config,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` - PolicyData *structpb.Struct `protobuf:"bytes,2,opt,name=policy_data,json=policyData,proto3" json:"policy_data,omitempty"` + Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } +func (x *StringList) Reset() { + *x = StringList{} + mi := &file_runner_proto_runner_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StringList) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StringList) ProtoMessage() {} + +func (x *StringList) ProtoReflect() protoreflect.Message { + mi := &file_runner_proto_runner_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StringList.ProtoReflect.Descriptor instead. +func (*StringList) Descriptor() ([]byte, []int) { + return file_runner_proto_runner_proto_rawDescGZIP(), []int{0} +} + +func (x *StringList) GetValues() []string { + if x != nil { + return x.Values + } + return nil +} + +type ConfigureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Config map[string]string `protobuf:"bytes,1,rep,name=config,proto3" json:"config,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + PolicyData *structpb.Struct `protobuf:"bytes,2,opt,name=policy_data,json=policyData,proto3" json:"policy_data,omitempty"` + PolicyBehavior map[string]*StringList `protobuf:"bytes,3,rep,name=policyBehavior,proto3" json:"policyBehavior,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + func (x *ConfigureRequest) Reset() { *x = ConfigureRequest{} - mi := &file_runner_proto_runner_proto_msgTypes[0] + mi := &file_runner_proto_runner_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -90,7 +135,7 @@ func (x *ConfigureRequest) String() string { func (*ConfigureRequest) ProtoMessage() {} func (x *ConfigureRequest) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[0] + mi := &file_runner_proto_runner_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -103,7 +148,7 @@ func (x *ConfigureRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ConfigureRequest.ProtoReflect.Descriptor instead. func (*ConfigureRequest) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{0} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{1} } func (x *ConfigureRequest) GetConfig() map[string]string { @@ -120,6 +165,13 @@ func (x *ConfigureRequest) GetPolicyData() *structpb.Struct { return nil } +func (x *ConfigureRequest) GetPolicyBehavior() map[string]*StringList { + if x != nil { + return x.PolicyBehavior + } + return nil +} + type ConfigureResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` @@ -129,7 +181,7 @@ type ConfigureResponse struct { func (x *ConfigureResponse) Reset() { *x = ConfigureResponse{} - mi := &file_runner_proto_runner_proto_msgTypes[1] + mi := &file_runner_proto_runner_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -141,7 +193,7 @@ func (x *ConfigureResponse) String() string { func (*ConfigureResponse) ProtoMessage() {} func (x *ConfigureResponse) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[1] + mi := &file_runner_proto_runner_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -154,7 +206,7 @@ func (x *ConfigureResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ConfigureResponse.ProtoReflect.Descriptor instead. func (*ConfigureResponse) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{1} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{2} } func (x *ConfigureResponse) GetValue() []byte { @@ -165,16 +217,17 @@ func (x *ConfigureResponse) GetValue() []byte { } type InitRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` - ApiServer uint32 `protobuf:"varint,2,opt,name=apiServer,proto3" json:"apiServer,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` + ApiServer uint32 `protobuf:"varint,2,opt,name=apiServer,proto3" json:"apiServer,omitempty"` + PolicyBehavior map[string]*StringList `protobuf:"bytes,3,rep,name=policyBehavior,proto3" json:"policyBehavior,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *InitRequest) Reset() { *x = InitRequest{} - mi := &file_runner_proto_runner_proto_msgTypes[2] + mi := &file_runner_proto_runner_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -186,7 +239,7 @@ func (x *InitRequest) String() string { func (*InitRequest) ProtoMessage() {} func (x *InitRequest) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[2] + mi := &file_runner_proto_runner_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -199,7 +252,7 @@ func (x *InitRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use InitRequest.ProtoReflect.Descriptor instead. func (*InitRequest) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{2} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{3} } func (x *InitRequest) GetPolicyPaths() []string { @@ -216,6 +269,13 @@ func (x *InitRequest) GetApiServer() uint32 { return 0 } +func (x *InitRequest) GetPolicyBehavior() map[string]*StringList { + if x != nil { + return x.PolicyBehavior + } + return nil +} + type InitResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -224,7 +284,7 @@ type InitResponse struct { func (x *InitResponse) Reset() { *x = InitResponse{} - mi := &file_runner_proto_runner_proto_msgTypes[3] + mi := &file_runner_proto_runner_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -236,7 +296,7 @@ func (x *InitResponse) String() string { func (*InitResponse) ProtoMessage() {} func (x *InitResponse) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[3] + mi := &file_runner_proto_runner_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -249,20 +309,21 @@ func (x *InitResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use InitResponse.ProtoReflect.Descriptor instead. func (*InitResponse) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{3} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{4} } type EvalRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` - ApiServer uint32 `protobuf:"varint,2,opt,name=apiServer,proto3" json:"apiServer,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + PolicyPaths []string `protobuf:"bytes,1,rep,name=policyPaths,proto3" json:"policyPaths,omitempty"` + ApiServer uint32 `protobuf:"varint,2,opt,name=apiServer,proto3" json:"apiServer,omitempty"` + PolicyBehavior map[string]*StringList `protobuf:"bytes,3,rep,name=policyBehavior,proto3" json:"policyBehavior,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *EvalRequest) Reset() { *x = EvalRequest{} - mi := &file_runner_proto_runner_proto_msgTypes[4] + mi := &file_runner_proto_runner_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -274,7 +335,7 @@ func (x *EvalRequest) String() string { func (*EvalRequest) ProtoMessage() {} func (x *EvalRequest) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[4] + mi := &file_runner_proto_runner_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -287,7 +348,7 @@ func (x *EvalRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use EvalRequest.ProtoReflect.Descriptor instead. func (*EvalRequest) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{4} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{5} } func (x *EvalRequest) GetPolicyPaths() []string { @@ -304,6 +365,13 @@ func (x *EvalRequest) GetApiServer() uint32 { return 0 } +func (x *EvalRequest) GetPolicyBehavior() map[string]*StringList { + if x != nil { + return x.PolicyBehavior + } + return nil +} + // * // EvalResponse is the result of an assessment check // Results are sent back by the plugins using the Result service defined @@ -317,7 +385,7 @@ type EvalResponse struct { func (x *EvalResponse) Reset() { *x = EvalResponse{} - mi := &file_runner_proto_runner_proto_msgTypes[5] + mi := &file_runner_proto_runner_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -329,7 +397,7 @@ func (x *EvalResponse) String() string { func (*EvalResponse) ProtoMessage() {} func (x *EvalResponse) ProtoReflect() protoreflect.Message { - mi := &file_runner_proto_runner_proto_msgTypes[5] + mi := &file_runner_proto_runner_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -342,7 +410,7 @@ func (x *EvalResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use EvalResponse.ProtoReflect.Descriptor instead. func (*EvalResponse) Descriptor() ([]byte, []int) { - return file_runner_proto_runner_proto_rawDescGZIP(), []int{5} + return file_runner_proto_runner_proto_rawDescGZIP(), []int{6} } func (x *EvalResponse) GetStatus() ExecutionStatus { @@ -356,23 +424,38 @@ var File_runner_proto_runner_proto protoreflect.FileDescriptor const file_runner_proto_runner_proto_rawDesc = "" + "\n" + - "\x19runner/proto/runner.proto\x12\x05proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc4\x01\n" + + "\x19runner/proto/runner.proto\x12\x05proto\x1a\x1cgoogle/protobuf/struct.proto\"$\n" + + "\n" + + "StringList\x12\x16\n" + + "\x06values\x18\x01 \x03(\tR\x06values\"\xef\x02\n" + "\x10ConfigureRequest\x12;\n" + "\x06config\x18\x01 \x03(\v2#.proto.ConfigureRequest.ConfigEntryR\x06config\x128\n" + "\vpolicy_data\x18\x02 \x01(\v2\x17.google.protobuf.StructR\n" + - "policyData\x1a9\n" + + "policyData\x12S\n" + + "\x0epolicyBehavior\x18\x03 \x03(\v2+.proto.ConfigureRequest.PolicyBehaviorEntryR\x0epolicyBehavior\x1a9\n" + "\vConfigEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\")\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\x1aT\n" + + "\x13PolicyBehaviorEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.proto.StringListR\x05value:\x028\x01\")\n" + "\x11ConfigureResponse\x12\x14\n" + - "\x05value\x18\x01 \x01(\fR\x05value\"M\n" + + "\x05value\x18\x01 \x01(\fR\x05value\"\xf3\x01\n" + "\vInitRequest\x12 \n" + "\vpolicyPaths\x18\x01 \x03(\tR\vpolicyPaths\x12\x1c\n" + - "\tapiServer\x18\x02 \x01(\rR\tapiServer\"\x0e\n" + - "\fInitResponse\"M\n" + + "\tapiServer\x18\x02 \x01(\rR\tapiServer\x12N\n" + + "\x0epolicyBehavior\x18\x03 \x03(\v2&.proto.InitRequest.PolicyBehaviorEntryR\x0epolicyBehavior\x1aT\n" + + "\x13PolicyBehaviorEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.proto.StringListR\x05value:\x028\x01\"\x0e\n" + + "\fInitResponse\"\xf3\x01\n" + "\vEvalRequest\x12 \n" + "\vpolicyPaths\x18\x01 \x03(\tR\vpolicyPaths\x12\x1c\n" + - "\tapiServer\x18\x02 \x01(\rR\tapiServer\">\n" + + "\tapiServer\x18\x02 \x01(\rR\tapiServer\x12N\n" + + "\x0epolicyBehavior\x18\x03 \x03(\v2&.proto.EvalRequest.PolicyBehaviorEntryR\x0epolicyBehavior\x1aT\n" + + "\x13PolicyBehaviorEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.proto.StringListR\x05value:\x028\x01\">\n" + "\fEvalResponse\x12.\n" + "\x06status\x18\x01 \x01(\x0e2\x16.proto.ExecutionStatusR\x06status*+\n" + "\x0fExecutionStatus\x12\v\n" + @@ -396,33 +479,43 @@ func file_runner_proto_runner_proto_rawDescGZIP() []byte { } var file_runner_proto_runner_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_runner_proto_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_runner_proto_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 11) var file_runner_proto_runner_proto_goTypes = []any{ (ExecutionStatus)(0), // 0: proto.ExecutionStatus - (*ConfigureRequest)(nil), // 1: proto.ConfigureRequest - (*ConfigureResponse)(nil), // 2: proto.ConfigureResponse - (*InitRequest)(nil), // 3: proto.InitRequest - (*InitResponse)(nil), // 4: proto.InitResponse - (*EvalRequest)(nil), // 5: proto.EvalRequest - (*EvalResponse)(nil), // 6: proto.EvalResponse - nil, // 7: proto.ConfigureRequest.ConfigEntry - (*structpb.Struct)(nil), // 8: google.protobuf.Struct + (*StringList)(nil), // 1: proto.StringList + (*ConfigureRequest)(nil), // 2: proto.ConfigureRequest + (*ConfigureResponse)(nil), // 3: proto.ConfigureResponse + (*InitRequest)(nil), // 4: proto.InitRequest + (*InitResponse)(nil), // 5: proto.InitResponse + (*EvalRequest)(nil), // 6: proto.EvalRequest + (*EvalResponse)(nil), // 7: proto.EvalResponse + nil, // 8: proto.ConfigureRequest.ConfigEntry + nil, // 9: proto.ConfigureRequest.PolicyBehaviorEntry + nil, // 10: proto.InitRequest.PolicyBehaviorEntry + nil, // 11: proto.EvalRequest.PolicyBehaviorEntry + (*structpb.Struct)(nil), // 12: google.protobuf.Struct } var file_runner_proto_runner_proto_depIdxs = []int32{ - 7, // 0: proto.ConfigureRequest.config:type_name -> proto.ConfigureRequest.ConfigEntry - 8, // 1: proto.ConfigureRequest.policy_data:type_name -> google.protobuf.Struct - 0, // 2: proto.EvalResponse.status:type_name -> proto.ExecutionStatus - 1, // 3: proto.Runner.Configure:input_type -> proto.ConfigureRequest - 5, // 4: proto.Runner.Eval:input_type -> proto.EvalRequest - 3, // 5: proto.Runner.Init:input_type -> proto.InitRequest - 2, // 6: proto.Runner.Configure:output_type -> proto.ConfigureResponse - 6, // 7: proto.Runner.Eval:output_type -> proto.EvalResponse - 4, // 8: proto.Runner.Init:output_type -> proto.InitResponse - 6, // [6:9] is the sub-list for method output_type - 3, // [3:6] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 8, // 0: proto.ConfigureRequest.config:type_name -> proto.ConfigureRequest.ConfigEntry + 12, // 1: proto.ConfigureRequest.policy_data:type_name -> google.protobuf.Struct + 9, // 2: proto.ConfigureRequest.policyBehavior:type_name -> proto.ConfigureRequest.PolicyBehaviorEntry + 10, // 3: proto.InitRequest.policyBehavior:type_name -> proto.InitRequest.PolicyBehaviorEntry + 11, // 4: proto.EvalRequest.policyBehavior:type_name -> proto.EvalRequest.PolicyBehaviorEntry + 0, // 5: proto.EvalResponse.status:type_name -> proto.ExecutionStatus + 1, // 6: proto.ConfigureRequest.PolicyBehaviorEntry.value:type_name -> proto.StringList + 1, // 7: proto.InitRequest.PolicyBehaviorEntry.value:type_name -> proto.StringList + 1, // 8: proto.EvalRequest.PolicyBehaviorEntry.value:type_name -> proto.StringList + 2, // 9: proto.Runner.Configure:input_type -> proto.ConfigureRequest + 6, // 10: proto.Runner.Eval:input_type -> proto.EvalRequest + 4, // 11: proto.Runner.Init:input_type -> proto.InitRequest + 3, // 12: proto.Runner.Configure:output_type -> proto.ConfigureResponse + 7, // 13: proto.Runner.Eval:output_type -> proto.EvalResponse + 5, // 14: proto.Runner.Init:output_type -> proto.InitResponse + 12, // [12:15] is the sub-list for method output_type + 9, // [9:12] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_runner_proto_runner_proto_init() } @@ -436,7 +529,7 @@ func file_runner_proto_runner_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_runner_proto_runner_proto_rawDesc), len(file_runner_proto_runner_proto_rawDesc)), NumEnums: 1, - NumMessages: 7, + NumMessages: 11, NumExtensions: 0, NumServices: 1, }, diff --git a/runner/proto/runner.proto b/runner/proto/runner.proto index 9a72bc6..be450c0 100644 --- a/runner/proto/runner.proto +++ b/runner/proto/runner.proto @@ -5,6 +5,10 @@ option go_package = "./proto"; import "google/protobuf/struct.proto"; +message StringList { + repeated string values = 1; +} + enum ExecutionStatus { SUCCESS = 0; FAILURE = 1; @@ -13,6 +17,7 @@ enum ExecutionStatus { message ConfigureRequest { map config = 1; google.protobuf.Struct policy_data = 2; + map policyBehavior = 3; } message ConfigureResponse { @@ -22,6 +27,7 @@ message ConfigureResponse { message InitRequest { repeated string policyPaths = 1; uint32 apiServer = 2; + map policyBehavior = 3; } message InitResponse { @@ -30,6 +36,7 @@ message InitResponse { message EvalRequest { repeated string policyPaths = 1; uint32 apiServer = 2; + map policyBehavior = 3; } /** From 88f2d8faa88d855a18aac22f84ca750fe2d8de62 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Thu, 21 May 2026 14:34:12 +0100 Subject: [PATCH 2/3] fix: copilot issues --- cmd/agent.go | 6 +++--- cmd/agent_test.go | 28 +++++++++++++++++++++++++- runner/proto/policy_behavior.go | 17 +++++++++------- runner/proto/policy_behavior_test.go | 30 ++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/cmd/agent.go b/cmd/agent.go index 7e44e73..42c4e2d 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -972,12 +972,12 @@ func mapToStruct(m map[string]interface{}) (*structpb.Struct, error) { } func policyBehaviorToProto(policyBehavior map[string][]string) map[string]*proto.StringList { - if policyBehavior == nil { + if len(policyBehavior) == 0 { return nil } - result := make(map[string]*proto.StringList) + result := make(map[string]*proto.StringList, len(policyBehavior)) for key, values := range policyBehavior { - result[key] = &proto.StringList{Values: values} + result[key] = &proto.StringList{Values: append([]string(nil), values...)} } return result } diff --git a/cmd/agent_test.go b/cmd/agent_test.go index 903de12..c33bcad 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -1015,7 +1015,7 @@ func TestConfigureRunner(t *testing.T) { if got := allowedVersions.Fields["wget"].GetStringValue(); got != "1.20.3" { t.Fatalf("Configure policy_data allowed_versions.wget = %q, expected %q", got, "1.20.3") } - if got := testRunner.configureRequest.PolicyBehavior["policy-bundle"].Values; !reflect.DeepEqual(got, []string{"vpc", "sg"}) { + if got := testRunner.configureRequest.GetPolicyBehavior()["policy-bundle"].GetValues(); !reflect.DeepEqual(got, []string{"vpc", "sg"}) { t.Fatalf("Configure policyBehavior policy-bundle = %#v, expected %#v", got, []string{"vpc", "sg"}) } }) @@ -1042,6 +1042,32 @@ func TestConfigureRunner(t *testing.T) { }) } +func TestPolicyBehaviorToProto(t *testing.T) { + t.Run("nil and empty maps return nil", func(t *testing.T) { + if got := policyBehaviorToProto(nil); got != nil { + t.Fatalf("policyBehaviorToProto(nil) = %#v, want nil", got) + } + + if got := policyBehaviorToProto(map[string][]string{}); got != nil { + t.Fatalf("policyBehaviorToProto(empty) = %#v, want nil", got) + } + }) + + t.Run("clones value slices", func(t *testing.T) { + input := map[string][]string{ + "bundle": {"vpc", "sg"}, + } + + got := policyBehaviorToProto(input) + input["bundle"][0] = "mutated" + + want := []string{"vpc", "sg"} + if !reflect.DeepEqual(got["bundle"].Values, want) { + t.Fatalf("policyBehaviorToProto() values = %#v, want %#v", got["bundle"].Values, want) + } + }) +} + func TestAgentRunnerBuildsAuthenticatedSDKClient(t *testing.T) { var ( tokenRequests int diff --git a/runner/proto/policy_behavior.go b/runner/proto/policy_behavior.go index 970a4f0..2adbec2 100644 --- a/runner/proto/policy_behavior.go +++ b/runner/proto/policy_behavior.go @@ -22,23 +22,23 @@ func (r *EvalRequest) WithUndefinedMappedTo(behavior []string) *EvalRequest { return nil } - copy := &EvalRequest{ + cloned := &EvalRequest{ PolicyPaths: slices.Clone(r.PolicyPaths), ApiServer: r.ApiServer, PolicyBehavior: mergePolicyBehavior(nil, r.PolicyBehavior), } - for _, path := range copy.PolicyPaths { - if pathCoveredByPolicyBehavior(path, copy.PolicyBehavior) { + for _, path := range cloned.PolicyPaths { + if pathCoveredByPolicyBehavior(path, cloned.PolicyBehavior) { continue } - if copy.PolicyBehavior == nil { - copy.PolicyBehavior = make(map[string]*StringList) + if cloned.PolicyBehavior == nil { + cloned.PolicyBehavior = make(map[string]*StringList) } - copy.PolicyBehavior[path] = &StringList{Values: slices.Clone(behavior)} + cloned.PolicyBehavior[path] = &StringList{Values: slices.Clone(behavior)} } - return copy + return cloned } func (r *EvalRequest) PolicyPathsForBehavior(behavior string) []string { @@ -99,6 +99,9 @@ func mergePolicyBehavior(defaults map[string][]string, configured map[string]*St func pathCoveredByPolicyBehavior(path string, behavior map[string]*StringList) bool { for key := range behavior { + if key == "" { + continue + } if strings.Contains(path, key) { return true } diff --git a/runner/proto/policy_behavior_test.go b/runner/proto/policy_behavior_test.go index e8f026c..23001cc 100644 --- a/runner/proto/policy_behavior_test.go +++ b/runner/proto/policy_behavior_test.go @@ -50,6 +50,25 @@ func TestEvalRequestWithUndefinedMappedTo(t *testing.T) { } }) + t.Run("empty key does not count as covered", func(t *testing.T) { + request := &EvalRequest{ + PolicyPaths: []string{"/tmp/custom/unmapped.rego"}, + PolicyBehavior: map[string]*StringList{ + "": {Values: []string{"vpc"}}, + }, + } + + got := request.WithUndefinedMappedTo([]string{"vpc"}) + + wantBehavior := map[string]*StringList{ + "": {Values: []string{"vpc"}}, + "/tmp/custom/unmapped.rego": {Values: []string{"vpc"}}, + } + if !reflect.DeepEqual(got.PolicyBehavior, wantBehavior) { + t.Fatalf("WithUndefinedMappedTo() with empty key = %#v, want %#v", got.PolicyBehavior, wantBehavior) + } + }) + t.Run("chains after defaults and fills only remaining uncovered paths", func(t *testing.T) { request := &EvalRequest{ PolicyPaths: []string{ @@ -203,6 +222,17 @@ func TestEvalRequestPolicyPathsForBehavior(t *testing.T) { behavior: "subnet", want: []string{}, }, + { + name: "empty key is ignored for matching", + request: &EvalRequest{ + PolicyPaths: []string{"/tmp/a", "/tmp/b"}, + PolicyBehavior: map[string]*StringList{ + "": {Values: []string{"vpc"}}, + }, + }, + behavior: "vpc", + want: []string{}, + }, } for _, test := range tests { From cdf6f1de16ab8d8ee41ca94c39e388dfae9fb2a1 Mon Sep 17 00:00:00 2001 From: Reece Bedding Date: Thu, 21 May 2026 14:45:37 +0100 Subject: [PATCH 3/3] fix: copilot issues --- README.md | 4 ++++ cmd/agent.go | 4 ++++ cmd/agent_test.go | 42 +++++++++++++++++++++++++++++++++ runner/proto/policy_behavior.go | 2 +- 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 41cebc8..5af85be 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,10 @@ plugins: - policy_data: # Optional: Mapping for supported policies. Can be any data structure : + policy_behavior: # Optional: Used with supported plugins to filter specific policies for different inputs + : + - label1 + - label2 config: : : diff --git a/cmd/agent.go b/cmd/agent.go index 42c4e2d..2767a1b 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -977,6 +977,10 @@ func policyBehaviorToProto(policyBehavior map[string][]string) map[string]*proto } result := make(map[string]*proto.StringList, len(policyBehavior)) for key, values := range policyBehavior { + if values == nil { + result[key] = nil + continue + } result[key] = &proto.StringList{Values: append([]string(nil), values...)} } return result diff --git a/cmd/agent_test.go b/cmd/agent_test.go index c33bcad..5d72fd6 100644 --- a/cmd/agent_test.go +++ b/cmd/agent_test.go @@ -31,6 +31,8 @@ type initTestRunner struct { configureCalls int configureErr error configureRequest *proto.ConfigureRequest + initCalls int + initRequest *proto.InitRequest initErr error } @@ -51,6 +53,8 @@ func (r *initTestRunner) Eval(request *proto.EvalRequest, a runner.ApiHelper) (* } func (r *initTestRunner) Init(request *proto.InitRequest, a runner.ApiHelper) (*proto.InitResponse, error) { + r.initCalls++ + r.initRequest = request return &proto.InitResponse{}, r.initErr } @@ -955,6 +959,31 @@ func TestInitRunner(t *testing.T) { } }) + t.Run("passes policy behavior to init request", func(t *testing.T) { + testRunner := &initTestRunner{} + policyBehavior := map[string]*proto.StringList{ + "policy-bundle": {Values: []string{"vpc", "sg"}}, + } + + err := initRunner( + "test-plugin", + RunnerV2ProtocolVersion, + testRunner, + []string{"/tmp/policies/vpc.rego"}, + policyBehavior, + nil, + ) + if err != nil { + t.Fatalf("initRunner() error = %v, expected nil", err) + } + if testRunner.initCalls != 1 { + t.Fatalf("Init called %d times, expected 1", testRunner.initCalls) + } + if got := testRunner.initRequest.GetPolicyBehavior()["policy-bundle"].GetValues(); !reflect.DeepEqual(got, []string{"vpc", "sg"}) { + t.Fatalf("Init policyBehavior policy-bundle = %#v, expected %#v", got, []string{"vpc", "sg"}) + } + }) + t.Run("wraps unimplemented init for configured v2 plugin", func(t *testing.T) { err := initRunner( "test-plugin", @@ -1066,6 +1095,19 @@ func TestPolicyBehaviorToProto(t *testing.T) { t.Fatalf("policyBehaviorToProto() values = %#v, want %#v", got["bundle"].Values, want) } }) + + t.Run("preserves nil slices as nil StringList values", func(t *testing.T) { + got := policyBehaviorToProto(map[string][]string{ + "bundle": nil, + }) + + if _, ok := got["bundle"]; !ok { + t.Fatalf("policyBehaviorToProto() missing bundle key: %#v", got) + } + if got["bundle"] != nil { + t.Fatalf("policyBehaviorToProto() bundle = %#v, want nil", got["bundle"]) + } + }) } func TestAgentRunnerBuildsAuthenticatedSDKClient(t *testing.T) { diff --git a/runner/proto/policy_behavior.go b/runner/proto/policy_behavior.go index 2adbec2..4dabeb6 100644 --- a/runner/proto/policy_behavior.go +++ b/runner/proto/policy_behavior.go @@ -66,7 +66,7 @@ func (r *EvalRequest) PolicyPathsForBehavior(behavior string) []string { outer: for _, path := range r.PolicyPaths { for _, key := range matchingKeys { - if pathCoveredByPolicyBehavior(path, map[string]*StringList{key: nil}) { + if key != "" && strings.Contains(path, key) { filtered = append(filtered, path) continue outer }