Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ plugins:
- <policy>
policy_data: # Optional: Mapping for supported policies. Can be any data structure
<data key>: <data value>
policy_behavior: # Optional: Used with supported plugins to filter specific policies for different inputs
<substring>:
- label1
- label2
config:
<config1>: <value>
<config2>: <value>
Expand Down
44 changes: 33 additions & 11 deletions cmd/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type agentPlugin struct {
Config agentPluginConfig `mapstructure:"config"`
Labels map[string]string `mapstructure:"labels"`
PolicyData map[string]interface{} `mapstructure:"policy_data,omitempty"`
PolicyBehavior map[string][]string `mapstructure:"policy_behavior,omitempty"`
protocolSet bool
}

Expand Down Expand Up @@ -377,13 +378,14 @@ func runnerDispenseName(protocolVersion int32) (string, error) {
}
}

func initRunner(name string, protocolVersion int32, runnerInstance runner.RunnerV2, policyPaths []string, resultsHelper runner.ApiHelper) error {
func initRunner(name string, protocolVersion int32, runnerInstance runner.RunnerV2, policyPaths []string, policyBehavior map[string]*proto.StringList, resultsHelper runner.ApiHelper) error {
if protocolVersion <= DefaultProtocolVersion {
return nil
}

_, err := runnerInstance.Init(&proto.InitRequest{
PolicyPaths: policyPaths,
PolicyPaths: policyPaths,
PolicyBehavior: policyBehavior,
}, resultsHelper)
if err == nil {
return nil
Expand All @@ -396,15 +398,16 @@ func initRunner(name string, protocolVersion int32, runnerInstance runner.Runner
return err
}

func configureRunner(name string, runnerInstance runner.RunnerV2, config agentPluginConfig, policyData map[string]interface{}) error {
func configureRunner(name string, runnerInstance runner.RunnerV2, config agentPluginConfig, policyData map[string]interface{}, policyBehavior map[string][]string) error {
policyDataStruct, err := mapToStruct(policyData)
if err != nil {
return fmt.Errorf("invalid policy_data for plugin %s: %w", name, err)
}

_, err = runnerInstance.Configure(&proto.ConfigureRequest{
Config: config,
PolicyData: policyDataStruct,
Config: config,
PolicyData: policyDataStruct,
PolicyBehavior: policyBehaviorToProto(policyBehavior),
})
return err
}
Expand Down Expand Up @@ -968,6 +971,21 @@ func mapToStruct(m map[string]interface{}) (*structpb.Struct, error) {
return structpb.NewStruct(m)
}

func policyBehaviorToProto(policyBehavior map[string][]string) map[string]*proto.StringList {
if len(policyBehavior) == 0 {
return nil
}
result := make(map[string]*proto.StringList, len(policyBehavior))
for key, values := range policyBehavior {
if values == nil {
result[key] = nil
continue
}
result[key] = &proto.StringList{Values: append([]string(nil), values...)}
}
return result
}

func pluginEvidenceLabels(config *agentConfig, pluginName string, pluginConfig *agentPlugin) map[string]string {
return pluginEvidenceLabelsWithHash(config, pluginName, pluginConfig, agentConfigurationHash(config))
}
Expand Down Expand Up @@ -1373,7 +1391,7 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error {
if err := func() error {
defer cleanupRunner()

if err := configureRunner(pluginName, runnerInstance, pluginConfig.Config, pluginConfig.PolicyData); err != nil {
if err := configureRunner(pluginName, runnerInstance, pluginConfig.Config, pluginConfig.PolicyData, pluginConfig.PolicyBehavior); err != nil {
// What do we do here ?
//endTimer := time.Now()
//_, err = client.Results.Create(&sdk.Result{
Expand Down Expand Up @@ -1402,13 +1420,15 @@ func (ar *AgentRunner) runAllPlugins(ctx context.Context) error {
)
resultsHelper := runner.NewApiHelper(logger, client, labels, pluginName)

if err := initRunner(pluginName, pluginConfig.ProtocolVersion, runnerInstance, policyPaths, resultsHelper); err != nil {
policyBehaviorProto := policyBehaviorToProto(pluginConfig.PolicyBehavior)
if err := initRunner(pluginName, pluginConfig.ProtocolVersion, runnerInstance, policyPaths, policyBehaviorProto, resultsHelper); err != nil {
return err
}

// TODO: Send failed results to the database?
_, err = runnerInstance.Eval(&proto.EvalRequest{
PolicyPaths: policyPaths,
PolicyPaths: policyPaths,
PolicyBehavior: policyBehaviorProto,
}, resultsHelper)

if err != nil {
Expand Down Expand Up @@ -1519,7 +1539,7 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent
}
defer cleanupRunner()

if err := configureRunner(name, runnerInstance, plugin.Config, plugin.PolicyData); err != nil {
if err := configureRunner(name, runnerInstance, plugin.Config, plugin.PolicyData, plugin.PolicyBehavior); err != nil {
return err
}

Expand All @@ -1531,13 +1551,15 @@ func (ar *AgentRunner) runPlugin(ctx context.Context, name string, plugin *agent
)
resultsHelper := runner.NewApiHelper(pluginLogger, client, labels, name)

if err := initRunner(name, plugin.ProtocolVersion, runnerInstance, policyPaths, resultsHelper); err != nil {
policyBehaviorProto := policyBehaviorToProto(plugin.PolicyBehavior)
if err := initRunner(name, plugin.ProtocolVersion, runnerInstance, policyPaths, policyBehaviorProto, resultsHelper); err != nil {
return err
}

// TODO: Send failed results to the database?
_, err = runnerInstance.Eval(&proto.EvalRequest{
PolicyPaths: policyPaths,
PolicyPaths: policyPaths,
PolicyBehavior: policyBehaviorProto,
}, resultsHelper)

if err != nil {
Expand Down
78 changes: 77 additions & 1 deletion cmd/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"strings"
"sync/atomic"
"testing"
Expand All @@ -30,6 +31,8 @@ type initTestRunner struct {
configureCalls int
configureErr error
configureRequest *proto.ConfigureRequest
initCalls int
initRequest *proto.InitRequest
initErr error
}

Expand All @@ -50,6 +53,8 @@ func (r *initTestRunner) Eval(request *proto.EvalRequest, a runner.ApiHelper) (*
}

func (r *initTestRunner) Init(request *proto.InitRequest, a runner.ApiHelper) (*proto.InitResponse, error) {
r.initCalls++
r.initRequest = request
return &proto.InitResponse{}, r.initErr
}

Expand Down Expand Up @@ -948,19 +953,45 @@ func activePluginClientCount(agentRunner *AgentRunner) int {

func TestInitRunner(t *testing.T) {
t.Run("skips init for v1", func(t *testing.T) {
err := initRunner("test-plugin", DefaultProtocolVersion, &initTestRunner{}, nil, nil)
err := initRunner("test-plugin", DefaultProtocolVersion, &initTestRunner{}, nil, nil, nil)
if err != nil {
t.Fatalf("initRunner() error = %v, expected nil", err)
}
})

t.Run("passes policy behavior to init request", func(t *testing.T) {
testRunner := &initTestRunner{}
policyBehavior := map[string]*proto.StringList{
"policy-bundle": {Values: []string{"vpc", "sg"}},
}

err := initRunner(
"test-plugin",
RunnerV2ProtocolVersion,
testRunner,
[]string{"/tmp/policies/vpc.rego"},
policyBehavior,
nil,
)
if err != nil {
t.Fatalf("initRunner() error = %v, expected nil", err)
}
if testRunner.initCalls != 1 {
t.Fatalf("Init called %d times, expected 1", testRunner.initCalls)
}
if got := testRunner.initRequest.GetPolicyBehavior()["policy-bundle"].GetValues(); !reflect.DeepEqual(got, []string{"vpc", "sg"}) {
t.Fatalf("Init policyBehavior policy-bundle = %#v, expected %#v", got, []string{"vpc", "sg"})
}
})

t.Run("wraps unimplemented init for configured v2 plugin", func(t *testing.T) {
err := initRunner(
"test-plugin",
RunnerV2ProtocolVersion,
&initTestRunner{initErr: status.Error(codes.Unimplemented, "not implemented")},
nil,
nil,
nil,
)
Comment on lines 954 to 995
if err == nil {
t.Fatal("initRunner() error = nil, expected wrapped error")
Expand All @@ -980,6 +1011,7 @@ func TestInitRunner(t *testing.T) {
&initTestRunner{initErr: expectedErr},
nil,
nil,
nil,
)
if !errors.Is(err, expectedErr) {
t.Fatalf("initRunner() error = %v, expected %v", err, expectedErr)
Expand All @@ -996,6 +1028,7 @@ func TestConfigureRunner(t *testing.T) {
testRunner,
agentPluginConfig{"endpoint": "localhost"},
map[string]interface{}{"allowed_versions": map[string]interface{}{"wget": "1.20.3"}},
map[string][]string{"policy-bundle": {"vpc", "sg"}},
)
if err != nil {
t.Fatalf("configureRunner() error = %v, expected nil", err)
Expand All @@ -1011,6 +1044,9 @@ func TestConfigureRunner(t *testing.T) {
if got := allowedVersions.Fields["wget"].GetStringValue(); got != "1.20.3" {
t.Fatalf("Configure policy_data allowed_versions.wget = %q, expected %q", got, "1.20.3")
}
if got := testRunner.configureRequest.GetPolicyBehavior()["policy-bundle"].GetValues(); !reflect.DeepEqual(got, []string{"vpc", "sg"}) {
t.Fatalf("Configure policyBehavior policy-bundle = %#v, expected %#v", got, []string{"vpc", "sg"})
}
})

t.Run("rejects unsupported policy data before configuring runner", func(t *testing.T) {
Expand All @@ -1021,6 +1057,7 @@ func TestConfigureRunner(t *testing.T) {
testRunner,
nil,
map[string]interface{}{"unsupported": make(chan int)},
nil,
)
if err == nil {
t.Fatal("configureRunner() error = nil, expected invalid policy_data error")
Expand All @@ -1034,6 +1071,45 @@ func TestConfigureRunner(t *testing.T) {
})
}

func TestPolicyBehaviorToProto(t *testing.T) {
t.Run("nil and empty maps return nil", func(t *testing.T) {
if got := policyBehaviorToProto(nil); got != nil {
t.Fatalf("policyBehaviorToProto(nil) = %#v, want nil", got)
}

if got := policyBehaviorToProto(map[string][]string{}); got != nil {
t.Fatalf("policyBehaviorToProto(empty) = %#v, want nil", got)
}
})

t.Run("clones value slices", func(t *testing.T) {
input := map[string][]string{
"bundle": {"vpc", "sg"},
}

got := policyBehaviorToProto(input)
input["bundle"][0] = "mutated"

want := []string{"vpc", "sg"}
if !reflect.DeepEqual(got["bundle"].Values, want) {
t.Fatalf("policyBehaviorToProto() values = %#v, want %#v", got["bundle"].Values, want)
}
})

t.Run("preserves nil slices as nil StringList values", func(t *testing.T) {
got := policyBehaviorToProto(map[string][]string{
"bundle": nil,
})

if _, ok := got["bundle"]; !ok {
t.Fatalf("policyBehaviorToProto() missing bundle key: %#v", got)
}
if got["bundle"] != nil {
t.Fatalf("policyBehaviorToProto() bundle = %#v, want nil", got["bundle"])
}
})
}

func TestAgentRunnerBuildsAuthenticatedSDKClient(t *testing.T) {
var (
tokenRequests int
Expand Down
111 changes: 111 additions & 0 deletions runner/proto/policy_behavior.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package proto

import (
"slices"
"strings"
)

func (r *EvalRequest) WithDefaultPolicyBehavior(defaults map[string][]string) *EvalRequest {
if r == nil {
return nil
}

return &EvalRequest{
PolicyPaths: slices.Clone(r.PolicyPaths),
ApiServer: r.ApiServer,
PolicyBehavior: mergePolicyBehavior(defaults, r.PolicyBehavior),
}
}

func (r *EvalRequest) WithUndefinedMappedTo(behavior []string) *EvalRequest {
if r == nil {
return nil
}

cloned := &EvalRequest{
PolicyPaths: slices.Clone(r.PolicyPaths),
ApiServer: r.ApiServer,
PolicyBehavior: mergePolicyBehavior(nil, r.PolicyBehavior),
}

for _, path := range cloned.PolicyPaths {
if pathCoveredByPolicyBehavior(path, cloned.PolicyBehavior) {
continue
}
if cloned.PolicyBehavior == nil {
cloned.PolicyBehavior = make(map[string]*StringList)
}
cloned.PolicyBehavior[path] = &StringList{Values: slices.Clone(behavior)}
}

return cloned
}

func (r *EvalRequest) PolicyPathsForBehavior(behavior string) []string {
if r == nil {
return nil
}

if len(r.PolicyBehavior) == 0 {
return []string{}
}

matchingKeys := make([]string, 0, len(r.PolicyBehavior))
for key, list := range r.PolicyBehavior {
if list == nil || !slices.Contains(list.Values, behavior) {
continue
}
matchingKeys = append(matchingKeys, key)
}

if len(matchingKeys) == 0 {
return []string{}
}

filtered := make([]string, 0, len(r.PolicyPaths))
outer:
for _, path := range r.PolicyPaths {
for _, key := range matchingKeys {
if key != "" && strings.Contains(path, key) {
filtered = append(filtered, path)
continue outer
}
}
Comment on lines +65 to +73
}

return filtered
}

func mergePolicyBehavior(defaults map[string][]string, configured map[string]*StringList) map[string]*StringList {
if len(defaults) == 0 && len(configured) == 0 {
return nil
}

merged := make(map[string]*StringList, len(defaults)+len(configured))
for key, values := range defaults {
merged[key] = &StringList{Values: slices.Clone(values)}
}

for key, list := range configured {
if list == nil {
merged[key] = nil
continue
}
merged[key] = &StringList{Values: slices.Clone(list.Values)}
}

return merged
}

func pathCoveredByPolicyBehavior(path string, behavior map[string]*StringList) bool {
for key := range behavior {
if key == "" {
continue
}
if strings.Contains(path, key) {
return true
}
}
Comment on lines +100 to +108

return false
}
Loading
Loading