diff --git a/agent/build.gradle b/agent/build.gradle index d6656bf92..c2a06b7a0 100644 --- a/agent/build.gradle +++ b/agent/build.gradle @@ -12,6 +12,7 @@ dependencies { compileOnly 'io.projectreactor.netty:reactor-netty-http:1.2.1' // For Spring Webflux compileOnly 'io.javalin:javalin:6.4.0' compileOnly 'org.springframework:spring-web:5.3.20' + compileOnly 'org.springframework:spring-webflux:5.3.20' // For Spring WebClient } shadowJar { diff --git a/agent/src/main/java/dev/aikido/agent/Wrappers.java b/agent/src/main/java/dev/aikido/agent/Wrappers.java index a71c13b26..66cc7a94c 100644 --- a/agent/src/main/java/dev/aikido/agent/Wrappers.java +++ b/agent/src/main/java/dev/aikido/agent/Wrappers.java @@ -9,6 +9,8 @@ import dev.aikido.agent.wrappers.spring.SpringWebfluxWrapper; import dev.aikido.agent.wrappers.spring.SpringControllerWrapper; import dev.aikido.agent.wrappers.spring.SpringMVCJakartaWrapper; +import dev.aikido.agent.wrappers.spring.SpringWebClientWrapper; +import dev.aikido.agent.wrappers.spring.SpringWebClientRedirectWrapper; import java.util.Arrays; import java.util.List; @@ -30,11 +32,14 @@ private Wrappers() {} // SSRF/HTTP wrappers new HttpURLConnectionWrapper(), new InetAddressWrapper(), + new SocketChannelWrapper(), new HttpClientWrapper(), new HttpConnectionRedirectWrapper(), new HttpClientSendWrapper(), new OkHttpWrapper(), new ApacheHttpClientWrapper(), + new SpringWebClientWrapper(), + new SpringWebClientRedirectWrapper(), new PathWrapper(), new PathsWrapper(), diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/SocketChannelWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/SocketChannelWrapper.java new file mode 100644 index 000000000..d238e4698 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/SocketChannelWrapper.java @@ -0,0 +1,87 @@ +package dev.aikido.agent.wrappers; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; +import net.bytebuddy.matcher.ElementMatchers; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.MalformedURLException; +import java.net.SocketAddress; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.channels.SocketChannel; + +public class SocketChannelWrapper implements Wrapper { + public String getName() { + // Wrap connect(SocketAddress) on SocketChannel. Clients that resolve hostnames with + // their own DNS resolver instead of InetAddress.getAllByName() (e.g. Reactor Netty's + // async resolver, used by default by Spring's WebClient) never trigger + // InetAddressWrapper, so this is the only point where we see the resolved address + // before the connection is made. Also catches literal IP targets, which never go + // through any resolver at all. + // https://docs.oracle.com/javase/8/docs/api/java/nio/channels/SocketChannel.html#connect-java.net.SocketAddress- + return SocketChannelAdvice.class.getName(); + } + public ElementMatcher getMatcher() { + return ElementMatchers.named("connect"); + } + @Override + public ElementMatcher getTypeMatcher() { + return ElementMatchers.isSubTypeOf(SocketChannel.class); + } + public static class SocketChannelAdvice { + // Since we have to wrap a native Java Class stuff gets more complicated + // The classpath is not the same anymore, and we can't import our modules directly. + // To bypass this issue we load collectors from a .jar file + @Advice.OnMethodEnter + public static void before( + @Advice.Argument(0) SocketAddress remoteAddress + ) throws Throwable { + if (!(remoteAddress instanceof InetSocketAddress)) { + return; + } + InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress; + InetAddress resolvedAddress = inetSocketAddress.getAddress(); + if (resolvedAddress == null) { + // Unresolved: nothing to report yet, connect() will throw on its own. + return; + } + String hostname = inetSocketAddress.getHostString(); + + String jarFilePath = System.getProperty("AIK_agent_api_jar"); + URLClassLoader classLoader = null; + try { + URL[] urls = { new URL(jarFilePath) }; + classLoader = new URLClassLoader(urls); + } catch (MalformedURLException ignored) {} + if (classLoader == null) { + return; + } + + try { + // Load the class from the JAR + Class clazz = classLoader.loadClass("dev.aikido.agent_api.collectors.DNSRecordCollector"); + + // Run reportConnect with "argument" + for (Method method2: clazz.getMethods()) { + if(method2.getName().equals("reportConnect")) { + method2.invoke(null, hostname, resolvedAddress); + break; + } + } + classLoader.close(); // Close the class loader + } catch (InvocationTargetException invocationTargetException) { + if(invocationTargetException.getCause().toString().startsWith("dev.aikido.agent_api.vulnerabilities")) { + throw invocationTargetException.getCause(); + } + // Ignore non-aikido throwables. + } catch(Throwable e) { + System.out.println("AIKIDO: " + e.getMessage()); + } + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientRedirectWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientRedirectWrapper.java new file mode 100644 index 000000000..2e598d970 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientRedirectWrapper.java @@ -0,0 +1,66 @@ +package dev.aikido.agent.wrappers.spring; + +import dev.aikido.agent.wrappers.Wrapper; +import dev.aikido.agent_api.collectors.RedirectCollector; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; +import net.bytebuddy.matcher.ElementMatchers; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.URL; + +public class SpringWebClientRedirectWrapper implements Wrapper { + // Package-private in Reactor Netty, referenced by name only. This is the internal method + // that runs once per redirect hop, for both WebClient and the Netty-backed RestClient - + // Spring's own request-adaptation layer (ExchangeFunction/ReactorClientHttpRequest) is + // only invoked once per top-level call and never sees redirect targets for bodiless (e.g. + // GET) requests, since Reactor Netty resends internally without going back through it. + // Mirrors HttpConnectionRedirectWrapper, which hooks the JDK's equally-internal + // followRedirect0 for the same reason. + private static final String HTTP_CLIENT_HANDLER_CLASS_NAME = + "reactor.netty.http.client.HttpClientConnect$HttpClientHandler"; + + public String getName() { + return RedirectAdvice.class.getName(); + } + public ElementMatcher getMatcher() { + return ElementMatchers.isDeclaredBy(getTypeMatcher()) + .and(ElementMatchers.named("redirect")) + .and(ElementMatchers.takesArguments(1)); + } + @Override + public ElementMatcher getTypeMatcher() { + return ElementMatchers.named(HTTP_CLIENT_HANDLER_CLASS_NAME); + } + public static class RedirectAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + public static void after(@Advice.This Object handler) throws Exception { + // fromURI/toURI are UriEndpoint (also package-private), both reassigned by + // redirect() before this advice runs: fromURI is the hostname that redirected, + // toURI is where it redirected to. + String origin = externalForm(handler, "fromURI"); + String dest = externalForm(handler, "toURI"); + if (origin == null || dest == null) { + return; + } + RedirectCollector.report(new URL(origin), new URL(dest)); + } + + // Must be public: after weaving, this is called as a real cross-class invocation from + // inside the target class's own bytecode, so a private method would raise IllegalAccessError. + public static String externalForm(Object handler, String fieldName) throws Exception { + Field field = handler.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + Object uriEndpoint = field.get(handler); + if (uriEndpoint == null) { + return null; + } + Method toExternalForm = uriEndpoint.getClass().getDeclaredMethod("toExternalForm"); + toExternalForm.setAccessible(true); + return (String) toExternalForm.invoke(uriEndpoint); + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java new file mode 100644 index 000000000..91645852e --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java @@ -0,0 +1,48 @@ +package dev.aikido.agent.wrappers.spring; + +import dev.aikido.agent.wrappers.Wrapper; +import dev.aikido.agent_api.collectors.URLCollector; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; +import net.bytebuddy.matcher.ElementMatchers; +import org.springframework.web.reactive.function.client.ClientRequest; + +import java.net.MalformedURLException; + +public class SpringWebClientWrapper implements Wrapper { + // Referenced by name (not by .class) in the matchers below: ExchangeFunction is only on + // the target application's classloader (spring-webflux is compileOnly here), not on the + // agent's own classloader, so a .class literal would throw NoClassDefFoundError at premain. + private static final String EXCHANGE_FUNCTION_CLASS_NAME = + "org.springframework.web.reactive.function.client.ExchangeFunction"; + + public String getName() { + // Wrap exchange(ClientRequest) on ExchangeFunction, the interface every WebClient + // request goes through before Reactor Netty resolves/connects. + // https://docs.spring.io/spring-framework/docs/5.3.20/javadoc-api/org/springframework/web/reactive/function/client/ExchangeFunction.html + return SpringWebClientAdvice.class.getName(); + } + public ElementMatcher getMatcher() { + return ElementMatchers.isDeclaredBy(getTypeMatcher()) + .and(ElementMatchers.named("exchange")); + } + @Override + public ElementMatcher getTypeMatcher() { + return ElementMatchers.hasSuperType(ElementMatchers.named(EXCHANGE_FUNCTION_CLASS_NAME)); + } + public static class SpringWebClientAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(0) ClientRequest request + ) throws MalformedURLException { + if (request == null || request.url() == null) { + return; + } + // Report the URL before the request is sent, so DNSRecordCollector can match the + // DNS lookup that follows to this outgoing request. + URLCollector.report(request.url().toURL()); + } + } +} diff --git a/agent_api/build.gradle b/agent_api/build.gradle index e051f811d..3d09c3827 100644 --- a/agent_api/build.gradle +++ b/agent_api/build.gradle @@ -39,6 +39,9 @@ dependencies { testImplementation 'org.springframework:spring-web:5.3.20' testImplementation 'org.springframework:spring-webmvc:5.3.20' testImplementation 'org.springframework:spring-test:5.3.20' + // Spring WebFlux for WebClient + testImplementation 'org.springframework:spring-webflux:5.3.20' + testImplementation 'io.projectreactor.netty:reactor-netty-http:1.2.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.9.2' } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java index d33c165c9..dbd05c64d 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java @@ -12,6 +12,7 @@ import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException; import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; +import dev.aikido.agent_api.vulnerabilities.ssrf.IsPrivateIP; import dev.aikido.agent_api.vulnerabilities.ssrf.StoredSSRFDetector; import dev.aikido.agent_api.vulnerabilities.ssrf.StoredSSRFException; @@ -26,30 +27,46 @@ public final class DNSRecordCollector { private DNSRecordCollector() {} private static final Logger logger = LogManager.getLogger(DNSRecordCollector.class); - private static final String OPERATION_NAME = "java.net.InetAddress.getAllByName"; + private static final String INET_ADDRESS_OPERATION_NAME = "java.net.InetAddress.getAllByName"; + private static final String SOCKET_CHANNEL_OPERATION_NAME = "java.nio.channels.SocketChannel.connect"; + public static void report(String hostname, InetAddress[] inetAddresses) { + // InetAddress.getAllByName() resolves everything in one call, so it's safe to consume. + process(hostname, inetAddresses, PendingHostnamesStore.getAndRemove(hostname), INET_ADDRESS_OPERATION_NAME); + } + + // For clients that resolve their own DNS (e.g. Reactor Netty, used by Spring's WebClient) or + // connect straight to an IP literal. A single request can trigger multiple connect() calls to + // the same hostname (IPv4 then IPv6), so unlike report(), this peeks the pending port instead + // of consuming it - consuming on the first attempt would let a later attempt bypass SSRF. + public static void reportConnect(String hostname, InetAddress resolvedAddress) { + process(hostname, new InetAddress[]{resolvedAddress}, PendingHostnamesStore.getPorts(hostname), SOCKET_CHANNEL_OPERATION_NAME); + } + + private static void process(String hostname, InetAddress[] inetAddresses, Set ports, String operationName) { try { logger.trace("DNSRecordCollector called with %s & inet addresses: %s", hostname, List.of(inetAddresses)); // store stats - StatisticsStore.registerCall("java.net.InetAddress.getAllByName", OperationKind.OUTGOING_HTTP_OP); - - // Consume pending ports recorded by URLCollector for this hostname. - // Removing them here ensures each (hostname, port) pair is counted exactly once. - Set ports = PendingHostnamesStore.getAndRemove(hostname); - if (!ports.isEmpty()) { - for (int port : ports) { - HostnamesStore.incrementHits(hostname, port); + StatisticsStore.registerCall(operationName, OperationKind.OUTGOING_HTTP_OP); + + // No pending port + private IP literal = infrastructure noise (Netty resolver bootstrap + // etc), not a real request - skip recording/blocking. SSRF checks below still run regardless. + if (!ports.isEmpty() || !IsPrivateIP.isPrivateIp(hostname)) { + if (!ports.isEmpty()) { + for (int port : ports) { + HostnamesStore.incrementHits(hostname, port); + } + } else { + // We still need to report a hit to the hostname for outbound domain blocking + HostnamesStore.incrementHits(hostname, 0); } - } else { - // We still need to report a hit to the hostname for outbound domain blocking - HostnamesStore.incrementHits(hostname, 0); - } - // Block if the hostname is in the blocked domains list - if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) { - logger.debug("Blocking DNS lookup for domain: %s", hostname); - throw BlockedOutboundException.get(); + // Block if the hostname is in the blocked domains list + if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) { + logger.debug("Blocking DNS lookup for domain: %s", hostname); + throw BlockedOutboundException.get(); + } } // Convert inetAddresses array to a List of IP strings : @@ -62,7 +79,7 @@ public static void report(String hostname, InetAddress[] inetAddresses) { for (int port : ports) { logger.debug("Hostname: %s, Port: %s, IPs: %s", hostname, port, ipAddresses); - Attack attack = SSRFDetector.run(hostname, port, ipAddresses, OPERATION_NAME); + Attack attack = SSRFDetector.run(hostname, port, ipAddresses, operationName); if (attack == null) { continue; } @@ -81,7 +98,7 @@ public static void report(String hostname, InetAddress[] inetAddresses) { // We don't need the context object to check for stored ssrf, but we do want to run this after our other // SSRF checks, making sure if it's a normal ssrf attack it gets reported like that. - Attack storedSsrfAttack = new StoredSSRFDetector().run(hostname, ipAddresses, OPERATION_NAME); + Attack storedSsrfAttack = new StoredSSRFDetector().run(hostname, ipAddresses, operationName); if (storedSsrfAttack != null) { attackDetected(storedSsrfAttack, Context.get()); diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java index 2efd5ecf1..2644f5036 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java @@ -4,14 +4,30 @@ /** * Thread-local bridge between URLCollector and DNSRecordCollector. - * URLCollector records hostname+port here; DNSRecordCollector reads and removes the entry - * so each (hostname, port) pair is processed exactly once per DNS lookup. + * URLCollector records hostname+port here; DNSRecordCollector.report() (fed by + * InetAddress.getAllByName(), which resolves everything in one call) reads and removes the + * entry so each (hostname, port) pair is processed exactly once per DNS lookup. + * DNSRecordCollector.reportConnect() (fed by SocketChannel.connect(), which fires once per + * connect attempt) instead peeks the entry, since a single outbound request can trigger + * multiple connect attempts to the same hostname (e.g. IPv4 then IPv6 for a dual-stack host). + * + * Entries are normally cleared per incoming request by WebRequestCollector, but a peeked + * entry added outside any incoming-request context (e.g. a WebClient call from a @Scheduled + * task) would never be cleared that way. Capped at MAX_ENTRIES per thread, evicting the least + * recently used entry once exceeded, same bounded-LRU pattern as Hostnames. */ public final class PendingHostnamesStore { private PendingHostnamesStore() {} + private static final int MAX_ENTRIES = 1000; + private static final ThreadLocal>> store = - ThreadLocal.withInitial(LinkedHashMap::new); + ThreadLocal.withInitial(() -> new LinkedHashMap<>(16, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry> eldest) { + return size() > MAX_ENTRIES; + } + }); public static void add(String hostname, int port) { Map> map = store.get(); diff --git a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java index c7cdd4b3b..0893e24a7 100644 --- a/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java +++ b/agent_api/src/test/java/collectors/DNSRecordCollectorTest.java @@ -3,6 +3,7 @@ import dev.aikido.agent_api.background.cloud.api.APIResponse; import dev.aikido.agent_api.background.cloud.api.events.DetectedAttack; import dev.aikido.agent_api.collectors.DNSRecordCollector; +import dev.aikido.agent_api.collectors.RedirectCollector; import dev.aikido.agent_api.context.Context; import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.storage.AttackQueue; @@ -18,6 +19,7 @@ import utils.EmptySampleContextObject; import java.net.InetAddress; +import java.net.URL; import java.net.UnknownHostException; import java.util.List; @@ -223,4 +225,166 @@ public void testStoredSSRFWithNoContext() throws InterruptedException { DNSRecordCollector.report("metadata.google.internal", new InetAddress[]{imdsAddress1, inetAddress2}); }); } + + @Test + public void testPrivateIpLiteralWithNoPendingPortNotRecorded() { + // No pending port and the hostname is a private IP literal: infrastructure noise + // (e.g. Reactor Netty's resolver bootstrap resolving nameserver/bind addresses). + // Must not be recorded as an outbound hostname. + Context.set(null); + DNSRecordCollector.report("10.20.11.143", new InetAddress[]{inetAddress2}); + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + } + + @Test + public void testPrivateIpLiteralWithNoPendingPortNotBlockedInLockdown() { + // Lockdown mode (blockNewOutgoingRequests=true) must not block a private IP literal + // that has no pending port, otherwise it would break internal/infra resolutions. + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, true, List.of(), true, true, List.of() + )); + Context.set(null); + assertDoesNotThrow(() -> + DNSRecordCollector.report("10.20.11.143", new InetAddress[]{inetAddress2}) + ); + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + } + + @Test + public void testPrivateIpLiteralWithPendingPortStillRecordedAndBlockedInLockdown() { + // A private IP literal that DOES have a pending port came from a real outgoing + // request made through an instrumented client, not from infrastructure noise. It + // must still be recorded and still be subject to outbound blocking in lockdown mode. + ServiceConfigStore.updateFromAPIResponse(new APIResponse( + true, null, 0L, null, null, null, true, List.of(), true, true, List.of() + )); + PendingHostnamesStore.add("10.20.11.143", 443); + Context.set(mock(ContextObject.class)); + + assertThrows(BlockedOutboundException.class, () -> + DNSRecordCollector.report("10.20.11.143", new InetAddress[]{inetAddress2}) + ); + } + + @Test + public void testSsrfStillDetectedForPrivateIpLiteralWithPendingPort() { + // Regression test: an attacker-supplied private IP literal (e.g. a webhook URL field + // pointing straight at 169.254.169.254) reaching a real outgoing request through an + // instrumented client must still be caught as SSRF. Earlier attempts at filtering + // private IP literals used an early return that accidentally skipped this check. + ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.add("169.254.169.254", 80); + Context.set(new EmptySampleContextObject("http://169.254.169.254:80/latest/meta-data/")); + + Exception exception = assertThrows(SSRFException.class, () -> + DNSRecordCollector.report("169.254.169.254", new InetAddress[]{imdsAddress1}) + ); + assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage()); + } + + // reportConnect(): used by SocketChannelWrapper for clients that resolve their own DNS + // (e.g. Reactor Netty, used by Spring's WebClient) instead of InetAddress.getAllByName(), + // reporting one resolved address per connect() attempt. + + @Test + public void testReportConnectRecordsHostnameWithPendingPort() { + PendingHostnamesStore.add("example.com", 443); + Context.set(mock(ContextObject.class)); + + DNSRecordCollector.reportConnect("example.com", inetAddress1); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("example.com", entries[0].getHostname()); + assertEquals(443, entries[0].getPort()); + } + + @Test + public void testReportConnectDoesNotConsumePendingPort() { + // Unlike report(), reportConnect() must peek instead of consume: a single outbound + // request can trigger multiple connect() calls to the same hostname (e.g. trying the + // IPv4 then the IPv6 address of a dual-stack host), and each one must still see the + // pending port to be checked correctly. + PendingHostnamesStore.add("example.com", 443); + Context.set(mock(ContextObject.class)); + + DNSRecordCollector.reportConnect("example.com", inetAddress1); + assertFalse(PendingHostnamesStore.getPorts("example.com").isEmpty()); + + // A second connect attempt (e.g. the IPv6 address) still sees the same pending port + // and records another hit, instead of falling back to port 0 or being skipped. + DNSRecordCollector.reportConnect("example.com", inetAddress2); + Hostnames.HostnameEntry[] entries = HostnamesStore.getHostnamesAsList(); + assertEquals(1, entries.length); + assertEquals("example.com", entries[0].getHostname()); + assertEquals(443, entries[0].getPort()); + assertEquals(2, entries[0].getHits()); + } + + @Test + public void testSsrfDetectedOnEveryConnectAttemptForDualStackHostname() throws UnknownHostException { + // Regression test for a real bug found via e2e testing: "localhost" resolves to both + // 127.0.0.1 and ::1, and Reactor Netty tries both addresses via separate connect() + // calls. With a naive getAndRemove() the first attempt would consume the pending port + // and the second attempt would silently skip the SSRF check, letting the request + // through despite the first attempt having been blocked. + InetAddress loopbackIPv6 = InetAddress.getByName("::1"); + ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.add("localhost", 5000); + Context.set(new EmptySampleContextObject("http://localhost:5000")); + + assertThrows(SSRFException.class, () -> + DNSRecordCollector.reportConnect("localhost", inetAddress2) // 127.0.0.1 + ); + assertThrows(SSRFException.class, () -> + DNSRecordCollector.reportConnect("localhost", loopbackIPv6) // ::1 + ); + } + + @Test + public void testReportConnectPrivateIpLiteralWithNoPendingPortNotRecorded() { + // Same private-IP-literal infrastructure-noise filtering as report(), but for the + // reportConnect() path: a literal IP with no pending port (e.g. a raw socket connect + // Reactor Netty makes that we never asked for) must not be recorded. + Context.set(null); + DNSRecordCollector.reportConnect("10.20.11.143", inetAddress2); + assertEquals(0, HostnamesStore.getHostnamesAsList().length); + } + + @Test + public void testReportConnectStoredSsrfStillRunsUnconditionally() { + ServiceConfigStore.updateBlocking(true); + Context.set(null); + + assertThrows(StoredSSRFException.class, () -> + DNSRecordCollector.reportConnect("dev.aikido", imdsAddress1) + ); + } + + @Test + public void testSsrfDetectedForRedirectToPrivateIp() throws Exception { + // Regression test: a WebClient call to a user-supplied, safe-looking URL that redirects + // to a private IP must still be caught, even though the redirect target itself never + // has a pending port (SpringWebClientWrapper only sees the original request). + // RedirectCollector.report() (called by SpringWebClientRedirectWrapper for each redirect + // hop) records the chain so SSRFDetector's PrivateIPRedirectFinder fallback can trace the + // private-IP target back to the tainted origin. + // Uses attacker-supplied.test rather than example.com since EmptySampleContextObject's + // own server URL defaults to example.com, which would collide with the origin hostname. + ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.add("attacker-supplied.test", 80); + Context.set(new EmptySampleContextObject("http://attacker-supplied.test/redirect-me")); + + RedirectCollector.report( + new URL("http://attacker-supplied.test/redirect-me"), + new URL("http://169.254.169.254/latest/meta-data/") + ); + + InetAddress imdsResolved = InetAddress.getByAddress( + "169.254.169.254", new byte[]{(byte) 169, (byte) 254, (byte) 169, (byte) 254}); + + Exception exception = assertThrows(SSRFException.class, () -> + DNSRecordCollector.reportConnect("169.254.169.254", imdsResolved) + ); + assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage()); + } } diff --git a/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java b/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java new file mode 100644 index 000000000..b7f88a87d --- /dev/null +++ b/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java @@ -0,0 +1,81 @@ +package storage; + +import dev.aikido.agent_api.storage.PendingHostnamesStore; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.*; + +public class PendingHostnamesStoreTest { + @AfterEach + public void cleanup() { + PendingHostnamesStore.clear(); + } + + @Test + public void testGetPortsDoesNotRemoveEntry() { + PendingHostnamesStore.add("dev.aikido", 443); + + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("dev.aikido")); + // Reading again still sees it: getPorts() peeks, it doesn't consume. + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("dev.aikido")); + } + + @Test + public void testGetAndRemoveConsumesEntry() { + PendingHostnamesStore.add("dev.aikido", 443); + + assertEquals(Set.of(443), PendingHostnamesStore.getAndRemove("dev.aikido")); + assertTrue(PendingHostnamesStore.getAndRemove("dev.aikido").isEmpty()); + } + + @Test + public void testUnboundedHostnamesDoNotGrowThreadLocalMapForever() { + // Regression test: entries added outside any incoming-request context (e.g. a + // WebClient call from a @Scheduled task) never get cleared by WebRequestCollector's + // per-request clear(). Adding well over the internal cap of distinct hostnames must + // not let the store grow unboundedly - the oldest, untouched entries get evicted. + for (int i = 0; i < 2000; i++) { + PendingHostnamesStore.add("host-" + i + ".example.com", 443); + } + + // The very first hostnames added, never read again, must have been evicted. + assertTrue(PendingHostnamesStore.getPorts("host-0.example.com").isEmpty()); + assertTrue(PendingHostnamesStore.getPorts("host-1.example.com").isEmpty()); + + // The most recently added hostnames must still be present. + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("host-1999.example.com")); + } + + @Test + public void testReadingAnEntryProtectsItFromEvictionWhileStillInUse() { + // A dual-stack connect sequence peeks the same hostname's entry more than once (e.g. + // IPv4 then IPv6 attempt), realistically with only a handful of unrelated hostnames + // registered on the same thread in between (well under the eviction cap) - not + // thousands. Each read counts as "recently used", so the entry survives that window. + PendingHostnamesStore.add("dual-stack.example.com", 443); + + for (int i = 0; i < 10; i++) { + PendingHostnamesStore.add("host-" + i + ".example.com", 443); + } + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("dual-stack.example.com")); + + for (int i = 10; i < 20; i++) { + PendingHostnamesStore.add("host-" + i + ".example.com", 443); + } + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("dual-stack.example.com")); + } + + @Test + public void testClearRemovesEverything() { + PendingHostnamesStore.add("dev.aikido", 443); + PendingHostnamesStore.add("dev.aikido.not", 80); + + PendingHostnamesStore.clear(); + + assertTrue(PendingHostnamesStore.getPorts("dev.aikido").isEmpty()); + assertTrue(PendingHostnamesStore.getPorts("dev.aikido.not").isEmpty()); + } +} diff --git a/agent_api/src/test/java/wrappers/WebClientTest.java b/agent_api/src/test/java/wrappers/WebClientTest.java new file mode 100644 index 000000000..4a4e21d32 --- /dev/null +++ b/agent_api/src/test/java/wrappers/WebClientTest.java @@ -0,0 +1,80 @@ +package wrappers; + +import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.storage.HostnamesStore; +import dev.aikido.agent_api.storage.PendingHostnamesStore; +import dev.aikido.agent_api.storage.ServiceConfigStore; +import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.web.reactive.function.client.WebClient; +import utils.EmptySampleContextObject; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.channels.SocketChannel; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * SpringWebClientWrapper (URLCollector.report on ExchangeFunction.exchange) and + * SocketChannelWrapper (DNSRecordCollector.reportConnect on SocketChannel.connect) run on + * different threads for a real WebClient call: the former on the subscribing thread, the + * latter on Reactor Netty's own event-loop thread. PendingHostnamesStore/Context are + * ThreadLocal, so a plain "Context.set() then webClient.block()" test can't observe both + * halves together the way HttpURLConnectionTest can for a same-thread blocking client - that + * only works in production because a real WebFlux request stays on one reactor-http-nio + * thread throughout. So this file tests each wrapper's own contribution separately. + */ +public class WebClientTest { + private static final WebClient webClient = WebClient.create(); + + @AfterEach + void cleanup() { + Context.set(null); + HostnamesStore.clear(); + PendingHostnamesStore.clear(); + } + + @BeforeEach + void beforeEach() { + cleanup(); + ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.clear(); + } + + @Test + public void testSpringWebClientWrapperRegistersPendingPort() { + // ExchangeFunction.exchange() runs on the subscribing thread, same as this test - + // confirms SpringWebClientWrapper fires and calls URLCollector.report() correctly. + assertTrue(PendingHostnamesStore.getPorts("aikido.dev").isEmpty()); + + webClient.get().uri("https://aikido.dev").retrieve().bodyToMono(String.class).block(); + + assertEquals(Set.of(443), PendingHostnamesStore.getPorts("aikido.dev")); + } + + @Test + public void testSocketChannelWrapperBlocksSsrf() throws Exception { + // SocketChannel.connect() is synchronous and single-threaded regardless of caller, so + // this exercises SocketChannelWrapper + DNSRecordCollector.reportConnect's real SSRF + // logic deterministically, without Reactor's thread-hopping. + ServiceConfigStore.updateBlocking(true); + PendingHostnamesStore.add("localhost", 5000); + Context.set(new EmptySampleContextObject("http://localhost:5000")); + + // Built via getByAddress (no lookup, no InetAddressWrapper interception) so the + // resolved address reaches SocketChannel.connect() untouched, isolating this test to + // SocketChannelWrapper's own behavior. + InetAddress resolved = InetAddress.getByAddress("localhost", new byte[]{127, 0, 0, 1}); + InetSocketAddress address = new InetSocketAddress(resolved, 5000); + try (SocketChannel channel = SocketChannel.open()) { + SSRFException exception = assertThrows(SSRFException.class, () -> channel.connect(address)); + assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage()); + } + } +} diff --git a/end2end/server/mock_aikido_core.py b/end2end/server/mock_aikido_core.py index 6238f5741..f744dd273 100644 --- a/end2end/server/mock_aikido_core.py +++ b/end2end/server/mock_aikido_core.py @@ -1,6 +1,6 @@ import gzip -from flask import Flask, request, jsonify, Response +from flask import Flask, request, jsonify, Response, redirect import sys import os import json @@ -144,6 +144,11 @@ def mock_reset(): events = [] # Reset events return jsonify({}) +@app.route('/mock/redirect-to-metadata', methods=['GET']) +def mock_redirect_to_metadata(): + # Used to test redirect-based SSRF: a safe-looking URL that redirects to a private IP. + return redirect('http://169.254.169.254/latest/meta-data/', code=302) + @app.route('/mock/set_protection', methods=['POST']) def mock_set_protection(): req = request.get_json() diff --git a/end2end/spring_webflux_postgres.py b/end2end/spring_webflux_postgres.py index 0095830b8..d937f4704 100644 --- a/end2end/spring_webflux_postgres.py +++ b/end2end/spring_webflux_postgres.py @@ -30,6 +30,13 @@ unsafe_request=Request("/api/commands/executeFromCookie", method='GET', headers={'Cookie': 'command=|sleep;command=|sleep'}), ) +# WebClient SSRF: query params are the taint source tracked for Spring WebFlux (the request +# body isn't, see agent_api's SpringWebfluxContextObject). +spring_webflux_postgres_app.add_payload("ssrf", + safe_request=Request("/api/request?url=https://aikido.dev/", method='GET'), + unsafe_request=Request("/api/request?url=http://localhost:5000", method='GET') +) + spring_webflux_postgres_app.test_all_payloads() spring_webflux_postgres_app.test_blocking() spring_webflux_postgres_app.test_rate_limiting() diff --git a/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java b/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java new file mode 100644 index 000000000..bf9ee90eb --- /dev/null +++ b/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java @@ -0,0 +1,70 @@ +package dev.aikido.SpringWebfluxSampleApp; + +import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; + +@RestController +@RequestMapping("/api/request") +public class RequestController { + private record UrlRequest(String url) {} + + private static final WebClient webClient = WebClient.create(); + + // A separate client with followRedirect enabled, to exercise SSRF detection across + // redirects (see SpringWebClientRedirectWrapper). + private static final WebClient webClientFollowingRedirects = WebClient.builder() + .clientConnector(new ReactorClientHttpConnector(HttpClient.create().followRedirect(true))) + .build(); + + @PostMapping + public Mono makeRequest(@RequestBody UrlRequest urlRequest) { + return makeRequestInternal(urlRequest.url()); + } + + // Query params are a tracked taint source for Spring WebFlux (unlike the request body), + // so this variant is used to exercise SSRF detection end to end. + @GetMapping + public Mono makeRequestFromQuery(@RequestParam String url) { + return makeRequestInternal(url); + } + + @GetMapping("/follow-redirects") + public Mono makeRequestFollowingRedirects(@RequestParam String url) { + return webClientFollowingRedirects.get() + .uri(url) + .retrieve() + .bodyToMono(String.class) + .onErrorResume(e -> isAikidoBlock(e) + ? Mono.error(e) + : Mono.just("Error: " + e.getMessage())); + } + + private Mono makeRequestInternal(String url) { + return webClient.get() + .uri(url) + .retrieve() + .bodyToMono(String.class) + .onErrorResume(e -> isAikidoBlock(e) + ? Mono.error(e) + : Mono.just("Error: " + e.getMessage())); + } + + // Aikido Zen blocks (SSRF, outbound blocking, ...) must propagate as a server error + // instead of being swallowed into a 200 response, same as any other Aikido block. + private static boolean isAikidoBlock(Throwable e) { + for (Throwable cause = e; cause != null; cause = cause.getCause()) { + if (cause.getClass().getName().startsWith("dev.aikido.agent_api.vulnerabilities")) { + return true; + } + } + return false; + } +}