From d2fc84a0f5bc98e7c3b8bbf508f9c25de7a94bdb Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Fri, 19 Jun 2026 16:47:58 +0200 Subject: [PATCH 1/7] Fix JSON keys must be strings --- lib/aikido/zen/actor.rb | 20 +- lib/aikido/zen/attack.rb | 28 +- lib/aikido/zen/attack_wave.rb | 16 +- lib/aikido/zen/collector/event.rb | 52 +-- lib/aikido/zen/collector/routes.rb | 8 +- lib/aikido/zen/collector/sink_stats.rb | 20 +- lib/aikido/zen/collector/stats.rb | 34 +- lib/aikido/zen/event.rb | 24 +- lib/aikido/zen/outbound_connection.rb | 6 +- lib/aikido/zen/payload.rb | 6 +- lib/aikido/zen/request.rb | 12 +- lib/aikido/zen/route.rb | 6 +- lib/aikido/zen/system_info.rb | 26 +- test/aikido/zen/actor_test.rb | 10 +- test/aikido/zen/attack_test.rb | 90 ++--- test/aikido/zen/attack_wave_test.rb | 6 +- test/aikido/zen/collector/event_test.rb | 38 +-- test/aikido/zen/collector/routes_test.rb | 18 +- test/aikido/zen/collector/stats_test.rb | 156 ++++----- test/aikido/zen/collector_test.rb | 348 ++++++++++---------- test/aikido/zen/event_test.rb | 88 ++--- test/aikido/zen/outbound_connection_test.rb | 2 +- test/aikido/zen/payload_test.rb | 16 +- test/aikido/zen/request_test.rb | 16 +- test/aikido/zen/route_test.rb | 2 +- test/aikido/zen/system_info_test.rb | 30 +- 26 files changed, 539 insertions(+), 539 deletions(-) diff --git a/lib/aikido/zen/actor.rb b/lib/aikido/zen/actor.rb index e2584fdd..d60ab7af 100644 --- a/lib/aikido/zen/actor.rb +++ b/lib/aikido/zen/actor.rb @@ -40,11 +40,11 @@ def self.Actor(data) class Actor def self.from_json(data) new( - id: data[:id], - name: data[:name], - ip: data[:lastIpAddress], - first_seen_at: Time.at(data[:firstSeenAt] / 1000), - last_seen_at: Time.at(data[:lastSeenAt] / 1000) + id: data["id"], + name: data["name"], + ip: data["lastIpAddress"], + first_seen_at: Time.at(data["firstSeenAt"] / 1000), + last_seen_at: Time.at(data["lastSeenAt"] / 1000) ) end @@ -135,11 +135,11 @@ def hash def as_json { - id: id, - name: name, - lastIpAddress: ip, - firstSeenAt: first_seen_at.to_i * 1000, - lastSeenAt: last_seen_at.to_i * 1000 + "id" => id, + "name" => name, + "lastIpAddress" => ip, + "firstSeenAt" => first_seen_at.to_i * 1000, + "lastSeenAt" => last_seen_at.to_i * 1000 }.compact end end diff --git a/lib/aikido/zen/attack.rb b/lib/aikido/zen/attack.rb index 79f19ac6..a9774a86 100644 --- a/lib/aikido/zen/attack.rb +++ b/lib/aikido/zen/attack.rb @@ -44,11 +44,11 @@ def metadata def as_json { - kind: kind, - blocked: blocked?, - metadata: metadata, - operation: @operation, - stack: @stack + "kind" => kind, + "blocked" => blocked?, + "metadata" => metadata, + "operation" => @operation, + "stack" => @stack }.compact.merge(input.as_json) end @@ -70,7 +70,7 @@ def initialize(input:, filepath:, **opts) def metadata { - filename: filepath + "filename" => filepath } end @@ -107,7 +107,7 @@ def kind def metadata { - command: @command + "command" => @command } end @@ -139,9 +139,9 @@ def kind def metadata { - sql: @query, - dialect: @dialect.name, - failedToTokenize: @failed_to_tokenize || nil + "sql" => @query, + "dialect" => @dialect.name, + "failedToTokenize" => @failed_to_tokenize || nil }.compact end @@ -174,8 +174,8 @@ def exception(*) def metadata { - hostname: @request.uri.hostname, - port: @request.uri.port.to_s + "hostname" => @request.uri.hostname, + "port" => @request.uri.port.to_s } end end @@ -212,8 +212,8 @@ def input def metadata { - hostname: @hostname, - privateIP: @address + "hostname" => @hostname, + "privateIP" => @address } end end diff --git a/lib/aikido/zen/attack_wave.rb b/lib/aikido/zen/attack_wave.rb index 640397c4..73fdb1da 100644 --- a/lib/aikido/zen/attack_wave.rb +++ b/lib/aikido/zen/attack_wave.rb @@ -69,9 +69,9 @@ def initialize(ip_address:, user_agent:, source:) def as_json { - ipAddress: @ip_address, - userAgent: @user_agent, - source: @source + "ipAddress" => @ip_address, + "userAgent" => @user_agent, + "source" => @source }.compact end @@ -101,10 +101,10 @@ def initialize(samples:, user:) def as_json { - metadata: { - samples: @samples.as_json.to_json # The API only accepts string values in metadata + "metadata" => { + "samples" => @samples.as_json.to_json # The API only accepts string values in metadata }, - user: @user.as_json + "user" => @user.as_json }.compact end @@ -130,8 +130,8 @@ def initialize(verb:, path:) def as_json { - method: @verb.as_json, - url: @path.as_json + "method" => @verb.as_json, + "url" => @path.as_json }.compact end diff --git a/lib/aikido/zen/collector/event.rb b/lib/aikido/zen/collector/event.rb index 7134fdd7..5ae95c6b 100644 --- a/lib/aikido/zen/collector/event.rb +++ b/lib/aikido/zen/collector/event.rb @@ -12,7 +12,7 @@ def self.register(type) end def self.from_json(data) - type = data[:type] + type = data["type"] subclass = @@registry[type] subclass.from_json(data) end @@ -25,7 +25,7 @@ def initialize def as_json { - type: @type + "type" => @type } end @@ -72,7 +72,7 @@ class TrackUserAgent < Event register "track_user_agent" def self.from_json(data) - new(data[:user_agent_keys]) + new(data["user_agent_keys"]) end def initialize(user_agent_keys) @@ -82,7 +82,7 @@ def initialize(user_agent_keys) def as_json super.update({ - user_agent_keys: @user_agent_keys + "user_agent_keys" => @user_agent_keys }) end @@ -99,7 +99,7 @@ class TrackIPList < Event register "track_ip_list" def self.from_json(data) - new(data[:ip_list_keys]) + new(data["ip_list_keys"]) end def initialize(ip_list_keys) @@ -109,7 +109,7 @@ def initialize(ip_list_keys) def as_json super.update({ - ip_list_keys: @ip_list_keys + "ip_list_keys" => @ip_list_keys }) end @@ -127,7 +127,7 @@ class TrackAttackWave < Event def self.from_json(data) new( - being_blocked: data[:being_blocked] + being_blocked: data["being_blocked"] ) end @@ -138,7 +138,7 @@ def initialize(being_blocked:) def as_json super.update({ - being_blocked: @being_blocked + "being_blocked" => @being_blocked }) end @@ -156,9 +156,9 @@ class TrackScan < Event def self.from_json(data) new( - data[:sink_name], - data[:duration], - has_errors: data[:has_errors] + data["sink_name"], + data["duration"], + has_errors: data["has_errors"] ) end @@ -171,9 +171,9 @@ def initialize(sink_name, duration, has_errors:) def as_json super.update({ - sink_name: @sink_name, - duration: @duration, - has_errors: @has_errors + "sink_name" => @sink_name, + "duration" => @duration, + "has_errors" => @has_errors }) end @@ -191,8 +191,8 @@ class TrackAttack < Event def self.from_json(data) new( - data[:sink_name], - being_blocked: data[:being_blocked] + data["sink_name"], + being_blocked: data["being_blocked"] ) end @@ -204,8 +204,8 @@ def initialize(sink_name, being_blocked:) def as_json super.update({ - sink_name: @sink_name, - being_blocked: @being_blocked + "sink_name" => @sink_name, + "being_blocked" => @being_blocked }) end @@ -222,7 +222,7 @@ class TrackUser < Event register "track_user" def self.from_json(data) - new(Aikido::Zen::Actor.from_json(data[:actor])) + new(Aikido::Zen::Actor.from_json(data["actor"])) end def initialize(actor) @@ -232,7 +232,7 @@ def initialize(actor) def as_json super.update({ - actor: @actor.as_json + "actor" => @actor.as_json }) end @@ -249,7 +249,7 @@ class TrackOutbound < Event register "track_outbound" def self.from_json(data) - new(OutboundConnection.from_json(data[:connection])) + new(OutboundConnection.from_json(data["connection"])) end def initialize(connection) @@ -259,7 +259,7 @@ def initialize(connection) def as_json super.update({ - connection: @connection.as_json + "connection" => @connection.as_json }) end @@ -277,8 +277,8 @@ class TrackRoute < Event def self.from_json(data) new( - Route.from_json(data[:route]), - Request::Schema.from_json(data[:schema]) + Route.from_json(data["route"]), + Request::Schema.from_json(data["schema"]) ) end @@ -290,8 +290,8 @@ def initialize(route, schema) def as_json super.update({ - route: @route.as_json, - schema: @schema.as_json + "route" => @route.as_json, + "schema" => @schema.as_json }) end diff --git a/lib/aikido/zen/collector/routes.rb b/lib/aikido/zen/collector/routes.rb index 74b64efa..018a0770 100644 --- a/lib/aikido/zen/collector/routes.rb +++ b/lib/aikido/zen/collector/routes.rb @@ -26,10 +26,10 @@ def add(route, schema) def as_json @visits.map do |route, record| { - method: route.verb, - path: route.path, - hits: record.hits, - apispec: record.schema.as_json + "method" => route.verb, + "path" => route.path, + "hits" => record.hits, + "apispec" => record.schema.as_json }.compact end end diff --git a/lib/aikido/zen/collector/sink_stats.rb b/lib/aikido/zen/collector/sink_stats.rb index 24caadbf..fff8c80b 100644 --- a/lib/aikido/zen/collector/sink_stats.rb +++ b/lib/aikido/zen/collector/sink_stats.rb @@ -62,14 +62,14 @@ def compress_timings(at: Time.now.utc) def as_json { - total: @scans, - interceptorThrewError: @errors, - withoutContext: 0, - attacksDetected: { - total: @attacks, - blocked: @blocked_attacks + "total" => @scans, + "interceptorThrewError" => @errors, + "withoutContext" => 0, + "attacksDetected" => { + "total" => @attacks, + "blocked" => @blocked_attacks }, - compressedTimings: @compressed_timings.as_json + "compressedTimings" => @compressed_timings.as_json } end @@ -85,9 +85,9 @@ def as_json CompressedTiming = Struct.new(:mean, :percentiles, :compressed_at) do def as_json { - averageInMs: mean * 1000, - percentiles: percentiles.transform_values { |t| t * 1000 }, - compressedAt: compressed_at.to_i * 1000 + "averageInMs" => mean * 1000, + "percentiles" => percentiles.transform_values { |t| t * 1000 }, + "compressedAt" => compressed_at.to_i * 1000 } end end diff --git a/lib/aikido/zen/collector/stats.rb b/lib/aikido/zen/collector/stats.rb index 789809b2..95a7ffce 100644 --- a/lib/aikido/zen/collector/stats.rb +++ b/lib/aikido/zen/collector/stats.rb @@ -116,27 +116,27 @@ def add_attack(sink_name, being_blocked:) def as_json total_attacks, total_blocked = aggregate_attacks_from_sinks { - startedAt: @started_at.to_i * 1000, - endedAt: (@ended_at.to_i * 1000 if @ended_at), - operations: @sinks.transform_values(&:as_json), - requests: { - total: @requests, - aborted: @aborted_requests, - rateLimited: @rate_limited_requests, - attacksDetected: { - total: total_attacks, - blocked: total_blocked + "startedAt" => @started_at.to_i * 1000, + "endedAt" => (@ended_at.to_i * 1000 if @ended_at), + "operations" => @sinks.transform_values(&:as_json), + "requests" => { + "total" => @requests, + "aborted" => @aborted_requests, + "rateLimited" => @rate_limited_requests, + "attacksDetected" => { + "total" => total_attacks, + "blocked" => total_blocked }, - attackWaves: { - total: @attack_waves, - blocked: @blocked_attack_waves + "attackWaves" => { + "total" => @attack_waves, + "blocked" => @blocked_attack_waves } }, - userAgents: { - breakdown: @user_agents + "userAgents" => { + "breakdown" => @user_agents }, - ipAddresses: { - breakdown: @ip_lists + "ipAddresses" => { + "breakdown" => @ip_lists } } end diff --git a/lib/aikido/zen/event.rb b/lib/aikido/zen/event.rb index 5e1bdd42..da6d4a44 100644 --- a/lib/aikido/zen/event.rb +++ b/lib/aikido/zen/event.rb @@ -16,9 +16,9 @@ def initialize(type:, system_info: Aikido::Zen.system_info, time: Time.now.utc) def as_json { - type: type, - time: time.to_i * 1000, - agent: system_info.as_json + "type" => type, + "time" => time.to_i * 1000, + "agent" => system_info.as_json } end end @@ -42,8 +42,8 @@ def initialize(attack:, **opts) def as_json super.update( { - attack: @attack.as_json, - request: @attack.context&.request&.as_json + "attack" => @attack.as_json, + "request" => @attack.context&.request&.as_json }.compact ) end @@ -61,11 +61,11 @@ def initialize(stats:, users:, hosts:, routes:, middleware_installed:, **opts) def as_json super.update( - stats: @stats.as_json, - users: @users.as_json, - routes: @routes.as_json, - hostnames: @hosts.as_json, - middlewareInstalled: @middleware_installed + "stats" => @stats.as_json, + "users" => @users.as_json, + "routes" => @routes.as_json, + "hostnames" => @hosts.as_json, + "middlewareInstalled" => @middleware_installed ) end end @@ -90,8 +90,8 @@ def initialize(request:, attack:, **opts) def as_json super.update( - request: @request.as_json, - attack: @attack.as_json + "request" => @request.as_json, + "attack" => @attack.as_json ) end end diff --git a/lib/aikido/zen/outbound_connection.rb b/lib/aikido/zen/outbound_connection.rb index cbad39aa..821f87ca 100644 --- a/lib/aikido/zen/outbound_connection.rb +++ b/lib/aikido/zen/outbound_connection.rb @@ -5,8 +5,8 @@ module Aikido::Zen class OutboundConnection def self.from_json(data) new( - host: data[:hostname], - port: data[:port] + host: data["hostname"], + port: data["port"] ) end @@ -41,7 +41,7 @@ def hit end def as_json - {hostname: host, port: port, hits: hits}.compact + {"hostname" => host, "port" => port, "hits" => hits}.compact end def ==(other) diff --git a/lib/aikido/zen/payload.rb b/lib/aikido/zen/payload.rb index 632f1055..797c8fcd 100644 --- a/lib/aikido/zen/payload.rb +++ b/lib/aikido/zen/payload.rb @@ -25,9 +25,9 @@ def ==(other) def as_json { - payload: value.to_s, - source: SOURCE_SERIALIZATIONS[source], - path: ".#{path}" + "payload" => value.to_s, + "source" => SOURCE_SERIALIZATIONS[source], + "path" => ".#{path}" } end diff --git a/lib/aikido/zen/request.rb b/lib/aikido/zen/request.rb index 1ecb94b8..75ff5add 100644 --- a/lib/aikido/zen/request.rb +++ b/lib/aikido/zen/request.rb @@ -82,12 +82,12 @@ def normalized_headers def as_json { - method: request_method.upcase, - url: url, - ipAddress: client_ip, - userAgent: user_agent, - source: framework, - route: route&.path + "method" => request_method.upcase, + "url" => url, + "ipAddress" => client_ip, + "userAgent" => user_agent, + "source" => framework, + "route" => route&.path } end diff --git a/lib/aikido/zen/route.rb b/lib/aikido/zen/route.rb index f55f0733..ab92bbde 100644 --- a/lib/aikido/zen/route.rb +++ b/lib/aikido/zen/route.rb @@ -7,8 +7,8 @@ module Aikido::Zen class Route def self.from_json(data) new( - verb: data[:method], - path: data[:path] + verb: data["method"], + path: data["path"] ) end @@ -25,7 +25,7 @@ def initialize(verb:, path:) end def as_json - {method: verb, path: path} + {"method" => verb, "path" => path} end def ==(other) diff --git a/lib/aikido/zen/system_info.rb b/lib/aikido/zen/system_info.rb index f57fbd3d..6543ae63 100644 --- a/lib/aikido/zen/system_info.rb +++ b/lib/aikido/zen/system_info.rb @@ -61,19 +61,19 @@ def os_version def as_json { - dryMode: attacks_are_only_reported?, - library: library_name, - version: library_version, - hostname: hostname, - ipAddress: ip_address, - platform: {version: platform_version}, - os: {name: os_name, version: os_version}, - packages: packages.reduce({}) { |all, package| all.update(package.as_json) }, - incompatiblePackages: {}, - stack: [], - serverless: false, - nodeEnv: "", - preventedPrototypePollution: false + "dryMode" => attacks_are_only_reported?, + "library" => library_name, + "version" => library_version, + "hostname" => hostname, + "ipAddress" => ip_address, + "platform" => {"version" => platform_version}, + "os" => {"name" => os_name, "version" => os_version}, + "packages" => packages.reduce({}) { |all, package| all.update(package.as_json) }, + "incompatiblePackages" => {}, + "stack" => [], + "serverless" => false, + "nodeEnv" => "", + "preventedPrototypePollution" => false } end end diff --git a/test/aikido/zen/actor_test.rb b/test/aikido/zen/actor_test.rb index b47a9255..ce4a796d 100644 --- a/test/aikido/zen/actor_test.rb +++ b/test/aikido/zen/actor_test.rb @@ -160,11 +160,11 @@ def to_aikido_actor actor.update(seen_at: Time.at(1234577890)) expected = { - id: "123", - name: "Jane Doe", - lastIpAddress: "1.2.3.4", - firstSeenAt: 1234567890000, - lastSeenAt: 1234577890000 + "id" => "123", + "name" => "Jane Doe", + "lastIpAddress" => "1.2.3.4", + "firstSeenAt" => 1234567890000, + "lastSeenAt" => 1234577890000 } assert_equal expected, actor.as_json diff --git a/test/aikido/zen/attack_test.rb b/test/aikido/zen/attack_test.rb index 3c658bd4..d0618074 100644 --- a/test/aikido/zen/attack_test.rb +++ b/test/aikido/zen/attack_test.rb @@ -54,16 +54,16 @@ class SQLInjectionTest < ActiveSupport::TestCase ) expected = { - kind: "sql_injection", - operation: @op, - blocked: false, - payload: @input.value, - metadata: { - sql: @query, - dialect: @dialect.name + "kind" => "sql_injection", + "operation" => @op, + "blocked" => false, + "payload" => @input.value, + "metadata" => { + "sql" => @query, + "dialect" => @dialect.name }, - source: "routeParams", - path: ".id" + "source" => "routeParams", + "path" => ".id" } assert_equal expected, attack.as_json @@ -77,16 +77,16 @@ class SQLInjectionTest < ActiveSupport::TestCase attack.will_be_blocked! expected = { - kind: "sql_injection", - operation: @op, - blocked: true, - payload: @input.value, - metadata: { - sql: @query, - dialect: @dialect.name + "kind" => "sql_injection", + "operation" => @op, + "blocked" => true, + "payload" => @input.value, + "metadata" => { + "sql" => @query, + "dialect" => @dialect.name }, - source: "routeParams", - path: ".id" + "source" => "routeParams", + "path" => ".id" } assert_equal expected, attack.as_json @@ -145,9 +145,9 @@ class SSRFAttackTest < ActiveSupport::TestCase metadata = attack.metadata - assert_equal "localhost", metadata[:hostname] - assert_equal "7000", metadata[:port] - assert_kind_of String, metadata[:port], "Port should be a string, not an integer" + assert_equal "localhost", metadata["hostname"] + assert_equal "7000", metadata["port"] + assert_kind_of String, metadata["port"], "Port should be a string, not an integer" end test "#as_json includes the expected fields with port as string" do @@ -156,16 +156,16 @@ class SSRFAttackTest < ActiveSupport::TestCase ) expected = { - kind: "ssrf", - operation: @op, - blocked: false, - payload: @input.value, - metadata: { - hostname: "localhost", - port: "7000" + "kind" => "ssrf", + "operation" => @op, + "blocked" => false, + "payload" => @input.value, + "metadata" => { + "hostname" => "localhost", + "port" => "7000" }, - source: "body", - path: ".url" + "source" => "body", + "path" => ".url" } assert_equal expected, attack.as_json @@ -179,16 +179,16 @@ class SSRFAttackTest < ActiveSupport::TestCase attack.will_be_blocked! expected = { - kind: "ssrf", - operation: @op, - blocked: true, - payload: @input.value, - metadata: { - hostname: "localhost", - port: "7000" + "kind" => "ssrf", + "operation" => @op, + "blocked" => true, + "payload" => @input.value, + "metadata" => { + "hostname" => "localhost", + "port" => "7000" }, - source: "body", - path: ".url" + "source" => "body", + "path" => ".url" } assert_equal expected, attack.as_json @@ -207,9 +207,9 @@ class SSRFAttackTest < ActiveSupport::TestCase metadata = attack.metadata - assert_equal "example.com", metadata[:hostname] - assert_equal "80", metadata[:port] - assert_kind_of String, metadata[:port] + assert_equal "example.com", metadata["hostname"] + assert_equal "80", metadata["port"] + assert_kind_of String, metadata["port"] end test "#metadata handles default HTTPS port 443" do @@ -225,9 +225,9 @@ class SSRFAttackTest < ActiveSupport::TestCase metadata = attack.metadata - assert_equal "example.com", metadata[:hostname] - assert_equal "443", metadata[:port] - assert_kind_of String, metadata[:port] + assert_equal "example.com", metadata["hostname"] + assert_equal "443", metadata["port"] + assert_kind_of String, metadata["port"] end end end diff --git a/test/aikido/zen/attack_wave_test.rb b/test/aikido/zen/attack_wave_test.rb index 8c194c44..c4543122 100644 --- a/test/aikido/zen/attack_wave_test.rb +++ b/test/aikido/zen/attack_wave_test.rb @@ -215,9 +215,9 @@ def setup assert 3, samples.size expected = [ - {method: "GET", url: "/.config"}, - {method: "GET", url: "/.git/config"}, - {method: "GET", url: "/.ssh/known_hosts"} + {"method" => "GET", "url" => "/.config"}, + {"method" => "GET", "url" => "/.git/config"}, + {"method" => "GET", "url" => "/.ssh/known_hosts"} ] assert_equal expected, samples.as_json diff --git a/test/aikido/zen/collector/event_test.rb b/test/aikido/zen/collector/event_test.rb index 5fec2672..6e50555d 100644 --- a/test/aikido/zen/collector/event_test.rb +++ b/test/aikido/zen/collector/event_test.rb @@ -130,7 +130,7 @@ def stub_event_from_json(data) event = stub_track_request_event assert_hash_subset_of event.as_json, { - type: "track_request" + "type" => "track_request" } end @@ -170,7 +170,7 @@ def stub_event_from_json(data) event = stub_track_rate_limited_request_event assert_hash_subset_of event.as_json, { - type: "track_rate_limited_request" + "type" => "track_rate_limited_request" } end @@ -210,7 +210,7 @@ def stub_event_from_json(data) event = stub_track_user_agent_event assert_hash_subset_of event.as_json, { - type: "track_user_agent" + "type" => "track_user_agent" } end @@ -250,7 +250,7 @@ def stub_event_from_json(data) event = stub_track_ip_list_event assert_hash_subset_of event.as_json, { - type: "track_ip_list" + "type" => "track_ip_list" } end @@ -290,7 +290,7 @@ def stub_event_from_json(data) event = stub_track_attack_wave_event assert_hash_subset_of event.as_json, { - type: "track_attack_wave" + "type" => "track_attack_wave" } end @@ -330,10 +330,10 @@ def stub_event_from_json(data) event = stub_track_scan_event assert_hash_subset_of event.as_json, { - type: "track_scan", - sink_name: "sink_name", - duration: 1.0, - has_errors: false + "type" => "track_scan", + "sink_name" => "sink_name", + "duration" => 1.0, + "has_errors" => false } end @@ -373,9 +373,9 @@ def stub_event_from_json(data) event = stub_track_attack_event assert_hash_subset_of event.as_json, { - type: "track_attack", - sink_name: "sink_name", - being_blocked: false + "type" => "track_attack", + "sink_name" => "sink_name", + "being_blocked" => false } end @@ -415,8 +415,8 @@ def stub_event_from_json(data) event_hash = event.as_json assert_hash_subset_of event_hash, { - type: "track_user", - actor: stub_actor.as_json + "type" => "track_user", + "actor" => stub_actor.as_json } end @@ -459,8 +459,8 @@ def stub_event_from_json(data) event_hash = event.as_json assert_hash_subset_of event_hash, { - type: "track_outbound", - connection: stub_outbound_connection.as_json + "type" => "track_outbound", + "connection" => stub_outbound_connection.as_json } end @@ -503,9 +503,9 @@ def stub_event_from_json(data) event_hash = event.as_json assert_hash_subset_of event_hash, { - type: "track_route", - route: stub_route.as_json, - schema: stub_schema.as_json + "type" => "track_route", + "route" => stub_route.as_json, + "schema" => stub_schema.as_json } end diff --git a/test/aikido/zen/collector/routes_test.rb b/test/aikido/zen/collector/routes_test.rb index 21e4e56c..2081d254 100644 --- a/test/aikido/zen/collector/routes_test.rb +++ b/test/aikido/zen/collector/routes_test.rb @@ -135,10 +135,10 @@ class Aikido::Zen::Collector::RoutesTest < ActiveSupport::TestCase assert_equal @routes.as_json, [ { - method: "GET", - path: "/", - hits: 2, - apispec: { + "method" => "GET", + "path" => "/", + "hits" => 2, + "apispec" => { "query" => { "type" => "object", "properties" => {"mode" => {"type" => "string"}} @@ -146,10 +146,10 @@ class Aikido::Zen::Collector::RoutesTest < ActiveSupport::TestCase } }, { - method: "POST", - path: "/users", - hits: 1, - apispec: { + "method" => "POST", + "path" => "/users", + "hits" => 1, + "apispec" => { "body" => { "type" => :json, "schema" => { @@ -205,7 +205,7 @@ class Aikido::Zen::Collector::RoutesTest < ActiveSupport::TestCase request = build_request(build_route("GET", "/")) @routes.add(request.route, request.schema) - assert_equal [{method: "GET", path: "/", hits: 1}], @routes.as_json + assert_equal [{"method" => "GET", "path" => "/", "hits" => 1}], @routes.as_json end def build_route(verb, path) diff --git a/test/aikido/zen/collector/stats_test.rb b/test/aikido/zen/collector/stats_test.rb index 3e54745a..d4bc89bf 100644 --- a/test/aikido/zen/collector/stats_test.rb +++ b/test/aikido/zen/collector/stats_test.rb @@ -205,8 +205,8 @@ def stub_outbound(**opts) @stats.ended_at = Time.at(1234577890) assert_hash_subset_of @stats.as_json, { - startedAt: 1234567890000, - endedAt: 1234577890000 + "startedAt" => 1234567890000, + "endedAt" => 1234577890000 } end @@ -214,17 +214,17 @@ def stub_outbound(**opts) 3.times { @stats.add_request } assert_hash_subset_of @stats.as_json, { - requests: { - total: 3, - aborted: 0, - rateLimited: 0, - attacksDetected: { - total: 0, - blocked: 0 + "requests" => { + "total" => 3, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - attackWaves: { - total: 0, - blocked: 0 + "attackWaves" => { + "total" => 0, + "blocked" => 0 } } } @@ -241,26 +241,26 @@ def stub_outbound(**opts) @stats.add_scan(scan.sink.name, scan.duration, has_errors: scan.errors?) assert_hash_subset_of @stats.as_json, { - operations: { + "operations" => { "test" => { - total: 2, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 2, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [] + "compressedTimings" => [] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [] + "compressedTimings" => [] } } } @@ -277,26 +277,26 @@ def stub_outbound(**opts) @stats.add_scan(scan.sink.name, scan.duration, has_errors: scan.errors?) assert_hash_subset_of @stats.as_json, { - operations: { + "operations" => { "test" => { - total: 2, - interceptorThrewError: 1, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 2, + "interceptorThrewError" => 1, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [] + "compressedTimings" => [] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [] + "compressedTimings" => [] } } } @@ -316,26 +316,26 @@ def stub_outbound(**opts) @stats.add_attack("another", being_blocked: true) assert_hash_subset_of @stats.as_json, { - operations: { + "operations" => { "test" => { - total: 2, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 1, - blocked: 1 + "total" => 2, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - compressedTimings: [] + "compressedTimings" => [] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 1, - blocked: 1 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - compressedTimings: [] + "compressedTimings" => [] } } } @@ -357,45 +357,45 @@ def stub_outbound(**opts) @stats.sinks.each_value { |s| s.compress_timings(at: Time.at(1234577890)) } assert_hash_subset_of @stats.as_json, { - operations: { + "operations" => { "test" => { - total: 3, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 3, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [{ - averageInMs: 2000, - percentiles: { + "compressedTimings" => [{ + "averageInMs" => 2000, + "percentiles" => { 50 => 2000, 75 => 3000, 90 => 3000, 95 => 3000, 99 => 3000 }, - compressedAt: 1234577890000 + "compressedAt" => 1234577890000 }] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [{ - averageInMs: 1000, - percentiles: { + "compressedTimings" => [{ + "averageInMs" => 1000, + "percentiles" => { 50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000 }, - compressedAt: 1234577890000 + "compressedAt" => 1234577890000 }] } } diff --git a/test/aikido/zen/collector_test.rb b/test/aikido/zen/collector_test.rb index 9ddf6746..4d62d8d4 100644 --- a/test/aikido/zen/collector_test.rb +++ b/test/aikido/zen/collector_test.rb @@ -203,34 +203,34 @@ class Aikido::Zen::CollectorTest < ActiveSupport::TestCase event = @collector.flush(at: Time.at(1234577890)) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - operations: {}, - requests: { - total: 0, - aborted: 0, - rateLimited: 0, - attacksDetected: { - total: 0, - blocked: 0 + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "operations" => {}, + "requests" => { + "total" => 0, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - attackWaves: { - total: 0, - blocked: 0 + "attackWaves" => { + "total" => 0, + "blocked" => 0 } }, - userAgents: { - breakdown: {} + "userAgents" => { + "breakdown" => {} }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} } }, - users: [], - routes: [], - hostnames: [], - middlewareInstalled: false + "users" => [], + "routes" => [], + "hostnames" => [], + "middlewareInstalled" => false } end @@ -249,36 +249,36 @@ class Aikido::Zen::CollectorTest < ActiveSupport::TestCase event = @collector.flush(at: Time.at(1234577890)) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - operations: {}, - requests: { - total: 3, - aborted: 0, - rateLimited: 0, - attacksDetected: { - total: 0, - blocked: 0 + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "operations" => {}, + "requests" => { + "total" => 3, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - attackWaves: { - total: 5, - blocked: 2 + "attackWaves" => { + "total" => 5, + "blocked" => 2 } }, - userAgents: { - breakdown: {} + "userAgents" => { + "breakdown" => {} }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} } }, - users: [], - routes: [ - {path: "/", method: "GET", hits: 3, apispec: {}} + "users" => [], + "routes" => [ + {"path" => "/", "method" => "GET", "hits" => 3, "apispec" => {}} ], - hostnames: [], - middlewareInstalled: false + "hostnames" => [], + "middlewareInstalled" => false } end @@ -297,69 +297,69 @@ class Aikido::Zen::CollectorTest < ActiveSupport::TestCase event = @collector.flush(at: Time.at(1234577890)) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - requests: { - total: 2, - aborted: 0, - rateLimited: 0, - attacksDetected: { - total: 0, - blocked: 0 + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "requests" => { + "total" => 2, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - attackWaves: { - total: 0, - blocked: 0 + "attackWaves" => { + "total" => 0, + "blocked" => 0 } }, - userAgents: { - breakdown: {} + "userAgents" => { + "breakdown" => {} }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} }, - operations: { + "operations" => { "test" => { - total: 2, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 2, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [ + "compressedTimings" => [ { - averageInMs: 1000, - percentiles: {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, - compressedAt: 1234577890000 + "averageInMs" => 1000, + "percentiles" => {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, + "compressedAt" => 1234577890000 } ] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 0, - blocked: 0 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 0, + "blocked" => 0 }, - compressedTimings: [ + "compressedTimings" => [ { - averageInMs: 1000, - percentiles: {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, - compressedAt: 1234577890000 + "averageInMs" => 1000, + "percentiles" => {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, + "compressedAt" => 1234577890000 } ] } } }, - users: [], - routes: [ - {path: "/", method: "GET", hits: 2, apispec: {}} + "users" => [], + "routes" => [ + {"path" => "/", "method" => "GET", "hits" => 2, "apispec" => {}} ], - middlewareInstalled: false, - hostnames: [] + "middlewareInstalled" => false, + "hostnames" => [] } end @@ -382,69 +382,69 @@ class Aikido::Zen::CollectorTest < ActiveSupport::TestCase event = @collector.flush(at: Time.at(1234577890)) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - operations: { + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "operations" => { "test" => { - total: 2, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 1, - blocked: 1 + "total" => 2, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - compressedTimings: [ + "compressedTimings" => [ { - averageInMs: 1000, - percentiles: {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, - compressedAt: 1234577890000 + "averageInMs" => 1000, + "percentiles" => {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, + "compressedAt" => 1234577890000 } ] }, "another" => { - total: 1, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 1, - blocked: 1 + "total" => 1, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - compressedTimings: [ + "compressedTimings" => [ { - averageInMs: 1000, - percentiles: {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, - compressedAt: 1234577890000 + "averageInMs" => 1000, + "percentiles" => {50 => 1000, 75 => 1000, 90 => 1000, 95 => 1000, 99 => 1000}, + "compressedAt" => 1234577890000 } ] } }, - requests: { - total: 2, - aborted: 0, - rateLimited: 0, - attacksDetected: { - total: 2, - blocked: 2 + "requests" => { + "total" => 2, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => { + "total" => 2, + "blocked" => 2 }, - attackWaves: { - total: 0, - blocked: 0 + "attackWaves" => { + "total" => 0, + "blocked" => 0 } }, - userAgents: { - breakdown: {} + "userAgents" => { + "breakdown" => {} }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} } }, - users: [], - routes: [ - {path: "/", method: "GET", hits: 2, apispec: {}} + "users" => [], + "routes" => [ + {"path" => "/", "method" => "GET", "hits" => 2, "apispec" => {}} ], - middlewareInstalled: false, - hostnames: [] + "middlewareInstalled" => false, + "hostnames" => [] } end @@ -488,73 +488,73 @@ class Aikido::Zen::CollectorTest < ActiveSupport::TestCase event = @collector.flush(at: Time.at(1234577890)) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - operations: { + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "operations" => { "test" => { - total: 3, - interceptorThrewError: 0, - withoutContext: 0, - attacksDetected: { - total: 1, - blocked: 1 + "total" => 3, + "interceptorThrewError" => 0, + "withoutContext" => 0, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - compressedTimings: [{ - averageInMs: 2000, - percentiles: { + "compressedTimings" => [{ + "averageInMs" => 2000, + "percentiles" => { 50 => 2000, 75 => 3000, 90 => 3000, 95 => 3000, 99 => 3000 }, - compressedAt: 1234577890000 + "compressedAt" => 1234577890000 }] } }, - requests: { - total: 2, - aborted: 0, - rateLimited: 5, - attacksDetected: { - total: 1, - blocked: 1 + "requests" => { + "total" => 2, + "aborted" => 0, + "rateLimited" => 5, + "attacksDetected" => { + "total" => 1, + "blocked" => 1 }, - attackWaves: { - total: 3, - blocked: 1 + "attackWaves" => { + "total" => 3, + "blocked" => 1 } }, - userAgents: { - breakdown: { + "userAgents" => { + "breakdown" => { "google_adwords" => 2 } }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} } }, - middlewareInstalled: false, - routes: [{method: "GET", path: "/", hits: 2, apispec: {}}], - users: [ + "middlewareInstalled" => false, + "routes" => [{"method" => "GET", "path" => "/", "hits" => 2, "apispec" => {}}], + "users" => [ { - id: "123", - lastIpAddress: "1.2.3.4", - firstSeenAt: 12345567890000, - lastSeenAt: 12345567890000 + "id" => "123", + "lastIpAddress" => "1.2.3.4", + "firstSeenAt" => 12345567890000, + "lastSeenAt" => 12345567890000 }, { - id: "234", - lastIpAddress: "5.6.7.8", - firstSeenAt: 12334567890000, - lastSeenAt: 12334567890000 + "id" => "234", + "lastIpAddress" => "5.6.7.8", + "firstSeenAt" => 12334567890000, + "lastSeenAt" => 12334567890000 } ], - hostnames: [ - {hostname: "example.com", port: 2000, hits: 1}, - {hostname: "example.com", port: 2001, hits: 1}, - {hostname: "example.com", port: 2002, hits: 1} + "hostnames" => [ + {"hostname" => "example.com", "port" => 2000, "hits" => 1}, + {"hostname" => "example.com", "port" => 2001, "hits" => 1}, + {"hostname" => "example.com", "port" => 2002, "hits" => 1} ] } end diff --git a/test/aikido/zen/event_test.rb b/test/aikido/zen/event_test.rb index 630f3c4e..d6912204 100644 --- a/test/aikido/zen/event_test.rb +++ b/test/aikido/zen/event_test.rb @@ -27,20 +27,20 @@ class Aikido::Zen::EventTest < ActiveSupport::TestCase test "#as_json includes the type" do event = Aikido::Zen::Event.new(type: "test") - assert_equal "test", event.as_json[:type] + assert_equal "test", event.as_json["type"] end test "#as_json serializes the time in milliseconds" do event = Aikido::Zen::Event.new(type: "test", time: Time.at(123)) - assert_equal 123000, event.as_json[:time] + assert_equal 123000, event.as_json["time"] end test "#as_json serializes the system info" do event = Aikido::Zen::Event.new(type: "test") system_info = Aikido::Zen::SystemInfo.new - refute_nil system_info.as_json, event.as_json[:agent] + refute_nil system_info.as_json, event.as_json["agent"] end class StartedTest < ActiveSupport::TestCase @@ -69,7 +69,7 @@ class AttackTest < ActiveSupport::TestCase attack = TestAttack.new(context: stub_context) event = Aikido::Zen::Events::Attack.new(attack: attack) - assert_equal({some: "info"}, event.as_json[:attack]) + assert_equal({"some" => "info"}, event.as_json["attack"]) end test "includes the request's JSON representation" do @@ -78,14 +78,14 @@ class AttackTest < ActiveSupport::TestCase attack = TestAttack.new(context: context) event = Aikido::Zen::Events::Attack.new(attack: attack) - assert_equal context.request.as_json, event.as_json[:request] + assert_equal context.request.as_json, event.as_json["request"] end test "request key is absent when context is nil" do attack = TestAttack.new(context: nil) event = Aikido::Zen::Events::Attack.new(attack: attack) - refute event.as_json.key?(:request) + refute event.as_json.key?("request") end def stub_context(**options) @@ -103,7 +103,7 @@ def humanized_name end def as_json - {some: "info"} + {"some" => "info"} end def exception(*) @@ -137,28 +137,28 @@ class HeartbeatTest < ActiveSupport::TestCase ) assert_hash_subset_of event.as_json, { - stats: { - startedAt: 1234567890000, - endedAt: 1234577890000, - operations: {}, - requests: { - total: 0, - aborted: 0, - rateLimited: 0, - attacksDetected: {total: 0, blocked: 0}, - attackWaves: {total: 0, blocked: 0} + "stats" => { + "startedAt" => 1234567890000, + "endedAt" => 1234577890000, + "operations" => {}, + "requests" => { + "total" => 0, + "aborted" => 0, + "rateLimited" => 0, + "attacksDetected" => {"total" => 0, "blocked" => 0}, + "attackWaves" => {"total" => 0, "blocked" => 0} }, - userAgents: { - breakdown: {} + "userAgents" => { + "breakdown" => {} }, - ipAddresses: { - breakdown: {} + "ipAddresses" => { + "breakdown" => {} } }, - users: [], - routes: [], - hostnames: [], - middlewareInstalled: true + "users" => [], + "routes" => [], + "hostnames" => [], + "middlewareInstalled" => true } end @@ -188,14 +188,14 @@ class HeartbeatTest < ActiveSupport::TestCase ) serialized = event.as_json - assert_includes serialized[:routes], - {path: "/", method: "GET", hits: 1, apispec: {}} - assert_includes serialized[:routes], + assert_includes serialized["routes"], + {"path" => "/", "method" => "GET", "hits" => 1, "apispec" => {}} + assert_includes serialized["routes"], { - path: "/users(.:format)", - method: "GET", - hits: 1, - apispec: { + "path" => "/users(.:format)", + "method" => "GET", + "hits" => 1, + "apispec" => { "query" => { "type" => "object", "properties" => { @@ -205,12 +205,12 @@ class HeartbeatTest < ActiveSupport::TestCase } } } - assert_includes serialized[:routes], + assert_includes serialized["routes"], { - path: "/users(.:format)", - method: "POST", - hits: 2, - apispec: { + "path" => "/users(.:format)", + "method" => "POST", + "hits" => 2, + "apispec" => { "body" => { "type" => :json, "schema" => { @@ -307,19 +307,19 @@ def build_attack_wave(context, time:) attack_wave_data = attack_wave.as_json request = { - ipAddress: "1.2.3.4", - source: "rack" + "ipAddress" => "1.2.3.4", + "source" => "rack" } attack = { - metadata: { - samples: '[{"method":"GET","url":"/.config"}]' + "metadata" => { + "samples" => '[{"method":"GET","url":"/.config"}]' } } - assert_equal "detected_attack_wave", attack_wave_data[:type] - assert_equal request, attack_wave_data[:request] - assert_equal attack, attack_wave_data[:attack] + assert_equal "detected_attack_wave", attack_wave_data["type"] + assert_equal request, attack_wave_data["request"] + assert_equal attack, attack_wave_data["attack"] end end end diff --git a/test/aikido/zen/outbound_connection_test.rb b/test/aikido/zen/outbound_connection_test.rb index 634abe21..1f108928 100644 --- a/test/aikido/zen/outbound_connection_test.rb +++ b/test/aikido/zen/outbound_connection_test.rb @@ -54,6 +54,6 @@ class Aikido::Zen::OutboundConnectionTest < ActiveSupport::TestCase test "#as_json includes hostname and port" do conn = Aikido::Zen::OutboundConnection.new(host: "example.com", port: 443) - assert_equal({hostname: "example.com", port: 443}, conn.as_json) + assert_equal({"hostname" => "example.com", "port" => 443}, conn.as_json) end end diff --git a/test/aikido/zen/payload_test.rb b/test/aikido/zen/payload_test.rb index 1ff3b5df..b571a710 100644 --- a/test/aikido/zen/payload_test.rb +++ b/test/aikido/zen/payload_test.rb @@ -5,41 +5,41 @@ class Aikido::Zen::PayloadTest < ActiveSupport::TestCase test "query payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :query, "path") - assert_equal({payload: "value", source: "query", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "query", "path" => ".path"}, payload.as_json) end test "body payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :body, "path") - assert_equal({payload: "value", source: "body", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "body", "path" => ".path"}, payload.as_json) end test "header payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :header, "path") - assert_equal({payload: "value", source: "headers", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "headers", "path" => ".path"}, payload.as_json) end test "cookie payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :cookie, "path") - assert_equal({payload: "value", source: "cookies", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "cookies", "path" => ".path"}, payload.as_json) end test "route payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :route, "path") - assert_equal({payload: "value", source: "routeParams", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "routeParams", "path" => ".path"}, payload.as_json) end test "graphql payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :graphql, "path") - assert_equal({payload: "value", source: "graphql", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "graphql", "path" => ".path"}, payload.as_json) end test "xml payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :xml, "path") - assert_equal({payload: "value", source: "xml", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "xml", "path" => ".path"}, payload.as_json) end test "subdomain payloads have the proper JSON serialization" do payload = Aikido::Zen::Payload.new("value", :subdomain, "path") - assert_equal({payload: "value", source: "subdomains", path: ".path"}, payload.as_json) + assert_equal({"payload" => "value", "source" => "subdomains", "path" => ".path"}, payload.as_json) end end diff --git a/test/aikido/zen/request_test.rb b/test/aikido/zen/request_test.rb index da00991c..71fa17e5 100644 --- a/test/aikido/zen/request_test.rb +++ b/test/aikido/zen/request_test.rb @@ -50,8 +50,8 @@ module Tests req = build_request(env) assert_equal( - {method: "POST", url: "http://example.org/test"}, - req.as_json.slice(:method, :url) + {"method" => "POST", "url" => "http://example.org/test"}, + req.as_json.slice("method", "url") ) end @@ -59,7 +59,7 @@ module Tests env = Rack::MockRequest.env_for("/test", "REMOTE_ADDR" => "1.2.3.4") req = build_request(env) - assert_equal "1.2.3.4", req.as_json[:ipAddress] + assert_equal "1.2.3.4", req.as_json["ipAddress"] end test "#as_json includes the remote IP from the custom client IP header" do @@ -68,7 +68,7 @@ module Tests env = Rack::MockRequest.env_for("/test", "REMOTE_ADDR" => "1.2.3.4", "HTTP_CUSTOM_CLIENT_IP" => "4.3.2.1") req = build_request(env) - assert_equal "4.3.2.1", req.as_json[:ipAddress] + assert_equal "4.3.2.1", req.as_json["ipAddress"] Aikido::Zen.config.client_ip_header = nil end @@ -77,14 +77,14 @@ module Tests env = Rack::MockRequest.env_for("/test", "HTTP_USER_AGENT" => "Some/UA") req = build_request(env) - assert_equal "Some/UA", req.as_json[:userAgent] + assert_equal "Some/UA", req.as_json["userAgent"] end test "#as_json includes the framework handling the request as source" do env = Rack::MockRequest.env_for("/test") req = build_request(env) - assert_equal req.framework, req.as_json[:source] + assert_equal req.framework, req.as_json["source"] end test "#schema builds the request schema" do @@ -107,7 +107,7 @@ class RackRequestTest < ActiveSupport::TestCase env = Rack::MockRequest.env_for("/test/123") req = build_request(env) - assert_equal "/test/:number", req.as_json[:route] + assert_equal "/test/:number", req.as_json["route"] end def build_request(env) @@ -132,7 +132,7 @@ class ActionDispatchRequestTest < ActiveSupport::TestCase env = Rack::MockRequest.env_for("/cats/123") req = build_request(env) - assert_equal "/cats/:id(.:format)", req.as_json[:route] + assert_equal "/cats/:id(.:format)", req.as_json["route"] end test "#schema gets built from the request body" do diff --git a/test/aikido/zen/route_test.rb b/test/aikido/zen/route_test.rb index e1c2a156..f1c59c58 100644 --- a/test/aikido/zen/route_test.rb +++ b/test/aikido/zen/route_test.rb @@ -70,7 +70,7 @@ class Aikido::Zen::RouteTest < ActiveSupport::TestCase test "#as_json includes method and path" do route = Aikido::Zen::Route.new(verb: "GET", path: "/users/:id") - assert_equal({method: "GET", path: "/users/:id"}, route.as_json) + assert_equal({"method" => "GET", "path" => "/users/:id"}, route.as_json) end test "routes support wildcard matching on verbs" do diff --git a/test/aikido/zen/system_info_test.rb b/test/aikido/zen/system_info_test.rb index a05e2ba0..35e3e6cf 100644 --- a/test/aikido/zen/system_info_test.rb +++ b/test/aikido/zen/system_info_test.rb @@ -86,26 +86,26 @@ class Aikido::Zen::InfoTest < ActiveSupport::TestCase test "as_json includes the expected fields" do Aikido::Zen::Sinks.add("concurrent-ruby", scanners: [NOOP]) - assert_equal @info.attacks_are_only_reported?, @info.as_json[:dryMode] - assert_equal @info.library_name, @info.as_json[:library] - assert_equal @info.library_version, @info.as_json[:version] - assert_equal @info.hostname, @info.as_json[:hostname] - assert_equal @info.ip_address, @info.as_json[:ipAddress] - assert_equal @info.os_name, @info.as_json.dig(:os, :name) - assert_equal @info.os_version, @info.as_json.dig(:os, :version) - assert_equal @info.platform_version, @info.as_json.dig(:platform, :version) + assert_equal @info.attacks_are_only_reported?, @info.as_json["dryMode"] + assert_equal @info.library_name, @info.as_json["library"] + assert_equal @info.library_version, @info.as_json["version"] + assert_equal @info.hostname, @info.as_json["hostname"] + assert_equal @info.ip_address, @info.as_json["ipAddress"] + assert_equal @info.os_name, @info.as_json.dig("os", "name") + assert_equal @info.os_version, @info.as_json.dig("os", "version") + assert_equal @info.platform_version, @info.as_json.dig("platform", "version") # To keep the test scalable, only test one known dependency. - assert_kind_of Hash, @info.as_json[:packages] + assert_kind_of Hash, @info.as_json["packages"] assert_equal \ Gem.loaded_specs["concurrent-ruby"].version.to_s, - @info.as_json.dig(:packages, "concurrent-ruby") + @info.as_json.dig("packages", "concurrent-ruby") - assert_equal "", @info.as_json[:nodeEnv] - assert_equal false, @info.as_json[:preventedPrototypePollution] + assert_equal "", @info.as_json["nodeEnv"] + assert_equal false, @info.as_json["preventedPrototypePollution"] - assert_equal false, @info.as_json[:serverless] - assert_equal [], @info.as_json[:stack] - assert_equal({}, @info.as_json[:incompatiblePackages]) + assert_equal false, @info.as_json["serverless"] + assert_equal [], @info.as_json["stack"] + assert_equal({}, @info.as_json["incompatiblePackages"]) end end From 02299fb1c25e73a9d3e29a025441c0b0d36c0855 Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Fri, 12 Jun 2026 10:35:10 +0200 Subject: [PATCH 2/7] Implement IPC and RPC --- lib/aikido/zen.rb | 1 + lib/aikido/zen/ipc.rb | 4 + lib/aikido/zen/ipc/ipc.rb | 372 ++++++++++++++++++++++++++++++++++++++ lib/aikido/zen/ipc/rpc.rb | 250 +++++++++++++++++++++++++ 4 files changed, 627 insertions(+) create mode 100644 lib/aikido/zen/ipc.rb create mode 100644 lib/aikido/zen/ipc/ipc.rb create mode 100644 lib/aikido/zen/ipc/rpc.rb diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index 491c4b25..adf37042 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -1,5 +1,6 @@ # frozen_string_literal: true +require_relative "zen/ipc" require_relative "zen/helpers" require_relative "zen/version" require_relative "zen/errors" diff --git a/lib/aikido/zen/ipc.rb b/lib/aikido/zen/ipc.rb new file mode 100644 index 00000000..c5621eff --- /dev/null +++ b/lib/aikido/zen/ipc.rb @@ -0,0 +1,4 @@ +# frozen_string_literal: true + +require_relative "ipc/ipc" +require_relative "ipc/rpc" diff --git a/lib/aikido/zen/ipc/ipc.rb b/lib/aikido/zen/ipc/ipc.rb new file mode 100644 index 00000000..bdc3634c --- /dev/null +++ b/lib/aikido/zen/ipc/ipc.rb @@ -0,0 +1,372 @@ +# frozen_string_literal: true + +require "openssl" +require "securerandom" +require "socket" +require "concurrent" + +# Code coverage is disabled here because `OpenSSL.fixed_length_secure_compare` +# is already defined in the normal case. +# :nocov: +unless OpenSSL.respond_to?(:fixed_length_secure_compare) + def OpenSSL.fixed_length_secure_compare(a, b) + l = a.unpack("C#{a.bytesize}") + + res = 0 + b.each_byte { |byte| res |= byte ^ l.shift } + res == 0 + end +end +# :nocov: + +module Aikido + module Zen + module IPC + CONNECT_TIMEOUT = 2.0 + HANDSHAKE_TIMEOUT = 3.0 + READ_TIMEOUT = 5.0 + WRITE_TIMEOUT = 5.0 + + module TimedIO + private + + def connect_with_deadline(host, port, deadline) + socket = ::Socket.new(:INET, :STREAM) + + addr = ::Socket.sockaddr_in(port, host) + + connected = false + + case socket.connect_nonblock(addr, exception: false) + # Code coverage is disabled here because this is hard to control. + # :nocov: + when 0 + connected = true + # :nocov: + when :wait_writable + remaining = deadline - Process.clock_gettime(Process::CLOCK_MONOTONIC) + + # Code coverage is disabled here because this is hard to control. + # :nocov: + unless remaining > 0 && ::IO.select(nil, [socket], nil, remaining) + raise Errno::ETIMEDOUT, "connect timed out" + end + # :nocov: + + errno = socket.getsockopt(::Socket::SOL_SOCKET, ::Socket::SO_ERROR).int + + unless errno == 0 + raise SystemCallError.new(errno) + end + + connected = true + + # Code coverage is disabled here because this code is unreachable. + # :nocov: + else + # empty + end + # :nocov: + + socket + ensure + socket.close unless connected + end + + def connect_with_timeout(host, port, timeout) + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout + + connect_with_deadline(host, port, deadline) + end + + def read_with_deadline(socket, length, deadline) + buf = String.new(encoding: Encoding::BINARY) + + while buf.bytesize < length + remaining = deadline - Process.clock_gettime(Process::CLOCK_MONOTONIC) + + raise Errno::ETIMEDOUT, "read timed out" unless remaining > 0 + + case chunk = socket.read_nonblock(length - buf.bytesize, exception: false) + # Code coverage is disabled here because this is hard to control. + # :nocov: + when :wait_readable + raise Errno::ETIMEDOUT, "read timed out" unless ::IO.select([socket], nil, nil, remaining) + # :nocov: + when nil + raise EOFError + else + buf << chunk + end + end + + buf + end + + def write_with_deadline(socket, data, deadline) + written = 0 + + while written < data.bytesize + remaining = deadline - Process.clock_gettime(Process::CLOCK_MONOTONIC) + + raise Errno::ETIMEDOUT, "write timed out" unless remaining > 0 + + case n = socket.write_nonblock(data.byteslice(written..), exception: false) + # Code coverage is disabled here because this is hard to control. + # :nocov: + when :wait_writable + raise Errno::ETIMEDOUT, "write timed out" unless ::IO.select(nil, [socket], nil, remaining) + # :nocov: + else + written += n + end + end + end + end + + module FramedIO + include TimedIO + + class FrameTooLargeError < StandardError + def initialize(size, max_size) + super("frame too large: #{size} bytes (max: #{max_size})") + end + end + + private + + def read_frame_with_deadline(socket, max_size, deadline) + len = read_with_deadline(socket, 4, deadline).unpack1("N") + + if max_size && len > max_size + raise FrameTooLargeError.new(len, max_size) + end + + read_with_deadline(socket, len, deadline) + end + + def read_frame_with_timeout(socket, max_size, timeout) + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout + + read_frame_with_deadline(socket, max_size, deadline) + end + + def write_frame_with_deadline(socket, data, max_size, deadline) + bytes = data.b + + if max_size && bytes.bytesize > max_size + raise FrameTooLargeError.new(bytes.bytesize, max_size) + end + + write_with_deadline(socket, [bytes.bytesize].pack("N"), deadline) + write_with_deadline(socket, bytes, deadline) + end + + def write_frame_with_timeout(socket, data, max_size, timeout) + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout + + write_frame_with_deadline(socket, data, max_size, deadline) + end + end + + module Handshake + CHALLENGE_LEN = 32 + HMAC_LEN = 32 + + class Error < StandardError; end + + module Server + include TimedIO + + private + + def handshake(socket, secret, timeout = IPC::HANDSHAKE_TIMEOUT) + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout + + server_challenge = SecureRandom.bytes(CHALLENGE_LEN) + + write_with_deadline(socket, server_challenge, deadline) + + buf = read_with_deadline(socket, HMAC_LEN + CHALLENGE_LEN, deadline) + + client_hmac = buf.byteslice(0, HMAC_LEN) + client_challenge = buf.byteslice(HMAC_LEN, CHALLENGE_LEN) + + expected = OpenSSL::HMAC.digest("SHA256", secret, "CLIENT-AUTH" + server_challenge) + + unless OpenSSL.fixed_length_secure_compare(client_hmac, expected) + socket.close + raise Error, "client authentication failed" + end + + server_hmac = OpenSSL::HMAC.digest("SHA256", secret, "SERVER-AUTH" + client_challenge) + + write_with_deadline(socket, server_hmac, deadline) + rescue Errno::ETIMEDOUT + socket.close + raise Error, "handshake timed out" + rescue EOFError, Errno::ECONNRESET, Errno::EPIPE + socket.close + raise Error, "connection closed" + end + end + + module Client + include TimedIO + + private + + def handshake(socket, secret, timeout = IPC::HANDSHAKE_TIMEOUT) + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout + + server_challenge = read_with_deadline(socket, CHALLENGE_LEN, deadline) + + client_hmac = OpenSSL::HMAC.digest("SHA256", secret, "CLIENT-AUTH" + server_challenge) + client_challenge = SecureRandom.bytes(CHALLENGE_LEN) + + write_with_deadline(socket, client_hmac + client_challenge, deadline) + + server_hmac = read_with_deadline(socket, HMAC_LEN, deadline) + + expected = OpenSSL::HMAC.digest("SHA256", secret, "SERVER-AUTH" + client_challenge) + + unless OpenSSL.fixed_length_secure_compare(server_hmac, expected) + socket.close + raise Error, "server authentication failed" + end + rescue Errno::ETIMEDOUT + socket.close + raise Error, "handshake timed out" + rescue EOFError, Errno::ECONNRESET, Errno::EPIPE + socket.close + raise Error, "connection closed" + end + end + end + + class Server + include Handshake::Server + + attr_reader :host + attr_reader :port + + def initialize( + secret, + host = "127.0.0.1", + port = 0, + handshake_timeout: IPC::HANDSHAKE_TIMEOUT + ) + @secret = secret + @handshake_timeout = handshake_timeout + + @running = Concurrent::AtomicBoolean.new(false) + + @server = TCPServer.new(host, port) + @host = @server.addr[3] + @port = @server.addr[1] + end + + def accept(&block) + socket = @server.accept + + Thread.new do + begin + handshake(socket, @secret, @handshake_timeout) + rescue Handshake::Error + # rejected connection + else + # accepted connection + block.call(socket) + end + ensure + socket.close + end + end + + def close + @server.close + end + + def start(&block) + raise ArgumentError, "block required" unless block + + return false unless @running.make_true + + Thread.new do + loop do + accept do |socket| + # accepted connection + block.call(socket) + end + end + rescue IOError + # server stopped + ensure + @running.make_false + close + end + + true + end + + def stop(&block) + return false unless @running.make_false + + block&.call + + close + + true + end + end + + class Client + include Handshake::Client + + attr_reader :socket + + def initialize( + secret, + host = "127.0.0.1", + port = 0, + connect_timeout: IPC::CONNECT_TIMEOUT, + handshake_timeout: IPC::HANDSHAKE_TIMEOUT + ) + @running = Concurrent::AtomicBoolean.new(false) + + @socket = connect_with_timeout(host, port, connect_timeout) + handshake(@socket, secret, handshake_timeout) + end + + def close + @socket.close + end + + def start(&block) + raise ArgumentError, "block required" unless block + + return false unless @running.make_true + + Thread.new do + block.call(@socket) + ensure + @running.make_false + close + end + + true + end + + def stop(&block) + return false unless @running.make_false + + block&.call + + close + + true + end + end + end + end +end diff --git a/lib/aikido/zen/ipc/rpc.rb b/lib/aikido/zen/ipc/rpc.rb new file mode 100644 index 00000000..e57e56c5 --- /dev/null +++ b/lib/aikido/zen/ipc/rpc.rb @@ -0,0 +1,250 @@ +# frozen_string_literal: true + +require "json" +require "securerandom" +require "concurrent" + +module Aikido + module Zen + module RPC + class NoHandlerError < StandardError; end + + class Server + include IPC::FramedIO + + def initialize( + secret, + host = "127.0.0.1", + port = 0, + handshake_timeout: IPC::HANDSHAKE_TIMEOUT, + read_timeout: IPC::READ_TIMEOUT, + write_timeout: IPC::WRITE_TIMEOUT, + max_read_size: nil, + max_write_size: nil, + logger: Aikido::Zen.config.logger + ) + @read_timeout = read_timeout + @write_timeout = write_timeout + @max_read_size = max_read_size + @max_write_size = max_write_size + @logger = logger + + @handlers = {} + + @server = IPC::Server.new( + secret, + host, + port, + handshake_timeout: handshake_timeout + ) + end + + def host + @server.host + end + + def port + @server.port + end + + def start + @server.start do |socket| + @logger.info("RPC server: client connected") + + handle_messages(socket) + rescue EOFError, Errno::ECONNRESET, Errno::EPIPE + # disconnected + rescue IOError + # client stopped + rescue => err + @logger.error("RPC server error: #{err.class}: #{err.message}") + @logger.debug(err.backtrace.join("\n")) + ensure + @logger.info("RPC server: client disconnected") + end + end + + def stop + @server.stop + end + + def handle(name, &block) + @handlers[name] = block + end + + private + + def handle_messages(socket) + write_mutex = Mutex.new + + loop do + message = read_message(socket) + next unless valid_message?(message) + + handle_message(message, write_mutex, socket) + end + end + + def handle_message(message, write_mutex, socket) + id, name, args, kwargs = message + kwargs.transform_keys!(&:to_sym) + + respond_called = false + + respond = proc do |result| + write_mutex.synchronize do + next if respond_called + + write_message(socket, [id, result, nil]) + + respond_called = true + end + end + + handler = @handlers[name] + + raise NoHandlerError, "undefined handler '#{name}'" unless handler + + handler.call(respond, *args, **kwargs) + + respond.call(nil) + rescue => err + write_mutex.synchronize do + next if respond_called + + write_message(socket, [id, nil, err.message]) + end + end + + def valid_message?(message) + message.is_a?(Array) && message.length == 4 && + message[2].is_a?(Array) && message[3].is_a?(Hash) + end + + def read_message(socket) + JSON.parse(read_frame_with_timeout(socket, @max_read_size, @read_timeout)) + end + + def write_message(socket, message) + write_frame_with_timeout(socket, JSON.generate(message), @max_write_size, @write_timeout) + end + end + + class Client + include IPC::FramedIO + + def initialize( + secret, + host, + port, + connect_timeout: IPC::CONNECT_TIMEOUT, + handshake_timeout: IPC::HANDSHAKE_TIMEOUT, + read_timeout: IPC::READ_TIMEOUT, + write_timeout: IPC::WRITE_TIMEOUT, + max_read_size: nil, + max_write_size: nil, + logger: Aikido::Zen.config.logger + ) + @read_timeout = read_timeout + @write_timeout = write_timeout + @max_read_size = max_read_size + @max_write_size = max_write_size + @logger = logger + + @pending = Concurrent::Hash.new + @write_mutex = Mutex.new + + @client = IPC::Client.new( + secret, + host, + port, + connect_timeout: connect_timeout, + handshake_timeout: handshake_timeout + ) + end + + def start + @client.start do |socket| + @logger.info("RPC client connected") + + handle_messages(socket) + rescue EOFError, Errno::ECONNRESET, Errno::EPIPE => err + # disconnected + rescue IOError => err + # client stopped + rescue => err + @logger.error("RPC client error: #{err.class}: #{err.message}") + @logger.debug(err.backtrace.join("\n")) + ensure + @logger.info("RPC client disconnected") + + @client.stop + + @pending.each_value { |ivar| ivar.fail(err) } + @pending.clear + end + end + + def stop + @client.stop do + @client.socket.shutdown(Socket::SHUT_RDWR) + end + end + + def invoke(name, *args, timeout: nil, **kwargs) + id = SecureRandom.uuid + + ivar = Concurrent::IVar.new + + @pending[id] = ivar + + @write_mutex.synchronize { write_message(@client.socket, [id, name, args, kwargs]) } + + ivar.wait!(timeout) + + raise Errno::ETIMEDOUT, "invoke timed out" if ivar.incomplete? + + ivar.value + ensure + @pending.delete(id) + end + + private + + def handle_messages(socket) + loop do + message = read_message(socket) + next unless valid_message?(message) + + handle_message(message) + end + end + + def handle_message(message) + id, result, error = message + + ivar = @pending.delete(id) + return unless ivar + + if error + ivar.fail(RuntimeError.new(error)) + else + ivar.set(result) + end + end + + def valid_message?(message) + message.is_a?(Array) && message.length == 3 + end + + def read_message(socket) + JSON.parse(read_frame_with_timeout(socket, @max_read_size, @read_timeout)) + end + + def write_message(socket, message) + write_frame_with_timeout(socket, JSON.generate(message), @max_write_size, @write_timeout) + end + end + end + end +end From ba7b8db0824a2f2500e776faeca908af06a30e20 Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Mon, 22 Jun 2026 18:18:35 +0200 Subject: [PATCH 3/7] Add tests --- test/aikido/zen/ipc/ipc_test.rb | 557 ++++++++++++++++++++++++++++++++ test/aikido/zen/ipc/rpc_test.rb | 482 +++++++++++++++++++++++++++ 2 files changed, 1039 insertions(+) create mode 100644 test/aikido/zen/ipc/ipc_test.rb create mode 100644 test/aikido/zen/ipc/rpc_test.rb diff --git a/test/aikido/zen/ipc/ipc_test.rb b/test/aikido/zen/ipc/ipc_test.rb new file mode 100644 index 00000000..da4ef5a4 --- /dev/null +++ b/test/aikido/zen/ipc/ipc_test.rb @@ -0,0 +1,557 @@ +# frozen_string_literal: true + +require "test_helper" +require "timeout" + +module IPCHelpers + def build_ipc_server(secret = Aikido::Zen.secret) + Aikido::Zen::IPC::Server.new(secret) + end + + def build_ipc_client(server, secret = Aikido::Zen.secret) + Aikido::Zen::IPC::Client.new( + secret, + server.host, + server.port, + connect_timeout: 1, + handshake_timeout: 1 + ) + end + + def start_ipc_server(secret = Aikido::Zen.secret, &block) + server = build_ipc_server(secret) + server.start(&block) + server + end + + def start_ipc_client(server, secret = Aikido::Zen.secret, &block) + client = build_ipc_client(server, secret) + client.start(&block) + client + end +end + +class Aikido::Zen::IPC::ServerTest < ActiveSupport::TestCase + include IPCHelpers + + test "#initialize binds to 127.0.0.1 on a free port" do + server = build_ipc_server + + assert_equal "127.0.0.1", server.host + assert_operator server.port, :>, 0 + ensure + server.close + end + + test "#start raises ArgumentError when called without a block" do + server = build_ipc_server + + err = assert_raises(ArgumentError) { server.start } + assert_equal "block required", err.message + ensure + server.close + end + + test "#start returns true when not running" do + server = build_ipc_server + + assert_equal true, server.start {} + ensure + server.stop + end + + test "#start returns false when running" do + server = start_ipc_server {} + + assert_equal false, server.start {} + ensure + server.stop + end + + test "#stop returns true when running" do + server = start_ipc_server {} + + assert_equal true, server.stop + end + + test "#stop returns false when not running" do + server = build_ipc_server + + assert_equal false, server.stop + ensure + server.close + end + + test "#stop yields the block when running" do + server = start_ipc_server {} + + yielded = false + server.stop { yielded = true } + + assert yielded + end + + test "#stop does not yield the block when not running" do + server = build_ipc_server + + yielded = false + server.stop { yielded = true } + + refute yielded + ensure + server.close + end +end + +class Aikido::Zen::IPC::ClientTest < ActiveSupport::TestCase + include IPCHelpers + + setup do + @server = start_ipc_server {} + end + + teardown do + @server.stop + end + + test "#initialize connects to the given host and port" do + client = build_ipc_client(@server) + + addr = client.socket.remote_address + + assert_equal @server.host, addr.ip_address + assert_equal @server.port, addr.ip_port + ensure + client.close + end + + test "#start raises ArgumentError when called without a block" do + client = build_ipc_client(@server) + + err = assert_raises(ArgumentError) { client.start } + assert_equal "block required", err.message + ensure + client.close + end + + test "#start returns true when not running" do + client = build_ipc_client(@server) + + assert_equal true, client.start {} + ensure + client.stop + end + + test "#start returns false when running" do + gate = Concurrent::CountDownLatch.new(1) + + client = start_ipc_client(@server) { gate.wait } + + assert_equal false, client.start {} + ensure + gate.count_down + end + + test "#stop returns true when running" do + gate = Concurrent::CountDownLatch.new(1) + + client = start_ipc_client(@server) { gate.wait } + + assert_equal true, client.stop + ensure + gate.count_down + end + + test "#stop returns false when not running" do + client = build_ipc_client(@server) + + assert_equal false, client.stop + ensure + client.close + end + + test "#stop yields the block when running" do + gate = Concurrent::CountDownLatch.new(1) + + client = start_ipc_client(@server) { gate.wait } + + yielded = false + client.stop { yielded = true } + + assert yielded + ensure + gate.count_down + end + + test "#stop does not yield the block when not running" do + client = build_ipc_client(@server) + + yielded = false + client.stop { yielded = true } + + refute yielded + ensure + client.close + end +end + +class Aikido::Zen::IPC::ConnectionTest < ActiveSupport::TestCase + include IPCHelpers + + test "client connects successfully when the shared secret matches" do + connected = Queue.new + + server = start_ipc_server { connected.push(:ok) } + + client = build_ipc_client(server) + + assert_equal :ok, Timeout.timeout(2) { connected.pop } + ensure + client.close + server.stop + end + + test "client raises Handshake::Error and the server block is never called when the secret is wrong" do + connected = Queue.new + + server = start_ipc_server { connected.push(:ok) } + + assert_raises(Aikido::Zen::IPC::Handshake::Error) do + build_ipc_client(server, "wrong...wrong...wrong...wrong...") + end + + assert_empty connected, "server block should not have been called" + ensure + server.stop + end + + test "server continues accepting connections after a handshake rejection" do + connected = Queue.new + + server = start_ipc_server { connected.push(:ok) } + + assert_raises(Aikido::Zen::IPC::Handshake::Error) do + build_ipc_client(server, "wrong...wrong...wrong...wrong...") + end + + client = build_ipc_client(server) + + assert_equal :ok, Timeout.timeout(2) { connected.pop } + ensure + client.close + server.stop + end + + test "server handles multiple sequential connections from different clients" do + connected = Queue.new + + server = start_ipc_server { connected.push(:ok) } + + 3.times do + client = build_ipc_client(server) + client.close + end + + results = Timeout.timeout(2) do + Array.new(3) { connected.pop } + end + + assert_equal [:ok, :ok, :ok], results + ensure + server.stop + end + + test "server handles multiple concurrent connections from different clients" do + connected = Queue.new + + gate = Concurrent::CountDownLatch.new(1) + + server = start_ipc_server do + connected.push(:ok) + gate.wait + end + + threads = 3.times.map do + Thread.new do + client = build_ipc_client(server) + client.close + end + end + + results = Timeout.timeout(2) do + Array.new(3) { connected.pop } + end + + assert_equal [:ok, :ok, :ok], results + ensure + gate.count_down + + threads.each { |thread| thread.join(1) } + + server.stop + end + + test "server continues accepting connections after a client disconnects abruptly" do + connected = Queue.new + + server = start_ipc_server { connected.push(:ok) } + + client1 = build_ipc_client(server) + Timeout.timeout(2) { connected.pop } + + client1.socket.close # disconnect abruptly + + client2 = build_ipc_client(server) + + assert_equal :ok, Timeout.timeout(2) { connected.pop } + ensure + client2.close + server.stop + end + + test "server continues accepting connections after the block raises" do + connected = Queue.new + + server = start_ipc_server do + Thread.current.report_on_exception = false + + connected.push(:attempted) + raise "something went wrong" + end + + client1 = build_ipc_client(server) + Timeout.timeout(2) { connected.pop } + + client2 = build_ipc_client(server) + + assert_equal :attempted, Timeout.timeout(2) { connected.pop } + ensure + client1.close + client2.close + server.stop + end + + test "#stop causes subsequent connection attempts to fail" do + server = start_ipc_server {} + server.stop + + assert_raises(Errno::ECONNREFUSED, Errno::ETIMEDOUT) do + build_ipc_client(server) + end + end + + test "the client can write data to the server" do + received = Queue.new + + server = start_ipc_server do |socket| + data = socket.read(5) + received.push(data) + end + + client = build_ipc_client(server) + + client.socket.write("hello") + + assert_equal "hello", Timeout.timeout(2) { received.pop } + ensure + client.close + server.stop + end + + test "the server can write data to the client" do + server = start_ipc_server { |socket| socket.write("hello") } + + client = build_ipc_client(server) + + assert_equal "hello", client.socket.read(5) + ensure + client.close + server.stop + end +end + +class Aikido::Zen::IPC::FramedIOTest < ActiveSupport::TestCase + include Aikido::Zen::IPC::FramedIO + + def socket_pair + Socket.pair(:UNIX, :STREAM, 0) + end + + test "#read_frame_with_timeout raises FrameTooLargeError when the frame exceeds max_size" do + reader, writer = socket_pair + + writer.write([10].pack("N")) # declare a 10-byte frame + + err = assert_raises(Aikido::Zen::IPC::FramedIO::FrameTooLargeError) do + read_frame_with_timeout(reader, 4, 1) + end + + assert_equal "frame too large: 10 bytes (max: 4)", err.message + ensure + reader.close + writer.close + end + + test "#write_frame_with_timeout raises FrameTooLargeError when the frame exceeds max_size" do + reader, writer = socket_pair + + err = assert_raises(Aikido::Zen::IPC::FramedIO::FrameTooLargeError) do + write_frame_with_timeout(writer, "hello world", 5, 1) + end + assert_equal "frame too large: 11 bytes (max: 5)", err.message + ensure + reader.close + writer.close + end + + test "handles large frames correctly" do + reader, writer = socket_pair + + data = SecureRandom.bytes(1 * 1024 * 1024) + + thread = Thread.new { write_frame_with_timeout(writer, data, nil, 5) } + result = read_frame_with_timeout(reader, nil, 5) + thread.join + + assert_equal data, result + ensure + reader.close + writer.close + end +end + +class Aikido::Zen::IPC::HandshakeTest < ActiveSupport::TestCase + CHALLENGE_LEN = Aikido::Zen::IPC::Handshake::CHALLENGE_LEN + HMAC_LEN = Aikido::Zen::IPC::Handshake::HMAC_LEN + + def socket_pair + Socket.pair(:UNIX, :STREAM, 0) + end + + def build_ipc_server + Object.new.extend(Aikido::Zen::IPC::Handshake::Server) + end + + def build_ipc_client + Object.new.extend(Aikido::Zen::IPC::Handshake::Client) + end + + def build_server_handshake_thread(socket, secret = Aikido::Zen.secret) + server = build_ipc_server + + thread = Thread.new { server.send(:handshake, socket, secret) } + thread.report_on_exception = false + thread + end + + def build_client_handshake_thread(socket, secret = Aikido::Zen.secret) + client = build_ipc_client + + thread = Thread.new { client.send(:handshake, socket, secret) } + thread.report_on_exception = false + thread + end + + test "Handshake completes successfully with a matching secret" do + server_socket, client_socket = socket_pair + + server_thread = build_server_handshake_thread(server_socket) + client_thread = build_client_handshake_thread(client_socket) + + assert_nothing_raised { server_thread.value } + assert_nothing_raised { client_thread.value } + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Server#handshake raises Handshake::Error on timeout" do + server_socket, client_socket = socket_pair + + server = build_ipc_server + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) do + server.send(:handshake, server_socket, Aikido::Zen.secret, 0) + end + + assert_equal "handshake timed out", err.message + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Client#handshake raises Handshake::Error on timeout" do + server_socket, client_socket = socket_pair + + client = build_ipc_client + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) do + client.send(:handshake, client_socket, Aikido::Zen.secret, 0) + end + + assert_equal "handshake timed out", err.message + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Server#handshake raises Handshake::Error when the connection closes" do + server_socket, client_socket = socket_pair + + thread = build_server_handshake_thread(server_socket) + + client_socket.close + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) { thread.value } + assert_equal "connection closed", err.message + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Client#handshake raises Handshake::Error when the connection closes" do + server_socket, client_socket = socket_pair + + thread = build_client_handshake_thread(client_socket) + + server_socket.close + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) { thread.value } + assert_equal "connection closed", err.message + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Server#handshake raises Handshake::Error when the client sends a wrong HMAC" do + server_socket, client_socket = socket_pair + + thread = build_server_handshake_thread(server_socket) + + _server_challenge = client_socket.read(CHALLENGE_LEN) + client_socket.write(SecureRandom.bytes(HMAC_LEN + CHALLENGE_LEN)) # garbage client HMAC and challenge + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) { thread.value } + assert_equal "client authentication failed", err.message + ensure + server_socket.close + client_socket.close + end + + test "Handshake::Client#handshake raises Handshake::Error when the server sends a wrong HMAC" do + server_socket, client_socket = socket_pair + + thread = build_client_handshake_thread(client_socket) + + server_challenge = SecureRandom.bytes(CHALLENGE_LEN) + server_socket.write(server_challenge) + _client_hmac_and_challenge = server_socket.read(HMAC_LEN + CHALLENGE_LEN) + server_socket.write(SecureRandom.bytes(HMAC_LEN)) # garbage server HMAC + + err = assert_raises(Aikido::Zen::IPC::Handshake::Error) { thread.value } + assert_equal "server authentication failed", err.message + ensure + server_socket.close + client_socket.close + end +end diff --git a/test/aikido/zen/ipc/rpc_test.rb b/test/aikido/zen/ipc/rpc_test.rb new file mode 100644 index 00000000..d9804a27 --- /dev/null +++ b/test/aikido/zen/ipc/rpc_test.rb @@ -0,0 +1,482 @@ +# frozen_string_literal: true + +require "test_helper" + +module IPCHelpers + def build_ipc_server(secret = Aikido::Zen.secret) + Aikido::Zen::IPC::Server.new(secret) + end + + def build_ipc_client(server, secret = Aikido::Zen.secret) + Aikido::Zen::IPC::Client.new( + secret, + server.host, + server.port, + connect_timeout: 1, + handshake_timeout: 1 + ) + end + + def start_ipc_server(secret = Aikido::Zen.secret, &block) + server = build_ipc_server(secret) + server.start(&block) + server + end + + def start_ipc_client(server, secret = Aikido::Zen.secret, &block) + client = build_ipc_client(server, secret) + client.start(&block) + client + end +end + +module RPCHelpers + def build_rpc_server(secret = Aikido::Zen.secret, logger: Aikido::Zen.config.logger) + Aikido::Zen::RPC::Server.new(secret, logger: logger) + end + + def build_rpc_client(server, secret = Aikido::Zen.secret, logger: Aikido::Zen.config.logger) + Aikido::Zen::RPC::Client.new(secret, server.host, server.port, logger: logger) + end + + def start_rpc_server(secret = Aikido::Zen.secret, logger: Aikido::Zen.config.logger, &block) + server = build_rpc_server(secret, logger: logger) + block&.call(server) + server.start + server + end + + def start_rpc_client(server, secret = Aikido::Zen.secret, logger: Aikido::Zen.config.logger) + client = build_rpc_client(server, secret, logger: logger) + client.start + client + end +end + +class Aikido::Zen::RPC::ServerTest < ActiveSupport::TestCase + include IPCHelpers + include RPCHelpers + include Aikido::Zen::IPC::FramedIO + + test "skips messages with an invalid structure and continues processing" do + server = start_rpc_server do |server| + server.handle("echo") { |respond, value| respond.call(value) } + end + + client = build_ipc_client(server) + socket = client.socket + + write_frame_with_timeout(socket, JSON.generate(["bad", "hello"]), nil, 1) + write_frame_with_timeout(socket, JSON.generate(["abc", "echo", ["hello"], {}]), nil, 1) + response = JSON.parse(read_frame_with_timeout(socket, nil, 2)) + + assert_equal ["abc", "hello", nil], response + ensure + client.close + server.stop + end + + test "#respond only sends the first response when called multiple times" do + server = start_rpc_server do |server| + server.handle("repeat") do |respond| + respond.call("first") + respond.call("second") + end + end + + client = build_ipc_client(server) + socket = client.socket + + write_frame_with_timeout(socket, JSON.generate(["abc", "repeat", [], {}]), nil, 1) + response = JSON.parse(read_frame_with_timeout(socket, nil, 2)) + + assert_equal ["abc", "first", nil], response + ensure + client.close + server.stop + end + + test "logs unexpected errors and drops the connection" do + server = start_rpc_server + client = build_ipc_client(server) + socket = client.socket + + write_frame_with_timeout(socket, "invalid JSON", nil, 1) + assert_raises(EOFError) { read_frame_with_timeout(socket, nil, 1) } + + assert_logged :error, /invalid JSON/ + ensure + client.close + server.stop + end +end + +class Aikido::Zen::RPC::ClientTest < ActiveSupport::TestCase + include IPCHelpers + include RPCHelpers + include Aikido::Zen::IPC::FramedIO + + test "#invoke skips messages with an invalid structure and continues processing" do + server = start_ipc_server do |socket| + raw = read_frame_with_timeout(socket, nil, 1) + id, _name, _args, _kwargs = JSON.parse(raw) + write_frame_with_timeout(socket, JSON.generate(["bad"]), nil, 1) + write_frame_with_timeout(socket, JSON.generate([id, "hello", nil]), nil, 1) + end + + client = start_rpc_client(server) + + result = client.invoke("echo", timeout: 2) + assert_equal "hello", result + ensure + client.stop + server.stop + end + + test "#invoke ignores responses with an unknown ID" do + server = start_ipc_server do |socket| + raw = read_frame_with_timeout(socket, nil, 1) + id, _name, _args, _kwargs = JSON.parse(raw) + write_frame_with_timeout(socket, JSON.generate(["unknown-id", "ignored", nil]), nil, 1) + write_frame_with_timeout(socket, JSON.generate([id, "hello", nil]), nil, 1) + end + + client = start_rpc_client(server) + + result = client.invoke("echo", timeout: 2) + assert_equal "hello", result + ensure + client.stop + server.stop + end + + test "#invoke raises on server disconnect" do + server = start_ipc_server do |socket| + read_frame_with_timeout(socket, nil, 2) + socket.close + end + + client = start_rpc_client(server) + + assert_raises(EOFError, Errno::ECONNRESET) do + client.invoke("echo", timeout: 2) + end + ensure + client.stop + server.stop + end + + test "#invoke raises RuntimeError when the server responds with an error" do + server = start_ipc_server do |socket| + raw = read_frame_with_timeout(socket, nil, 2) + id, _name, _args, _kwargs = JSON.parse(raw) + write_frame_with_timeout(socket, JSON.generate([id, nil, "something went wrong"]), nil, 1) + end + + client = start_rpc_client(server) + + err = assert_raises(RuntimeError) do + client.invoke("echo", timeout: 2) + end + + assert_equal "something went wrong", err.message + ensure + client.stop + server.stop + end + + test "logs unexpected errors" do + gate = Concurrent::CountDownLatch.new(1) + + server = start_ipc_server do |socket| + write_frame_with_timeout(socket, "invalid JSON", nil, 1) + end + + logger = Aikido::Zen.config.logger + logger.stub(:error, ->(msg) { + logger.add(Logger::ERROR, msg) + gate.count_down + }) do + client = start_rpc_client(server) + gate.wait(2) + client.stop + end + + assert_logged :error, /JSON/ + ensure + server.stop + end +end + +class Aikido::Zen::RPC::ConnectionTest < ActiveSupport::TestCase + include RPCHelpers + + test "#invoke calls the handler and returns the result" do + server = start_rpc_server do |server| + server.handle("echo") { |respond, text| respond.call(text) } + end + + client = start_rpc_client(server) + + result = client.invoke("echo", "hello", timeout: 2) + + assert_equal "hello", result + ensure + client.stop + server.stop + end + + test "#invoke raises Errno::ETIMEDOUT on timeout" do + gate = Concurrent::CountDownLatch.new(1) + + server = start_rpc_server do |server| + server.handle("slow") { gate.wait } + end + + client = start_rpc_client(server) + + assert_raises(Errno::ETIMEDOUT) do + client.invoke("slow", timeout: 0.1) + end + ensure + gate.count_down + + client.stop + server.stop + end + + test "#invoke raises RuntimeError when the handler is not registered" do + server = start_rpc_server + + client = start_rpc_client(server) + + err = assert_raises(RuntimeError) do + client.invoke("nonexistent", timeout: 2) + end + + assert_match(/nonexistent/, err.message) + ensure + client.stop + server.stop + end + + test "#invoke passes positional arguments to the handler" do + server = start_rpc_server do |server| + server.handle("add") { |respond, a, b| respond.call(a + b) } + end + + client = start_rpc_client(server) + + result = client.invoke("add", 3, 4, timeout: 2) + + assert_equal 7, result + ensure + client.stop + server.stop + end + + test "#invoke passes keyword arguments to the handler" do + server = start_rpc_server do |server| + server.handle("greet") { |respond, name:, greeting: "Hello"| respond.call("#{greeting}, #{name}!") } + end + + client = start_rpc_client(server) + + result = client.invoke("greet", name: "Alice", timeout: 2) + + assert_equal "Hello, Alice!", result + ensure + client.stop + server.stop + end + + test "#invoke passes positional and keyword arguments to the handler" do + server = start_rpc_server do |server| + server.handle("greet") { |respond, greeting, name:| respond.call("#{greeting}, #{name}!") } + end + + client = start_rpc_client(server) + + result = client.invoke("greet", "Hello", name: "Alice", timeout: 2) + + assert_equal "Hello, Alice!", result + ensure + client.stop + server.stop + end + + test "#invoke returns nil when the handler explicitly responds with nil" do + server = start_rpc_server do |server| + server.handle("null") { |respond| respond.call(nil) } + end + + client = start_rpc_client(server) + + result = client.invoke("null", timeout: 2) + + assert_nil result + ensure + client.stop + server.stop + end + + test "#invoke returns nil when the handler does not respond explicitly" do + server = start_rpc_server do |server| + server.handle("noop") {} + end + + client = start_rpc_client(server) + + result = client.invoke("noop", timeout: 2) + + assert_nil result + ensure + client.stop + server.stop + end + + test "#invoke returns immediately when the handler responds" do + gate = Concurrent::CountDownLatch.new(1) + + server = start_rpc_server do |server| + server.handle("work") do |respond| + respond.call(nil) + gate.wait + end + end + + client = start_rpc_client(server) + + result = client.invoke("work", timeout: 2) + + assert_nil result + ensure + gate.count_down + + client.stop + server.stop + end + + test "#invoke returns a string when the handler responds with a symbol" do + server = start_rpc_server do |server| + server.handle("symbolize") { |respond| respond.call(:ok) } + end + + client = start_rpc_client(server) + + result = client.invoke("symbolize", timeout: 2) + + assert_equal "ok", result + ensure + client.stop + server.stop + end + + test "#invoke handles complex nested structures" do + data = {"users" => [{"id" => 1, "name" => "Alice"}, {"id" => 2, "name" => "Bob"}]} + + server = start_rpc_server do |server| + server.handle("echo") { |respond, value| respond.call(value) } + end + + client = start_rpc_client(server) + + result = client.invoke("echo", data, timeout: 2) + + assert_equal data, result + ensure + client.stop + server.stop + end + + test "#invoke returns the result when the handler raises after responding" do + server = start_rpc_server do |server| + server.handle("risky") do |respond| + respond.call("safe result") + raise "something went wrong" + end + end + + client = start_rpc_client(server) + + result = client.invoke("risky", timeout: 2) + + assert_equal "safe result", result + ensure + client.stop + server.stop + end + + test "#invoke propagates errors raised inside the handler" do + server = start_rpc_server do |server| + server.handle("boom") { raise "something went wrong" } + end + + client = start_rpc_client(server) + + err = assert_raises(RuntimeError) do + client.invoke("boom", timeout: 2) + end + + assert_equal "something went wrong", err.message + ensure + client.stop + server.stop + end + + test "#invoke routes concurrent responses to the correct callers with single handler" do + server = start_rpc_server do |server| + server.handle("echo") { |respond, value| respond.call(value) } + end + + client = start_rpc_client(server) + + threads = 5.times.map do |i| + Thread.new do + [i, client.invoke("echo", i, timeout: 2)] + end + end + + threads.each do |thread| + input, output = thread.value + assert_equal input, output + end + ensure + client.stop + server.stop + end + + test "#invoke routes concurrent responses to the correct callers with multiple handlers" do + server = start_rpc_server do |server| + server.handle("double") { |respond, n| respond.call(n * 2) } + server.handle("negate") { |respond, n| respond.call(-n) } + end + + client = start_rpc_client(server) + + doubles = 3.times.map do |i| + Thread.new do + [i, client.invoke("double", i, timeout: 2)] + end + end + + negates = 3.times.map do |i| + Thread.new do + [i, client.invoke("negate", i, timeout: 2)] + end + end + + doubles.each do |thread| + i, result = thread.value + assert_equal i * 2, result + end + + negates.each do |thread| + i, result = thread.value + assert_equal(-i, result) + end + ensure + client.stop + server.stop + end +end From 82fc0fe72290db11b8d4c6acff5f1b799600b7e0 Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Mon, 15 Jun 2026 16:42:17 +0200 Subject: [PATCH 4/7] Reimplement multiprocess architecture --- lib/aikido/zen.rb | 63 +++-- lib/aikido/zen/agent.rb | 4 +- lib/aikido/zen/api_cache.rb | 8 + lib/aikido/zen/config.rb | 27 -- lib/aikido/zen/detached_agent.rb | 2 - lib/aikido/zen/detached_agent/agent.rb | 79 ------ lib/aikido/zen/detached_agent/front_object.rb | 41 --- lib/aikido/zen/detached_agent/server.rb | 78 ------ lib/aikido/zen/errors.rb | 8 - lib/aikido/zen/middleware/rack_throttler.rb | 6 +- lib/aikido/zen/rate_limiter/result.rb | 20 ++ lib/aikido/zen/sinks/action_controller.rb | 6 +- lib/aikido/zen/worker_process.rb | 3 + lib/aikido/zen/worker_process/agent.rb | 4 + lib/aikido/zen/worker_process/agent/client.rb | 97 +++++++ lib/aikido/zen/worker_process/agent/server.rb | 86 ++++++ test/aikido/zen/config_test.rb | 13 - test/aikido/zen/detached_agent/agent_test.rb | 90 ------ .../detached_agent_server_test.rb | 53 ---- .../zen/detached_agent/front_object_test.rb | 22 -- test/aikido/zen/rate_limiter/result_test.rb | 53 ++++ .../zen/worker_process/agent/client_test.rb | 257 ++++++++++++++++++ .../zen/worker_process/agent/server_test.rb | 78 ++++++ test/aikido/zen_test.rb | 24 ++ test/test_helper.rb | 28 +- 25 files changed, 685 insertions(+), 465 deletions(-) create mode 100644 lib/aikido/zen/api_cache.rb delete mode 100644 lib/aikido/zen/detached_agent.rb delete mode 100644 lib/aikido/zen/detached_agent/agent.rb delete mode 100644 lib/aikido/zen/detached_agent/front_object.rb delete mode 100644 lib/aikido/zen/detached_agent/server.rb create mode 100644 lib/aikido/zen/worker_process.rb create mode 100644 lib/aikido/zen/worker_process/agent.rb create mode 100644 lib/aikido/zen/worker_process/agent/client.rb create mode 100644 lib/aikido/zen/worker_process/agent/server.rb delete mode 100644 test/aikido/zen/detached_agent/agent_test.rb delete mode 100644 test/aikido/zen/detached_agent/detached_agent_server_test.rb delete mode 100644 test/aikido/zen/detached_agent/front_object_test.rb create mode 100644 test/aikido/zen/rate_limiter/result_test.rb create mode 100644 test/aikido/zen/worker_process/agent/client_test.rb create mode 100644 test/aikido/zen/worker_process/agent/server_test.rb diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index adf37042..69cba9f9 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -11,10 +11,11 @@ require_relative "zen/worker" require_relative "zen/agent" require_relative "zen/api_client" +require_relative "zen/api_cache" require_relative "zen/api_stream" require_relative "zen/context" require_relative "zen/current_context" -require_relative "zen/detached_agent" +require_relative "zen/worker_process" require_relative "zen/middleware/middleware" require_relative "zen/middleware/fork_detector" require_relative "zen/middleware/context_setter" @@ -87,6 +88,23 @@ def self.runtime_settings=(settings) @runtime_settings = settings end + def self.api_cache + @api_cache ||= APICache.new + end + + def self.rate_limiter + @rate_limiter ||= RateLimiter.new + end + + def self.calculate_rate_limits(request) + agent = @worker_process_client + agent ? agent.calculate_rate_limits(request) : rate_limiter.calculate_rate_limits(request) + end + + def self.secret + @secret ||= SecureRandom.bytes(32) + end + # @return [Boolean] whether the Aikido agent is currently blocking requests. # Blocking mode is configured at startup and can be controlled through the # Aikido dashboard at runtime. @@ -317,7 +335,8 @@ def self.load_sinks! # Stop any background threads. def self.stop! @agent&.stop! - @detached_agent_server&.stop! + @worker_process_client&.stop + @worker_process_server&.stop! end # @!visibility private @@ -326,28 +345,24 @@ def self.agent @agent ||= Agent.start end - def self.detached_agent - @detached_agent ||= DetachedAgent::Agent.new + def self.worker_process_server + @worker_process_server ||= WorkerProcess::Agent::Server.start end - def self.detached_agent_server - @detached_agent_server ||= DetachedAgent::Server.start - end + @pid = Process.pid + @has_started = Concurrent::AtomicBoolean.new(false) + @has_handled_fork = Concurrent::AtomicBoolean.new(false) class << self - # `agent` and `detached_agent` are started on the first method call. - # A mutex controls thread execution to prevent multiple attempts. - LOCK = Mutex.new - def start! + return if @has_started.true? + return unless start? - @pid = Process.pid + return unless @has_started.make_true - LOCK.synchronize do - agent - detached_agent_server - end + worker_process_server + agent end def start? @@ -357,17 +372,25 @@ def start? end def check_and_handle_fork + return unless @has_started.true? && @has_handled_fork.make_true + handle_fork if forked? end def forked? - pid_changed = Process.pid != @pid - @pid = Process.pid - pid_changed + current_pid = Process.pid + return false if current_pid == @pid + @pid = current_pid + true end def handle_fork - @detached_agent&.handle_fork + server = @worker_process_server + client = WorkerProcess::Agent::Client.new(server.host, server.port) + client.start + @worker_process_client = client + rescue => err + config.logger.error("Forked worker process #{Process.pid}: failed to start worker process client: #{err.message}") end end diff --git a/lib/aikido/zen/agent.rb b/lib/aikido/zen/agent.rb index ace75d83..5a4a3c4a 100644 --- a/lib/aikido/zen/agent.rb +++ b/lib/aikido/zen/agent.rb @@ -19,14 +19,12 @@ def self.start(**opts) def initialize( config: Aikido::Zen.config, collector: Aikido::Zen.collector, - detached_agent: Aikido::Zen.detached_agent, worker: Aikido::Zen::Worker.new(config: config), api_client: Aikido::Zen::APIClient.new(config: config), api_stream: Aikido::Zen::APIStream.new(config: config) ) @config = config @collector = collector - @detached_agent = detached_agent @worker = worker @api_client = api_client @api_stream = api_stream @@ -243,6 +241,7 @@ def heartbeats def update_settings_from_runtime_config!(data) return unless @runtime_config_update_mutex.try_lock begin + Aikido::Zen.api_cache.runtime_config = data Aikido::Zen.runtime_settings.update_from_runtime_config_json(data) ensure @runtime_config_update_mutex.unlock @@ -254,6 +253,7 @@ def update_settings_from_runtime_config!(data) def update_settings_from_runtime_firewall_lists!(data) return unless @runtime_firewall_lists_update_mutex.try_lock begin + Aikido::Zen.api_cache.runtime_firewall_lists = data Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(data) ensure @runtime_firewall_lists_update_mutex.unlock diff --git a/lib/aikido/zen/api_cache.rb b/lib/aikido/zen/api_cache.rb new file mode 100644 index 00000000..cf80bf8d --- /dev/null +++ b/lib/aikido/zen/api_cache.rb @@ -0,0 +1,8 @@ +# frozen_string_literal: true + +module Aikido::Zen + class APICache + attr_accessor :runtime_config + attr_accessor :runtime_firewall_lists + end +end diff --git a/lib/aikido/zen/config.rb b/lib/aikido/zen/config.rb index a7c6b067..3ded6b59 100644 --- a/lib/aikido/zen/config.rb +++ b/lib/aikido/zen/config.rb @@ -66,11 +66,6 @@ class Config # @return [Logger] attr_reader :logger - # @return [String] Path of the socket where the detached agent will listen. - # By default, the socket file is created in the current working directory. - # Defaults to `aikido-detached-agent.sock`. - attr_accessor :detached_agent_socket_path - # @return [String] environment specific HTTP header providing the client IP. attr_accessor :client_ip_header @@ -239,7 +234,6 @@ def initialize self.json_encoder = DEFAULT_JSON_ENCODER self.json_decoder = DEFAULT_JSON_DECODER self.logger = Logger.new($stdout, progname: "aikido", level: debugging ? Logger::DEBUG : Logger::INFO) - self.detached_agent_socket_path = ENV.fetch("AIKIDO_DETACHED_AGENT_SOCKET_PATH", DEFAULT_DETACHED_AGENT_SOCKET_PATH) self.client_ip_header = ENV.fetch("AIKIDO_CLIENT_IP_HEADER", nil) self.max_performance_samples = 5000 self.max_compressed_stats = 100 @@ -327,26 +321,8 @@ def api_token_hash @api_token_hash ||= Digest::SHA1.hexdigest(api_token)[0, 7] end - def detached_agent_socket_uri - "drbunix:" + @detached_agent_socket_path - end - - def expanded_detached_agent_socket_path - @exanded_detached_agent_path ||= expand_socket_path(detached_agent_socket_path) - end - - def expanded_detached_agent_socket_uri - @exanded_detached_agent_uri ||= expand_socket_path(detached_agent_socket_uri) - end - private - def expand_socket_path(socket_path) - socket_path = socket_path.dup - socket_path.gsub!("%h", api_token_hash) if api_token_hash - socket_path - end - def read_boolean_from_env(value) return value unless value.respond_to?(:to_str) @@ -373,9 +349,6 @@ def read_boolean_from_env(value) # @!visibility private DEFAULT_JSON_DECODER = JSON.method(:parse) - # @!visibility private - DEFAULT_DETACHED_AGENT_SOCKET_PATH = "aikido-detached-agent.%h.sock" - # @!visibility private DEFAULT_BLOCKED_RESPONDER = ->(request, blocking_type, reason = nil) do message = case blocking_type diff --git a/lib/aikido/zen/detached_agent.rb b/lib/aikido/zen/detached_agent.rb deleted file mode 100644 index 6a1e2648..00000000 --- a/lib/aikido/zen/detached_agent.rb +++ /dev/null @@ -1,2 +0,0 @@ -require_relative "detached_agent/agent" -require_relative "detached_agent/server" diff --git a/lib/aikido/zen/detached_agent/agent.rb b/lib/aikido/zen/detached_agent/agent.rb deleted file mode 100644 index 70c589e7..00000000 --- a/lib/aikido/zen/detached_agent/agent.rb +++ /dev/null @@ -1,79 +0,0 @@ -# frozen_string_literal: true - -require "drb/drb" -require "drb/unix" -require_relative "front_object" -require_relative "../background_worker" - -module Aikido::Zen::DetachedAgent - # Agent that runs in forked processes. It communicates with the parent process to dRB - # calls. It's in charge of schedule and send heartbeats to the *parent process*, to be - # later pushed. - # - # heartbeat & polling interval are configured to 10s , because they are connecting with - # parent process. We want to have the freshest data. - # - # It's possible to use `extend Forwardable` here for one-line forward calls to the - # @front_object object. Unfortunately, the methods to be called are - # created at runtime by `DRbObject`, which leads to an ugly warning about - # private methods after the delegator is bound. - class Agent - attr_reader :worker - - def initialize( - config: Aikido::Zen.config, - worker: Aikido::Zen::Worker.new(config: config), - heartbeat_interval: 10, - polling_interval: 10, - collector: Aikido::Zen.collector - ) - @config = config - @worker = worker - @heartbeat_interval = heartbeat_interval - @polling_interval = polling_interval - - @collector = collector - - @front_object = DRbObject.new_with_uri(config.expanded_detached_agent_socket_uri) - - @has_forked = false - schedule_tasks - end - - def send_collector_events - events_data = @collector.flush_events.map(&:as_json) - @front_object.send_collector_events(events_data) - end - - def calculate_rate_limits(request) - @front_object.calculate_rate_limits(request.route.as_json, request.client_ip, request.actor.as_json) - end - - # Every time a fork occurs (a new child process is created), we need to start - # a DRb service in a background thread within the child process. This service - # will manage the connection and handle resource cleanup. - def handle_fork - @has_forked = true - DRb.start_service - # we need to ensure that there are not more jobs in the queue, but - # we reuse the same object - @worker.restart - schedule_tasks - end - - private - - def schedule_tasks - @worker.every(@heartbeat_interval, run_now: false) { send_collector_events } - - # Runtime_settings fetch must happens only in the child processes, otherwise, due to - # we are updating the global runtime_settings, we could have an infinite recursion. - if @has_forked - @worker.every(@polling_interval) do - Aikido::Zen.runtime_settings = @front_object.updated_settings - @config.logger.debug "Updated runtime settings after polling from child process #{Process.pid}" - end - end - end - end -end diff --git a/lib/aikido/zen/detached_agent/front_object.rb b/lib/aikido/zen/detached_agent/front_object.rb deleted file mode 100644 index fb02b246..00000000 --- a/lib/aikido/zen/detached_agent/front_object.rb +++ /dev/null @@ -1,41 +0,0 @@ -# frozen_string_literal: true - -# dRB Front object that will work as a bridge communication between child & parent -# processes. -# Every method is called from the child but it runs in the parent process. -module Aikido::Zen::DetachedAgent - class FrontObject - def initialize( - config: Aikido::Zen.config, - runtime_settings: Aikido::Zen.runtime_settings, - collector: Aikido::Zen.collector, - rate_limiter: Aikido::Zen::RateLimiter.new - ) - @config = config - @runtime_settings = runtime_settings - @collector = collector - @rate_limiter = rate_limiter - end - - RequestKind = Struct.new(:route, :schema, :client_ip, :actor) - - def send_collector_events(events_data) - events_data.each do |event_data| - event = Aikido::Zen::Collector::Event.from_json(event_data) - @collector.add_event(event) - end - end - - # Method called by child processes to get an up-to-date version of the - # runtime_settings - def updated_settings - @runtime_settings - end - - def calculate_rate_limits(route_data, ip, actor_data) - actor = Aikido::Zen::Actor.from_json(actor_data) if actor_data - route = Aikido::Zen::Route.from_json(route_data) - @rate_limiter.calculate_rate_limits(RequestKind.new(route, nil, ip, actor)) - end - end -end diff --git a/lib/aikido/zen/detached_agent/server.rb b/lib/aikido/zen/detached_agent/server.rb deleted file mode 100644 index c2b31ec9..00000000 --- a/lib/aikido/zen/detached_agent/server.rb +++ /dev/null @@ -1,78 +0,0 @@ -# frozen_string_literal: true - -require "fileutils" - -module Aikido::Zen::DetachedAgent - class Server - # Initialize and start a detached agent server instance. - # - # @return [Aikido::Zen::DetachedAgent::Server] - def self.start(**opts) - new(**opts).tap(&:start!) - end - - def initialize(config: Aikido::Zen.config) - @started_at = nil - - @config = config - - @socket_path = config.expanded_detached_agent_socket_path - @socket_uri = config.expanded_detached_agent_socket_uri - end - - def started? - !!@started_at - end - - def start! - @config.logger.info("Starting DRb Server...") - - # Try to ensure that the DRb service can start if the DRb service did - # not stop cleanly. - begin - # Check whether the Unix domain socket is in use by another process. - UNIXSocket.new(@socket_path).close - rescue Errno::ECONNREFUSED - @config.logger.debug("Removing residual Unix domain socket...") - - # Remove the residual Unix domain socket. - FileUtils.rm_f(@socket_path) - rescue - # empty - end - - @front = FrontObject.new - - # If the Unix domain socket is in use by another process and/or the - # residual Unix domain socket could not be removed DRb will raise an - # appropriate error. - @drb_server = DRb.start_service(@socket_uri, @front) - - # Only show DRb output in debug mode. - @drb_server.verbose = @config.logger.debug? - - # Ensure that the DRb server is alive. - max_attempts = 10 - attempts = 0 - until @drb_server.alive? - @config.logger.info("DRb Server still not alive. #{max_attempts - attempts} attempts remaining") - sleep 0.1 - attempts += 1 - raise Aikido::Zen::DetachedAgentError.new("Impossible to start the dRB server (socket=#{Aikido::Zen.config.detached_agent_socket_path})") \ - if attempts == max_attempts - end - - @started_at = Time.now.utc - - at_exit { stop! if started? } - end - - def stop! - @config.logger.info("Stopping DRb Server...") - @started_at = nil - - @drb_server.stop_service if @drb_server.alive? - DRb.stop_service - end - end -end diff --git a/lib/aikido/zen/errors.rb b/lib/aikido/zen/errors.rb index 318532be..d98d2918 100644 --- a/lib/aikido/zen/errors.rb +++ b/lib/aikido/zen/errors.rb @@ -107,14 +107,6 @@ def initialize(attempt, problem, libname) end end - class DetachedAgentError < ZenError - extend Forwardable - - def initialize(msg) - super - end - end - class OutboundConnectionBlockedError < StandardError def initialize(connection) super("Zen blocked an outbound connection to #{connection.host}.") diff --git a/lib/aikido/zen/middleware/rack_throttler.rb b/lib/aikido/zen/middleware/rack_throttler.rb index 24552537..e5792382 100644 --- a/lib/aikido/zen/middleware/rack_throttler.rb +++ b/lib/aikido/zen/middleware/rack_throttler.rb @@ -12,14 +12,12 @@ def initialize( app, zen: Aikido::Zen, config: Aikido::Zen.config, - settings: Aikido::Zen.runtime_settings, - detached_agent: Aikido::Zen.detached_agent + settings: Aikido::Zen.runtime_settings ) @zen = zen @app = app @config = config @settings = settings - @detached_agent = detached_agent end def call(env) @@ -42,7 +40,7 @@ def should_throttle?(request) return false unless @settings.endpoints[request.route].rate_limiting.enabled? - result = @detached_agent.calculate_rate_limits(request) + result = @zen.calculate_rate_limits(request) return false unless result diff --git a/lib/aikido/zen/rate_limiter/result.rb b/lib/aikido/zen/rate_limiter/result.rb index a19842ee..ae0af05b 100644 --- a/lib/aikido/zen/rate_limiter/result.rb +++ b/lib/aikido/zen/rate_limiter/result.rb @@ -27,5 +27,25 @@ def initialize(throttled:, discriminator:, current_requests:, max_requests:, tim def throttled? @throttled end + + def as_json + { + "throttled" => @throttled, + "discriminator" => @discriminator, + "current_requests" => @current_requests, + "max_requests" => @max_requests, + "time_remaining" => @time_remaining + } + end + + def self.from_json(data) + new( + throttled: data["throttled"], + discriminator: data["discriminator"], + current_requests: data["current_requests"], + max_requests: data["max_requests"], + time_remaining: data["time_remaining"] + ) + end end end diff --git a/lib/aikido/zen/sinks/action_controller.rb b/lib/aikido/zen/sinks/action_controller.rb index b5087c91..ea77a767 100644 --- a/lib/aikido/zen/sinks/action_controller.rb +++ b/lib/aikido/zen/sinks/action_controller.rb @@ -12,13 +12,11 @@ class BlockRequestChecker def initialize( zen: Aikido::Zen, config: Aikido::Zen.config, - settings: Aikido::Zen.runtime_settings, - detached_agent: Aikido::Zen.detached_agent + settings: Aikido::Zen.runtime_settings ) @zen = zen @config = config @settings = settings - @detached_agent = detached_agent end def block?(controller) @@ -57,7 +55,7 @@ def block?(controller) return false unless @settings.endpoints[request.route].rate_limiting.enabled? - result = @detached_agent.calculate_rate_limits(request) + result = @zen.calculate_rate_limits(request) return false unless result request.env["aikido.rate_limiting"] = result diff --git a/lib/aikido/zen/worker_process.rb b/lib/aikido/zen/worker_process.rb new file mode 100644 index 00000000..29cb9716 --- /dev/null +++ b/lib/aikido/zen/worker_process.rb @@ -0,0 +1,3 @@ +# frozen_string_literal: true + +require_relative "worker_process/agent" diff --git a/lib/aikido/zen/worker_process/agent.rb b/lib/aikido/zen/worker_process/agent.rb new file mode 100644 index 00000000..d5b7bae3 --- /dev/null +++ b/lib/aikido/zen/worker_process/agent.rb @@ -0,0 +1,4 @@ +# frozen_string_literal: true + +require_relative "agent/client" +require_relative "agent/server" diff --git a/lib/aikido/zen/worker_process/agent/client.rb b/lib/aikido/zen/worker_process/agent/client.rb new file mode 100644 index 00000000..a06bfefb --- /dev/null +++ b/lib/aikido/zen/worker_process/agent/client.rb @@ -0,0 +1,97 @@ +# frozen_string_literal: true + +require_relative "../../background_worker" + +module Aikido::Zen::WorkerProcess + module Agent + class Client + attr_reader :worker + + def initialize( + host, + port, + secret: Aikido::Zen.secret, + config: Aikido::Zen.config, + worker: Aikido::Zen::Worker.new(config: config), + heartbeat_interval: 10, + polling_interval: 10, + collector: Aikido::Zen.collector + ) + @config = config + @worker = worker + @heartbeat_interval = heartbeat_interval + @polling_interval = polling_interval + @collector = collector + + @rpc_client = Aikido::Zen::RPC::Client.new(secret, host, port) + end + + def start + @rpc_client.start + + begin + update_settings(updated_settings) + rescue => err + @config.logger.error("Forked worker process #{Process.pid}: failed to get initial settings from parent: #{err.message}") + end + + schedule_tasks + end + + def stop + @worker.shutdown + @rpc_client.stop + end + + def send_collector_events + events_data = @collector.flush_events.map(&:as_json) + @rpc_client.invoke("send_collector_events", events_data) + rescue => err + @config.logger.error("Forked worker process #{Process.pid}: failed to send collector events to parent: #{err.message}") + end + + def calculate_rate_limits(request) + result = @rpc_client.invoke( + "calculate_rate_limits", + request.route.as_json, request.client_ip, request.actor&.as_json, + timeout: Aikido::Zen::IPC::READ_TIMEOUT + ) + + Aikido::Zen::RateLimiter::Result.from_json(result) if result + rescue => err + @config.logger.error("Forked worker process #{Process.pid}: failed to get rate limits from parent: #{err.message}") + nil + end + + private + + def updated_settings + @rpc_client.invoke("updated_settings", timeout: Aikido::Zen::IPC::READ_TIMEOUT) + end + + def update_settings(settings) + return unless settings + + if settings["config"] + Aikido::Zen.runtime_settings.update_from_runtime_config_json(settings["config"]) + end + + if settings["firewall_lists"] + Aikido::Zen.runtime_settings.update_from_runtime_firewall_lists_json(settings["firewall_lists"]) + end + end + + def schedule_tasks + @worker.every(@heartbeat_interval, run_now: false) { send_collector_events } + + @worker.every(@polling_interval, run_now: false) do + update_settings(updated_settings) + + @config.logger.debug("Forked worker process #{Process.pid}: updated runtime settings from parent") + rescue => err + @config.logger.error("Forked worker process #{Process.pid}: failed to get settings from parent: #{err.message}") + end + end + end + end +end diff --git a/lib/aikido/zen/worker_process/agent/server.rb b/lib/aikido/zen/worker_process/agent/server.rb new file mode 100644 index 00000000..6813c0ce --- /dev/null +++ b/lib/aikido/zen/worker_process/agent/server.rb @@ -0,0 +1,86 @@ +# frozen_string_literal: true + +module Aikido::Zen::WorkerProcess + module Agent + class Server + def self.start(**opts) + new(**opts).tap(&:start!) + end + + def initialize(config: Aikido::Zen.config) + @config = config + + @rpc_server = Aikido::Zen::RPC::Server.new(Aikido::Zen.secret) + + @rpc_server.handle("send_collector_events") do |respond, events_data| + respond.call(nil) + send_collector_events(events_data) + end + + @rpc_server.handle("updated_settings") do |respond| + respond.call(updated_settings) + end + + @rpc_server.handle("calculate_rate_limits") do |respond, route_data, ip, actor_data| + result = calculate_rate_limits(route_data, ip, actor_data) + respond.call(result&.as_json) + end + end + + def host + @rpc_server.host + end + + def port + @rpc_server.port + end + + def started? + !!@started_at + end + + def start! + @config.logger.info("Starting RPC Server...") + + @rpc_server.start + + @started_at = Time.now.utc + + at_exit { stop! if started? } + end + + def stop! + @config.logger.info("Stopping RPC Server...") + + @rpc_server.stop + + @started_at = nil + end + + # @api private + # + # Visible for testing. + RequestKind = Struct.new(:route, :schema, :client_ip, :actor) + + def send_collector_events(events_data) + events_data.each do |event_data| + event = Aikido::Zen::Collector::Event.from_json(event_data) + Aikido::Zen.collector.add_event(event) + end + end + + def updated_settings + { + "config" => Aikido::Zen.api_cache.runtime_config, + "firewall_lists" => Aikido::Zen.api_cache.runtime_firewall_lists + } + end + + def calculate_rate_limits(route_data, ip, actor_data) + actor = Aikido::Zen::Actor.from_json(actor_data) if actor_data + route = Aikido::Zen::Route.from_json(route_data) + Aikido::Zen.rate_limiter.calculate_rate_limits(RequestKind.new(route, nil, ip, actor)) + end + end + end +end diff --git a/test/aikido/zen/config_test.rb b/test/aikido/zen/config_test.rb index fc78c83c..b2d3abab 100644 --- a/test/aikido/zen/config_test.rb +++ b/test/aikido/zen/config_test.rb @@ -23,7 +23,6 @@ class Aikido::Zen::ConfigTest < ActiveSupport::TestCase assert_equal Aikido::Zen::Config::DEFAULT_JSON_ENCODER, @config.json_encoder assert_equal Aikido::Zen::Config::DEFAULT_JSON_DECODER, @config.json_decoder assert_kind_of ::Logger, @config.logger - assert_equal Aikido::Zen::Config::DEFAULT_DETACHED_AGENT_SOCKET_PATH, @config.detached_agent_socket_path assert_nil @config.client_ip_header assert_equal 5000, @config.max_performance_samples assert_equal 100, @config.max_compressed_stats @@ -248,16 +247,4 @@ def with_env(data = {}) ensure ENV.replace(env) end - - test "#expanded_detached_agent_socket_path includes the API token hash when set" do - @config.api_token = "S3CR3T" - - assert_equal "aikido-detached-agent.f3974fa.sock", @config.expanded_detached_agent_socket_path - end - - test "#expanded_detached_agent_socket_uri includes the API token hash when set" do - @config.api_token = "S3CR3T" - - assert_equal "drbunix:aikido-detached-agent.f3974fa.sock", @config.expanded_detached_agent_socket_uri - end end diff --git a/test/aikido/zen/detached_agent/agent_test.rb b/test/aikido/zen/detached_agent/agent_test.rb deleted file mode 100644 index bd3c4342..00000000 --- a/test/aikido/zen/detached_agent/agent_test.rb +++ /dev/null @@ -1,90 +0,0 @@ -# frozen_string_literal: true - -require "test_helper" - -class Aikido::Zen::DetachedAgent::AgentTest < ActiveSupport::TestCase - include WorkerHelpers - - def with_mocks(front_object, on_drb_start) - DRbObject.stub :new_with_uri, front_object do - DRb.stub :start_service, on_drb_start do - config = Aikido::Zen.config - collector = Minitest::Mock.new - worker = MockWorker.new - interval = 10 - - detached_agent_agent = Aikido::Zen::DetachedAgent::Agent.new( - heartbeat_interval: interval, - config: config, - worker: worker, - collector: collector - ) - - yield ({ - agent: detached_agent_agent, - interval: interval, - config: config, - worker: worker, - collector: collector - }) - end - end - end - - test "child to parent events are scheduled" do - drb_start_called = false - on_drb_start = -> { drb_start_called = true } - - with_mocks(Minitest::Mock.new, on_drb_start) do |mocks| - assert_equal 1, mocks[:worker].jobs.size - timer = mocks[:worker].jobs.first - assert_equal mocks[:interval], timer.execution_interval - end - - refute drb_start_called - end - - test "collector events are send to the front object" do - drb_start_called = false - on_drb_start = -> { drb_start_called = true } - front_object = Minitest::Mock.new - - with_mocks(front_object, on_drb_start) do |mocks| - events = Array.new(3) { Aikido::Zen::Collector::Events::TrackRequest.new } - - mocks[:collector].expect(:flush_events, events) - - front_object.expect(:send_collector_events, nil, [events.map(&:as_json)]) - - mocks[:agent].send_collector_events - - assert_mock mocks[:collector] - end - - assert_mock front_object - refute drb_start_called - end - - test "forks are properly handled" do - front_object = Minitest::Mock.new - - drb_start_called = false - on_drb_start = -> { drb_start_called = true } - - with_mocks(front_object, on_drb_start) do |mocks| - front_object.expect :updated_settings, {new: :settings} - - mocks[:agent].handle_fork - - assert_equal({new: :settings}, Aikido::Zen.runtime_settings) - - assert mocks[:worker].restarted - assert_same mocks[:worker], mocks[:agent].worker - assert_equal 2, mocks[:worker].jobs.size - - assert drb_start_called - - assert_mock front_object - end - end -end diff --git a/test/aikido/zen/detached_agent/detached_agent_server_test.rb b/test/aikido/zen/detached_agent/detached_agent_server_test.rb deleted file mode 100644 index ac0d095a..00000000 --- a/test/aikido/zen/detached_agent/detached_agent_server_test.rb +++ /dev/null @@ -1,53 +0,0 @@ -# frozen_string_literal: true - -require "test_helper" - -class Aikido::Zen::DetachedAgent::ServerTest < ActiveSupport::TestCase - FakeDrbServer = Struct.new(:stopped, :attempts) do - def initialize(max_attempts) - super - self[:stopped] = false - self[:attempts] = 0 - @max_attempts = max_attempts - end - - def alive? - self[:attempts] += 1 - self[:attempts] >= @max_attempts - end - - def stop_service - self[:stopped] = true - end - - def verbose=(v) - end - end - - test "server starts after a certain number of retries" do - @fake_drb_server = FakeDrbServer.new(4) - - DRb.stub(:start_service, @fake_drb_server) do - Aikido::Zen::DetachedAgent::Server.start - end - assert_equal 4, @fake_drb_server.attempts - end - - test "An exception is raised in case we exhaust the max number of attempts while starting the server" do - @fake_drb_server = FakeDrbServer.new(11) - DRb.stub(:start_service, @fake_drb_server) do - assert_raises Aikido::Zen::DetachedAgentError do - Aikido::Zen::DetachedAgent::Server.start - end - end - end - - test "Server is stopped " do - @fake_drb_server = FakeDrbServer.new(2) - DRb.stub(:start_service, @fake_drb_server) do - server = Aikido::Zen::DetachedAgent::Server.start - server.stop! - end - assert @fake_drb_server.stopped - end -end diff --git a/test/aikido/zen/detached_agent/front_object_test.rb b/test/aikido/zen/detached_agent/front_object_test.rb deleted file mode 100644 index 0d666fbd..00000000 --- a/test/aikido/zen/detached_agent/front_object_test.rb +++ /dev/null @@ -1,22 +0,0 @@ -# frozen_string_literal: true - -require "test_helper" - -class Aikido::Zen::DetachedAgent::FrontObjectTest < ActiveSupport::TestCase - setup do - @config = Aikido::Zen::Config.new - @collector = Aikido::Zen::Collector.new(config: @config) - @front_object = Aikido::Zen::DetachedAgent::FrontObject.new(config: @config, collector: @collector) - end - - test "it pushes collector events into collector" do - input_events = Array.new(3) { Aikido::Zen::Collector::Events::TrackRequest.new } - input_events_data = input_events.map(&:as_json) - - @front_object.send_collector_events(input_events_data) - - output_events = @collector.flush_events - - assert_equal output_events.map(&:as_json), input_events_data - end -end diff --git a/test/aikido/zen/rate_limiter/result_test.rb b/test/aikido/zen/rate_limiter/result_test.rb new file mode 100644 index 00000000..c9d85ddc --- /dev/null +++ b/test/aikido/zen/rate_limiter/result_test.rb @@ -0,0 +1,53 @@ +# frozen_string_literal: true + +require "test_helper" + +class Aikido::Zen::RateLimiter::ResultTest < ActiveSupport::TestCase + setup do + @result = Aikido::Zen::RateLimiter::Result.new( + throttled: true, + discriminator: "1.2.3.4", + current_requests: 5, + max_requests: 10, + time_remaining: 30 + ) + end + + test "#as_json serializes all fields" do + assert_equal({ + "throttled" => true, + "discriminator" => "1.2.3.4", + "current_requests" => 5, + "max_requests" => 10, + "time_remaining" => 30 + }, @result.as_json) + end + + test ".from_json deserializes all fields" do + data = { + "throttled" => true, + "discriminator" => "1.2.3.4", + "current_requests" => 5, + "max_requests" => 10, + "time_remaining" => 30 + } + + result = Aikido::Zen::RateLimiter::Result.from_json(data) + + assert result.throttled? + assert_equal "1.2.3.4", result.discriminator + assert_equal 5, result.current_requests + assert_equal 10, result.max_requests + assert_equal 30, result.time_remaining + end + + test "#as_json and #from_json round-trip" do + result = Aikido::Zen::RateLimiter::Result.from_json(@result.as_json) + + assert_equal @result.throttled?, result.throttled? + assert_equal @result.discriminator, result.discriminator + assert_equal @result.current_requests, result.current_requests + assert_equal @result.max_requests, result.max_requests + assert_equal @result.time_remaining, result.time_remaining + end +end diff --git a/test/aikido/zen/worker_process/agent/client_test.rb b/test/aikido/zen/worker_process/agent/client_test.rb new file mode 100644 index 00000000..3c060b20 --- /dev/null +++ b/test/aikido/zen/worker_process/agent/client_test.rb @@ -0,0 +1,257 @@ +# frozen_string_literal: true + +require "test_helper" + +class Aikido::Zen::WorkerProcess::Agent::ClientTest < ActiveSupport::TestCase + include WorkerHelpers + + class MockRPCClient + attr_reader :started, :stopped, :invoke_results + + def initialize + @started = false + @stopped = false + @invoke_results = {} + end + + def start + @started = true + end + + def stop + @stopped = true + end + + def invoke(name, *args, timeout: nil) + @invoke_results[name] + end + end + + def build_agent(invoke_results = {}) + client = MockRPCClient.new + client.invoke_results.merge!(invoke_results) + + Aikido::Zen::RPC::Client.stub(:new, client) do + collector = Minitest::Mock.new + worker = MockWorker.new + + agent = Aikido::Zen::WorkerProcess::Agent::Client.new( + "127.0.0.1", + 12345, + config: Aikido::Zen.config, + worker: worker, + collector: collector, + heartbeat_interval: 10, + polling_interval: 10 + ) + agent.start + + yield agent, worker, collector, client + end + end + + test "no tasks are scheduled before #start" do + client = MockRPCClient.new + Aikido::Zen::RPC::Client.stub(:new, client) do + worker = MockWorker.new + Aikido::Zen::WorkerProcess::Agent::Client.new("127.0.0.1", 12345, worker: worker) + assert_empty worker.jobs + end + end + + test "#start connects the RPC client and schedules two tasks" do + build_agent("updated_settings" => {}) do |agent, worker, collector, client| + assert client.started + assert_equal 2, worker.jobs.size + end + end + + test "#start handles nil settings from parent gracefully" do + build_agent("updated_settings" => nil) do |agent| + pass + end + end + + test "#start applies initial settings from parent" do + config_data = {"configUpdatedAt" => 0, "heartbeatIntervalInMS" => 60_000, + "endpoints" => [], "blockedUserIds" => [], "allowedIPAddresses" => [], + "receivedAnyStats" => false, "block" => false, + "blockNewOutgoingRequests" => false, "domains" => {}, + "excludedUserIdsFromRateLimiting" => []} + firewall_data = {"blockedUserAgents" => nil, "monitoredUserAgents" => nil, + "userAgentDetails" => [], "blockedIPAddresses" => [], + "allowedIPAddresses" => [], "monitoredIPAddresses" => []} + + build_agent("updated_settings" => {"config" => config_data, "firewall_lists" => firewall_data}) do + assert_equal 60, Aikido::Zen.runtime_settings.heartbeat_interval + end + end + + test "#send_collector_events flushes the collector and sends events to the parent" do + events = Array.new(3) { Aikido::Zen::Collector::Events::TrackRequest.new } + + build_agent("updated_settings" => {}) do |agent, worker, collector, client| + collector.expect(:flush_events, events) + agent.send_collector_events + + assert_mock collector + end + end + + MockRequest = Struct.new(:route, :client_ip, :actor) + + test "#calculate_rate_limits returns nil when the parent returns no result" do + build_agent("updated_settings" => {}, "calculate_rate_limits" => nil) do |agent| + request = MockRequest.new( + Aikido::Zen::Route.new(verb: "GET", path: "/test"), + "1.2.3.4", + nil + ) + + assert_nil agent.calculate_rate_limits(request) + end + end + + test "#calculate_rate_limits returns nil and logs when the RPC call raises" do + build_agent("updated_settings" => {}) do |agent, worker, collector, client| + request = MockRequest.new( + Aikido::Zen::Route.new(verb: "GET", path: "/test"), + "1.2.3.4", + nil + ) + + client.stub(:invoke, ->(*) { raise "RPC error" }) do + assert_nil agent.calculate_rate_limits(request) + assert_logged :error, /failed to get rate limits from parent/i + end + end + end + + test "#stop shuts down the worker and RPC client" do + build_agent("updated_settings" => {}) do |agent, worker, collector, client| + agent.stop + + assert client.stopped + assert worker.jobs.none?(&:running?) + end + end + + test "#calculate_rate_limits parses the result returned by the parent" do + result_data = { + "throttled" => false, + "discriminator" => "1.2.3.4", + "current_requests" => 1, + "max_requests" => 100, + "time_remaining" => 60 + } + + build_agent("updated_settings" => {}, "calculate_rate_limits" => result_data) do |agent| + request = MockRequest.new( + Aikido::Zen::Route.new(verb: "GET", path: "/test"), + "1.2.3.4", + nil + ) + + result = agent.calculate_rate_limits(request) + + assert_instance_of Aikido::Zen::RateLimiter::Result, result + refute result.throttled? + assert_equal "1.2.3.4", result.discriminator + end + end +end + +class Aikido::Zen::WorkerProcess::Agent::ClientIntegrationTest < ActiveSupport::TestCase + include WorkerHelpers + + setup do + @server = Aikido::Zen::WorkerProcess::Agent::Server.new + @server.start! + end + + teardown do + @server.stop! if @server.started? + end + + def in_forked_worker + reader, writer = IO.pipe + + pid = fork do + reader.close + begin + yield + writer.write("ok") + rescue => err + writer.write("#{err.class}: #{err.message}") + ensure + writer.close + exit! + end + end + + writer.close + result = reader.read + reader.close + Process.waitpid(pid) + + flunk(result) unless result == "ok" + end + + def build_client + Aikido::Zen::WorkerProcess::Agent::Client.new( + @server.host, @server.port, + worker: MockWorker.new, + collector: Aikido::Zen.collector + ) + end + + MockRequest = Struct.new(:route, :client_ip, :actor) + + test "client receives initial runtime settings from the server on startup" do + Aikido::Zen.api_cache.runtime_config = { + "configUpdatedAt" => 0, "heartbeatIntervalInMS" => 90_000, + "endpoints" => [], "blockedUserIds" => [], "allowedIPAddresses" => [], + "receivedAnyStats" => false, "block" => false, + "blockNewOutgoingRequests" => false, "domains" => {}, + "excludedUserIdsFromRateLimiting" => [] + } + + in_forked_worker do + client = build_client + client.start + + interval = Aikido::Zen.runtime_settings.heartbeat_interval + raise "expected heartbeat_interval=90, got #{interval}" unless interval == 90 + end + end + + test "client flushes collector events to the server" do + in_forked_worker do + 3.times { Aikido::Zen.collector.track_request } + client = build_client + client.start + client.send_collector_events + end + + captured = [] + wait_until(timeout: 2) do + captured.concat(Aikido::Zen.collector.flush_events) + captured.size >= 3 + end + assert_equal 3, captured.size, "Log: #{@log_output.string}" + end + + test "client delegates rate-limit calculation to the server" do + in_forked_worker do + client = build_client + client.start + + result = client.calculate_rate_limits(MockRequest.new( + Aikido::Zen::Route.new(verb: "GET", path: "/test"), + "1.2.3.4", + nil + )) + raise "expected nil with no rate limit rules, got #{result.inspect}" unless result.nil? + end + end +end diff --git a/test/aikido/zen/worker_process/agent/server_test.rb b/test/aikido/zen/worker_process/agent/server_test.rb new file mode 100644 index 00000000..91079f80 --- /dev/null +++ b/test/aikido/zen/worker_process/agent/server_test.rb @@ -0,0 +1,78 @@ +# frozen_string_literal: true + +require "test_helper" + +class Aikido::Zen::WorkerProcess::Agent::ServerTest < ActiveSupport::TestCase + setup do + @server = Aikido::Zen::WorkerProcess::Agent::Server.new + end + + teardown do + @server.stop! if @server.started? + end + + test "#started? returns false before #start! is called" do + refute @server.started? + end + + test "#start! marks the server as started and exposes host and port" do + @server.start! + + assert @server.started? + assert_equal "127.0.0.1", @server.host + end + + test "#stop! marks the server as stopped" do + @server.start! + @server.stop! + + refute @server.started? + end + + test "#send_collector_events pushes events into the collector" do + input_events = Array.new(3) { Aikido::Zen::Collector::Events::TrackRequest.new } + input_events_data = input_events.map(&:as_json) + + @server.send_collector_events(input_events_data) + + output_events = Aikido::Zen.collector.flush_events + assert_equal input_events_data, output_events.map(&:as_json) + end + + test "#updated_settings returns nil config and firewall_lists when cache is empty" do + Aikido::Zen.api_cache.runtime_config = nil + Aikido::Zen.api_cache.runtime_firewall_lists = nil + + settings = @server.updated_settings + + assert_nil settings["config"] + assert_nil settings["firewall_lists"] + end + + test "#updated_settings returns cached config and firewall_lists" do + config = {"configUpdatedAt" => 1234567890} + firewall_lists = {"blockedIPAddresses" => []} + + Aikido::Zen.api_cache.runtime_config = config + Aikido::Zen.api_cache.runtime_firewall_lists = firewall_lists + + settings = @server.updated_settings + + assert_equal config, settings["config"] + assert_equal firewall_lists, settings["firewall_lists"] + end + + test "#calculate_rate_limits works without an actor" do + route_data = {"method" => "GET", "path" => "/test"} + + assert_nil @server.calculate_rate_limits(route_data, "1.2.3.4", nil) + end + + test "#calculate_rate_limits works with an actor" do + route_data = {"method" => "GET", "path" => "/test"} + now_ms = Time.now.to_i * 1000 + actor_data = {"id" => "user1", "name" => "Test User", "firstSeenAt" => now_ms, "lastSeenAt" => now_ms} + + assert_nil @server.calculate_rate_limits(route_data, "1.2.3.4", actor_data) + end +end diff --git a/test/aikido/zen_test.rb b/test/aikido/zen_test.rb index e9c89d22..c79cd41d 100644 --- a/test/aikido/zen_test.rb +++ b/test/aikido/zen_test.rb @@ -129,4 +129,28 @@ class Aikido::ZenTest < ActiveSupport::TestCase assert_equal "block required", err.message end + + test ".calculate_rate_limits delegates to the rate limiter when there is no detached agent" do + mock = Minitest::Mock.new + mock.expect(:calculate_rate_limits, nil, [Object]) + + Aikido::Zen.stub(:rate_limiter, mock) do + Aikido::Zen.calculate_rate_limits(Object.new) + end + + assert_mock mock + end + + test ".calculate_rate_limits delegates to the detached agent when one is set" do + mock = Minitest::Mock.new + mock.expect(:calculate_rate_limits, nil, [Object]) + + Aikido::Zen.instance_variable_set(:@worker_process_client, mock) + + Aikido::Zen.calculate_rate_limits(Object.new) + + assert_mock mock + ensure + Aikido::Zen.instance_variable_set(:@worker_process_client, nil) + end end diff --git a/test/test_helper.rb b/test/test_helper.rb index f2937ddc..1c3144eb 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -19,26 +19,6 @@ require "debug" if Gem::Version.new(RUBY_VERSION) >= Gem::Version.new("3.0") require "support/capture_stream" -class FakeDetachedAgent - extend Forwardable - - def_delegators :@collector, :track_request, :track_route, :track_outbound, :track_scan, :track_user, :track_attack - def_delegator :@rate_limiter, :calculate_rate_limits - - def initialize(collector, rate_limiter) - @collector = collector - @rate_limiter = rate_limiter - end - - def handle_fork - end -end - -Aikido::Zen.instance_variable_set( - :@detached_agent, - FakeDetachedAgent.new(Aikido::Zen::Collector.new, Aikido::Zen::RateLimiter.new) -) - # Silence warnings that result from loading HTTPClient. ActiveSupport::Testing::Stream.quietly { require "webmock" } # For the HTTP adapters shipped with WebMock by default, requiring webmock first @@ -76,10 +56,14 @@ class ActiveSupport::TestCase collector = Aikido::Zen::Collector.new Aikido::Zen.instance_variable_set(:@collector, collector) - Aikido::Zen.detached_agent.instance_variable_set(:@collector, collector) Aikido::Zen.instance_variable_set(:@runtime_settings, nil) - Aikido::Zen.detached_agent.instance_variable_set(:@rate_limiter, Aikido::Zen::RateLimiter.new) + Aikido::Zen.instance_variable_set(:@rate_limiter, Aikido::Zen::RateLimiter.new) + Aikido::Zen.instance_variable_set(:@worker_process_client, nil) + Aikido::Zen.instance_variable_set(:@worker_process_server, nil) + Aikido::Zen.instance_variable_set(:@pid, Process.pid) + Aikido::Zen.instance_variable_set(:@has_started, Concurrent::AtomicBoolean.new(false)) + Aikido::Zen.instance_variable_set(:@has_handled_fork, Concurrent::AtomicBoolean.new(false)) Aikido::Zen.instance_variable_set(:@attack_wave_detector, nil) Aikido::Zen.instance_variable_set(:@idor_protector, nil) From 5ebef1b50d19739cb53b8bebcc2b9e27518e7f2e Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Thu, 25 Jun 2026 13:33:54 +0200 Subject: [PATCH 5/7] WIP Fix synchronization in start and fork --- lib/aikido/zen.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index 69cba9f9..4c60c40d 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -349,7 +349,6 @@ def self.worker_process_server @worker_process_server ||= WorkerProcess::Agent::Server.start end - @pid = Process.pid @has_started = Concurrent::AtomicBoolean.new(false) @has_handled_fork = Concurrent::AtomicBoolean.new(false) @@ -361,6 +360,8 @@ def start! return unless @has_started.make_true + @pid = Process.pid + worker_process_server agent end @@ -372,7 +373,7 @@ def start? end def check_and_handle_fork - return unless @has_started.true? && @has_handled_fork.make_true + return unless @has_started.true? && @has_handled_fork.false? handle_fork if forked? end @@ -385,6 +386,8 @@ def forked? end def handle_fork + return unless @has_handled_fork.make_true + server = @worker_process_server client = WorkerProcess::Agent::Client.new(server.host, server.port) client.start From 31c552dbe563bf65939dc3149ba4c4c041c552de Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Thu, 25 Jun 2026 13:38:30 +0200 Subject: [PATCH 6/7] WIP Purify predicate #forked? --- lib/aikido/zen.rb | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index 4c60c40d..0350afb7 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -379,15 +379,14 @@ def check_and_handle_fork end def forked? - current_pid = Process.pid - return false if current_pid == @pid - @pid = current_pid - true + Process.pid != @pid end def handle_fork return unless @has_handled_fork.make_true + @pid = Process.pid + server = @worker_process_server client = WorkerProcess::Agent::Client.new(server.host, server.port) client.start From bba0c1094173fb4c377e02e0f30e49f5a92eca0e Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Fri, 26 Jun 2026 16:30:41 +0200 Subject: [PATCH 7/7] WIP Support worker refork --- lib/aikido/zen.rb | 38 +++++++------------ lib/aikido/zen/middleware/fork_detector.rb | 10 ++++- lib/aikido/zen/worker_process/agent/client.rb | 4 ++ test/test_helper.rb | 1 - 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/lib/aikido/zen.rb b/lib/aikido/zen.rb index 0350afb7..8f5fcfff 100644 --- a/lib/aikido/zen.rb +++ b/lib/aikido/zen.rb @@ -335,35 +335,29 @@ def self.load_sinks! # Stop any background threads. def self.stop! @agent&.stop! - @worker_process_client&.stop @worker_process_server&.stop! + + @worker_process_client&.stop end - # @!visibility private - # Starts the background agent if it has not been started yet. def self.agent - @agent ||= Agent.start + @agent end def self.worker_process_server - @worker_process_server ||= WorkerProcess::Agent::Server.start + @worker_process_server end @has_started = Concurrent::AtomicBoolean.new(false) - @has_handled_fork = Concurrent::AtomicBoolean.new(false) class << self def start! - return if @has_started.true? - return unless start? return unless @has_started.make_true - @pid = Process.pid - - worker_process_server - agent + @worker_process_server = WorkerProcess::Agent::Server.start + @agent = Agent.start end def start? @@ -372,22 +366,16 @@ def start? config.debugging? end - def check_and_handle_fork - return unless @has_started.true? && @has_handled_fork.false? - - handle_fork if forked? - end - - def forked? - Process.pid != @pid - end + def fork! + server = @worker_process_server + return unless server - def handle_fork - return unless @has_handled_fork.make_true + # TODO: Factor; stop the server and old client then start the new client - @pid = Process.pid + client = @worker_process_client + @worker_process_client = nil + client&.close - server = @worker_process_server client = WorkerProcess::Agent::Client.new(server.host, server.port) client.start @worker_process_client = client diff --git a/lib/aikido/zen/middleware/fork_detector.rb b/lib/aikido/zen/middleware/fork_detector.rb index 416cb0a3..1f807f16 100644 --- a/lib/aikido/zen/middleware/fork_detector.rb +++ b/lib/aikido/zen/middleware/fork_detector.rb @@ -9,11 +9,17 @@ module Middleware class ForkDetector def initialize(app) @app = app + + @pid = Concurrent::AtomicFixnum.new(Process.pid) end def call(env) - # This is the single, reliable trigger point for the fork check. - Aikido::Zen.check_and_handle_fork + new_pid = Process.pid + old_pid = @pid.value + + if new_pid != old_pid && @pid.compare_and_set(old_pid, new_pid) + Aikido::Zen.fork! + end @app.call(env) end diff --git a/lib/aikido/zen/worker_process/agent/client.rb b/lib/aikido/zen/worker_process/agent/client.rb index a06bfefb..dcfd8e16 100644 --- a/lib/aikido/zen/worker_process/agent/client.rb +++ b/lib/aikido/zen/worker_process/agent/client.rb @@ -43,6 +43,10 @@ def stop @rpc_client.stop end + def close + @rpc_client.stop + end + def send_collector_events events_data = @collector.flush_events.map(&:as_json) @rpc_client.invoke("send_collector_events", events_data) diff --git a/test/test_helper.rb b/test/test_helper.rb index 1c3144eb..0d958fa9 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -63,7 +63,6 @@ class ActiveSupport::TestCase Aikido::Zen.instance_variable_set(:@worker_process_server, nil) Aikido::Zen.instance_variable_set(:@pid, Process.pid) Aikido::Zen.instance_variable_set(:@has_started, Concurrent::AtomicBoolean.new(false)) - Aikido::Zen.instance_variable_set(:@has_handled_fork, Concurrent::AtomicBoolean.new(false)) Aikido::Zen.instance_variable_set(:@attack_wave_detector, nil) Aikido::Zen.instance_variable_set(:@idor_protector, nil)