diff --git a/.github/workflows/multi-agent-integration.yml b/.github/workflows/multi-agent-integration.yml new file mode 100644 index 0000000..784f359 --- /dev/null +++ b/.github/workflows/multi-agent-integration.yml @@ -0,0 +1,53 @@ +name: Multi-Agent Integration + +# End-to-end test of the examples/multi-agent-setup stack: build the images, +# send a message to orchestrator-api, and verify the orchestrator → worker → +# callback round-trip. +# +# Runs post-merge on main (and on manual dispatch) — not as a PR gate. It hits +# the real Claude API, so it needs the CLAUDE_API_KEY repo secret; if the secret +# is unset the job soft-skips rather than failing main. + +on: + push: + branches: [main] + paths: + - examples/multi-agent-setup/** + - api/** + - frontend/** + - processor-apps/processing/** + - think/think-consumer/** + - .github/workflows/multi-agent-integration.yml + workflow_dispatch: + +jobs: + e2e: + name: Orchestrator → Worker round-trip + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v4 + + - name: Check for CLAUDE_API_KEY secret + id: guard + env: + CLAUDE_API_KEY: ${{ secrets.CLAUDE_API_KEY }} + run: | + if [ -z "$CLAUDE_API_KEY" ]; then + echo "::warning::CLAUDE_API_KEY secret not set — skipping integration test" + echo "skip=true" >> "$GITHUB_OUTPUT" + fi + + - name: Run integration test + if: steps.guard.outputs.skip != 'true' + working-directory: examples/multi-agent-setup + env: + CLAUDE_API_KEY: ${{ secrets.CLAUDE_API_KEY }} + TIMEOUT: "300" + run: ./integration-test.sh + + - name: Tear down (backstop) + if: always() + working-directory: examples/multi-agent-setup + run: docker compose -p multiagent-it down -v || true diff --git a/api/chat-api/Dockerfile b/api/chat-api/Dockerfile index 3c40d4e..2a7f9b9 100644 --- a/api/chat-api/Dockerfile +++ b/api/chat-api/Dockerfile @@ -3,7 +3,9 @@ WORKDIR /app COPY pom.xml . RUN mvn dependency:go-offline -q COPY src ./src -RUN mvn package -q -DskipTests +# Skip test compilation+execution for the image build (tests run in CI via +# `mvn test`); keeps the image build decoupled from test sources. +RUN mvn package -q -Dmaven.test.skip=true FROM eclipse-temurin:17-jre WORKDIR /app diff --git a/api/chat-api/pom.xml b/api/chat-api/pom.xml index ee6a26c..c3f6209 100644 --- a/api/chat-api/pom.xml +++ b/api/chat-api/pom.xml @@ -53,10 +53,29 @@ Java-WebSocket 1.5.7 + + + + org.junit.jupiter + junit-jupiter + 5.10.2 + test + + + org.assertj + assertj-core + 3.25.3 + test + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.5 + org.apache.maven.plugins maven-shade-plugin diff --git a/api/chat-api/src/main/java/io/flightdeck/api/CallbackRegistry.java b/api/chat-api/src/main/java/io/flightdeck/api/CallbackRegistry.java new file mode 100644 index 0000000..973c4ec --- /dev/null +++ b/api/chat-api/src/main/java/io/flightdeck/api/CallbackRegistry.java @@ -0,0 +1,81 @@ +package io.flightdeck.api; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Resolves a logical callback-service name to the trusted base URL configured for + * it, then appends the fixed callback path ({@value #CALLBACK_PATH}). + * + *

The mapping is loaded once from the {@code ALLOWED_HOST_MAPPING} environment + * variable: a comma-separated list of {@code name:baseUrl} entries. Only the first + * colon of each entry separates the name from the URL, so the base URL keeps its + * own {@code scheme://host[:port]} colons: + * + *

+ *   ALLOWED_HOST_MAPPING=my-agent-a:https://hosta.local,my-agent-c:http://hostc.local
+ * 
+ * + *

Why this exists. A caller (a peer agent) supplies only the service + * name in its {@code reply} descriptor; the destination URL is chosen + * here from operator-controlled config, never from caller input. An untrusted + * caller therefore cannot steer the server-side callback at an arbitrary host — + * the SSRF primitive that an attacker-controlled {@code endpoint} would create is + * structurally removed. Unknown names fail closed. + */ +final class CallbackRegistry { + + /** Fixed path appended to every resolved base URL. */ + static final String CALLBACK_PATH = "/api/tools/response"; + + private static final Map MAPPING = + parse(ChatApiApp.env("ALLOWED_HOST_MAPPING", "")); + + private CallbackRegistry() {} + + /** True if {@code service} resolves to a configured base URL. */ + static boolean isKnown(String service) { + return service != null && MAPPING.containsKey(service); + } + + /** + * Resolves the full callback URL for a service name. + * + * @throws IllegalArgumentException if the name is not configured (fail closed) + */ + static String resolve(String service) { + String base = service == null ? null : MAPPING.get(service); + if (base == null) { + throw new IllegalArgumentException("unknown callbackService: " + service); + } + return toCallbackUrl(base); + } + + /** Strips any trailing slash from the base URL and appends the fixed callback path. */ + static String toCallbackUrl(String base) { + String trimmed = base.endsWith("/") ? base.substring(0, base.length() - 1) : base; + return trimmed + CALLBACK_PATH; + } + + /** Parses {@code name:baseUrl} entries, splitting each on its FIRST colon only. */ + static Map parse(String raw) { + Map mapping = new LinkedHashMap<>(); + if (raw == null || raw.isBlank()) { + return Collections.unmodifiableMap(mapping); + } + for (String entry : raw.split(",")) { + String e = entry.trim(); + if (e.isEmpty()) continue; + int sep = e.indexOf(':'); + String name = sep > 0 ? e.substring(0, sep).trim() : ""; + String url = sep > 0 ? e.substring(sep + 1).trim() : ""; + if (name.isEmpty() || url.isEmpty()) { + throw new IllegalArgumentException( + "Malformed ALLOWED_HOST_MAPPING entry (expected name:baseUrl): " + entry); + } + mapping.put(name, url); + } + return Collections.unmodifiableMap(mapping); + } +} diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ChatApiApp.java b/api/chat-api/src/main/java/io/flightdeck/api/ChatApiApp.java index 93c97e4..580da7f 100644 --- a/api/chat-api/src/main/java/io/flightdeck/api/ChatApiApp.java +++ b/api/chat-api/src/main/java/io/flightdeck/api/ChatApiApp.java @@ -21,12 +21,24 @@ public class ChatApiApp { private static final int WS_PORT = Integer.parseInt(env("WS_PORT", "8001")); public static void main(String[] args) throws Exception { - // 1. Kafka producer (chat → message-input) + // 1. Kafka producers (chat → message-input, async callbacks → tool-use-result, + // multi-agent reply routes → reply-to) KafkaMessageProducer producer = new KafkaMessageProducer(); + ToolResultProducer toolResultProducer = new ToolResultProducer(); + ReplyToProducer replyToProducer = new ReplyToProducer(); + + // Shared secret for verifying async tool callback tokens. Optional — + // if unset, /api/tools/response rejects every callback. + String callbackSecret = env("TOOL_CALLBACK_SECRET", ""); + if (callbackSecret.isBlank()) { + log.warn("TOOL_CALLBACK_SECRET is not set — /api/tools/response will reject all callbacks"); + } // 2. HTTP server for REST API HttpServer httpServer = HttpServer.create(new InetSocketAddress(PORT), 0); - httpServer.createContext("/api/chat", new ChatHandler(producer)); + httpServer.createContext("/api/chat", new ChatHandler(producer, replyToProducer)); + httpServer.createContext("/api/tools/response", + new ToolResponseHandler(toolResultProducer, callbackSecret)); httpServer.setExecutor(null); httpServer.start(); log.info("HTTP server started on port {}", PORT); @@ -35,8 +47,9 @@ public static void main(String[] args) throws Exception { ChatWebSocketServer wsServer = new ChatWebSocketServer(WS_PORT); wsServer.start(); - // 4. Kafka consumer (message-output → WebSocket chat response) - OutputConsumer outputConsumer = new OutputConsumer(wsServer); + // 4. Kafka consumer (message-output → WebSocket chat response, or → HTTP + // callback for sessions that carry a reply-to descriptor) + OutputConsumer outputConsumer = new OutputConsumer(wsServer, replyToProducer); Thread outputThread = new Thread(outputConsumer, "output-consumer"); outputThread.setDaemon(true); outputThread.start(); @@ -55,6 +68,8 @@ public static void main(String[] args) throws Exception { try { wsServer.stop(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } httpServer.stop(2); producer.close(); + toolResultProducer.close(); + replyToProducer.close(); })); log.info("Chat API ready — HTTP={} WS={}", PORT, WS_PORT); diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ChatHandler.java b/api/chat-api/src/main/java/io/flightdeck/api/ChatHandler.java index 3d712ec..6c6b04c 100644 --- a/api/chat-api/src/main/java/io/flightdeck/api/ChatHandler.java +++ b/api/chat-api/src/main/java/io/flightdeck/api/ChatHandler.java @@ -24,6 +24,21 @@ * { "session_id": "...", "user_id": "user_42", "role": "user", * "content": "hello", "timestamp": "...", * "metadata": { "locale": "en-US", "client": "web" } } + * + *

For multi-agent calls the request may also carry a transport-level + * {@code reply} descriptor naming where this session's terminal response should + * be delivered, e.g.: + *

+ *   { "session_id": "...", "content": "...",
+ *     "reply": { "callbackService": "my-agent-a", "bearerToken": "<HMAC>" } }
+ * 
+ * {@code callbackService} is a logical name resolved server-side against + * {@code ALLOWED_HOST_MAPPING} ({@link CallbackRegistry}); the caller never + * supplies a URL, so the descriptor cannot steer the callback at an arbitrary + * host. Unknown names are rejected here with a 400. + * The {@code reply} object is NOT placed into the message content/metadata — it + * is written to the reply-to topic (keyed by session_id) so it never reaches the + * LLM. The agent processes the request as an ordinary chat. */ public class ChatHandler implements HttpHandler { @@ -31,9 +46,11 @@ public class ChatHandler implements HttpHandler { private static final ObjectMapper mapper = new ObjectMapper(); private final KafkaMessageProducer producer; + private final ReplyToProducer replyToProducer; - public ChatHandler(KafkaMessageProducer producer) { + public ChatHandler(KafkaMessageProducer producer, ReplyToProducer replyToProducer) { this.producer = producer; + this.replyToProducer = replyToProducer; } @Override @@ -59,6 +76,28 @@ public void handle(HttpExchange exchange) throws IOException { String sessionId = requireField(body, "session_id"); String content = requireField(body, "content"); + // Transport-level reply routing (multi-agent). Written to the reply-to + // topic keyed by session_id; never placed into the message content. + if (body.hasNonNull("reply")) { + JsonNode reply = body.get("reply"); + if (!reply.isObject()) { + throw new IllegalArgumentException("'reply' must be an object"); + } + String callbackService = reply.path("callbackService").asText(""); + if (callbackService.isBlank()) { + throw new IllegalArgumentException("'reply' must include 'callbackService'"); + } + // Fail closed: reject a descriptor naming a service this agent is not + // configured to call back, so unroutable routes never reach the + // reply-to topic and the caller gets an immediate 400. + if (!CallbackRegistry.isKnown(callbackService)) { + throw new IllegalArgumentException("unknown callbackService: " + callbackService); + } + replyToProducer.send(sessionId, mapper.writeValueAsString(reply)); + log.info("[{}] Stored reply-to descriptor (callbackService={})", + sessionId, callbackService); + } + // Build the full message-input payload ObjectNode message = mapper.createObjectNode(); message.put("session_id", sessionId); diff --git a/api/chat-api/src/main/java/io/flightdeck/api/OutputConsumer.java b/api/chat-api/src/main/java/io/flightdeck/api/OutputConsumer.java index 31c3b8d..175a0d8 100644 --- a/api/chat-api/src/main/java/io/flightdeck/api/OutputConsumer.java +++ b/api/chat-api/src/main/java/io/flightdeck/api/OutputConsumer.java @@ -1,5 +1,7 @@ package io.flightdeck.api; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; @@ -8,17 +10,41 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; import java.util.List; import java.util.Properties; /** - * Kafka consumer that reads UserResponse records from the message-output topic - * and forwards them to connected WebSocket clients via {@link ChatWebSocketServer}. + * Kafka consumer that reads {@code UserResponse} records from the message-output + * topic and delivers each one to its destination: + * + *
    + *
  • WebSocket — the default. Forwarded to connected browser clients + * via {@link ChatWebSocketServer}.
  • + *
  • HTTP callback — when the record carries a {@code reply_to} + * descriptor (a multi-agent call). The descriptor's {@code callbackService} + * is resolved to a trusted base URL via {@link CallbackRegistry} + * ({@code ALLOWED_HOST_MAPPING}) and the response is POSTed to that host's + * fixed callback path with the HMAC token as a bearer credential. On success + * the reply-to route is tombstoned so the one-shot call cannot be + * double-delivered.
  • + *
+ * + *

Retry policy for HTTP callbacks

+ * Retriable failures (connection errors, timeouts, HTTP 5xx/408/429) are retried + * with exponential backoff up to {@code REPLY_RETRY_MAX_MS} (default 120s, kept + * below the consumer's {@code max.poll.interval.ms} to avoid a rebalance). + * Non-retriable failures (other 4xx, e.g. an expired/invalid token) are logged + * and skipped. */ public class OutputConsumer implements Runnable { private static final Logger log = LoggerFactory.getLogger(OutputConsumer.class); + private static final ObjectMapper mapper = new ObjectMapper(); private static final String AGENT_NAME = ChatApiApp.requireEnv("AGENT_NAME"); private static final String TOPIC = AGENT_NAME + "-message-output"; @@ -27,11 +53,25 @@ public class OutputConsumer implements Runnable { private static final String CONSUMER_GROUP = ChatApiApp.env("OUTPUT_CONSUMER_GROUP", "chat-api-output-group"); + /** Total budget for retrying a single HTTP callback before giving up. */ + private static final long REPLY_RETRY_MAX_MS = + Long.parseLong(ChatApiApp.env("REPLY_RETRY_MAX_MS", "120000")); + private static final long RETRY_INITIAL_BACKOFF_MS = 1_000L; + private static final long RETRY_MAX_BACKOFF_MS = 15_000L; + + private static final String DEFAULT_RESPONSE_FIELD = "result"; + private final ChatWebSocketServer wsServer; + private final ReplyToProducer replyToProducer; + private final HttpClient httpClient; private volatile boolean running = true; - public OutputConsumer(ChatWebSocketServer wsServer) { + public OutputConsumer(ChatWebSocketServer wsServer, ReplyToProducer replyToProducer) { this.wsServer = wsServer; + this.replyToProducer = replyToProducer; + this.httpClient = HttpClient.newBuilder() + .connectTimeout(Duration.ofSeconds(10)) + .build(); } @Override @@ -53,8 +93,7 @@ public void run() { ConsumerRecords records = consumer.poll(Duration.ofMillis(200)); for (ConsumerRecord record : records) { String sessionId = record.key() != null ? record.key() : "unknown"; - log.info("[{}] Received response from message-output", sessionId); - wsServer.broadcastResponse(sessionId, record.value()); + deliver(sessionId, record.value()); } } } catch (Exception e) { @@ -65,6 +104,132 @@ public void run() { log.info("Output consumer stopped"); } + /** Routes a message-output record either to an HTTP callback or to WebSocket clients. */ + private void deliver(String sessionId, String value) { + JsonNode replyTo = null; + if (value != null) { + try { + JsonNode root = mapper.readTree(value); + JsonNode r = root.get("reply_to"); + if (r != null && r.isObject()) { + replyTo = r; + } + } catch (Exception e) { + log.warn("[{}] Failed to parse message-output — falling back to WebSocket: {}", + sessionId, e.getMessage()); + } + } + + if (replyTo != null && !replyTo.path("callbackService").asText("").isBlank()) { + log.info("[{}] Delivering response via HTTP callback", sessionId); + deliverHttp(sessionId, value, replyTo); + } else { + log.info("[{}] Received response from message-output", sessionId); + wsServer.broadcastResponse(sessionId, value); + } + } + + /** POSTs the response back to the calling agent, with bounded retries. */ + private void deliverHttp(String sessionId, String value, JsonNode replyTo) { + final HttpRequest request; + try { + request = buildRequest(value, replyTo); + } catch (Exception e) { + log.error("[{}] Invalid reply-to descriptor — skipping delivery: {}", sessionId, e.getMessage()); + return; + } + + long deadline = System.currentTimeMillis() + REPLY_RETRY_MAX_MS; + long backoff = RETRY_INITIAL_BACKOFF_MS; + int attempt = 0; + + while (running) { + attempt++; + try { + HttpResponse resp = + httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + int code = resp.statusCode(); + if (code >= 200 && code < 300) { + log.info("[{}] Reply delivered (HTTP {}, attempt {}) — tombstoning route", + sessionId, code, attempt); + replyToProducer.tombstone(sessionId); + return; + } + if (!isRetriable(code)) { + log.error("[{}] Reply delivery failed with non-retriable HTTP {} — skipping. Body: {}", + sessionId, code, truncate(resp.body())); + return; + } + log.warn("[{}] Reply delivery got retriable HTTP {} (attempt {})", sessionId, code, attempt); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return; + } catch (Exception e) { + // Connection/timeout errors are retriable. + log.warn("[{}] Reply delivery error (attempt {}): {}", sessionId, attempt, e.getMessage()); + } + + if (System.currentTimeMillis() + backoff >= deadline) { + log.error("[{}] Reply delivery exhausted retry budget ({}ms) after {} attempts — skipping", + sessionId, REPLY_RETRY_MAX_MS, attempt); + return; + } + try { + Thread.sleep(backoff); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return; + } + backoff = Math.min(backoff * 2, RETRY_MAX_BACKOFF_MS); + } + } + + /** Builds the callback HTTP request from the reply-to descriptor and the response body. */ + private HttpRequest buildRequest(String value, JsonNode replyTo) throws Exception { + String callbackService = replyTo.path("callbackService").asText(""); + String bearerToken = replyTo.path("bearerToken").asText(""); + + if (callbackService.isBlank()) { + throw new IllegalArgumentException("reply-to descriptor missing 'callbackService'"); + } + + // The caller supplies only a logical service name; the destination URL is + // resolved from operator-controlled config (ALLOWED_HOST_MAPPING) and the + // path is fixed, so a caller can never steer this request at an arbitrary + // host. Throws if the name is not configured — fail closed. + String url = CallbackRegistry.resolve(callbackService); + + // The agent's free-form answer goes under the default response field. + String content = ""; + try { + JsonNode root = mapper.readTree(value); + content = root.path("content").asText(""); + } catch (Exception ignore) { + // value already validated upstream; fall through with empty content + } + String body = mapper.writeValueAsString( + mapper.createObjectNode().put(DEFAULT_RESPONSE_FIELD, content)); + + HttpRequest.Builder b = HttpRequest.newBuilder() + .uri(URI.create(url)) + .timeout(Duration.ofSeconds(30)) + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(body)); + if (!bearerToken.isBlank()) { + b.header("Authorization", "Bearer " + bearerToken); + } + return b.build(); + } + + private static boolean isRetriable(int code) { + return code >= 500 || code == 408 || code == 429; + } + + private static String truncate(String s) { + if (s == null) return ""; + return s.length() <= 200 ? s : s.substring(0, 200) + "..."; + } + public void stop() { running = false; } diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ReplyToProducer.java b/api/chat-api/src/main/java/io/flightdeck/api/ReplyToProducer.java new file mode 100644 index 0000000..782ea5e --- /dev/null +++ b/api/chat-api/src/main/java/io/flightdeck/api/ReplyToProducer.java @@ -0,0 +1,90 @@ +package io.flightdeck.api; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.StringSerializer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; + +/** + * Produces reply-routing descriptors to the {@code {AGENT_NAME}-reply-to} topic, + * keyed by session_id. + * + *

Used in two places: + *

    + *
  • {@link ChatHandler} writes a descriptor when an inbound {@code /api/chat} + * request carries a {@code reply} object — establishing where this session's + * terminal response should be delivered.
  • + *
  • {@link OutputConsumer} writes a tombstone (null value) after it has + * successfully delivered a session's response, so a one-shot sub-agent call + * cannot be double-delivered.
  • + *
+ * + *

The topic is compacted (latest-per-key) with a time-based retention, so the + * descriptor lives only until it is tombstoned or until {@code REPLY_TO_STATE_TTL_MS} + * elapses, whichever comes first. + */ +public class ReplyToProducer { + + private static final Logger log = LoggerFactory.getLogger(ReplyToProducer.class); + + private static final String AGENT_NAME = ChatApiApp.requireEnv("AGENT_NAME"); + private static final String TOPIC = AGENT_NAME + "-reply-to"; + + private static final String BOOTSTRAP_SERVERS = + ChatApiApp.env("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092"); + + private final KafkaProducer producer; + + public ReplyToProducer() { + Properties props = new Properties(); + KafkaEnvProps.apply(props); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, BOOTSTRAP_SERVERS); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + props.put(ProducerConfig.RETRIES_CONFIG, 3); + + this.producer = new KafkaProducer<>(props); + log.info("Reply-to producer initialized — bootstrap={} topic={}", BOOTSTRAP_SERVERS, TOPIC); + } + + /** + * Writes (or overwrites) the reply-routing descriptor for a session. + * + *

TODO(multi-agent): keyed by sessionId, so only one outstanding reply + * route per session is supported (a second write overwrites the first). For + * multiple concurrent callbacks within one session, key by + * {@code sessionId:requestId}. Deferred — see Topics.REPLY_TO. + */ + public void send(String sessionId, String descriptorJson) { + ProducerRecord record = new ProducerRecord<>(TOPIC, sessionId, descriptorJson); + producer.send(record, (metadata, exception) -> { + if (exception != null) { + log.error("[{}] Failed to produce reply-to descriptor to {}", sessionId, TOPIC, exception); + } else { + log.debug("[{}] Wrote reply-to descriptor to {}", sessionId, TOPIC); + } + }); + } + + /** Tombstones (null value) a session's reply-routing descriptor after delivery. */ + public void tombstone(String sessionId) { + ProducerRecord record = new ProducerRecord<>(TOPIC, sessionId, null); + producer.send(record, (metadata, exception) -> { + if (exception != null) { + log.error("[{}] Failed to tombstone reply-to descriptor on {}", sessionId, TOPIC, exception); + } else { + log.debug("[{}] Tombstoned reply-to descriptor on {}", sessionId, TOPIC); + } + }); + } + + public void close() { + producer.close(); + log.info("Reply-to producer closed"); + } +} diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ToolCallbackToken.java b/api/chat-api/src/main/java/io/flightdeck/api/ToolCallbackToken.java new file mode 100644 index 0000000..6a68ca0 --- /dev/null +++ b/api/chat-api/src/main/java/io/flightdeck/api/ToolCallbackToken.java @@ -0,0 +1,111 @@ +package io.flightdeck.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.time.Instant; +import java.util.Base64; + +/** + * Verifies and decodes the HMAC-signed callback token minted by the tool + * consumer SDK when it dispatches an async tool. + * + *

Token layout (must match the SDK's {@code signCallbackToken}): + *

+ *   token = base64url(payloadJson) "." base64url(HMAC_SHA256(secret, payloadJson))
+ * 
+ * + *

The HMAC is recomputed over the exact decoded payload bytes, so JSON key + * ordering is irrelevant — we never re-serialise during verification. + */ +public final class ToolCallbackToken { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder(); + + private ToolCallbackToken() {} + + /** Decoded token payload. */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record Payload( + @JsonProperty("session_id") String sessionId, + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("tool_id") String toolId, + @JsonProperty("name") String name, + @JsonProperty("total_tools") int totalTools, + @JsonProperty("agent") String agent, + @JsonProperty("iat") long iat, + @JsonProperty("exp") long exp + ) {} + + /** Raised when a token is malformed, has a bad signature, or has expired. */ + public static class InvalidTokenException extends Exception { + public InvalidTokenException(String message) { super(message); } + } + + /** + * Verifies the token signature against {@code secret} and its expiry against + * the current time, then returns the decoded payload. + * + * @throws InvalidTokenException if the token is malformed, the signature + * does not match, or the token has expired. + */ + public static Payload verify(String token, String secret) throws InvalidTokenException { + if (token == null || token.isBlank()) { + throw new InvalidTokenException("Missing token"); + } + if (secret == null || secret.isBlank()) { + throw new InvalidTokenException("Server is not configured to verify callback tokens"); + } + + int dot = token.indexOf('.'); + if (dot <= 0 || dot == token.length() - 1) { + throw new InvalidTokenException("Malformed token"); + } + + byte[] payloadBytes; + byte[] providedSig; + try { + payloadBytes = URL_DECODER.decode(token.substring(0, dot)); + providedSig = URL_DECODER.decode(token.substring(dot + 1)); + } catch (IllegalArgumentException e) { + throw new InvalidTokenException("Token is not valid base64url"); + } + + byte[] expectedSig = hmacSha256(secret, payloadBytes); + // Constant-time comparison to avoid signature timing oracles. + if (!MessageDigest.isEqual(expectedSig, providedSig)) { + throw new InvalidTokenException("Bad token signature"); + } + + Payload payload; + try { + payload = MAPPER.readValue(payloadBytes, Payload.class); + } catch (Exception e) { + throw new InvalidTokenException("Token payload is not valid JSON"); + } + + if (payload.exp() > 0 && Instant.now().getEpochSecond() > payload.exp()) { + throw new InvalidTokenException("Token has expired"); + } + if (payload.sessionId() == null || payload.toolUseId() == null) { + throw new InvalidTokenException("Token missing session_id or tool_use_id"); + } + return payload; + } + + private static byte[] hmacSha256(String secret, byte[] data) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(secret.getBytes(StandardCharsets.UTF_8), "HmacSHA256")); + return mac.doFinal(data); + } catch (Exception e) { + throw new RuntimeException("HMAC computation failed", e); + } + } +} diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ToolResponseHandler.java b/api/chat-api/src/main/java/io/flightdeck/api/ToolResponseHandler.java new file mode 100644 index 0000000..65c58e1 --- /dev/null +++ b/api/chat-api/src/main/java/io/flightdeck/api/ToolResponseHandler.java @@ -0,0 +1,136 @@ +package io.flightdeck.api; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +/** + * Handles {@code POST /api/tools/response} — the callback endpoint for + * asynchronous tools, including sub-agent (multi-agent) replies. + * + *

When a tool consumer dispatches an async tool it acks the {@code tool-use} + * message immediately (via {@code ctx.pending()}) and hands the external system + * a signed callback token. When the external work finishes, the external system + * calls this endpoint with the token as a bearer credential and the result in + * the body: + * + *

+ *   POST /api/tools/response
+ *   Authorization: Bearer <signed token>
+ *   { "result": <any JSON> }
+ *   → 202 Accepted
+ * 
+ * + *

The token carries session_id, tool_use_id, tool_id, name and total_tools, + * so the body only needs the result payload. The handler verifies the token and + * writes a complete {@code ToolUseResult} into {@code tool-use-result} — the + * exact same topic and shape a synchronous tool would have produced. The + * aggregator cannot tell the two apart and dedupes by tool_use_id, so duplicate + * callbacks are safe. + * + *

This is a server-to-server endpoint authenticated by the bearer token, so + * no CORS headers are emitted. + */ +public class ToolResponseHandler implements HttpHandler { + + private static final Logger log = LoggerFactory.getLogger(ToolResponseHandler.class); + private static final ObjectMapper mapper = new ObjectMapper(); + private static final String BEARER_PREFIX = "Bearer "; + + private final ToolResultProducer producer; + private final String callbackSecret; + + public ToolResponseHandler(ToolResultProducer producer, String callbackSecret) { + this.producer = producer; + this.callbackSecret = callbackSecret; + } + + @Override + public void handle(HttpExchange exchange) throws IOException { + if (!"POST".equalsIgnoreCase(exchange.getRequestMethod())) { + sendJson(exchange, 405, "{\"error\":\"Method not allowed\"}"); + return; + } + + try (InputStream is = exchange.getRequestBody()) { + // Auth: HMAC callback token carried as `Authorization: Bearer `. + String authHeader = exchange.getRequestHeaders().getFirst("Authorization"); + String token = (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) + ? authHeader.substring(BEARER_PREFIX.length()).trim() + : null; + + ToolCallbackToken.Payload payload; + try { + payload = ToolCallbackToken.verify(token, callbackSecret); + } catch (ToolCallbackToken.InvalidTokenException e) { + log.warn("Rejected tool callback: {}", e.getMessage()); + sendJson(exchange, 401, "{\"error\":\"" + e.getMessage() + "\"}"); + return; + } + + JsonNode body = mapper.readTree(is); + + // The body IS the tool result payload (e.g. { "result": "" }). + // ToolUseResult.result is an object, so use the body directly when it + // is one and wrap primitives/arrays under a "result" key otherwise. + ObjectNode resultObj; + if (body != null && body.isObject()) { + resultObj = (ObjectNode) body; + } else { + resultObj = mapper.createObjectNode(); + if (body != null && !body.isNull()) { + resultObj.set("result", body); + } + } + + // Latency measured from when the token was minted (iat, epoch seconds). + long latencyMs = payload.iat() > 0 + ? Math.max(0, System.currentTimeMillis() - payload.iat() * 1000L) + : 0L; + + ObjectNode result = mapper.createObjectNode(); + result.put("session_id", payload.sessionId()); + result.put("tool_use_id", payload.toolUseId()); + if (payload.toolId() != null) { + result.put("tool_id", payload.toolId()); + } + result.put("name", payload.name()); + result.set("result", resultObj); + result.put("latency_ms", latencyMs); + result.put("status", "success"); + result.put("total_tools", payload.totalTools()); + result.put("timestamp", Instant.now().toString()); + + producer.send(payload.sessionId(), mapper.writeValueAsString(result)); + + log.info("[{}] Async tool result accepted via callback — tool_use_id={} tool_id={}", + payload.sessionId(), payload.toolUseId(), payload.toolId()); + + sendJson(exchange, 202, + "{\"status\":\"accepted\",\"tool_use_id\":\"" + payload.toolUseId() + "\"}"); + + } catch (Exception e) { + log.error("Failed to handle tool callback", e); + sendJson(exchange, 500, "{\"error\":\"Internal server error\"}"); + } + } + + private static void sendJson(HttpExchange exchange, int status, String json) throws IOException { + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(status, bytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } + } +} diff --git a/api/chat-api/src/main/java/io/flightdeck/api/ToolResultProducer.java b/api/chat-api/src/main/java/io/flightdeck/api/ToolResultProducer.java new file mode 100644 index 0000000..62be041 --- /dev/null +++ b/api/chat-api/src/main/java/io/flightdeck/api/ToolResultProducer.java @@ -0,0 +1,61 @@ +package io.flightdeck.api; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.StringSerializer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Properties; + +/** + * Produces async tool results to the Kafka {@code tool-use-result} topic — the + * same topic the synchronous tool consumers write to. The record key is the + * session_id so results co-partition with the aggregator's per-session store. + *

Fed by {@link ToolResponseHandler} when an external system calls back to + * {@code POST /api/tools/response} with the result of an async tool. + */ +public class ToolResultProducer { + + private static final Logger log = LoggerFactory.getLogger(ToolResultProducer.class); + + private static final String AGENT_NAME = ChatApiApp.requireEnv("AGENT_NAME"); + private static final String TOPIC = AGENT_NAME + "-tool-use-result"; + + private static final String BOOTSTRAP_SERVERS = + ChatApiApp.env("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092"); + + private final KafkaProducer producer; + + public ToolResultProducer() { + Properties props = new Properties(); + KafkaEnvProps.apply(props); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, BOOTSTRAP_SERVERS); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + props.put(ProducerConfig.RETRIES_CONFIG, 3); + + this.producer = new KafkaProducer<>(props); + log.info("Tool result producer initialized — bootstrap={} topic={}", BOOTSTRAP_SERVERS, TOPIC); + } + + /** Sends a tool-use-result record to {@code tool-use-result}, keyed by sessionId. */ + public void send(String sessionId, String resultJson) { + ProducerRecord record = new ProducerRecord<>(TOPIC, sessionId, resultJson); + producer.send(record, (metadata, exception) -> { + if (exception != null) { + log.error("[{}] Failed to produce to {}", sessionId, TOPIC, exception); + } else { + log.debug("[{}] Produced to {} partition={} offset={}", + sessionId, TOPIC, metadata.partition(), metadata.offset()); + } + }); + } + + public void close() { + producer.close(); + log.info("Tool result producer closed"); + } +} diff --git a/api/chat-api/src/test/java/io/flightdeck/api/CallbackRegistryTest.java b/api/chat-api/src/test/java/io/flightdeck/api/CallbackRegistryTest.java new file mode 100644 index 0000000..3b9ac08 --- /dev/null +++ b/api/chat-api/src/test/java/io/flightdeck/api/CallbackRegistryTest.java @@ -0,0 +1,69 @@ +package io.flightdeck.api; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Unit-tests the pure mapping logic of {@link CallbackRegistry} — the SSRF-relevant + * parts: parsing {@code ALLOWED_HOST_MAPPING}, splitting each entry on its first + * colon only (so {@code scheme://host:port} survives), and composing the fixed + * callback URL. The env-backed {@code MAPPING}/{@code resolve}/{@code isKnown} + * surface is exercised against the (empty) default config to assert fail-closed. + */ +class CallbackRegistryTest { + + @Test + void parsesSingleEntry() { + Map m = CallbackRegistry.parse("my-agent-a:https://hosta.local"); + assertEquals("https://hosta.local", m.get("my-agent-a")); + assertEquals(1, m.size()); + } + + @Test + void parsesMultipleEntriesAndPreservesSchemeAndPort() { + Map m = CallbackRegistry.parse( + "my-agent-a:https://hosta.local,orchestrator:http://orchestrator-api:8000"); + assertEquals("https://hosta.local", m.get("my-agent-a")); + // First-colon-only split must keep the http:// and the :8000 port intact. + assertEquals("http://orchestrator-api:8000", m.get("orchestrator")); + } + + @Test + void parseToleratesWhitespaceAndEmptyInput() { + assertEquals("https://x.local", + CallbackRegistry.parse(" a : https://x.local ").get("a")); + assertEquals(0, CallbackRegistry.parse("").size()); + assertEquals(0, CallbackRegistry.parse(" ").size()); + } + + @Test + void parseRejectsMalformedEntry() { + assertThrows(IllegalArgumentException.class, () -> CallbackRegistry.parse("no-url-here")); + assertThrows(IllegalArgumentException.class, () -> CallbackRegistry.parse(":https://x")); + assertThrows(IllegalArgumentException.class, () -> CallbackRegistry.parse("name:")); + } + + @Test + void toCallbackUrlAppendsFixedPathAndNormalizesTrailingSlash() { + assertEquals("https://hosta.local/api/tools/response", + CallbackRegistry.toCallbackUrl("https://hosta.local")); + assertEquals("https://hosta.local/api/tools/response", + CallbackRegistry.toCallbackUrl("https://hosta.local/")); + assertEquals("http://orchestrator-api:8000/api/tools/response", + CallbackRegistry.toCallbackUrl("http://orchestrator-api:8000")); + } + + @Test + void resolveFailsClosedForUnknownService() { + // Default config (no ALLOWED_HOST_MAPPING in the test env) → nothing is known. + assertFalse(CallbackRegistry.isKnown("anything")); + assertFalse(CallbackRegistry.isKnown(null)); + assertThrows(IllegalArgumentException.class, () -> CallbackRegistry.resolve("anything")); + assertThrows(IllegalArgumentException.class, () -> CallbackRegistry.resolve(null)); + } +} diff --git a/api/chat-api/src/test/java/io/flightdeck/api/ChatHandlerReplyValidationTest.java b/api/chat-api/src/test/java/io/flightdeck/api/ChatHandlerReplyValidationTest.java new file mode 100644 index 0000000..f7a2b9d --- /dev/null +++ b/api/chat-api/src/test/java/io/flightdeck/api/ChatHandlerReplyValidationTest.java @@ -0,0 +1,46 @@ +package io.flightdeck.api; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Validates the {@code reply} descriptor handling in {@link ChatHandler} — that a + * reply object carrying a non-blank {@code callbackService} is accepted and a + * malformed one is rejected before it reaches Kafka. Membership of the name in + * {@code ALLOWED_HOST_MAPPING} is enforced separately (see {@link CallbackRegistry} + * and {@link CallbackRegistryTest}). + */ +class ChatHandlerReplyValidationTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private boolean isWellFormedReply(JsonNode reply) { + // Mirror ChatHandler: require an object with a non-blank callbackService. + if (reply == null || !reply.isObject()) return false; + return !reply.path("callbackService").asText("").isBlank(); + } + + @Test + void acceptsReplyWithCallbackService() { + JsonNode reply = MAPPER.createObjectNode() + .put("callbackService", "my-agent-a") + .put("bearerToken", "tok-abc"); + assertTrue(isWellFormedReply(reply)); + } + + @Test + void rejectsNonObjectReply() { + JsonNode reply = MAPPER.createObjectNode().put("x", 1).path("x"); + assertFalse(isWellFormedReply(reply)); + } + + @Test + void rejectsReplyWithoutCallbackService() { + JsonNode reply = MAPPER.createObjectNode().put("bearerToken", "tok-abc"); + assertFalse(isWellFormedReply(reply)); + } +} diff --git a/api/chat-api/src/test/java/io/flightdeck/api/OutputConsumerReplyRoutingTest.java b/api/chat-api/src/test/java/io/flightdeck/api/OutputConsumerReplyRoutingTest.java new file mode 100644 index 0000000..7b9dae7 --- /dev/null +++ b/api/chat-api/src/test/java/io/flightdeck/api/OutputConsumerReplyRoutingTest.java @@ -0,0 +1,58 @@ +package io.flightdeck.api; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit-tests the routing decision in {@link OutputConsumer#deliver} — specifically + * the predicate that decides HTTP callback vs WebSocket — without standing up + * Kafka or a WebSocket server. The decision hinges on the presence of a + * {@code reply_to} descriptor carrying a non-blank {@code callbackService}. + */ +class OutputConsumerReplyRoutingTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private boolean isHttpCallback(String value) throws Exception { + JsonNode root = MAPPER.readTree(value); + JsonNode replyTo = root.get("reply_to"); + // Mirror the predicate in deliver(): replyTo present, object, callbackService set. + if (replyTo == null || !replyTo.isObject()) return false; + return !replyTo.path("callbackService").asText("").isBlank(); + } + + @Test + void routesToHttpWhenCallbackServicePresent() throws Exception { + String value = MAPPER.writeValueAsString(MAPPER.createObjectNode() + .set("reply_to", MAPPER.createObjectNode() + .put("callbackService", "my-agent-a") + .put("bearerToken", "tok-abc"))); + assertTrue(isHttpCallback(value)); + } + + @Test + void routesToWebSocketWhenNoReplyTo() throws Exception { + String value = MAPPER.writeValueAsString(MAPPER.createObjectNode() + .put("content", "hello")); + assertFalse(isHttpCallback(value)); + } + + @Test + void routesToWebSocketWhenReplyToNotObject() throws Exception { + String value = MAPPER.writeValueAsString(MAPPER.createObjectNode() + .put("reply_to", "not-an-object")); + assertFalse(isHttpCallback(value)); + } + + @Test + void routesToWebSocketWhenCallbackServiceBlank() throws Exception { + String value = MAPPER.writeValueAsString(MAPPER.createObjectNode() + .set("reply_to", MAPPER.createObjectNode() + .put("bearerToken", "tok-abc"))); + assertFalse(isHttpCallback(value)); + } +} diff --git a/examples/multi-agent-setup/README.md b/examples/multi-agent-setup/README.md new file mode 100644 index 0000000..2ac4ee7 --- /dev/null +++ b/examples/multi-agent-setup/README.md @@ -0,0 +1,159 @@ +# Multi-Agent Setup — Async Tool Delegation with Reply Routing + +Two FlightDeck agents on one Kafka cluster. The **orchestrator** delegates tasks +to the **worker** by calling a tool — but instead of blocking, the tool hands the +work off and the worker's answer flows back as the tool result *asynchronously*. + +The whole point: **the orchestrator treats the worker as a generic async tool, +and the worker treats the request as a generic chat. Neither knows the other is +an agent.** + +## How it works + +``` + User ──▶ Orchestrator (Agent A) + │ LLM calls delegate_task(task) + ▼ + orchestrator-tool-use + │ + ▼ + orchestrator-dispatcher ┌─ mint HMAC token + • ctx "pending": ack offset, NO result ────┤ (session_id, tool_use_id, + • POST task to worker /api/chat │ tool_id, name, total_tools) + with a `reply` descriptor └─ token = capability, opaque to B + │ + ▼ + Worker (Agent B) — an ordinary agent + /api/chat stores reply-to (keyed by session) ──▶ worker-reply-to topic + think → (its own tools, if any) → end turn + EndTurnProcessor left-joins reply-to ──▶ worker-message-output (carries reply_to) + │ + ▼ + worker-api OutputConsumer + • sees reply_to → POST answer back to A's /api/tools/response + Authorization: Bearer , body { "result": "" } + • on success: tombstone the reply-to route (one-shot) + │ + ▼ + Orchestrator /api/tools/response + • verify token → write tool-use-result (keyed by A's session) + │ + ▼ + orchestrator aggregator + • matches tool_use_id → turn completes (or times out → error result) + │ + ▼ + User ◀── Orchestrator synthesizes the worker's answer +``` + +## Design choices illustrated + +- **Async, not blocking.** The dispatcher acks the `tool-use` offset *without* + producing a `tool-use-result` (the "pending" ack) and returns immediately. No + thread is held open waiting for the worker. Compare this to a synchronous + bridge that blocks the tool consumer for the whole sub-call. + +- **HMAC token = capability, secret stays with A.** The orchestrator mints a + signed token carrying the correlation fields and hands it to the worker. The + worker cannot read, forge, or alter it — it only echoes it back. Only the + orchestrator holds `TOOL_CALLBACK_SECRET`, on the two components that need it + (dispatcher to sign, chat-api to verify). + +- **Reply routing is transport, not prompt.** The `reply` descriptor + (`{ "callbackService": "...", "bearerToken": "..." }`) travels as a transport + field on `/api/chat`, is stored on the worker's `reply-to` topic keyed by + `session_id`, and is re-attached to the worker's output at end-turn by a + left-join. The worker's LLM never sees it, and the worker has no "call back to + the orchestrator" tool. Routing lives in the delivery layer, keeping the worker + a vanilla agent. + +- **SSRF-safe callbacks: the caller names a service, never a URL.** The descriptor + carries only a logical `callbackService` name. The worker resolves it to a base + URL from its own operator-controlled `ALLOWED_HOST_MAPPING` + (`orchestrator:http://orchestrator-api:8000`) and POSTs to the fixed path + `/api/tools/response`. Because the destination host comes from worker + config and never from caller input, an untrusted caller cannot steer the + server-side callback at an internal host — the request is rejected (400 at + `/api/chat`, and fail-closed at delivery) if the name is not in the mapping. + +- **Free-form answer, A shapes it.** The worker returns ordinary prose. The + worker's OutputConsumer wraps it under the fixed field name + (`{ "result": "" }`) and the orchestrator turns it into a canonical tool + result. No JSON contract is forced on the worker. + +- **Failure is a timeout, not an error channel.** The worker never reports + failure. If it crashes or never answers, the orchestrator's aggregator hits its + deadline (`ASYNC_TOOL_TIMEOUT_MS`, default 5 min) and synthesizes an error + result from the expected tool set, so the turn always completes. + +- **One-shot, no double-delivery.** On a successful callback the worker tombstones + its reply-to route; the topic is also compacted with a TTL + (`REPLY_TO_STATE_TTL_MS`, default 24h) so stale routes are dropped. + +## Services + +| Service | Agent | Description | +|---------|-------|-------------| +| `orchestrator-api` | orchestrator | Chat API (user-facing) + `/api/tools/response` callback | +| `orchestrator-processing` | orchestrator | Kafka Streams pipeline | +| `orchestrator-think` | orchestrator | Claude API — decides when to `delegate_task` | +| `orchestrator-dispatcher` | orchestrator | The async tool: mints token, POSTs to worker, acks pending | +| `frontend` | orchestrator | Web UI | +| `worker-api` | worker | Chat API — accepts the task, delivers the reply via HTTP callback | +| `worker-processing` | worker | Kafka Streams pipeline | +| `worker-think` | worker | Claude API — answers the delegated task | + +> The worker has an empty `tools.json` here, so it answers directly from the LLM. +> It is still a full agent — give it its own tools and tool service and nothing +> about the reply routing changes. + +## Run + +```bash +cp .env.example .env +# Edit .env: add CLAUDE_API_KEY and set TOOL_CALLBACK_SECRET (e.g. `openssl rand -hex 32`) + +docker compose up --build +``` + +`chat-api`, `processing`, and `think-consumer` build from source in this repo +(this example demonstrates branch functionality not yet in the published images). +All three must share the same build — a source-built `processing` paired with a +published `think-consumer` mismatches the `ThinkResponse` schema and produces +empty content / lost history. The first build takes a few minutes. + +Open [http://localhost](http://localhost) and try: + +- "Ask the worker to write a haiku about Kafka." +- "Delegate this to the worker: summarize the tradeoffs of event-driven architecture in 3 bullets." +- "Have the worker explain what an HMAC is, simply." + +Watch `docker compose logs -f orchestrator-dispatcher worker-api` to see the +dispatch → pending ack → worker answer → callback → tool result round-trip. + +## Integration test + +`integration-test.sh` brings the whole stack up, posts a message to +`orchestrator-api`, and verifies the round-trip — that the worker handled a +delegated sub-session and the orchestrator emitted a non-empty final answer — +then tears everything down. It takes `CLAUDE_API_KEY` from the environment: + +```bash +CLAUDE_API_KEY=sk-ant-... ./integration-test.sh +``` + +The default prompt asks the worker to derive ∫₀^∞ x²/(eˣ−1) dx (= 2ζ(3) ≈ +2.404114), so the printed answer is easy to eyeball. Override `TIMEOUT` or +`CLAUDE_MODEL` via env if needed. + +## Troubleshooting + +- **`/api/tools/response` returns 401** — `TOOL_CALLBACK_SECRET` differs between + `orchestrator-api` and `orchestrator-dispatcher`, or the token expired. +- **Orchestrator replies "the tool failed / timed out"** — the worker didn't + answer within `ASYNC_TOOL_TIMEOUT_MS`. Check the `worker-think` logs (API key, + model) and `worker-api` logs (callback POST result). +- **Docker build fails compiling tests** — the service Dockerfiles compile the + module's `src/test`. Unrelated work-in-progress test files in your working tree + that don't compile will break the build; build from a clean checkout of the + branch. diff --git a/examples/multi-agent-setup/docker-compose.yml b/examples/multi-agent-setup/docker-compose.yml new file mode 100644 index 0000000..53bb498 --- /dev/null +++ b/examples/multi-agent-setup/docker-compose.yml @@ -0,0 +1,183 @@ +# Multi-agent setup — async tool delegation with reply routing. +# +# Two FlightDeck agents on one Kafka cluster: +# • orchestrator (user-facing) — delegates tasks to the worker via an async tool +# • worker (headless) — an ordinary agent that answers delegated tasks +# +# chat-api, processing AND think-consumer are BUILT FROM SOURCE in this repo +# because this example demonstrates functionality on the feat/multi-agent branch +# that is not in the published images yet. All three core services MUST share the +# same build — mixing a source-built processing with a published think-consumer +# causes a ThinkResponse schema mismatch (empty content / lost history). Only the +# frontend (a stable UI over chat-api's WebSocket) uses a published image. + +services: + + # ─── Shared Infrastructure ────────────────────────────────────────────────── + + kafka: + image: apache/kafka:4.1.1 + hostname: kafka + ports: + - "9092:9092" + volumes: + - kafka-data:/var/lib/kafka/data + environment: + KAFKA_NODE_ID: 1 + KAFKA_PROCESS_ROLES: broker,controller + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: CONTROLLER:PLAINTEXT,INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT + KAFKA_LISTENERS: INTERNAL://kafka:29092,EXTERNAL://0.0.0.0:9092,CONTROLLER://kafka:9093 + KAFKA_ADVERTISED_LISTENERS: INTERNAL://kafka:29092,EXTERNAL://localhost:9092 + KAFKA_INTER_BROKER_LISTENER_NAME: INTERNAL + KAFKA_CONTROLLER_LISTENER_NAMES: CONTROLLER + KAFKA_CONTROLLER_QUORUM_VOTERS: 1@kafka:9093 + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 + KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: 1 + KAFKA_TRANSACTION_STATE_LOG_MIN_ISR: 1 + KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS: 0 + KAFKA_AUTO_CREATE_TOPICS_ENABLE: "true" + KAFKA_LOG_DIRS: /var/lib/kafka/data + CLUSTER_ID: MkU3OEVBNTcwNTJENDM2Qk + healthcheck: + test: ["CMD", "/opt/kafka/bin/kafka-topics.sh", "--bootstrap-server", "localhost:9092", "--list"] + interval: 10s + timeout: 10s + retries: 10 + start_period: 30s + + # ═══════════════════════════════════════════════════════════════════════════ + # ORCHESTRATOR AGENT (user-facing, Agent A) + # ═══════════════════════════════════════════════════════════════════════════ + + orchestrator-api: + build: + context: ../../api/chat-api + image: flightdeck-multiagent/chat-api:local + ports: + - "8000:8000" + - "8001:8001" + networks: + default: + aliases: + - api + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: orchestrator + PORT: 8000 + WS_PORT: 8001 + # Required to verify async tool callbacks on /api/tools/response. + TOOL_CALLBACK_SECRET: ${TOOL_CALLBACK_SECRET} + depends_on: + kafka: + condition: service_healthy + + orchestrator-processing: + build: + context: ../../processor-apps/processing + image: flightdeck-multiagent/processing:local + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: orchestrator + MEMOIR_ENABLED: "false" + depends_on: + kafka: + condition: service_healthy + + orchestrator-think: + build: + context: ../../think/think-consumer + image: flightdeck-multiagent/think-consumer:local + volumes: + - ./orchestrator/tools.json:/app/tools.json:ro + - ./orchestrator/system-prompt.txt:/app/system-prompt.txt:ro + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: orchestrator + CLAUDE_API_KEY: ${CLAUDE_API_KEY} + CLAUDE_MODEL: ${CLAUDE_MODEL:-claude-haiku-4-5-20251001} + CLAUDE_MAX_TOKENS: ${CLAUDE_MAX_TOKENS:-8096} + TOOLS_JSON_FILE: /app/tools.json + SYSTEM_PROMPT_FILE: /app/system-prompt.txt + depends_on: + kafka: + condition: service_healthy + + # The async multi-agent tool: dispatches delegate_task to the worker and acks + # "pending" — the result returns via the callback, not synchronously. + orchestrator-dispatcher: + build: + context: ../.. + dockerfile: examples/multi-agent-setup/orchestrator/dispatcher/Dockerfile + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: orchestrator + TOOL_CALLBACK_SECRET: ${TOOL_CALLBACK_SECRET} + TARGET_CHAT_URL: http://worker-api:8000/api/chat + # Logical callback-service name; the worker resolves it to the URL below. + CALLBACK_SERVICE: orchestrator + depends_on: + kafka: + condition: service_healthy + + frontend: + image: ghcr.io/tsuz/flightdeck/frontend:${FLIGHTDECK_VERSION:-latest} + ports: + - "80:80" + depends_on: + - orchestrator-api + + # ═══════════════════════════════════════════════════════════════════════════ + # WORKER AGENT (headless, Agent B) + # ═══════════════════════════════════════════════════════════════════════════ + + # The worker runs its own chat-api: it accepts /api/chat (with the reply + # descriptor) and its OutputConsumer POSTs the answer back to the orchestrator. + # No host ports are exposed — it is reached internally at worker-api:8000. + worker-api: + image: flightdeck-multiagent/chat-api:local + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: worker + PORT: 8000 + WS_PORT: 8001 + # SSRF-safe callback routing: a `reply` descriptor may only name one of + # these services; the worker resolves the name to the base URL here and + # POSTs to /api/tools/response. Callers never supply a URL. + ALLOWED_HOST_MAPPING: "orchestrator:http://orchestrator-api:8000" + depends_on: + kafka: + condition: service_healthy + orchestrator-api: + condition: service_started + + worker-processing: + image: flightdeck-multiagent/processing:local + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: worker + MEMOIR_ENABLED: "false" + depends_on: + kafka: + condition: service_healthy + orchestrator-processing: + condition: service_started + + worker-think: + image: flightdeck-multiagent/think-consumer:local + volumes: + - ./worker/tools.json:/app/tools.json:ro + - ./worker/system-prompt.txt:/app/system-prompt.txt:ro + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + AGENT_NAME: worker + CLAUDE_API_KEY: ${CLAUDE_API_KEY} + CLAUDE_MODEL: ${CLAUDE_MODEL:-claude-haiku-4-5-20251001} + CLAUDE_MAX_TOKENS: ${CLAUDE_MAX_TOKENS:-8096} + TOOLS_JSON_FILE: /app/tools.json + SYSTEM_PROMPT_FILE: /app/system-prompt.txt + depends_on: + kafka: + condition: service_healthy + +volumes: + kafka-data: diff --git a/examples/multi-agent-setup/integration-test.sh b/examples/multi-agent-setup/integration-test.sh new file mode 100755 index 0000000..0833551 --- /dev/null +++ b/examples/multi-agent-setup/integration-test.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash +# +# Integration test for the multi-agent-setup example. +# +# Spins up the whole stack (orchestrator + worker + Kafka), sends a message to +# the orchestrator's /api/chat, and verifies the async multi-agent round-trip: +# 1. the worker handled a delegated sub-session, and +# 2. the orchestrator produced a non-empty final answer for our session. +# +# The result (the orchestrator's final answer) is printed on success. +# +# Usage: +# CLAUDE_API_KEY=sk-ant-... ./integration-test.sh +# +# Optional env: +# TOOL_CALLBACK_SECRET shared HMAC secret (default: random) +# TIMEOUT seconds to wait for the answer (default: 240) +# CLAUDE_MODEL model override (default: compose default) +# +# Requires: docker compose, curl, python3. + +set -euo pipefail + +: "${CLAUDE_API_KEY:?CLAUDE_API_KEY must be set}" +export CLAUDE_API_KEY +export TOOL_CALLBACK_SECRET="${TOOL_CALLBACK_SECRET:-it-secret-${RANDOM}${RANDOM}}" +[ -n "${CLAUDE_MODEL:-}" ] && export CLAUDE_MODEL + +cd "$(dirname "$0")" + +PROJECT="multiagent-it" +COMPOSE=(docker compose -p "$PROJECT") +ORCH_URL="http://localhost:8000" +SESSION="it-$(date +%s)" +TIMEOUT="${TIMEOUT:-300}" +PROMPT="Compute the exact closed-form value of the integral ∫₀^∞ (x² / (eˣ − 1)) dx. Show the full derivation: rewrite the integrand using the geometric series expansion of 1/(eˣ − 1), interchange sum and integral with justification, and reduce to a product of a Gamma function and a Riemann zeta value. State the final answer in terms of ζ(3). Then independently confirm the numerical value to 6 decimal places using a direct numerical integration, and report both numbers side by side so they can be compared." + +cleanup() { + echo "--- tearing down ---" + "${COMPOSE[@]}" down -v >/dev/null 2>&1 || true +} +trap cleanup EXIT + +# Read a topic from the beginning with a short idle timeout (existing records only). +consume() { + "${COMPOSE[@]}" exec -T kafka /opt/kafka/bin/kafka-console-consumer.sh \ + --bootstrap-server localhost:9092 --topic "$1" \ + --from-beginning --timeout-ms 4000 2>/dev/null || true +} + +echo "--- building & starting stack (first build can take several minutes) ---" +"${COMPOSE[@]}" up -d --build + +echo "--- waiting for orchestrator-api at $ORCH_URL ---" +ready="" +for _ in $(seq 1 60); do + if curl -sf -o /dev/null -X OPTIONS "$ORCH_URL/api/chat"; then ready=1; break; fi + sleep 3 +done +[ -n "$ready" ] || { echo "FAIL: orchestrator-api never became ready"; exit 1; } + +echo "--- sending message (session=$SESSION) ---" +# Build the JSON body with python3 so the Unicode prompt (∫, ζ, −, …) is encoded safely. +payload=$(SESSION="$SESSION" PROMPT="$PROMPT" python3 -c \ + 'import json,os; print(json.dumps({"session_id":os.environ["SESSION"],"content":os.environ["PROMPT"]}))') +curl -sf -X POST "$ORCH_URL/api/chat" \ + -H 'Content-Type: application/json' \ + -d "$payload" +echo + +echo "--- awaiting orchestrator final answer (timeout ${TIMEOUT}s) ---" +final="" +worker_seen="" +end=$((SECONDS + TIMEOUT)) +while [ "$SECONDS" -lt "$end" ]; do + # 1. Did the worker handle a delegated sub-session ({SESSION}--{tool_use_id})? + if [ -z "$worker_seen" ] && consume worker-message-output | grep -q "${SESSION}--"; then + worker_seen=1 + echo "[ok] worker handled a delegated sub-session" + fi + + # 2. Did the orchestrator emit a non-empty final answer for our session? + line=$(consume orchestrator-message-output | grep "\"session_id\":\"${SESSION}\"" | tail -1 || true) + if [ -n "$line" ]; then + content=$(printf '%s' "$line" | python3 -c 'import sys,json; print(json.load(sys.stdin).get("content",""))' 2>/dev/null || true) + if [ -n "$content" ]; then final="$content"; break; fi + fi + + sleep 3 +done + +if [ -z "$final" ]; then + echo "FAIL: no non-empty orchestrator response for ${SESSION} within ${TIMEOUT}s" + echo "--- recent logs ---" + "${COMPOSE[@]}" logs --tail=40 \ + orchestrator-api orchestrator-think orchestrator-dispatcher worker-api worker-think || true + exit 1 +fi + +echo "=================== ORCHESTRATOR ANSWER ===================" +echo "$final" +echo "==========================================================" + +# Informational: the closed form is Γ(3)·ζ(3) = 2ζ(3) ≈ 2.404114. LLM phrasing +# varies, so this is a soft signal, not a hard assertion. +if printf '%s' "$final" | grep -qiE 'ζ\(3\)|zeta\(3\)|2\.40411'; then + echo "[ok] answer references the expected result (ζ(3) / ≈2.404114)" +else + echo "[warn] answer did not obviously reference ζ(3) / 2.404114 — review the output above" +fi + +if [ -z "$worker_seen" ]; then + echo "FAIL: got an answer but the worker never received a delegated task" + exit 1 +fi + +echo "PASS: multi-agent round-trip verified (orchestrator → worker → callback → answer)" diff --git a/examples/multi-agent-setup/orchestrator/dispatcher/Dockerfile b/examples/multi-agent-setup/orchestrator/dispatcher/Dockerfile new file mode 100644 index 0000000..8c68310 --- /dev/null +++ b/examples/multi-agent-setup/orchestrator/dispatcher/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.11-slim +WORKDIR /app + +# Install the FlightDeck Python SDK (used here only for kafka_env_props). +COPY sdk/python /tmp/sdk +RUN pip install --no-cache-dir /tmp/sdk && rm -rf /tmp/sdk + +COPY examples/multi-agent-setup/orchestrator/dispatcher/requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY examples/multi-agent-setup/orchestrator/dispatcher/app.py . + +ENV PYTHONUNBUFFERED=1 +CMD ["python", "app.py"] diff --git a/examples/multi-agent-setup/orchestrator/dispatcher/app.py b/examples/multi-agent-setup/orchestrator/dispatcher/app.py new file mode 100644 index 0000000..7cf811a --- /dev/null +++ b/examples/multi-agent-setup/orchestrator/dispatcher/app.py @@ -0,0 +1,158 @@ +""" +Orchestrator dispatcher — the asynchronous multi-agent tool. + +This is the Agent-A side of the multi-agent design. When the orchestrator LLM +calls `delegate_task`, the tool does NOT answer synchronously. Instead it: + + 1. Mints an HMAC callback token correlating this exact tool call + (session_id, tool_use_id, tool_id, name, total_tools). Only the + orchestrator holds the secret, so only the orchestrator can verify it — the + worker is a blind courier that just echoes the token back. + + 2. POSTs the task to the worker agent's /api/chat with a transport-level + `reply` descriptor naming the callback service to answer to (a logical name + the worker resolves to the orchestrator's URL via ALLOWED_HOST_MAPPING), + carrying the token as a bearer credential. The dispatcher never sends a URL, + and the reply descriptor never enters the worker's prompt. + + 3. Commits the tool-use offset WITHOUT producing a tool-use-result. This is the + "pending" ack: the dispatch is acknowledged so it is not redelivered, but no + result is published — the result will arrive later via the callback. + +We use a plain Kafka consumer here (rather than the SDK's ToolConsumerRunner) so +we have explicit control over the "ack-without-result" semantics. A future SDK +`ctx.pending()` would replace this loop. + +The worker processes the task as an ordinary chat. When it finishes, the worker's +chat-api OutputConsumer POSTs the answer back to the orchestrator's +/api/tools/response, which verifies the token and writes the tool-use-result — +completing the original tool call. If the worker never responds, the +orchestrator's aggregator times out and synthesizes an error result, so the turn +cannot hang. + +Env vars: + KAFKA_BOOTSTRAP_SERVERS Kafka broker + AGENT_NAME orchestrator agent name (prefixes its topics) + TOOL_CALLBACK_SECRET shared HMAC secret (same as orchestrator chat-api) + TARGET_CHAT_URL worker /api/chat URL, e.g. http://worker-api:8000/api/chat + CALLBACK_SERVICE logical callback-service name the worker maps to the + orchestrator's URL via ALLOWED_HOST_MAPPING, e.g. "orchestrator" + TOKEN_TTL_SEC callback token lifetime in seconds (default 86400) +""" + +import base64 +import hashlib +import hmac +import json +import os +import time + +import requests +from confluent_kafka import Consumer + +from flightdeck_sdk import kafka_env_props + +KAFKA_BOOTSTRAP = os.environ.get("KAFKA_BOOTSTRAP_SERVERS", "kafka:29092") +AGENT_NAME = os.environ["AGENT_NAME"] +SECRET = os.environ["TOOL_CALLBACK_SECRET"] +TARGET_CHAT_URL = os.environ["TARGET_CHAT_URL"] +CALLBACK_SERVICE = os.environ["CALLBACK_SERVICE"] +TOKEN_TTL_SEC = int(os.environ.get("TOKEN_TTL_SEC", "86400")) + +INPUT_TOPIC = f"{AGENT_NAME}-tool-use" +GROUP_ID = f"{AGENT_NAME}-tool-execution" + + +def _b64url(raw: bytes) -> str: + """base64url without padding — matches the FlightDeck token format.""" + return base64.urlsafe_b64encode(raw).rstrip(b"=").decode() + + +def mint_token(session_id, tool_use_id, tool_id, name, total_tools): + now = int(time.time()) + payload = { + "session_id": session_id, + "tool_use_id": tool_use_id, + "tool_id": tool_id, + "name": name, + "total_tools": total_tools, + "agent": AGENT_NAME, + "iat": now, + "exp": now + TOKEN_TTL_SEC, + } + # The HMAC is computed over the exact transmitted payload bytes, so the + # verifier never re-serializes — JSON key ordering is irrelevant. + payload_bytes = json.dumps(payload, separators=(",", ":")).encode() + sig = hmac.new(SECRET.encode(), payload_bytes, hashlib.sha256).digest() + return f"{_b64url(payload_bytes)}.{_b64url(sig)}" + + +def dispatch(tool_use): + session_id = tool_use.get("session_id", "") + tool_use_id = tool_use.get("tool_use_id", "") + tool_id = tool_use.get("tool_id", "") + name = tool_use.get("name", "") + total_tools = tool_use.get("total_tools", 1) + task = (tool_use.get("input") or {}).get("task", "") + + token = mint_token(session_id, tool_use_id, tool_id, name, total_tools) + + # A fresh session for the worker's sub-conversation. The worker keys its + # reply-to state by THIS id; the orchestrator-side correlation lives in the + # token (which carries the orchestrator session_id + tool_use_id), not here. + worker_session = f"{session_id}--{tool_use_id}" + + body = { + "session_id": worker_session, + "content": task, + "reply": { + "callbackService": CALLBACK_SERVICE, + "bearerToken": token, + }, + } + + print(f"[{session_id}] delegate tool_use_id={tool_use_id} → worker " + f"session={worker_session}: {task[:80]!r}", flush=True) + resp = requests.post(TARGET_CHAT_URL, json=body, timeout=15) + resp.raise_for_status() + print(f"[{session_id}] worker accepted task (HTTP {resp.status_code}); " + f"awaiting async callback", flush=True) + + +def main(): + consumer = Consumer({ + **kafka_env_props(), + "bootstrap.servers": KAFKA_BOOTSTRAP, + "group.id": GROUP_ID, + "auto.offset.reset": "earliest", + "enable.auto.commit": True, + "enable.auto.offset.store": False, # we store offsets explicitly after acking + }) + consumer.subscribe([INPUT_TOPIC]) + print(f"Orchestrator dispatcher started — listening on [{INPUT_TOPIC}]", flush=True) + print(f" target chat : {TARGET_CHAT_URL}", flush=True) + print(f" callback svc: {CALLBACK_SERVICE} (resolved by worker via ALLOWED_HOST_MAPPING)", flush=True) + + try: + while True: + msg = consumer.poll(1.0) + if msg is None: + continue + if msg.error(): + print(f"consumer error: {msg.error()}", flush=True) + continue + try: + dispatch(json.loads(msg.value())) + except Exception as e: + # We still ack on failure: the orchestrator aggregator will time + # out and synthesize an error result, so the turn cannot hang. + # (Prefer NOT acking if you'd rather redeliver on transient errors.) + print(f"dispatch failed: {e}", flush=True) + # "pending" ack: record the source offset, publish NO tool-use-result. + consumer.store_offsets(msg) + finally: + consumer.close() + + +if __name__ == "__main__": + main() diff --git a/examples/multi-agent-setup/orchestrator/dispatcher/requirements.txt b/examples/multi-agent-setup/orchestrator/dispatcher/requirements.txt new file mode 100644 index 0000000..891b2f1 --- /dev/null +++ b/examples/multi-agent-setup/orchestrator/dispatcher/requirements.txt @@ -0,0 +1,2 @@ +confluent-kafka>=2.6.0 +requests>=2.31.0 diff --git a/examples/multi-agent-setup/orchestrator/system-prompt.txt b/examples/multi-agent-setup/orchestrator/system-prompt.txt new file mode 100644 index 0000000..c626ff2 --- /dev/null +++ b/examples/multi-agent-setup/orchestrator/system-prompt.txt @@ -0,0 +1,13 @@ +You are the Orchestrator — a user-facing coordinator agent. + +You do not do specialist work yourself. When the user asks for something that +should be handled by the worker agent, call the `delegate_task` tool with a +clear, self-contained instruction. The worker is an independent agent that will +complete the task and return its answer to you. + +The worker runs asynchronously: when you call `delegate_task` the result is not +instant — it arrives once the worker has finished. You will then receive the +worker's answer as the tool result. Synthesize it into a clear, helpful reply +for the user, and feel free to delegate follow-up tasks if needed. + +Keep your own messages concise. Let the worker do the heavy lifting. diff --git a/examples/multi-agent-setup/orchestrator/tools.json b/examples/multi-agent-setup/orchestrator/tools.json new file mode 100644 index 0000000..fd21bd8 --- /dev/null +++ b/examples/multi-agent-setup/orchestrator/tools.json @@ -0,0 +1,16 @@ +[ + { + "name": "delegate_task", + "description": "Delegate a self-contained task to the worker agent and receive its completed result. Use this whenever the user asks for work that the specialist worker should handle. The worker is an independent agent; you hand it an instruction and its answer comes back to you. You may delegate more than one task.", + "input_schema": { + "type": "object", + "properties": { + "task": { + "type": "string", + "description": "A clear, self-contained instruction for the worker agent. Include everything the worker needs — it has no access to this conversation." + } + }, + "required": ["task"] + } + } +] diff --git a/examples/multi-agent-setup/worker/system-prompt.txt b/examples/multi-agent-setup/worker/system-prompt.txt new file mode 100644 index 0000000..f1bdcc3 --- /dev/null +++ b/examples/multi-agent-setup/worker/system-prompt.txt @@ -0,0 +1,15 @@ +You are the Worker — an independent task-handling agent. + +You receive a single, self-contained task and complete it to the best of your +ability, returning a clear and concise answer. You are unaware of any +orchestration around you: you simply answer the request you were given, as +helpfully as possible, and stop. + +If the task is ambiguous, make a reasonable assumption and state it briefly +rather than asking a follow-up question — there is no interactive user on the +other end, only another agent waiting for your result. + +This worker has no tools of its own in this example, but it is an ordinary +FlightDeck agent: you could give it its own `tools.json` and tool service and it +would use them exactly like any other agent, with no change to how its result is +routed back to the orchestrator. diff --git a/examples/multi-agent-setup/worker/tools.json b/examples/multi-agent-setup/worker/tools.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/examples/multi-agent-setup/worker/tools.json @@ -0,0 +1 @@ +[] diff --git a/processor-apps/processing/Dockerfile b/processor-apps/processing/Dockerfile index 7a8a342..a21b702 100644 --- a/processor-apps/processing/Dockerfile +++ b/processor-apps/processing/Dockerfile @@ -3,7 +3,9 @@ WORKDIR /app COPY pom.xml . RUN mvn dependency:go-offline -q COPY src ./src -RUN mvn package -q -DskipTests +# Skip test compilation+execution for the image build (tests run in CI via +# `mvn test`); keeps the image build decoupled from test sources. +RUN mvn package -q -Dmaven.test.skip=true FROM eclipse-temurin:17-jre WORKDIR /app diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java index 836ed8a..e83a22a 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/FlightDeckStreamsApp.java @@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -45,6 +46,11 @@ public class FlightDeckStreamsApp { static final String MEMOIR_CONTEXT_STORE = "memoir-context-store"; static final String THINK_RESPONSE_STORE = "think-response-store"; + static final String REPLY_TO_STORE = "reply-to-store"; + + /** Time-based expiry for reply-to routing state (default 24h). */ + static final long REPLY_TO_STATE_TTL_MS = Long.parseLong( + System.getenv().getOrDefault("REPLY_TO_STATE_TTL_MS", "86400000")); public static void main(String[] args) { @@ -118,11 +124,23 @@ static Topology buildTopology(boolean memoirEnabled) { .withValueSerde(JsonSerde.of(ThinkResponse.class)) ); + // ── Shared KTable: reply-to routing descriptors (multi-agent) ─────── + // session_id → reply descriptor JSON. Left-joined into message-output + // at end-turn so the OutputConsumer knows where to deliver a reply. + KTable replyToTable = builder.table( + Topics.REPLY_TO, + Consumed.with(Serdes.String(), Serdes.String()), + Materialized.as( + Stores.persistentKeyValueStore(REPLY_TO_STORE)) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + // ── Register each processor fragment ────────────────────────────────── EnrichInputMessageProcessor.register(builder, memoirTable, thinkTable); ExtractToolUseItemsProcessor.register(builder, thinkStream); - EndTurnProcessor.register(builder, thinkStream); - AggregateToolExecutionResultProcessor.register(builder); + EndTurnProcessor.register(builder, thinkStream, replyToTable); + AggregateToolExecutionResultProcessor.register(builder, thinkStream); TransformToolUseDoneProcessor.register(builder); if (memoirEnabled) { @@ -158,7 +176,8 @@ private static void ensureTopicsExist(Properties streamsProps) { Topics.TOOL_USE_RESULT, Topics.TOOL_USE_ALL_COMPLETE, Topics.TOOL_USE_LATENCY, - Topics.MESSAGE_OUTPUT + Topics.MESSAGE_OUTPUT, + Topics.REPLY_TO )); if (MEMOIR_ENABLED) { @@ -174,7 +193,18 @@ private static void ensureTopicsExist(Properties streamsProps) { List toCreate = requiredTopics.stream() .filter(t -> !existing.contains(t)) - .map(t -> new NewTopic(t, 1, (short) 1)) + .map(t -> { + NewTopic topic = new NewTopic(t, 1, (short) 1); + // reply-to is a keyed-state topic: keep the latest descriptor + // per session_id, but also drop anything older than the TTL. + if (t.equals(Topics.REPLY_TO)) { + topic.configs(Map.of( + "cleanup.policy", "compact,delete", + "retention.ms", String.valueOf(REPLY_TO_STATE_TTL_MS) + )); + } + return topic; + }) .collect(Collectors.toList()); if (!toCreate.isEmpty()) { diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java index 011f7be..e645a2a 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/config/Topics.java @@ -66,4 +66,23 @@ private static String requireEnv(String key) { // ── Outbound ────────────────────────────────────────────────────────────── /** Final responses sent back to the user-facing layer */ public static final String MESSAGE_OUTPUT = PREFIX + "message-output"; + + // ── Multi-agent reply routing ────────────────────────────────────────────── + /** + * Per-session reply-routing descriptors, keyed by session_id. Written by the + * chat-api {@code /api/chat} endpoint when an inbound request carries a + * {@code reply} object, joined into {@code message-output} at end-turn, and + * consumed by the chat-api OutputConsumer to deliver the response back to the + * caller. Compacted with a time-based retention ({@code REPLY_TO_STATE_TTL_MS}) + * so stale routes are eventually dropped. + * + *

TODO(multi-agent): the key is {@code session_id}, so a session can hold + * only ONE outstanding reply route at a time — a second {@code reply} + * descriptor for the same session overwrites the first. To support multiple + * concurrent callbacks within one session (e.g. parallel sub-agent calls that + * reuse the session), key by {@code sessionId:requestId} instead and carry the + * requestId through the join. Deferred — one route per session is sufficient + * for the current one-shot delegation flow. + */ + public static final String REPLY_TO = PREFIX + "reply-to"; } \ No newline at end of file diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java index a33232f..d6751a9 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ThinkResponse.java @@ -1,5 +1,6 @@ package io.flightdeck.streams.model; +import com.fasterxml.jackson.annotation.JsonAlias; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; @@ -20,14 +21,22 @@ public record ThinkResponse( @JsonProperty("session_id") String sessionId, @JsonProperty("user_id") String userId, - @JsonProperty("total_session_cost") Double totalSessionCost, - @JsonProperty("previous_session_cost") Double previousSessionCost, + // @JsonAlias accepts the older think-consumer wire schema too: published + // think-consumer images (pre the previous/last-input split) emit + // "cost"/"prev_session_cost"/"input_tokens"/"output_tokens" and a single + // combined "messages" list. Aliasing lets this model read both schemas, so + // a source-built processing stays compatible with published think-consumer + // images. assembleContent()/history filter by role, so a combined + // "messages" list (user + assistant) deserialized into lastInputResponse + // still yields the correct content. + @JsonProperty("total_session_cost") @JsonAlias("cost") Double totalSessionCost, + @JsonProperty("previous_session_cost") @JsonAlias("prev_session_cost") Double previousSessionCost, @JsonProperty("think_cost") Double thinkCost, - @JsonProperty("think_input_tokens") int thinkInputTokens, - @JsonProperty("think_output_tokens") int thinkOutputTokens, + @JsonProperty("think_input_tokens") @JsonAlias("input_tokens") int thinkInputTokens, + @JsonProperty("think_output_tokens") @JsonAlias("output_tokens") int thinkOutputTokens, @JsonProperty("previous_messages") List previousMessages, @JsonProperty("last_input_message") MessageInput lastInputMessage, - @JsonProperty("last_input_response") List lastInputResponse, + @JsonProperty("last_input_response") @JsonAlias("messages") List lastInputResponse, @JsonProperty("tool_uses") List toolUses, @JsonProperty("end_turn") boolean endTurn, @JsonProperty("compaction") boolean compaction, diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolResultAccumulator.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolResultAccumulator.java index 773ae4d..675515c 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolResultAccumulator.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolResultAccumulator.java @@ -4,15 +4,34 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; /** * Per-session accumulator held in RocksDB state store while tool results are - * arriving. Published to {@code tool-use-all-complete} once - * {@code results.size() == expectedCount}. + * arriving. Published to {@code tool-use-all-complete} once every expected tool + * result has been collected. * - *

The {@code emitted} flag acts as a one-shot guard — once the complete - * event has been forwarded downstream the flag is set to {@code true} so - * that any late-arriving duplicate results do not trigger a second emission. + *

Expected set vs count

+ * The accumulator is seeded from {@code think-request-response} with the full + * {@link ExpectedTool} set (every {@code tool_use_id} + {@code name} the LLM + * requested) plus a wall-clock {@code deadlineMs}. Completion is then defined as + * "every expected tool_use_id has a result". This holds whether a result was + * produced synchronously by a tool consumer or asynchronously via the + * {@code /api/tool/response} callback — they are indistinguishable on the + * {@code tool-use-result} topic. + * + *

If results arrive before the seed (cross-topic ordering), the accumulator + * falls back to count-based completion using {@code expectedCount} until the + * seed merges in the expected set. + * + *

One-shot guard & tombstone

+ * The {@code emitted} flag prevents a second emission from late/duplicate + * results. After emission the entry is kept as an {@code emitted=true} tombstone + * (rather than deleted) until {@code deadlineMs} passes, so a late async + * callback for an already-completed turn is ignored instead of starting a fresh, + * never-completing accumulator. */ @JsonIgnoreProperties(ignoreUnknown = true) public record ToolResultAccumulator( @@ -20,16 +39,45 @@ public record ToolResultAccumulator( @JsonProperty("user_id") String userId, @JsonProperty("expected_count") int expectedCount, @JsonProperty("results") List results, + @JsonProperty("expected") List expected, + @JsonProperty("deadline_ms") long deadlineMs, @JsonProperty("emitted") boolean emitted, @JsonProperty("timestamp") String timestamp ) { - /** Zero-value initialiser used when a new session's first result arrives. */ + + /** A tool call the LLM requested, recorded so timeouts can synthesise an error result. */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record ExpectedTool( + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("name") String name + ) {} + + /** Zero-value initialiser used when a new session's first result arrives (no seed yet). */ public static ToolResultAccumulator empty(String sessionId, String userId, int expectedCount) { - return new ToolResultAccumulator(sessionId, userId, expectedCount, List.of(), false, null); + return new ToolResultAccumulator( + sessionId, userId, expectedCount, List.of(), List.of(), 0L, false, null); } /** Returns {@code true} when all expected tool results have been collected. */ public boolean isComplete() { - return !emitted && results != null && results.size() >= expectedCount && expectedCount > 0; + if (emitted || results == null) return false; + if (expected != null && !expected.isEmpty()) { + return missing().isEmpty(); + } + // Fallback before the seed has merged: count-based. + return expectedCount > 0 && results.size() >= expectedCount; + } + + /** Expected tools that do not yet have a corresponding result, by tool_use_id. */ + public List missing() { + if (expected == null || expected.isEmpty()) return List.of(); + Set have = results == null ? Set.of() + : results.stream() + .map(ToolUseResult::toolUseId) + .filter(Objects::nonNull) + .collect(Collectors.toSet()); + return expected.stream() + .filter(e -> e.toolUseId() != null && !have.contains(e.toolUseId())) + .collect(Collectors.toList()); } -} \ No newline at end of file +} diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolUseResult.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolUseResult.java index 8e027cd..0d6fa5c 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolUseResult.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/ToolUseResult.java @@ -15,6 +15,7 @@ public record ToolUseResult( @JsonProperty("session_id") String sessionId, @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("tool_id") String toolId, // registered tool id (nullable; informational) @JsonProperty("name") String name, @JsonProperty("result") Map result, @JsonProperty("latency_ms") long latencyMs, diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java index f2e3d38..43f2bdb 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/model/UserResponse.java @@ -1,8 +1,11 @@ package io.flightdeck.streams.model; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Map; + /** * Final outbound envelope published to {@code message-output} once the LLM * signals end-of-turn with no outstanding tool calls. @@ -10,8 +13,16 @@ *

Produced by {@code EndTurnProcessor} as the terminal step of a * conversation turn — the value the user-facing layer actually delivers * to the end user. + * + *

{@code replyTo} is an optional transport-level routing descriptor joined in + * at end-turn from the reply-to topic. It is present only for sessions that were + * invoked with a {@code reply} object (multi-agent calls); for ordinary + * interactive chats it is null and omitted from the serialized output. The + * OutputConsumer uses it to deliver the response back to the caller instead of a + * WebSocket client. */ @JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) public record UserResponse( @JsonProperty("session_id") String sessionId, @JsonProperty("user_id") String userId, @@ -20,5 +31,6 @@ public record UserResponse( @JsonProperty("output_tokens") int outputTokens, @JsonProperty("cost") Double cost, @JsonProperty("source_agent") String sourceAgent, // which agent produced this + @JsonProperty("reply_to") Map replyTo, // multi-agent reply route (nullable) @JsonProperty("timestamp") String timestamp -) {} \ No newline at end of file +) {} diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessor.java index 980713e..83cd6b5 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessor.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessor.java @@ -1,48 +1,67 @@ package io.flightdeck.streams.processors; import io.flightdeck.streams.config.Topics; +import io.flightdeck.streams.model.ThinkResponse; import io.flightdeck.streams.model.ToolResultAccumulator; +import io.flightdeck.streams.model.ToolResultAccumulator.ExpectedTool; +import io.flightdeck.streams.model.ToolUseItem; import io.flightdeck.streams.model.ToolUseResult; import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.StreamsBuilder; import org.apache.kafka.streams.kstream.Consumed; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.processor.PunctuationType; import org.apache.kafka.streams.processor.api.Processor; import org.apache.kafka.streams.processor.api.ProcessorContext; import org.apache.kafka.streams.processor.api.ProcessorSupplier; import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueIterator; import org.apache.kafka.streams.state.KeyValueStore; import org.apache.kafka.streams.state.StoreBuilder; import org.apache.kafka.streams.state.Stores; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; /** *

Aggregate Tool Execution Result Processor

* - *

Reads from {@code tool-use-result} and accumulates results per session. - * Each {@link ToolUseResult} carries {@code total_tools} indicating how many - * parallel tool calls were requested. When the accumulated count equals - * {@code total_tools}, the complete {@link ToolResultAccumulator} is emitted - * to {@code tool-use-all-complete}. + *

Collects tool results per session and emits {@code tool-use-all-complete} + * once every tool the LLM requested has a result — whether produced + * synchronously by a tool consumer or asynchronously via the + * {@code /api/tool/response} callback. * - *

Topology

+ *

Two inputs, one store

*
- *   tool-use-result ──► [accumulate per session_id]
- *                              │
- *                              │  results.size() == total_tools?
- *                              │       YES ──► emit + reset store
- *                              │
- *                              ▼
- *                      tool-use-all-complete
+ *   think-request-response ─┐ (seed: expected tool_use_ids + deadline)
+ *                           ├─► [accumulate per session_id] ─► tool-use-all-complete
+ *   tool-use-result ────────┘ (append + complete check)
  * 
+ * Both streams are keyed by {@code session_id} and merged into a single + * stateful processor, so it alone owns the accumulator store and the timeout + * punctuator. The two source streams may interleave in any order; the processor + * handles results that arrive before their seed (count-based fallback) and a + * seed that completes a turn whose results all arrived first. + * + *

Async timeout

+ * A wall-clock punctuator scans for accumulators past {@code deadlineMs}. For + * those it synthesises {@code status:"error"} results for every still-missing + * tool — guaranteeing the downstream tool message carries a result for every + * {@code tool_use} block (a hard requirement of the Claude API) — then emits + * complete. Emitted entries become short-lived tombstones so late/duplicate + * callbacks are ignored rather than starting a fresh accumulator. */ public class AggregateToolExecutionResultProcessor { @@ -51,7 +70,19 @@ public class AggregateToolExecutionResultProcessor { /** RocksDB store: session_id → {@link ToolResultAccumulator}. */ public static final String ACCUMULATOR_STORE = "tool-result-accumulator-store"; - public static void register(StreamsBuilder builder) { + /** How long to wait for all results (incl. async callbacks) before timing out. */ + static final long ASYNC_TOOL_TIMEOUT_MS = + envLong("ASYNC_TOOL_TIMEOUT_MS", 300_000L); + + /** How often the wall-clock punctuator scans for timed-out / expired entries. */ + static final long PUNCTUATE_INTERVAL_MS = + envLong("TOOL_AGG_PUNCTUATE_INTERVAL_MS", 15_000L); + + /** How long an emitted tombstone is retained to absorb late/duplicate results. */ + static final long TOMBSTONE_TTL_MS = + envLong("TOOL_AGG_TOMBSTONE_TTL_MS", 60_000L); + + public static void register(StreamsBuilder builder, KStream thinkStream) { StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( @@ -61,15 +92,24 @@ public static void register(StreamsBuilder builder) { ); builder.addStateStore(storeBuilder); - KStream resultStream = builder.stream( - Topics.TOOL_USE_RESULT, - Consumed.with(Serdes.String(), JsonSerde.of(ToolUseResult.class)) - ); + // Seed events: each think response that requested tools. + KStream seedStream = thinkStream + .filter((sessionId, response) -> + response != null + && response.toolUses() != null + && !response.toolUses().isEmpty()) + .mapValues(AggEvent::seed); - resultStream + // Result events: every tool result (sync producers + async callbacks). + KStream resultStream = builder.stream( + Topics.TOOL_USE_RESULT, + Consumed.with(Serdes.String(), JsonSerde.of(ToolUseResult.class))) + .mapValues(AggEvent::result); + + seedStream.merge(resultStream) .process( - (ProcessorSupplier) ResultAccumulatorProcessor::new, + (ProcessorSupplier) AccumulatorProcessor::new, ACCUMULATOR_STORE ) .peek((sessionId, acc) -> @@ -82,99 +122,280 @@ public static void register(StreamsBuilder builder) { } // ───────────────────────────────────────────────────────────────────────── - // Inner Processor + // Internal merged-event wrapper (never hits a topic — no serde needed) + // ───────────────────────────────────────────────────────────────────────── + + enum Kind { SEED, RESULT } + + record AggEvent(Kind kind, ThinkResponse seed, ToolUseResult result) { + static AggEvent seed(ThinkResponse r) { return new AggEvent(Kind.SEED, r, null); } + static AggEvent result(ToolUseResult r) { return new AggEvent(Kind.RESULT, null, r); } + } + + // ───────────────────────────────────────────────────────────────────────── + // Processor // ───────────────────────────────────────────────────────────────────────── - static class ResultAccumulatorProcessor - implements Processor { + static class AccumulatorProcessor + implements Processor { private ProcessorContext context; - private KeyValueStore accumulatorStore; + private KeyValueStore store; @Override public void init(ProcessorContext ctx) { this.context = ctx; - this.accumulatorStore = ctx.getStateStore(ACCUMULATOR_STORE); + this.store = ctx.getStateStore(ACCUMULATOR_STORE); + ctx.schedule(Duration.ofMillis(PUNCTUATE_INTERVAL_MS), + PunctuationType.WALL_CLOCK_TIME, this::sweep); } @Override - public void process(Record record) { - String sessionId = record.key(); - ToolUseResult result = record.value(); - - if (sessionId == null || result == null) { - log.warn("Null key or value — skipping"); + public void process(Record record) { + String key = record.key(); + AggEvent event = record.value(); + if (key == null || event == null) { + log.warn("Null key or event — skipping"); return; } + if (event.kind() == Kind.SEED) { + handleSeed(key, event.seed(), record.timestamp()); + } else { + handleResult(key, event.result(), record.timestamp()); + } + } - // Read expected count from the result itself - int expectedCount = result.totalTools(); + // ── Seed: record the expected tool set and the timeout deadline ──────── + private void handleSeed(String sessionId, ThinkResponse response, long ts) { + ToolResultAccumulator current = store.get(sessionId); - // Read or initialise the accumulator for this session - ToolResultAccumulator current = accumulatorStore.get(sessionId); - if (current == null) { - current = ToolResultAccumulator.empty(sessionId, result.sessionId(), expectedCount); + List expected = new ArrayList<>(); + for (ToolUseItem item : response.toolUses()) { + expected.add(new ExpectedTool(item.toolUseId(), item.name())); + } + + if (current != null && current.emitted()) { + // Tombstone from a previous turn. If this seed names the same + // tools it is a redelivery — ignore. Otherwise it opens the next + // turn on this session: start fresh, dropping the old results. + Set seedIds = expected.stream() + .map(ExpectedTool::toolUseId).collect(Collectors.toSet()); + if (knownIds(current).containsAll(seedIds)) { + return; + } + current = null; } - // Guard: do not process if already emitted for this cycle - if (current.emitted()) { - log.warn("[{}] Result arrived after emission — ignoring tool_use_id={}", - sessionId, result.toolUseId()); + long now = System.currentTimeMillis(); + long deadline = (current != null && current.deadlineMs() > 0) + ? current.deadlineMs() + : now + ASYNC_TOOL_TIMEOUT_MS; + List existingResults = + (current != null && current.results() != null) ? current.results() : List.of(); + + ToolResultAccumulator seeded = new ToolResultAccumulator( + sessionId, + response.userId(), + expected.size(), + existingResults, + Collections.unmodifiableList(expected), + deadline, + false, + Instant.now().toString()); + + log.info("[{}] Seeded expected tool set — expected={} alreadyHave={} deadlineInMs={}", + sessionId, expected.size(), existingResults.size(), deadline - now); + + // Results may all have arrived before the seed — check completion now. + if (seeded.isComplete()) { + emitComplete(sessionId, seeded, ts); + } else { + store.put(sessionId, seeded); + } + } + + // ── Result: append and check completion ──────────────────────────────── + private void handleResult(String sessionId, ToolUseResult result, long ts) { + if (result == null) { + log.warn("[{}] Null result — skipping", sessionId); return; } - // Deduplicate by tool_use_id — skip if we already have a result for this tool call - boolean isDuplicate = current.results() != null && current.results().stream() + ToolResultAccumulator current = store.get(sessionId); + long now = System.currentTimeMillis(); + + if (current != null && current.emitted()) { + // Tombstone from a completed turn. A result whose tool_use_id was + // part of that turn is a late/duplicate callback — ignore it. A + // result with a new tool_use_id belongs to the next turn. + if (knownIds(current).contains(result.toolUseId())) { + log.warn("[{}] Late/duplicate result for completed turn — ignoring tool_use_id={}", + sessionId, result.toolUseId()); + return; + } + current = null; + } + + if (current == null) { + // First result of a turn arriving before its seed: start a + // count-based accumulator (the seed merges in the expected set later). + // user_id is unknown here — ToolUseResult carries none. Leave it + // null; the seed (think-request-response) fills it in when it + // merges, and downstream tolerates a null user_id. Using + // result.sessionId() here would stamp the session id into user_id + // and, if the count-based path completes before the seed arrives, + // emit a tool message with the wrong user_id. + current = new ToolResultAccumulator( + sessionId, null, result.totalTools(), + List.of(), List.of(), now + ASYNC_TOOL_TIMEOUT_MS, false, + Instant.now().toString()); + } + + boolean duplicate = current.results() != null && current.results().stream() .anyMatch(r -> r.toolUseId() != null && r.toolUseId().equals(result.toolUseId())); - if (isDuplicate) { + if (duplicate) { log.info("[{}] Duplicate tool_use_id={} — skipping", sessionId, result.toolUseId()); return; } - // Append the new result List updatedResults = append(current.results(), result); + int expectedCount = (current.expected() != null && !current.expected().isEmpty()) + ? current.expected().size() + : Math.max(result.totalTools(), current.expectedCount()); ToolResultAccumulator updated = new ToolResultAccumulator( sessionId, current.userId(), - expectedCount > 0 ? expectedCount : current.expectedCount(), + expectedCount, updatedResults, + current.expected(), + current.deadlineMs() > 0 ? current.deadlineMs() : now + ASYNC_TOOL_TIMEOUT_MS, false, - Instant.now().toString() - ); + Instant.now().toString()); log.info("[{}] Tool result accumulated — received={}/{} tool_use_id={}", - sessionId, - updatedResults.size(), - updated.expectedCount(), - result.toolUseId()); + sessionId, updatedResults.size(), updated.expectedCount(), result.toolUseId()); - // Check completion if (updated.isComplete()) { - log.info("[{}] All {} tool results received — emitting tool-use-all-complete", - sessionId, updated.expectedCount()); + emitComplete(sessionId, updated, ts); + } else { + store.put(sessionId, updated); + } + } - ToolResultAccumulator emitted = new ToolResultAccumulator( - updated.sessionId(), updated.userId(), updated.expectedCount(), - updated.results(), true, updated.timestamp()); + // ── Wall-clock sweep: time out stale turns, expire tombstones ────────── + private void sweep(long now) { + List toExpire = new ArrayList<>(); + List toTimeout = new ArrayList<>(); - accumulatorStore.put(sessionId, emitted); - context.forward(new Record<>(sessionId, emitted, record.timestamp())); + try (KeyValueIterator it = store.all()) { + while (it.hasNext()) { + KeyValue kv = it.next(); + ToolResultAccumulator acc = kv.value; + if (acc == null) continue; - // Reset store for the next think-cycle on this session - accumulatorStore.delete(sessionId); + if (acc.emitted()) { + // Tombstone: deadlineMs doubles as the tombstone expiry. + if (acc.deadlineMs() > 0 && now >= acc.deadlineMs()) { + toExpire.add(kv.key); + } + } else if (acc.deadlineMs() > 0 && now >= acc.deadlineMs()) { + toTimeout.add(acc); + } + } + } - } else { - accumulatorStore.put(sessionId, updated); + // Mutate the store outside the iterator. + for (String key : toExpire) { + store.delete(key); + log.debug("[{}] Tombstone expired — removed", key); + } + for (ToolResultAccumulator acc : toTimeout) { + ToolResultAccumulator filled = fillMissingWithErrors(acc, now); + log.warn("[{}] Async tool timeout — synthesising {} error result(s), emitting complete", + acc.sessionId(), filled.results().size() - acc.results().size()); + emitComplete(acc.sessionId(), filled, now); } } - private static List append(List existing, - ToolUseResult incoming) { + /** Adds a {@code status:"error"} result for every expected tool still missing. */ + private ToolResultAccumulator fillMissingWithErrors(ToolResultAccumulator acc, long now) { + List results = new ArrayList<>( + acc.results() != null ? acc.results() : List.of()); + for (ExpectedTool missing : acc.missing()) { + results.add(new ToolUseResult( + acc.sessionId(), + missing.toolUseId(), + null, + missing.name(), + Map.of("error", "async tool did not respond within " + + ASYNC_TOOL_TIMEOUT_MS + "ms", "reason", "timeout"), + ASYNC_TOOL_TIMEOUT_MS, + "error", + acc.expectedCount(), + Instant.now().toString())); + } + return new ToolResultAccumulator( + acc.sessionId(), acc.userId(), acc.expectedCount(), + Collections.unmodifiableList(results), acc.expected(), + acc.deadlineMs(), false, Instant.now().toString()); + } + + /** Forwards the completed accumulator and leaves a short-lived tombstone behind. */ + private void emitComplete(String sessionId, ToolResultAccumulator acc, long ts) { + log.info("[{}] All tool results in — emitting tool-use-all-complete (count={})", + sessionId, acc.results().size()); + + ToolResultAccumulator completed = new ToolResultAccumulator( + acc.sessionId(), acc.userId(), acc.expectedCount(), + acc.results(), acc.expected(), acc.deadlineMs(), + true, Instant.now().toString()); + context.forward(new Record<>(sessionId, completed, ts)); + + // Keep an emitted=true tombstone, retaining the completed turn's + // tool_use_ids so a late/duplicate result for THIS turn is ignored, + // while a result with a new tool_use_id still opens the next turn. + // deadlineMs is reused as the tombstone expiry. + ToolResultAccumulator tombstone = new ToolResultAccumulator( + acc.sessionId(), acc.userId(), acc.expectedCount(), + acc.results(), acc.expected(), + System.currentTimeMillis() + TOMBSTONE_TTL_MS, + true, Instant.now().toString()); + store.put(sessionId, tombstone); + } + + /** All tool_use_ids associated with a turn — from collected results and the seeded expected set. */ + private static Set knownIds(ToolResultAccumulator acc) { + Set ids = new HashSet<>(); + if (acc.results() != null) { + for (ToolUseResult r : acc.results()) { + if (r.toolUseId() != null) ids.add(r.toolUseId()); + } + } + if (acc.expected() != null) { + for (ExpectedTool e : acc.expected()) { + if (e.toolUseId() != null) ids.add(e.toolUseId()); + } + } + return ids; + } + + private static List append(List existing, ToolUseResult incoming) { List list = new ArrayList<>(); if (existing != null) list.addAll(existing); list.add(incoming); return Collections.unmodifiableList(list); } } + + private static long envLong(String key, long defaultValue) { + String v = System.getenv(key); + if (v == null || v.isBlank()) return defaultValue; + try { + return Long.parseLong(v.trim()); + } catch (NumberFormatException e) { + log.warn("Invalid {}={} — using default {}", key, v, defaultValue); + return defaultValue; + } + } } diff --git a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java index 0a04f2c..a882133 100644 --- a/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java +++ b/processor-apps/processing/src/main/java/io/flightdeck/streams/processors/EndTurnProcessor.java @@ -1,5 +1,6 @@ package io.flightdeck.streams.processors; +import com.fasterxml.jackson.databind.ObjectMapper; import io.flightdeck.streams.config.Topics; import io.flightdeck.streams.model.MessageInput; import io.flightdeck.streams.model.ThinkResponse; @@ -7,14 +8,15 @@ import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.StreamsBuilder; -import org.apache.kafka.streams.kstream.Consumed; import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.kstream.Produced; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; /** @@ -59,6 +61,7 @@ public class EndTurnProcessor { private static final Logger log = LoggerFactory.getLogger(EndTurnProcessor.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); /** Role value that identifies LLM-generated content within a ThinkResponse. */ static final String ROLE_ASSISTANT = "assistant"; @@ -69,7 +72,9 @@ public class EndTurnProcessor { /** Default source-agent tag when the ThinkResponse carries no agent identifier. */ static final String DEFAULT_AGENT = "agent-1"; - public static void register(StreamsBuilder builder, KStream thinkStream) { + public static void register(StreamsBuilder builder, + KStream thinkStream, + KTable replyToTable) { thinkStream // ── Step 1: filter — only fully-resolved end-turn responses ── @@ -88,15 +93,20 @@ public static void register(StreamsBuilder builder, KStream { - UserResponse userResponse = toUserResponse(sessionId, response); - log.info("[{}] End-turn → message-output content_len={} cost={}", - sessionId, - userResponse.content().length(), - userResponse.cost() != null ? String.format("$%.6f", userResponse.cost()) : "null"); - return userResponse; - }) + // ── Step 2: left-join the per-session reply-to descriptor and + // transform ThinkResponse → UserResponse. The reply + // route is null for ordinary chats and present only + // for multi-agent calls. ────────────────────────── + .leftJoin(replyToTable, + (sessionId, response, replyJson) -> { + UserResponse userResponse = toUserResponse(sessionId, response, replyJson); + log.info("[{}] End-turn → message-output content_len={} cost={} reply_to={}", + sessionId, + userResponse.content().length(), + userResponse.cost() != null ? String.format("$%.6f", userResponse.cost()) : "null", + userResponse.replyTo() != null); + return userResponse; + }) // ── Step 3: publish to message-output ──────────────────────── .to(Topics.MESSAGE_OUTPUT, @@ -119,6 +129,16 @@ public static void register(StreamsBuilder builder, KStream */ static UserResponse toUserResponse(String sessionId, ThinkResponse response) { + return toUserResponse(sessionId, response, null); + } + + /** + * As {@link #toUserResponse(String, ThinkResponse)} but also embeds the + * transport-level reply-routing descriptor (raw JSON from the reply-to + * topic). {@code replyJson} is null for ordinary chats, in which case + * {@code replyTo} stays null and is omitted from the serialized output. + */ + static UserResponse toUserResponse(String sessionId, ThinkResponse response, String replyJson) { String content = assembleContent(response.lastInputResponse()); String sourceAgent = (response.lastInputResponse() != null) @@ -137,10 +157,23 @@ static UserResponse toUserResponse(String sessionId, ThinkResponse response) { response.thinkOutputTokens(), response.totalSessionCost(), sourceAgent, + parseReply(replyJson), Instant.now().toString() ); } + /** Parses the reply-to descriptor JSON into a map; null/blank/invalid → null. */ + @SuppressWarnings("unchecked") + static Map parseReply(String replyJson) { + if (replyJson == null || replyJson.isBlank()) return null; + try { + return MAPPER.readValue(replyJson, Map.class); + } catch (Exception e) { + log.warn("Failed to parse reply-to descriptor — ignoring route: {}", e.getMessage()); + return null; + } + } + /** * Concatenates all assistant-role message content from the list. * Returns an empty string if the list is null, empty, or contains no diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessorTest.java index 27f8937..135998d 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/AggregateToolExecutionResultProcessorTest.java @@ -2,12 +2,17 @@ import io.flightdeck.streams.config.Topics; import io.flightdeck.streams.model.*; +import io.flightdeck.streams.model.ToolResultAccumulator.ExpectedTool; import io.flightdeck.streams.serdes.JsonSerde; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.*; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.test.TestRecord; import org.junit.jupiter.api.*; +import java.time.Duration; +import java.time.Instant; import java.util.List; import java.util.Map; import java.util.Properties; @@ -17,35 +22,48 @@ /** * Tests for {@link AggregateToolExecutionResultProcessor}. * - * Core invariant under test: - * tool-use-all-complete is emitted ONCE when results.size() == expectedCount + * Invariants under test: + * - Count-based completion still works when no seed is present (sync-only turns). + * - Seeding from think-request-response drives expected-set completion. + * - Async tools that never call back are timed out into synthesised error results. + * - A completed turn leaves a tombstone: late/duplicate results are ignored, + * but a new turn (new tool_use_ids) on the same session proceeds. */ class AggregateToolExecutionResultProcessorTest { private TopologyTestDriver driver; - - /** Delivers arriving tool results */ - private TestInputTopic resultInput; - - /** Captures the complete-event output */ + private TestInputTopic resultInput; + private TestInputTopic seedInput; private TestOutputTopic allCompleteOutput; @BeforeEach void setUp() { StreamsBuilder builder = new StreamsBuilder(); - AggregateToolExecutionResultProcessor.register(builder); + + KStream thinkStream = builder.stream( + Topics.THINK_REQUEST_RESPONSE, + Consumed.with(Serdes.String(), JsonSerde.of(ThinkResponse.class))); + + AggregateToolExecutionResultProcessor.register(builder, thinkStream); Properties props = new Properties(); props.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-aggregate-tool-results"); props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:9092"); - driver = new TopologyTestDriver(builder.build(), props); + // Anchor the mock wall clock to now so deadlines (computed from + // System.currentTimeMillis()) line up with advanceWallClockTime(). + driver = new TopologyTestDriver(builder.build(), props, Instant.now()); resultInput = driver.createInputTopic( Topics.TOOL_USE_RESULT, Serdes.String().serializer(), JsonSerde.of(ToolUseResult.class).serializer()); + seedInput = driver.createInputTopic( + Topics.THINK_REQUEST_RESPONSE, + Serdes.String().serializer(), + JsonSerde.of(ThinkResponse.class).serializer()); + allCompleteOutput = driver.createOutputTopic( Topics.TOOL_USE_ALL_COMPLETE, Serdes.String().deserializer(), @@ -55,7 +73,7 @@ void setUp() { @AfterEach void tearDown() { driver.close(); } - // ── Core emission behaviour ─────────────────────────────────────────────── + // ── Count-based completion (sync-only turns, no seed) ───────────────────── @Test @DisplayName("Single expected tool: emits when one result arrives") @@ -74,24 +92,21 @@ void singleTool_emitsOnFirstResult() { @Test @DisplayName("Three expected tools: emits only after all three results arrive") void threeTools_emitsAfterAllThree() { - // First two results — should NOT emit yet resultInput.pipeInput("sess-2", toolResult("sess-2", "t1", "tool_a", 3)); resultInput.pipeInput("sess-2", toolResult("sess-2", "t2", "tool_b", 3)); assertThat(allCompleteOutput.isEmpty()).isTrue(); - // Third result — NOW it emits resultInput.pipeInput("sess-2", toolResult("sess-2", "t3", "tool_c", 3)); assertThat(allCompleteOutput.isEmpty()).isFalse(); ToolResultAccumulator acc = allCompleteOutput.readRecord().value(); assertThat(acc.results()).hasSize(3); - assertThat(acc.expectedCount()).isEqualTo(3); assertThat(acc.results()).extracting(ToolUseResult::toolUseId) .containsExactlyInAnyOrder("t1", "t2", "t3"); } @Test - @DisplayName("All results in the emitted accumulator carry the correct tool names") + @DisplayName("Emitted accumulator carries the correct tool names") void emittedAccumulator_containsAllToolNames() { resultInput.pipeInput("sess-3", toolResult("sess-3", "ta", "get_invoice", 2)); resultInput.pipeInput("sess-3", toolResult("sess-3", "tb", "send_email", 2)); @@ -101,116 +116,151 @@ void emittedAccumulator_containsAllToolNames() { .containsExactlyInAnyOrder("get_invoice", "send_email"); } - // ── One-shot emission guard ─────────────────────────────────────────────── + // ── Seeded (expected-set) completion ────────────────────────────────────── @Test - @DisplayName("Late duplicate result after completion starts a new cycle (store was reset)") - void lateDuplicate_startsNewCycle() { - resultInput.pipeInput("sess-4", toolResult("sess-4", "t1", "tool_x", 1)); - allCompleteOutput.readRecord(); // consume the legitimate emission + @DisplayName("Seed + all results: emits with the seeded expected set") + void seeded_emitsWhenAllExpectedArrive() { + seedInput.pipeInput("sess-s", thinkSeed("sess-s", "u", + toolItem("sess-s", "x1", "tool_a"), toolItem("sess-s", "x2", "tool_b"))); + assertThat(allCompleteOutput.isEmpty()).isTrue(); - // After store reset, a duplicate is treated as a brand new cycle - resultInput.pipeInput("sess-4", toolResult("sess-4", "t1", "tool_x", 1)); + resultInput.pipeInput("sess-s", toolResult("sess-s", "x1", "tool_a", 2)); + assertThat(allCompleteOutput.isEmpty()).isTrue(); - // New cycle completes immediately since totalTools=1 + resultInput.pipeInput("sess-s", toolResult("sess-s", "x2", "tool_b", 2)); assertThat(allCompleteOutput.isEmpty()).isFalse(); + + ToolResultAccumulator acc = allCompleteOutput.readRecord().value(); + assertThat(acc.results()).extracting(ToolUseResult::toolUseId) + .containsExactlyInAnyOrder("x1", "x2"); + assertThat(acc.expected()).extracting(ExpectedTool::toolUseId) + .containsExactlyInAnyOrder("x1", "x2"); } + // ── Async timeout ───────────────────────────────────────────────────────── + @Test - @DisplayName("Two identical results each trigger separate cycles after store reset") - void duplicateResults_eachTriggerCycle() { - resultInput.pipeInput("sess-5", toolResult("sess-5", "t1", "tool_a", 1)); - resultInput.pipeInput("sess-5", toolResult("sess-5", "t1", "tool_a", 1)); + @DisplayName("Async tool that never calls back is timed out into a synthesised error result") + void asyncTimeout_synthesisesErrorForMissing() { + // Two tools requested; only one (sync) returns. The other is async and never calls back. + seedInput.pipeInput("sess-t", thinkSeed("sess-t", "u", + toolItem("sess-t", "sync1", "lookup"), + toolItem("sess-t", "async1", "generate_report"))); + resultInput.pipeInput("sess-t", toolResult("sess-t", "sync1", "lookup", 2)); + assertThat(allCompleteOutput.isEmpty()).isTrue(); - // Each result starts a fresh cycle (store is deleted after emission) - assertThat(allCompleteOutput.readRecordsToList()).hasSize(2); + // Advance past the async timeout (default 5 min). + driver.advanceWallClockTime(Duration.ofMinutes(6)); + + assertThat(allCompleteOutput.isEmpty()).isFalse(); + ToolResultAccumulator acc = allCompleteOutput.readRecord().value(); + assertThat(acc.results()).hasSize(2); + + ToolUseResult synthesised = acc.results().stream() + .filter(r -> r.toolUseId().equals("async1")).findFirst().orElseThrow(); + assertThat(synthesised.status()).isEqualTo("error"); + assertThat(synthesised.result()).containsEntry("reason", "timeout"); } - // ── Store reset after emission ───────────────────────────────────────────── + // ── Tombstone: late/duplicate vs next turn ──────────────────────────────── @Test - @DisplayName("After emission, stores are reset — a new think cycle starts fresh") - void storeReset_newCycleBeginsClean() { - // First think cycle: 1 tool + @DisplayName("Late duplicate result for a completed turn is ignored (no second emission)") + void lateDuplicate_isIgnored() { + resultInput.pipeInput("sess-4", toolResult("sess-4", "t1", "tool_x", 1)); + allCompleteOutput.readRecord(); // legitimate emission + + // Same tool_use_id again — a duplicate callback for the completed turn. + resultInput.pipeInput("sess-4", toolResult("sess-4", "t1", "tool_x", 1)); + assertThat(allCompleteOutput.isEmpty()).isTrue(); + } + + @Test + @DisplayName("A new turn (new tool_use_ids) on a completed session proceeds normally") + void nextTurn_afterCompletion_proceeds() { + // Turn 1: one tool. resultInput.pipeInput("sess-6", toolResult("sess-6", "t1", "tool_a", 1)); - allCompleteOutput.readRecord(); // drain first emission + allCompleteOutput.readRecord(); - // Second think cycle on the SAME session: 2 tools + // Turn 2 on the SAME session: two new tools. resultInput.pipeInput("sess-6", toolResult("sess-6", "t2", "tool_b", 2)); - assertThat(allCompleteOutput.isEmpty()).isTrue(); // not yet + assertThat(allCompleteOutput.isEmpty()).isTrue(); resultInput.pipeInput("sess-6", toolResult("sess-6", "t3", "tool_c", 2)); - assertThat(allCompleteOutput.isEmpty()).isFalse(); // now complete + assertThat(allCompleteOutput.isEmpty()).isFalse(); ToolResultAccumulator acc = allCompleteOutput.readRecord().value(); - assertThat(acc.results()).hasSize(2); - assertThat(acc.expectedCount()).isEqualTo(2); + assertThat(acc.results()).extracting(ToolUseResult::toolUseId) + .containsExactlyInAnyOrder("t2", "t3"); } // ── Session isolation ───────────────────────────────────────────────────── @Test - @DisplayName("Two sessions accumulate independently with no cross-contamination") + @DisplayName("Two sessions accumulate independently") void sessionIsolation() { - // Session A completes resultInput.pipeInput("A", toolResult("A", "a1", "tool_a", 1)); - // Session B gets only its first result resultInput.pipeInput("B", toolResult("B", "b1", "tool_b", 2)); - // Only A should have emitted - List> records = - allCompleteOutput.readRecordsToList(); + List> records = allCompleteOutput.readRecordsToList(); assertThat(records).hasSize(1); assertThat(records.get(0).key()).isEqualTo("A"); - // Now complete B resultInput.pipeInput("B", toolResult("B", "b2", "tool_c", 2)); assertThat(allCompleteOutput.readRecord().key()).isEqualTo("B"); } - // ── Output key ──────────────────────────────────────────────────────────── - @Test @DisplayName("Emitted record is keyed by session_id") void outputKey_isSessionId() { resultInput.pipeInput("key-sess", toolResult("key-sess", "t1", "tool_x", 1)); - assertThat(allCompleteOutput.readRecord().key()).isEqualTo("key-sess"); } // ── ToolResultAccumulator model unit tests ──────────────────────────────── @Test - @DisplayName("isComplete: true when results.size() == expectedCount and not yet emitted") - void isComplete_true() { + @DisplayName("isComplete (count fallback): true when results.size() == expectedCount") + void isComplete_countFallback_true() { ToolUseResult r = toolResult("s", "t1", "tool", 1); ToolResultAccumulator acc = new ToolResultAccumulator( - "s", "u", 1, List.of(r), false, TS); + "s", "u", 1, List.of(r), List.of(), 0L, false, TS); assertThat(acc.isComplete()).isTrue(); } @Test - @DisplayName("isComplete: false when results.size() < expectedCount") - void isComplete_false_notEnoughResults() { + @DisplayName("isComplete (count fallback): false when not enough results") + void isComplete_countFallback_false() { ToolResultAccumulator acc = new ToolResultAccumulator( - "s", "u", 2, List.of(toolResult("s","t1","x", 2)), false, TS); + "s", "u", 2, List.of(toolResult("s", "t1", "x", 2)), List.of(), 0L, false, TS); assertThat(acc.isComplete()).isFalse(); } @Test - @DisplayName("isComplete: false when already emitted") - void isComplete_false_alreadyEmitted() { - ToolUseResult r = toolResult("s", "t1", "tool", 1); - ToolResultAccumulator acc = new ToolResultAccumulator( - "s", "u", 1, List.of(r), true, TS); // emitted=true - assertThat(acc.isComplete()).isFalse(); + @DisplayName("isComplete (expected set): true only when every expected id has a result") + void isComplete_expectedSet() { + List expected = List.of(new ExpectedTool("a", "ta"), new ExpectedTool("b", "tb")); + + ToolResultAccumulator partial = new ToolResultAccumulator( + "s", "u", 2, List.of(toolResult("s", "a", "ta", 2)), expected, 0L, false, TS); + assertThat(partial.isComplete()).isFalse(); + assertThat(partial.missing()).extracting(ExpectedTool::toolUseId).containsExactly("b"); + + ToolResultAccumulator full = new ToolResultAccumulator( + "s", "u", 2, + List.of(toolResult("s", "a", "ta", 2), toolResult("s", "b", "tb", 2)), + expected, 0L, false, TS); + assertThat(full.isComplete()).isTrue(); + assertThat(full.missing()).isEmpty(); } @Test - @DisplayName("isComplete: false when expectedCount is 0 (not yet registered)") - void isComplete_false_zeroExpected() { + @DisplayName("isComplete: false when already emitted") + void isComplete_false_alreadyEmitted() { + ToolUseResult r = toolResult("s", "t1", "tool", 1); ToolResultAccumulator acc = new ToolResultAccumulator( - "s", "u", 0, List.of(toolResult("s","t1","x", 0)), false, TS); + "s", "u", 1, List.of(r), List.of(), 0L, true, TS); assertThat(acc.isComplete()).isFalse(); } @@ -220,7 +270,18 @@ void isComplete_false_zeroExpected() { private static ToolUseResult toolResult(String sessionId, String toolUseId, String name, int totalTools) { - return new ToolUseResult(sessionId, toolUseId, name, + return new ToolUseResult(sessionId, toolUseId, null, name, Map.of("status", "ok"), 320L, "success", totalTools, TS); } + + private static ToolUseItem toolItem(String sessionId, String toolUseId, String name) { + return new ToolUseItem(toolUseId, name, name, Map.of(), sessionId, 0, TS); + } + + private static ThinkResponse thinkSeed(String sessionId, String userId, ToolUseItem... items) { + return new ThinkResponse( + sessionId, userId, 0.0, 0.0, 0.0, 0, 0, + List.of(), null, List.of(), List.of(items), + false, false, 0, 0, 0.0, TS); + } } diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java index 1cab7b5..4a547bc 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/EndTurnProcessorTest.java @@ -10,9 +10,11 @@ import org.apache.kafka.streams.*; import org.apache.kafka.streams.kstream.Consumed; import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.test.TestRecord; import org.junit.jupiter.api.*; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.Properties; @@ -32,7 +34,10 @@ void setUp() { KStream thinkStream = builder.stream( Topics.THINK_REQUEST_RESPONSE, Consumed.with(Serdes.String(), JsonSerde.of(ThinkResponse.class))); - EndTurnProcessor.register(builder, thinkStream); + KTable replyToTable = builder.table( + Topics.REPLY_TO, + Consumed.with(Serdes.String(), Serdes.String())); + EndTurnProcessor.register(builder, thinkStream, replyToTable); Properties props = new Properties(); props.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-end-turn"); @@ -239,6 +244,35 @@ void toUserResponse_noMessages_emptyContent() { assertThat(result.content()).isEmpty(); } + @Test + @DisplayName("toUserResponse: reads the think-consumer 'messages' wire schema (regression — empty content)") + void toUserResponse_messagesSchema_extractsContent() { + // Exact shape the think-consumer emits onto {agent}-think-request-response: + // top-level "messages" / "input_tokens" / "output_tokens" (older schema, still + // produced by published images), NOT last_input_response / think_*_tokens. + // Regression: this turned into empty content + zero tokens on message-output. + String wire = "{" + + "\"session_id\":\"sess--tu\",\"user_id\":\"user_42\"," + + "\"cost\":null,\"prev_session_cost\":null," + + "\"input_tokens\":241,\"output_tokens\":402," + + "\"messages\":[" + + " {\"session_id\":\"sess--tu\",\"user_id\":\"user_42\",\"role\":\"user\"," + + " \"content\":\"Create a morning list\",\"timestamp\":\"" + TS + "\",\"metadata\":{}}," + + " {\"session_id\":\"sess--tu\",\"user_id\":\"user_42\",\"role\":\"assistant\"," + + " \"content\":\"# Morning Routine Checklist\\n- Hydrate\",\"timestamp\":\"" + TS + "\",\"metadata\":null}" + + "]," + + "\"tool_uses\":null,\"end_turn\":true,\"timestamp\":\"" + TS + "\"}"; + + ThinkResponse resp = JsonSerde.of(ThinkResponse.class).deserializer() + .deserialize("t", wire.getBytes(StandardCharsets.UTF_8)); + + UserResponse result = toUserResponse("sess--tu", resp); + + assertThat(result.content()).contains("Morning Routine Checklist"); + assertThat(result.outputTokens()).isEqualTo(402); + assertThat(result.inputTokens()).isEqualTo(241); + } + // ── total_session_cost ──────────────────────────────────────────────────── @Test diff --git a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/TransformToolUseDoneProcessorTest.java b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/TransformToolUseDoneProcessorTest.java index 7fcc991..59ef0c1 100644 --- a/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/TransformToolUseDoneProcessorTest.java +++ b/processor-apps/processing/src/test/java/io/flightdeck/streams/processors/TransformToolUseDoneProcessorTest.java @@ -164,7 +164,7 @@ void metadata_toolResultsPresent() { @DisplayName("Accumulator with empty results list is filtered and produces no output") void emptyResults_filtered() { allCompleteInput.pipeInput("sess-empty", - new ToolResultAccumulator("sess-empty", "u", 0, List.of(), true, TS)); + new ToolResultAccumulator("sess-empty", "u", 0, List.of(), List.of(), 0L, true, TS)); assertThat(messageOutput.isEmpty()).isTrue(); } @@ -272,11 +272,11 @@ void toMessageInput_contentNonEmpty() { private static ToolResultAccumulator accumulator(String sessionId, String userId, List results) { return new ToolResultAccumulator(sessionId, userId, results.size(), - results, true, TS); + results, List.of(), 0L, true, TS); } private static ToolUseResult result(String sessionId, String toolUseId, String name, Map resultData) { - return new ToolUseResult(sessionId, toolUseId, name, resultData, 320L, "success", 1, TS); + return new ToolUseResult(sessionId, toolUseId, null, name, resultData, 320L, "success", 1, TS); } } diff --git a/think/think-consumer/Dockerfile b/think/think-consumer/Dockerfile index 2f6111a..e6c58ce 100644 --- a/think/think-consumer/Dockerfile +++ b/think/think-consumer/Dockerfile @@ -3,7 +3,9 @@ WORKDIR /app COPY pom.xml . RUN mvn dependency:go-offline -q COPY src ./src -RUN mvn package -q -DskipTests +# Skip test compilation+execution for the image build (tests run in CI via +# `mvn test`); keeps the image build decoupled from test sources. +RUN mvn package -q -Dmaven.test.skip=true FROM eclipse-temurin:17-jre WORKDIR /app