From e4e62ed2e935635add4fb37dc3d2c1d68886bd75 Mon Sep 17 00:00:00 2001 From: purshotam shah Date: Mon, 15 Jun 2026 13:13:46 -0700 Subject: [PATCH 1/2] [api][routing] Pluggable in-chat LLM routing (ChatModelRouter + RoutingStrategy) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a drop-in chat model that selects which underlying model serves each request, then delegates to it. The router is a CHAT_MODEL resource, so an agent points at it by name with no runtime, event, or agent-definition change. Selection is a pluggable SPI (`RoutingStrategy`), decomposed into orthogonal concerns: - RoutingStrategy — pure selection (request -> candidate name). Returning null means "abstain / no opinion". - FallbackPolicy — optional: try remaining candidates on error. - CachingStrategy — optional bounded-LRU memoization of the decision per conversation, so an expensive strategy (e.g. an LLM judge) runs once per conversation rather than once per tool-call round. Built-in strategies: - RuleBasedRoutingStrategy — deterministic keyword/regex rules + default. - LlmRoutingStrategy — a small "judge" model picks the candidate from each candidate's name/description (RouteLLM-style). Bring-your-own strategies are first-class: implement RoutingStrategy and reference it by fully-qualified class name; loaded via the thread context classloader (cluster-safe). ML/learned routing is supported the same way. Routing-miss semantics: a strategy that abstains (null) or names a non-candidate degrades to the configured `default` candidate (validated at construction; defaults to the first candidate) rather than failing the request. The LLM judge distinguishes a transient failure (abstain -> not cached, retried next round) from an unparseable reply (deterministic default). Security: an LLM/ML routing decision is a hint, not an authority — the user's message is sent to the judge, so cost/privilege/safety must not be gated solely on it (prompt-injection risk). 
This is documented on the strategy. Includes an example (LlmRoutingAgentExample) and unit tests covering rule selection, judge parsing (whole-token match, no substring mis-routing), stickiness, fallback, caching (incl. abstain-not-cached), and bring-your-own. Also mirror the RULE_BASED/LLM ResourceName constants on the Python side (ResourceName.RoutingStrategy.Java) and register RoutingStrategy in the cross-language ResourceName parity check. --- .../routing/AbstractRoutingStrategy.java | 55 +++ .../chat/model/routing/CachingStrategy.java | 102 ++++ .../chat/model/routing/ChatModelRouter.java | 315 +++++++++++++ .../chat/model/routing/FallbackPolicy.java | 61 +++ .../model/routing/LlmRoutingStrategy.java | 211 +++++++++ .../chat/model/routing/RoutingCandidate.java | 131 ++++++ .../chat/model/routing/RoutingContext.java | 97 ++++ .../chat/model/routing/RoutingStrategy.java | 71 +++ .../routing/RuleBasedRoutingStrategy.java | 188 ++++++++ .../agents/api/resource/ResourceName.java | 14 + .../model/routing/CachingStrategyTest.java | 118 +++++ .../model/routing/ChatModelRouterTest.java | 444 ++++++++++++++++++ .../model/routing/LlmRoutingStrategyTest.java | 170 +++++++ .../model/routing/RoutingTestSupport.java | 146 ++++++ .../routing/RuleBasedRoutingStrategyTest.java | 169 +++++++ .../check_resource_consistency.py | 7 +- .../examples/LlmRoutingAgentExample.java | 186 ++++++++ python/flink_agents/api/resource.py | 12 + 18 files changed, 2496 insertions(+), 1 deletion(-) create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/AbstractRoutingStrategy.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/FallbackPolicy.java create mode 100644 
api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidate.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingContext.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingStrategy.java create mode 100644 api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategy.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategyTest.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategyTest.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategyTest.java create mode 100644 examples/src/main/java/org/apache/flink/agents/examples/LlmRoutingAgentExample.java diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/AbstractRoutingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/AbstractRoutingStrategy.java new file mode 100644 index 000000000..569929d88 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/AbstractRoutingStrategy.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; + +/** + * Convenience base class for {@link RoutingStrategy} implementations that are instantiated from a + * {@link ResourceDescriptor}. + * + *

{@link ChatModelRouter} instantiates the configured strategy reflectively, requiring a public + * constructor with the signature {@code (ResourceDescriptor, ResourceContext)} — the same + * convention used by {@link org.apache.flink.agents.api.resource.Resource}. Extending this class + * gives custom strategies that constructor for free and exposes the descriptor/context to + * subclasses. + */ +public abstract class AbstractRoutingStrategy implements RoutingStrategy { + + protected final ResourceDescriptor descriptor; + protected final ResourceContext resourceContext; + + protected AbstractRoutingStrategy( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + this.descriptor = descriptor; + this.resourceContext = resourceContext; + } + + /** Read a strategy configuration argument from the backing descriptor. */ + protected T arg(String name) { + return descriptor != null ? descriptor.getArgument(name) : null; + } + + /** Read a strategy configuration argument, falling back to {@code defaultValue} when absent. */ + protected T arg(String name, T defaultValue) { + T value = arg(name); + return value != null ? value : defaultValue; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java new file mode 100644 index 000000000..428d789ae --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A {@link RoutingStrategy} decorator that memoizes the wrapped strategy's decision per + * conversation (keyed on {@link RoutingContext#firstUserMessage()}), so an expensive selection — + * e.g. an LLM judge — runs once per conversation rather than on every tool-call round. + * + *

The cache is a bounded LRU with real eviction (oldest entries are dropped past the + * capacity), so it never grows without bound and never silently stops caching. Empty keys (requests + * with no user message) are not cached, to avoid coupling unrelated empty-prompt conversations. A + * {@code null} decision (the strategy abstaining — e.g. a transient LLM-judge failure) is likewise + * not cached, so the strategy is re-consulted on the next round rather than pinned to a fallback. + * Thread-safe for the async execution pool. + */ +public final class CachingStrategy implements RoutingStrategy { + + /** Default cache capacity if none is specified. */ + public static final int DEFAULT_MAX_ENTRIES = 1024; + + private final RoutingStrategy delegate; + private final Map cache; + + public CachingStrategy(RoutingStrategy delegate) { + this(delegate, DEFAULT_MAX_ENTRIES); + } + + public CachingStrategy(RoutingStrategy delegate, int maxEntries) { + if (delegate == null) { + throw new IllegalArgumentException("delegate strategy must not be null"); + } + if (maxEntries <= 0) { + throw new IllegalArgumentException("maxEntries must be positive: " + maxEntries); + } + this.delegate = delegate; + this.cache = Collections.synchronizedMap(new LruMap(maxEntries)); + } + + @Override + public String route(RoutingContext context) throws Exception { + String key = context.firstUserMessage(); + if (key.isEmpty()) { + // Don't cache empty keys: every empty-prompt conversation would otherwise share one + // decision. Recompute each time instead. + return delegate.route(context); + } + String cached = cache.get(key); + if (cached != null) { + return cached; + } + String chosen = delegate.route(context); + if (chosen != null) { + // Only memoize a real decision. A null is the strategy abstaining ("no opinion", e.g. a + // transient LLM-judge failure); caching it would pin the whole conversation to the + // router's default and never re-consult the strategy. 
+ cache.put(key, chosen); + } + return chosen; + } + + /** The strategy this caches. */ + public RoutingStrategy getDelegate() { + return delegate; + } + + /** Bounded access-order LRU map; evicts the eldest entry past {@code maxEntries}. */ + private static final class LruMap extends LinkedHashMap { + private static final long serialVersionUID = 1L; + private final int maxEntries; + + LruMap(int maxEntries) { + super(16, 0.75f, true); + this.maxEntries = maxEntries; + } + + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > maxEntries; + } + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java new file mode 100644 index 000000000..338e9857e --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.chat.model.routing; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A pluggable LLM router that selects, per request, which underlying chat model should serve it. + * + *

The router is a drop-in {@link BaseChatModelSetup}: it reports {@link ResourceType#CHAT_MODEL} + * and is resolved by the built-in chat action exactly like any other model, so an agent points at + * the router by name and nothing else in the framework needs to change. Concerns are layered: + * selection ({@link RoutingStrategy}) decides the model, optional {@link CachingStrategy} memoizes + * that decision per conversation, and {@link FallbackPolicy} decides what to try if the chosen + * model errors. The router then delegates to the chosen model's own {@code chat(...)}, preserving + * its prompt, tools, parameters, and token metrics. + * + *

Configuration ({@link ResourceDescriptor} arguments)

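+ * + *
  • {@code candidates} (required) — non-empty list of candidate specs; each entry is normalized with {@code RoutingCandidate.from(...)}. + *
  • {@code strategy} (required) — a {@link ResourceDescriptor} (or its serialized map form) naming a {@link RoutingStrategy} class with a public {@code (ResourceDescriptor, ResourceContext)} constructor. + *
  • {@code cache} (optional, default {@code true}) — wrap the strategy in a {@link CachingStrategy}. + *
  • {@code cache_size} (optional, default {@link CachingStrategy#DEFAULT_MAX_ENTRIES}) — maximum number of cached decisions; must be positive. + *
  • {@code fallback} (optional, default {@code false}) — on error, try the remaining candidates in declaration order. + *
  • {@code default} (optional) — candidate name used on a routing miss; must be one of the configured candidates. Defaults to the first candidate. + *
+ * + *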

Graceful degrade: if the strategy returns {@code null} ("no opinion", e.g. a transient + * LLM-judge failure) or a name that is not a configured candidate, the router treats it as a + * routing miss and serves the request from {@code default} (then the fallback order) rather than + * failing. + * + *

Bash/skill tool args (v1 scope): the built-in chat action injects bash allowlists and + * skill directories from the resource resolved by the agent's model name — i.e. this router — not + * the chosen backend. So configure {@code allowed_commands} / {@code allowed_script_dirs} / {@code + * skills} on the router; per-candidate skills/allowlists are not supported in v1. + * + *

Metrics note (v1): retry metrics recorded by the built-in chat action are grouped under + * this router's connection label ({@code "router"}), not the backend model actually used. + * Per-backend attribution is a documented follow-up. + */ +public class ChatModelRouter extends BaseChatModelSetup { + + private static final Logger LOG = LoggerFactory.getLogger(ChatModelRouter.class); + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + public static final String ARG_CANDIDATES = "candidates"; + public static final String ARG_STRATEGY = "strategy"; + public static final String ARG_FALLBACK = "fallback"; + public static final String ARG_CACHE = "cache"; + public static final String ARG_CACHE_SIZE = "cache_size"; + public static final String ARG_DEFAULT = "default"; + + private final List candidates; + private final Set candidateNames; + private final RoutingStrategy strategy; + private final FallbackPolicy fallbackPolicy; + private final boolean fallbackEnabled; + private final String defaultCandidate; + + @SuppressWarnings("unchecked") + public ChatModelRouter(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + List rawCandidates = descriptor.getArgument(ARG_CANDIDATES); + if (rawCandidates == null || rawCandidates.isEmpty()) { + throw new IllegalArgumentException( + "ChatModelRouter requires a non-empty 'candidates' argument."); + } + List parsed = new ArrayList<>(rawCandidates.size()); + Set names = new LinkedHashSet<>(); + for (Object spec : rawCandidates) { + RoutingCandidate candidate = RoutingCandidate.from(spec); + parsed.add(candidate); + names.add(candidate.getName()); + } + this.candidates = Collections.unmodifiableList(parsed); + this.candidateNames = Collections.unmodifiableSet(names); + + ResourceDescriptor strategyDescriptor = + toResourceDescriptor(descriptor.getArgument(ARG_STRATEGY)); + RoutingStrategy base = instantiateStrategy(strategyDescriptor, resourceContext); + + boolean 
cache = descriptor.getArgument(ARG_CACHE, Boolean.TRUE); + if (cache) { + int cacheSize = + descriptor.getArgument(ARG_CACHE_SIZE, CachingStrategy.DEFAULT_MAX_ENTRIES); + this.strategy = new CachingStrategy(base, cacheSize); + } else { + this.strategy = base; + } + + this.fallbackEnabled = descriptor.getArgument(ARG_FALLBACK, Boolean.FALSE); + this.fallbackPolicy = + fallbackEnabled ? FallbackPolicy.remainingCandidates() : FallbackPolicy.none(); + + // Default candidate used on a routing miss (strategy abstains / names a non-candidate). + // Validated at construction so a typo is a clear config error, not a per-request failure. + String configuredDefault = descriptor.getArgument(ARG_DEFAULT); + if (configuredDefault != null && !candidateNames.contains(configuredDefault)) { + throw new IllegalArgumentException( + "ChatModelRouter 'default' '" + + configuredDefault + + "' is not one of the configured candidates " + + candidateNames); + } + this.defaultCandidate = + configuredDefault != null ? configuredDefault : candidates.get(0).getName(); + } + + /** + * The router has no connection of its own to resolve (it delegates to candidate models, each of + * which resolves its own). Override to skip the base class's connection resolution. + */ + @Override + public void open() { + // no-op + } + + /** + * Coerce the {@code strategy} argument into a {@link ResourceDescriptor}: a descriptor directly + * (in-memory construction) or its deserialized {@link Map} form after the {@code AgentPlan} + * round-trips through JSON. The {@link Map} form is converted with the canonical {@link + * ObjectMapper} via {@link ResourceDescriptor}'s own Jackson binding, rather than hand-reading + * field names. 
+ */ + private static ResourceDescriptor toResourceDescriptor(Object strategyArg) { + if (strategyArg instanceof ResourceDescriptor) { + return (ResourceDescriptor) strategyArg; + } + if (strategyArg instanceof Map) { + return MAPPER.convertValue(strategyArg, ResourceDescriptor.class); + } + throw new IllegalArgumentException( + "ChatModelRouter requires a 'strategy' argument of type ResourceDescriptor (or its" + + " serialized map form), but got: " + + (strategyArg == null ? "null" : strategyArg.getClass().getName())); + } + + private static RoutingStrategy instantiateStrategy( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + String clazz = descriptor.getClazz(); + if (clazz == null || clazz.isEmpty()) { + throw new IllegalArgumentException("Routing strategy descriptor is missing a class."); + } + try { + // Use the thread context classloader (the convention in JavaResourceProvider) so that + // user-supplied strategy classes resolve on a Flink cluster, not just the API jar. + Class strategyClass = + Class.forName(clazz, true, Thread.currentThread().getContextClassLoader()); + if (!RoutingStrategy.class.isAssignableFrom(strategyClass)) { + throw new IllegalArgumentException( + clazz + " does not implement " + RoutingStrategy.class.getName()); + } + Constructor constructor = + strategyClass.getConstructor(ResourceDescriptor.class, ResourceContext.class); + return (RoutingStrategy) constructor.newInstance(descriptor, resourceContext); + } catch (RuntimeException e) { + throw e; + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate routing strategy " + clazz, e); + } + } + + @Override + public ChatMessage chat( + List messages, + Map promptArgs, + Map modelParams) { + String primary = select(messages, promptArgs); + if (primary == null) { + // Routing miss (strategy abstained or chose a non-candidate): degrade to the default + // candidate instead of failing, then apply the fallback order. 
+ primary = defaultCandidate; + } + List order = fallbackPolicy.attemptOrder(primary, candidates); + + Exception lastError = null; + for (String name : order) { + try { + return resolveCandidate(name).chat(messages, promptArgs, modelParams); + } catch (Exception e) { + lastError = e; + if (!fallbackEnabled) { + throw asRuntime(e, name); + } + LOG.warn( + "Routed model '{}' failed; falling back to the next candidate. Cause: {}", + name, + e.toString()); + } + } + throw new RuntimeException( + "All routed candidates failed for router. Tried: " + order, lastError); + } + + /** Run the strategy and validate its choice against the configured candidates. */ + private String select(List messages, Map promptArgs) { + String primary; + try { + primary = + strategy.route( + new RoutingContext(messages, promptArgs, candidates, resourceContext)); + } catch (Exception e) { + throw new RuntimeException("Routing strategy failed to select a model.", e); + } + if (primary == null || primary.isEmpty()) { + // Strategy abstained ("no opinion") -> routing miss; caller degrades to the default. + return null; + } + if (!candidateNames.contains(primary)) { + // A typo'd/unknown name must not hard-fail the request; treat it as a miss so the + // router can degrade gracefully to the default candidate. 
+ LOG.warn( + "Routing strategy chose '{}', not a configured candidate {}; treating as a" + + " routing miss (using default '{}').", + primary, + candidateNames, + defaultCandidate); + return null; + } + return primary; + } + + private BaseChatModelSetup resolveCandidate(String name) throws Exception { + if (resourceContext == null) { + throw new IllegalStateException( + "Router has no ResourceContext; cannot resolve candidate '" + name + "'."); + } + Object resource = resourceContext.getResource(name, ResourceType.CHAT_MODEL); + if (!(resource instanceof BaseChatModelSetup)) { + throw new IllegalStateException( + "Routed resource '" + + name + + "' is not a chat model setup (CHAT_MODEL): " + + (resource == null ? "null" : resource.getClass().getName())); + } + return (BaseChatModelSetup) resource; + } + + private static RuntimeException asRuntime(Exception e, String name) { + if (e instanceof RuntimeException) { + return (RuntimeException) e; + } + return new RuntimeException("Routed model '" + name + "' failed.", e); + } + + @Override + public Map getParameters() { + return Collections.emptyMap(); + } + + /** + * The router has no connection of its own; return a stable label so retry-metric grouping in + * the built-in chat action never sees a null connection name. + */ + @Override + public String getConnectionName() { + return "router"; + } + + /** The candidate models this router may route to. */ + public List getCandidates() { + return candidates; + } + + /** Whether the router falls back to the next candidate when the chosen model errors. 
*/ + public boolean isFallbackEnabled() { + return fallbackEnabled; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/FallbackPolicy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/FallbackPolicy.java new file mode 100644 index 000000000..55f2856f0 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/FallbackPolicy.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; + +/** + * The resilience concern of a {@link ChatModelRouter}, kept separate from selection ({@link + * RoutingStrategy}). Given the strategy's primary choice and the configured candidates, it produces + * the ordered list of model names the router will try in turn until one succeeds. + */ +@FunctionalInterface +public interface FallbackPolicy { + + /** + * Ordered, de-duplicated model names to attempt, primary first. 
+ * + * @param primary the model chosen by the {@link RoutingStrategy} + * @param candidates all configured candidates (for fallback ordering) + */ + List attemptOrder(String primary, List candidates); + + /** No fallback: only the chosen model is attempted; its failure surfaces to the caller. */ + static FallbackPolicy none() { + return (primary, candidates) -> Collections.singletonList(primary); + } + + /** + * On failure, fall back to the remaining configured candidates in declaration order (primary + * first, then the rest), de-duplicated. + */ + static FallbackPolicy remainingCandidates() { + return (primary, candidates) -> { + LinkedHashSet order = new LinkedHashSet<>(); + order.add(primary); + for (RoutingCandidate candidate : candidates) { + order.add(candidate.getName()); + } + return new ArrayList<>(order); + }; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java new file mode 100644 index 000000000..034539777 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.regex.Pattern; + +/** + * An "LLM-as-router" selection strategy: a small, cheap chat model is asked to choose which + * candidate model should serve the request (the approach popularized by RouteLLM). + * + *

The strategy builds a classification prompt from each candidate's {@code name} and {@code + * description}, asks the configured judge model to reply with a single model name, and parses that + * reply back to a candidate (keyed on {@link RoutingContext#firstUserMessage()} for stickiness). + * + *

On a parse miss (the judge answered but named no candidate) it returns the configured + * {@code default} (or the first candidate) — a deterministic outcome that is safe to memoize. On a + * transient judge failure (resolve/call error) it instead abstains by returning + * {@code null}, so the router degrades to its own default and a wrapping {@link CachingStrategy} + * does not pin the conversation to a fallback; the judge is retried on the next round. It does not + * cache — wrap it in {@link CachingStrategy} (the router does this by default) so the judge runs + * once per conversation. + * + *

Security note. The user's message is sent to the judge model, so this routing decision + * is susceptible to prompt injection (a crafted message could steer the choice). Do not gate cost, + * privilege, or safety solely on the LLM router's decision; treat it as a best-effort hint and keep + * authoritative controls elsewhere. Parsing prefers an exact model-name reply. + * + *

Configuration ({@link ResourceDescriptor} arguments)

+ * + *
    + *
  • {@code judge_model} (required) — name of a registered {@code CHAT_MODEL} resource + * (typically small/cheap) that performs the routing decision. + *
  • {@code default} (optional) — candidate name used when the judge's answer cannot be mapped. + * Defaults to the first configured candidate. + *
  • {@code instruction} (optional) — extra guidance appended to the system prompt. + *
+ */ +public class LlmRoutingStrategy extends AbstractRoutingStrategy { + + private static final Logger LOG = LoggerFactory.getLogger(LlmRoutingStrategy.class); + + public static final String ARG_JUDGE_MODEL = "judge_model"; + public static final String ARG_DEFAULT = "default"; + public static final String ARG_INSTRUCTION = "instruction"; + + private final String judgeModel; + private final String defaultModel; + private final String instruction; + + public LlmRoutingStrategy(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + this.judgeModel = arg(ARG_JUDGE_MODEL); + if (judgeModel == null || judgeModel.isEmpty()) { + throw new IllegalArgumentException( + "LlmRoutingStrategy requires a 'judge_model' (a registered CHAT_MODEL name)."); + } + this.defaultModel = arg(ARG_DEFAULT); + this.instruction = arg(ARG_INSTRUCTION, ""); + } + + @Override + public String route(RoutingContext context) { + List candidates = context.getCandidates(); + String defaultName = resolveDefault(candidates); + if (defaultName == null) { + throw new IllegalStateException("LlmRoutingStrategy has no candidates to route to."); + } + if (resourceContext == null) { + // No way to consult the judge: abstain so the router degrades. Return null (not the + // default) so this non-deterministic miss is not memoized by a wrapping cache. 
+ LOG.warn("No ResourceContext available; abstaining so the router can degrade."); + return null; + } + try { + Object resource = resourceContext.getResource(judgeModel, ResourceType.CHAT_MODEL); + if (!(resource instanceof BaseChatModelSetup)) { + throw new IllegalStateException( + "Judge model '" + judgeModel + "' is not a CHAT_MODEL setup."); + } + BaseChatModelSetup judge = (BaseChatModelSetup) resource; + + List messages = + Arrays.asList( + new ChatMessage(MessageRole.SYSTEM, buildSystemPrompt(candidates)), + new ChatMessage(MessageRole.USER, context.firstUserMessage())); + + ChatMessage response = + judge.chat(messages, Collections.emptyMap(), Collections.emptyMap()); + String chosen = parseChoice(response.getContent(), candidates); + if (chosen != null) { + return chosen; + } + // The judge answered, but named nothing recognizable. This is a deterministic outcome + // for this request, so it is safe to return (and cache) the default. + LOG.warn( + "Judge model '{}' returned an unrecognized choice; using default '{}'.", + judgeModel, + defaultName); + return defaultName; + } catch (Exception e) { + // Transient judge failure (resolve/call error): abstain with null so the router + // degrades and a wrapping cache does NOT pin this conversation to a fallback — the + // judge is retried on the next round. + LOG.warn( + "LLM routing via judge '{}' failed; abstaining. Cause: {}", + judgeModel, + e.toString()); + return null; + } + } + + private String resolveDefault(List candidates) { + if (defaultModel != null && !defaultModel.isEmpty()) { + return defaultModel; + } + return candidates.isEmpty() ? null : candidates.get(0).getName(); + } + + private String buildSystemPrompt(List candidates) { + StringBuilder sb = new StringBuilder(); + sb.append( + "You are a model router. Choose the single best model to answer the user's request. 
") + .append("Reply with ONLY the model name, exactly as listed, and nothing else.\n\n") + .append("Available models:\n"); + for (RoutingCandidate candidate : candidates) { + sb.append("- ").append(candidate.getName()); + if (!candidate.getDescription().isEmpty()) { + sb.append(": ").append(candidate.getDescription()); + } + sb.append('\n'); + } + if (instruction != null && !instruction.isEmpty()) { + sb.append('\n').append(instruction).append('\n'); + } + return sb.toString(); + } + + /** + * Map the judge's free-text answer back to a candidate name. Prefers an exact (trimmed, + * case-insensitive) match; otherwise the longest candidate name that appears as a whole token + * (bounded by non-identifier characters) so that e.g. a "gpt-4o-mini" reply does not match a + * configured "gpt-4". + */ + private static String parseChoice(String answer, List candidates) { + if (answer == null) { + return null; + } + String normalized = answer.trim().toLowerCase(Locale.ROOT); + if (normalized.isEmpty()) { + return null; + } + + for (RoutingCandidate candidate : candidates) { + if (normalized.equals(candidate.getName().toLowerCase(Locale.ROOT))) { + return candidate.getName(); + } + } + + List byLengthDesc = new ArrayList<>(candidates); + byLengthDesc.sort((a, b) -> b.getName().length() - a.getName().length()); + for (RoutingCandidate candidate : byLengthDesc) { + // Whole-token match: the name must not be flanked by a word char or '-', so "gpt-4" + // won't match inside "gpt-4o" or "gpt-4-mini". '.' is treated as a boundary so a model + // name ending a sentence (e.g. "...use big.") still matches. + Pattern p = + Pattern.compile( + "(?A candidate names a chat-model setup that was registered as a {@link + * org.apache.flink.agents.api.resource.ResourceType#CHAT_MODEL} resource, together with an optional + * human-readable {@code description} (consumed by LLM-as-router strategies to reason about which + * model fits a request) and free-form {@code metadata} (e.g. 
{@code cost}, {@code tags}, + * capabilities) that rule-based or custom strategies can match against. + */ +public class RoutingCandidate { + + private final String name; + private final String description; + private final Map metadata; + + public RoutingCandidate(String name, String description, Map metadata) { + this.name = Objects.requireNonNull(name, "candidate name must not be null"); + this.description = description != null ? description : ""; + this.metadata = + metadata != null + ? Collections.unmodifiableMap(new HashMap<>(metadata)) + : Collections.emptyMap(); + } + + public RoutingCandidate(String name, String description) { + this(name, description, Collections.emptyMap()); + } + + public RoutingCandidate(String name) { + this(name, "", Collections.emptyMap()); + } + + /** The registered {@code CHAT_MODEL} resource name this candidate routes to. */ + public String getName() { + return name; + } + + /** Human-readable description of when this model should be used (may be empty). */ + public String getDescription() { + return description; + } + + /** Free-form metadata strategies may match against (never null). */ + public Map getMetadata() { + return metadata; + } + + /** + * Normalize a single user-supplied candidate spec into a {@link RoutingCandidate}. + * + *

Accepts an existing {@link RoutingCandidate}, a plain {@link String} (name only), or a + * {@link Map} with keys {@code name} (required), {@code description}, and {@code metadata}. + * This keeps the router descriptor easy to author from Java and tolerant of values that have + * round-tripped through serialization. + */ + @SuppressWarnings("unchecked") + public static RoutingCandidate from(Object spec) { + if (spec instanceof RoutingCandidate) { + return (RoutingCandidate) spec; + } + if (spec instanceof CharSequence) { + return new RoutingCandidate(spec.toString()); + } + if (spec instanceof Map) { + Map map = (Map) spec; + Object name = map.get("name"); + if (name == null) { + throw new IllegalArgumentException( + "Routing candidate map must contain a 'name' entry: " + map); + } + Object description = map.get("description"); + Object metadata = map.get("metadata"); + return new RoutingCandidate( + name.toString(), + description != null ? description.toString() : "", + metadata instanceof Map ? (Map) metadata : null); + } + throw new IllegalArgumentException( + "Unsupported routing candidate spec: " + + (spec == null ? 
"null" : spec.getClass().getName())); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RoutingCandidate)) { + return false; + } + RoutingCandidate that = (RoutingCandidate) o; + return name.equals(that.name) + && description.equals(that.description) + && metadata.equals(that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(name, description, metadata); + } + + @Override + public String toString() { + return "RoutingCandidate{name='" + name + "', description='" + description + "'}"; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingContext.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingContext.java new file mode 100644 index 000000000..0460d0b26 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingContext.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.resource.ResourceContext; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * The information a {@link RoutingStrategy} sees when deciding which model to use for a single chat + * request. + * + *

Carries the request messages and prompt args, the configured {@link RoutingCandidate}s, and + * the {@link ResourceContext} so a strategy can resolve auxiliary resources it needs (for example + * an LLM-as-router strategy resolving its judge {@code CHAT_MODEL}, or a future semantic strategy + * resolving an embedding model). + */ +public class RoutingContext { + + private final List messages; + private final Map promptArgs; + private final List candidates; + private final ResourceContext resourceContext; + + public RoutingContext( + List messages, + Map promptArgs, + List candidates, + ResourceContext resourceContext) { + this.messages = + messages != null ? Collections.unmodifiableList(messages) : Collections.emptyList(); + this.promptArgs = promptArgs != null ? promptArgs : Collections.emptyMap(); + this.candidates = + candidates != null + ? Collections.unmodifiableList(candidates) + : Collections.emptyList(); + this.resourceContext = resourceContext; + } + + /** The full request message list (immutable). */ + public List getMessages() { + return messages; + } + + /** Variables supplied to fill the prompt template, if any (never null). */ + public Map getPromptArgs() { + return promptArgs; + } + + /** The models this router may route to (immutable, never null). */ + public List getCandidates() { + return candidates; + } + + /** Context for resolving other resources (e.g. a judge model). May be null in tests. */ + public ResourceContext getResourceContext() { + return resourceContext; + } + + /** + * The content of the first {@link MessageRole#USER} message, or an empty string if there is + * none. + * + *

Strategies should key their decision on this rather than the full (evolving) message list, + * so that the same model is chosen on every round of a multi-turn tool-calling conversation — + * the built-in chat action re-invokes the router with the accumulated messages on each tool + * response. See {@link ChatModelRouter} for the stickiness contract. + */ + public String firstUserMessage() { + for (ChatMessage message : messages) { + if (message.getRole() == MessageRole.USER) { + return message.getContent() != null ? message.getContent() : ""; + } + } + return ""; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingStrategy.java new file mode 100644 index 000000000..ad36a388f --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingStrategy.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +/** + * The pluggable selection logic of a {@link ChatModelRouter}: given a request, pick which + * candidate model should serve it. 
This is intentionally a pure concern — it returns a + * single model name and does not deal with resilience or caching, which are layered separately + * ({@link FallbackPolicy} for fallback, {@link CachingStrategy} for memoization). + * + *

Selection can be driven by any approach: + * + *

    + *
  • Rule-based — deterministic keyword/regex/metadata rules ({@link + * RuleBasedRoutingStrategy}, built-in). + *
  • LLM-as-router — a small judge model picks the candidate ({@link LlmRoutingStrategy}, + * built-in). + *
  • ML / learned — a trained classifier or learned scorer (e.g. RouteLLM-style, or + * embedding-similarity over per-route examples) chooses the candidate. This is supported as a + * bring-your-own strategy: implement {@code route(...)} to run your model and return + * the chosen candidate name. No built-in ML strategy ships yet (it carries a model + * training/serving lifecycle); it is a planned follow-up. + *
  • Bring-your-own — any custom logic. + *
+ * + *

Built-ins and custom (incl. ML) strategies are equally first-class: provide your own by + * implementing this interface (typically via {@link AbstractRoutingStrategy}) and referencing the + * class from the router's {@code strategy} {@link + * org.apache.flink.agents.api.resource.ResourceDescriptor} — no framework change required. + * + *

Stickiness contract. The built-in chat action re-invokes the router on every round of a + * multi-turn tool-calling conversation, passing the accumulated messages. To keep the same model + * across the whole conversation, a strategy must be deterministic with respect to the original + * request. Built-in strategies key on {@link RoutingContext#firstUserMessage()}; custom strategies + * are encouraged to do the same (and to wrap with {@link CachingStrategy} when the decision is + * expensive, e.g. an LLM judge). + */ +@FunctionalInterface +public interface RoutingStrategy { + + /** + * Choose the candidate model that should handle this request. + * + *

Return one of the configured candidate names to select it. Return {@code null} to + * abstain ("no opinion" — e.g. a transient judge failure): the router then degrades to + * its configured {@code default} candidate, and a wrapping {@link CachingStrategy} will not + * memoize the abstention. A returned name that is not a configured candidate is treated by the + * router as a routing miss (same degrade-to-default behaviour), not a hard failure. + * + * @param context the request messages, prompt args, candidates, and resource context + * @return the chosen candidate model name, or {@code null} to abstain + * @throws Exception if the decision could not be made + */ + String route(RoutingContext context) throws Exception; +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategy.java new file mode 100644 index 000000000..0955ef8c3 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategy.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * A deterministic, no-extra-call routing strategy that maps a request to a model using ordered + * rules. + * + *

This is the "pluggable rule logic" baseline: predictable, free, and side-effect-free. Rules + * are evaluated in order against the request text (by default {@link + * RoutingContext#firstUserMessage()}, which keeps routing sticky across tool-calling rounds), and + * the first matching rule's model wins. If no rule matches, the {@code default} model is used. + * + *

Configuration ({@link ResourceDescriptor} arguments)

+ * + *
    + *
  • {@code rules} (required) — an ordered list of rule maps. Each rule has: + *
      + *
    • {@code model} (required) — the candidate model name to route to. + *
    • {@code keywords} — a {@link String} or list of strings; matches when any keyword is a + * case-insensitive substring of the request text. + *
    • {@code regex} — a regular expression matched (find) against the request text. + *
    • {@code prompt_arg} + {@code equals} — matches when the named prompt arg equals the + * given value (string comparison). + *
    + * A rule matches when any of its present predicates match (OR semantics). + *
  • {@code default} (required) — the fallback model name when no rule matches. + *
  • {@code case_sensitive} (optional, default {@code false}) — keyword/text comparison mode. + *
+ */ +public class RuleBasedRoutingStrategy extends AbstractRoutingStrategy { + + public static final String ARG_RULES = "rules"; + public static final String ARG_DEFAULT = "default"; + public static final String ARG_CASE_SENSITIVE = "case_sensitive"; + + private final List rules; + private final String defaultModel; + private final boolean caseSensitive; + + @SuppressWarnings("unchecked") + public RuleBasedRoutingStrategy( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + this.defaultModel = arg(ARG_DEFAULT); + if (defaultModel == null || defaultModel.isEmpty()) { + throw new IllegalArgumentException( + "RuleBasedRoutingStrategy requires a 'default' model name."); + } + this.caseSensitive = arg(ARG_CASE_SENSITIVE, Boolean.FALSE); + + List rawRules = arg(ARG_RULES); + if (rawRules == null) { + rawRules = Collections.emptyList(); + } + List parsed = new ArrayList<>(rawRules.size()); + for (Object raw : rawRules) { + if (!(raw instanceof Map)) { + throw new IllegalArgumentException( + "Each rule must be a Map, but got: " + + (raw == null ? "null" : raw.getClass().getName())); + } + parsed.add(Rule.from((Map) raw, caseSensitive)); + } + this.rules = parsed; + } + + @Override + public String route(RoutingContext context) { + String text = context.firstUserMessage(); + String haystack = caseSensitive ? text : text.toLowerCase(Locale.ROOT); + for (Rule rule : rules) { + if (rule.matches(haystack, context.getPromptArgs())) { + return rule.model; + } + } + return defaultModel; + } + + /** A single compiled rule. 
*/ + private static final class Rule { + final String model; + final List keywords; // already case-normalized to match the haystack + final Pattern regex; + final String promptArg; + final String promptArgEquals; + + Rule( + String model, + List keywords, + Pattern regex, + String promptArg, + String promptArgEquals) { + this.model = model; + this.keywords = keywords; + this.regex = regex; + this.promptArg = promptArg; + this.promptArgEquals = promptArgEquals; + } + + @SuppressWarnings("unchecked") + static Rule from(Map map, boolean caseSensitive) { + Object model = map.get("model"); + if (model == null) { + throw new IllegalArgumentException("Rule is missing a 'model': " + map); + } + + List keywords = new ArrayList<>(); + Object kw = map.get("keywords"); + if (kw instanceof CharSequence) { + keywords.add(normalize(kw.toString(), caseSensitive)); + } else if (kw instanceof List) { + for (Object item : (List) kw) { + if (item != null) { + keywords.add(normalize(item.toString(), caseSensitive)); + } + } + } + + Pattern regex = null; + Object re = map.get("regex"); + if (re instanceof CharSequence) { + int flags = caseSensitive ? 0 : Pattern.CASE_INSENSITIVE; + regex = Pattern.compile(re.toString(), flags); + } + + Object promptArg = map.get("prompt_arg"); + Object equals = map.get("equals"); + + return new Rule( + model.toString(), + keywords, + regex, + promptArg != null ? promptArg.toString() : null, + equals != null ? equals.toString() : null); + } + + private static String normalize(String s, boolean caseSensitive) { + return caseSensitive ? 
s : s.toLowerCase(Locale.ROOT); + } + + boolean matches(String haystack, Map promptArgs) { + for (String keyword : keywords) { + if (!keyword.isEmpty() && haystack.contains(keyword)) { + return true; + } + } + if (regex != null && regex.matcher(haystack).find()) { + return true; + } + if (promptArg != null && promptArgs != null) { + Object value = promptArgs.get(promptArg); + if (value != null + && (promptArgEquals == null || promptArgEquals.equals(value.toString()))) { + return true; + } + } + return false; + } + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java index 1798b996f..2a284022f 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java @@ -227,4 +227,18 @@ private VectorStore() {} // ==================== MCP ==================== public static final String MCP_SERVER = "DECIDE_IN_RUNTIME_MCPServer"; + + // ==================== RoutingStrategy ==================== + /** + * Built-in {@code RoutingStrategy} implementations, for use in a {@code ChatModelRouter}'s + * {@code strategy} descriptor. 
+ */ + public static final class RoutingStrategy { + public static final String RULE_BASED = + "org.apache.flink.agents.api.chat.model.routing.RuleBasedRoutingStrategy"; + public static final String LLM = + "org.apache.flink.agents.api.chat.model.routing.LlmRoutingStrategy"; + + private RoutingStrategy() {} + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategyTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategyTest.java new file mode 100644 index 000000000..cbb23fb7b --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategyTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +/** Tests for {@link CachingStrategy}: memoization, LRU eviction, and empty-key passthrough. 
*/ +class CachingStrategyTest { + + /** A delegate that counts invocations and echoes the first user message as the choice. */ + private static final class CountingStrategy implements RoutingStrategy { + int calls = 0; + + @Override + public String route(RoutingContext context) { + calls++; + return "model:" + context.firstUserMessage(); + } + } + + private static RoutingContext ctxFor(String userText) { + return RoutingTestSupport.routingContext( + Collections.singletonList(RoutingTestSupport.user(userText)), + Collections.emptyList(), + null); + } + + @Test + @DisplayName("same first-user-message is computed once, then served from cache") + void testMemoizesPerKey() throws Exception { + CountingStrategy delegate = new CountingStrategy(); + CachingStrategy caching = new CachingStrategy(delegate); + + assertEquals("model:q1", caching.route(ctxFor("q1"))); + assertEquals("model:q1", caching.route(ctxFor("q1"))); + assertEquals("model:q1", caching.route(ctxFor("q1"))); + assertEquals(1, delegate.calls); + } + + @Test + @DisplayName("different keys are computed independently") + void testDifferentKeysRecompute() throws Exception { + CountingStrategy delegate = new CountingStrategy(); + CachingStrategy caching = new CachingStrategy(delegate); + + caching.route(ctxFor("a")); + caching.route(ctxFor("b")); + assertEquals(2, delegate.calls); + } + + @Test + @DisplayName("empty first-user-message is never cached") + void testEmptyKeyNotCached() throws Exception { + CountingStrategy delegate = new CountingStrategy(); + CachingStrategy caching = new CachingStrategy(delegate); + + caching.route(ctxFor("")); + caching.route(ctxFor("")); + assertEquals(2, delegate.calls); + } + + @Test + @DisplayName("a null (abstain) decision is not cached and is recomputed next round") + void testNullDecisionNotCached() throws Exception { + // Strategy abstains (returns null) on the first call, then commits to a real choice. 
+ RoutingStrategy flaky = + new RoutingStrategy() { + int calls = 0; + + @Override + public String route(RoutingContext context) { + return calls++ == 0 ? null : "model:" + context.firstUserMessage(); + } + }; + CachingStrategy caching = new CachingStrategy(flaky); + + assertNull(caching.route(ctxFor("q1"))); // abstained -> must not be cached + assertEquals("model:q1", caching.route(ctxFor("q1"))); // re-consulted, now decides + assertEquals("model:q1", caching.route(ctxFor("q1"))); // now served from cache + } + + @Test + @DisplayName("bounded LRU evicts the eldest entry past capacity") + void testLruEviction() throws Exception { + CountingStrategy delegate = new CountingStrategy(); + CachingStrategy caching = new CachingStrategy(delegate, 2); + + caching.route(ctxFor("a")); // {a} + caching.route(ctxFor("b")); // {a,b} + caching.route(ctxFor("c")); // evicts a -> {b,c} + assertEquals(3, delegate.calls); + + caching.route(ctxFor("a")); // a was evicted -> recompute + assertEquals(4, delegate.calls); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java new file mode 100644 index 000000000..e05633b00 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ChatModelRouterTest { + + /** A user-supplied strategy, used to verify the bring-your-own extension point. */ + public static class AlwaysBigStrategy extends AbstractRoutingStrategy { + public AlwaysBigStrategy(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + } + + @Override + public String route(RoutingContext context) { + return "big"; + } + } + + /** A strategy that always names a model that is not a configured candidate (a routing miss). 
*/ + public static class NonCandidateStrategy extends AbstractRoutingStrategy { + public NonCandidateStrategy( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + } + + @Override + public String route(RoutingContext context) { + return "does-not-exist"; + } + } + + /** A strategy that always abstains (returns null / "no opinion"). */ + public static class AbstainStrategy extends AbstractRoutingStrategy { + public AbstainStrategy(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + } + + @Override + public String route(RoutingContext context) { + return null; + } + } + + private static ResourceDescriptor ruleStrategy() { + return ResourceDescriptor.Builder.newBuilder(RuleBasedRoutingStrategy.class.getName()) + .addInitialArgument("default", "small") + .addInitialArgument( + "rules", + Collections.singletonList( + Map.of( + "model", + "big", + "keywords", + Collections.singletonList("code")))) + .build(); + } + + private static ChatModelRouter router( + ResourceDescriptor strategy, + boolean fallback, + List candidates, + ResourceContext ctx) + throws Exception { + ResourceDescriptor descriptor = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("candidates", candidates) + .addInitialArgument("strategy", strategy) + .addInitialArgument("fallback", fallback) + .build(); + ChatModelRouter router = new ChatModelRouter(descriptor, ctx); + router.open(); + return router; + } + + @Test + @DisplayName("router is a drop-in CHAT_MODEL resource") + void testResourceTypeIsChatModel() throws Exception { + ChatModelRouter router = + router( + ruleStrategy(), + false, + Arrays.asList("small", "big"), + RoutingTestSupport.context(new HashMap<>())); + assertEquals(ResourceType.CHAT_MODEL, router.getResourceType()); + } + + @Test + @DisplayName("rule-based router delegates to the selected model") + void 
testDelegatesToSelectedModel() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("please write code")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:big", response.getContent()); + assertEquals(1, big.callCount); + assertEquals(0, small.callCount); + } + + @Test + @DisplayName("non-matching request delegates to the default model") + void testDelegatesToDefault() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("how are you?")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:small", response.getContent()); + assertEquals(1, small.callCount); + assertEquals(0, big.callCount); + } + + @Test + @DisplayName("fallback enabled: a failing primary advances to the next candidate") + void testFallbackOnError() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.FailingModel small = new RoutingTestSupport.FailingModel(); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + 
registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + // No rule matches "hello" -> strategy picks default "small" (the failing model). + ChatModelRouter router = router(ruleStrategy(), true, Arrays.asList("small", "big"), ctx); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("hello")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:big", response.getContent()); + assertEquals(1, small.callCount); + assertEquals(1, big.callCount); + } + + @Test + @DisplayName("fallback disabled: a failing primary surfaces the error") + void testNoFallbackRethrows() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.FailingModel small = new RoutingTestSupport.FailingModel(); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + + assertThrows( + RuntimeException.class, + () -> + router.chat( + Collections.singletonList(RoutingTestSupport.user("hello")), + Collections.emptyMap(), + Collections.emptyMap())); + assertEquals(1, small.callCount); + assertEquals(0, big.callCount); + } + + @Test + @DisplayName("routing is sticky across a simulated tool-calling round") + void testStickyAcrossToolRound() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + + // First round: the 
original "code" request. + router.chat( + Collections.singletonList(RoutingTestSupport.user("write code")), + Collections.emptyMap(), + Collections.emptyMap()); + + // Second round: accumulated conversation (assistant tool call + tool result), as the chat + // action would re-invoke the router. Must still pick "big". + List conversation = + Arrays.asList( + RoutingTestSupport.user("write code"), + new ChatMessage(MessageRole.ASSISTANT, "calling tool"), + new ChatMessage(MessageRole.TOOL, "neutral tool output")); + router.chat(conversation, Collections.emptyMap(), Collections.emptyMap()); + + assertEquals(2, big.callCount); + assertEquals(0, small.callCount); + } + + @Test + @DisplayName("a user-supplied strategy plugs in by class name") + void testBringYourOwnStrategy() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ResourceDescriptor custom = + ResourceDescriptor.Builder.newBuilder(AlwaysBigStrategy.class.getName()).build(); + ChatModelRouter router = router(custom, false, Arrays.asList("small", "big"), ctx); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("anything")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:big", response.getContent()); + assertEquals(1, big.callCount); + } + + @Test + @DisplayName( + "a non-candidate selection is a routing miss: degrade to default (first candidate)") + void testNonCandidateDegradesToFirstCandidate() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + 
registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ResourceDescriptor custom = + ResourceDescriptor.Builder.newBuilder(NonCandidateStrategy.class.getName()).build(); + // No "default" configured -> the first candidate ("small") is the default. + ChatModelRouter router = router(custom, false, Arrays.asList("small", "big"), ctx); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("anything")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:small", response.getContent()); + assertEquals(1, small.callCount); + assertEquals(0, big.callCount); + } + + @Test + @DisplayName("an abstaining (null) strategy degrades to the configured default candidate") + void testAbstainDegradesToConfiguredDefault() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ResourceDescriptor descriptor = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("candidates", Arrays.asList("small", "big")) + .addInitialArgument( + "strategy", + ResourceDescriptor.Builder.newBuilder( + AbstainStrategy.class.getName()) + .build()) + .addInitialArgument("default", "big") + .build(); + ChatModelRouter router = new ChatModelRouter(descriptor, ctx); + router.open(); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("anything")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:big", response.getContent()); + assertEquals(1, big.callCount); + assertEquals(0, small.callCount); + } + + @Test + @DisplayName("a 'default' that is not a configured 
candidate is rejected at construction") + void testInvalidDefaultRejected() { + ResourceContext ctx = RoutingTestSupport.context(new HashMap<>()); + ResourceDescriptor descriptor = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("candidates", Arrays.asList("small", "big")) + .addInitialArgument("strategy", ruleStrategy()) + .addInitialArgument("default", "nope") + .build(); + assertThrows(IllegalArgumentException.class, () -> new ChatModelRouter(descriptor, ctx)); + } + + @Test + @DisplayName("candidates accept rich {name, description} maps") + void testRichCandidateSpecs() throws Exception { + Map registry = new HashMap<>(); + registry.put("small", new RoutingTestSupport.RecordingModel("small")); + registry.put("big", new RoutingTestSupport.RecordingModel("big")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + List candidates = + Arrays.asList( + Map.of("name", "small", "description", "cheap"), + Map.of("name", "big", "description", "strong")); + ChatModelRouter router = router(ruleStrategy(), false, candidates, ctx); + + assertEquals(2, router.getCandidates().size()); + assertEquals("cheap", router.getCandidates().get(0).getDescription()); + } + + @Test + @DisplayName("strategy supplied as a deserialized map (post-JSON round-trip) still works") + void testStrategyFromMapForm() throws Exception { + Map registry = new HashMap<>(); + RoutingTestSupport.RecordingModel small = new RoutingTestSupport.RecordingModel("small"); + RoutingTestSupport.RecordingModel big = new RoutingTestSupport.RecordingModel("big"); + registry.put("small", small); + registry.put("big", big); + ResourceContext ctx = RoutingTestSupport.context(registry); + + // Mimic what getArgument("strategy") returns after the AgentPlan round-trips through JSON: + // a plain map with the ResourceDescriptor's JSON field names rather than a typed object. 
+ Map strategyMap = new HashMap<>(); + strategyMap.put("target_clazz", RuleBasedRoutingStrategy.class.getName()); + strategyMap.put("target_module", ""); + Map rule = new HashMap<>(); + rule.put("model", "big"); + rule.put("keywords", Collections.singletonList("code")); + Map strategyArgs = new HashMap<>(); + strategyArgs.put("default", "small"); + strategyArgs.put("rules", Collections.singletonList(rule)); + strategyMap.put("arguments", strategyArgs); + + ResourceDescriptor descriptor = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("candidates", Arrays.asList("small", "big")) + .addInitialArgument("strategy", strategyMap) + .build(); + ChatModelRouter router = new ChatModelRouter(descriptor, ctx); + router.open(); + + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("write code")), + Collections.emptyMap(), + Collections.emptyMap()); + assertEquals("handled-by:big", response.getContent()); + } + + @Test + @DisplayName("missing candidates or strategy are rejected at construction") + void testInvalidConfigRejected() { + ResourceContext ctx = RoutingTestSupport.context(new HashMap<>()); + + ResourceDescriptor noCandidates = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("strategy", ruleStrategy()) + .build(); + assertThrows(IllegalArgumentException.class, () -> new ChatModelRouter(noCandidates, ctx)); + + ResourceDescriptor noStrategy = + ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument("candidates", Arrays.asList("small", "big")) + .build(); + assertThrows(IllegalArgumentException.class, () -> new ChatModelRouter(noStrategy, ctx)); + } + + @Test + @DisplayName("a null resource context surfaces a clear error at chat() time") + void testNullResourceContextRejected() throws Exception { + // Rule-based strategy needs no context; the router is built with a null ResourceContext. 
+ ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), null); + assertThrows( + IllegalStateException.class, + () -> + router.chat( + Collections.singletonList(RoutingTestSupport.user("write code")), + Collections.emptyMap(), + Collections.emptyMap())); + } + + @Test + @DisplayName("connection name is non-null so retry metrics never NPE") + void testStableConnectionName() throws Exception { + ChatModelRouter router = + router( + ruleStrategy(), + false, + Arrays.asList("small", "big"), + RoutingTestSupport.context(new HashMap<>())); + assertTrue(router.getConnectionName() != null && !router.getConnectionName().isEmpty()); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategyTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategyTest.java new file mode 100644 index 000000000..2173f8eb8 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategyTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class LlmRoutingStrategyTest { + + private static final List CANDIDATES = + Arrays.asList( + new RoutingCandidate("small", "cheap model for chit-chat"), + new RoutingCandidate("big", "strong model for code and reasoning")); + + private LlmRoutingStrategy strategy(Map args, ResourceContext ctx) { + return new LlmRoutingStrategy( + new ResourceDescriptor(LlmRoutingStrategy.class.getName(), args), ctx); + } + + private String route(LlmRoutingStrategy strategy, ResourceContext ctx, String userText) + throws Exception { + return strategy.route( + RoutingTestSupport.routingContext( + Collections.singletonList(RoutingTestSupport.user(userText)), + CANDIDATES, + ctx)); + } + + @Test + @DisplayName("judge model's choice is used when it names a candidate") + void testJudgeChoiceUsed() throws Exception { + Map registry = new HashMap<>(); + registry.put("judge", new RoutingTestSupport.ScriptedJudge("big")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + assertEquals( + "big", route(strategy(Map.of("judge_model", "judge"), ctx), ctx, "write code")); + } + + @Test + @DisplayName("a verbose judge answer still resolves to the named candidate") + void testJudgeChoiceParsedFromProse() throws Exception { + Map registry = new HashMap<>(); + registry.put( + "judge", + new RoutingTestSupport.ScriptedJudge("I think the best choice 
here is big.")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + assertEquals("big", route(strategy(Map.of("judge_model", "judge"), ctx), ctx, "hi")); + } + + @Test + @DisplayName("unrecognized judge answer falls back to the configured default") + void testFallbackToDefaultOnParseMiss() throws Exception { + Map registry = new HashMap<>(); + registry.put("judge", new RoutingTestSupport.ScriptedJudge("no idea")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + assertEquals( + "small", + route( + strategy(Map.of("judge_model", "judge", "default", "small"), ctx), + ctx, + "hi")); + } + + @Test + @DisplayName("default is the first candidate when none is configured") + void testDefaultsToFirstCandidate() throws Exception { + Map registry = new HashMap<>(); + registry.put("judge", new RoutingTestSupport.ScriptedJudge("garbage")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + assertEquals("small", route(strategy(Map.of("judge_model", "judge"), ctx), ctx, "hi")); + } + + @Test + @DisplayName("longest candidate name wins when names overlap as substrings") + void testLongestNameWins() throws Exception { + List candidates = + Arrays.asList(new RoutingCandidate("gpt-4"), new RoutingCandidate("gpt-4o")); + Map registry = new HashMap<>(); + registry.put("judge", new RoutingTestSupport.ScriptedJudge("use gpt-4o please")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + LlmRoutingStrategy strategy = strategy(Map.of("judge_model", "judge"), ctx); + assertEquals( + "gpt-4o", + strategy.route( + RoutingTestSupport.routingContext( + Collections.singletonList(RoutingTestSupport.user("hi")), + candidates, + ctx))); + } + + @Test + @DisplayName( + "word-boundary parse does not mis-route a 'gpt-4o-mini' reply to a 'gpt-4' candidate") + void testSubstringNotMisRouted() throws Exception { + List candidates = + Arrays.asList(new RoutingCandidate("gpt-4"), new RoutingCandidate("claude")); + Map registry = new 
HashMap<>(); + registry.put("judge", new RoutingTestSupport.ScriptedJudge("gpt-4o-mini")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + // "gpt-4" must NOT match inside "gpt-4o-mini"; with default=claude the reply is unparseable + // and we fall back to claude rather than mis-routing to gpt-4. + LlmRoutingStrategy strategy = + strategy(Map.of("judge_model", "judge", "default", "claude"), ctx); + assertEquals( + "claude", + strategy.route( + RoutingTestSupport.routingContext( + Collections.singletonList(RoutingTestSupport.user("hi")), + candidates, + ctx))); + } + + @Test + @DisplayName("a transient judge failure abstains (returns null), not the default") + void testTransientJudgeFailureAbstains() throws Exception { + // A judge whose chat() throws is a transient failure. The strategy must abstain with null + // (so the router degrades and a wrapping cache does not pin the conversation to a default), + // rather than returning the configured default as if it were a real decision. + Map registry = new HashMap<>(); + registry.put("judge", new RoutingTestSupport.FailingModel()); + ResourceContext ctx = RoutingTestSupport.context(registry); + + assertNull( + route( + strategy(Map.of("judge_model", "judge", "default", "small"), ctx), + ctx, + "hi")); + } + + @Test + @DisplayName("missing judge_model is rejected at construction") + void testMissingJudgeRejected() { + assertThrows(IllegalArgumentException.class, () -> strategy(Collections.emptyMap(), null)); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java new file mode 100644 index 000000000..2610238e2 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** Shared fakes for the routing tests. */ +final class RoutingTestSupport { + + private RoutingTestSupport() {} + + static ResourceDescriptor emptyDescriptor(Class clazz) { + return new ResourceDescriptor(clazz.getName(), Collections.emptyMap()); + } + + /** A chat model that records the messages it received and answers with its own tag. 
*/ + static final class RecordingModel extends BaseChatModelSetup { + final String tag; + int callCount = 0; + List lastMessages; + + RecordingModel(String tag) { + super(emptyDescriptor(RecordingModel.class), null); + this.tag = tag; + } + + @Override + public Map getParameters() { + return Collections.emptyMap(); + } + + @Override + public ChatMessage chat( + List messages, + Map promptArgs, + Map modelParams) { + this.callCount++; + this.lastMessages = new ArrayList<>(messages); + return new ChatMessage(MessageRole.ASSISTANT, "handled-by:" + tag); + } + } + + /** A chat model that always fails — used to exercise fallback behavior. */ + static final class FailingModel extends BaseChatModelSetup { + int callCount = 0; + + FailingModel() { + super(emptyDescriptor(FailingModel.class), null); + } + + @Override + public Map getParameters() { + return Collections.emptyMap(); + } + + @Override + public ChatMessage chat( + List messages, + Map promptArgs, + Map modelParams) { + this.callCount++; + throw new RuntimeException("boom from failing model"); + } + } + + /** A chat model that returns a scripted reply — used as an LLM-as-router judge. */ + static final class ScriptedJudge extends BaseChatModelSetup { + final String reply; + int callCount = 0; + List lastMessages; + + ScriptedJudge(String reply) { + super(emptyDescriptor(ScriptedJudge.class), null); + this.reply = reply; + } + + @Override + public Map getParameters() { + return Collections.emptyMap(); + } + + @Override + public ChatMessage chat( + List messages, + Map promptArgs, + Map modelParams) { + this.callCount++; + this.lastMessages = new ArrayList<>(messages); + return new ChatMessage(MessageRole.ASSISTANT, reply); + } + } + + /** A {@link ResourceContext} backed by a fixed name → resource map. 
*/ + static ResourceContext context(Map byName) { + return ResourceContext.fromGetResource( + (name, type) -> { + Resource resource = byName.get(name); + if (resource == null) { + throw new RuntimeException("No resource registered for name: " + name); + } + return resource; + }); + } + + static ChatMessage user(String content) { + return new ChatMessage(MessageRole.USER, content); + } + + static RoutingContext routingContext( + List messages, + List candidates, + ResourceContext resourceContext) { + return new RoutingContext(messages, Collections.emptyMap(), candidates, resourceContext); + } + + static ResourceType chatModelType() { + return ResourceType.CHAT_MODEL; + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategyTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategyTest.java new file mode 100644 index 000000000..e1fccbdb0 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RuleBasedRoutingStrategyTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class RuleBasedRoutingStrategyTest { + + private static final List CANDIDATES = + Arrays.asList(new RoutingCandidate("small"), new RoutingCandidate("big")); + + private RuleBasedRoutingStrategy strategy(Map args) { + return new RuleBasedRoutingStrategy( + new ResourceDescriptor(RuleBasedRoutingStrategy.class.getName(), args), null); + } + + private static String route(RuleBasedRoutingStrategy strategy, String userText) { + return strategy.route( + RoutingTestSupport.routingContext( + Collections.singletonList(RoutingTestSupport.user(userText)), + CANDIDATES, + null)); + } + + @Test + @DisplayName("keyword rule routes to the mapped model (case-insensitive by default)") + void testKeywordMatch() { + RuleBasedRoutingStrategy strategy = + strategy( + Map.of( + "default", + "small", + "rules", + Collections.singletonList( + Map.of( + "model", + "big", + "keywords", + Arrays.asList("code", "sql"))))); + + assertEquals("big", route(strategy, "Please write some CODE for me")); + } + + @Test + @DisplayName("no rule match falls back to the default model") + void testDefaultWhenNoMatch() { + RuleBasedRoutingStrategy strategy = + strategy( + Map.of( + "default", + "small", + "rules", + Collections.singletonList( + Map.of( + "model", + "big", + "keywords", + Collections.singletonList("code"))))); + + assertEquals("small", route(strategy, "just saying hello")); + } + + @Test + @DisplayName("regex rule matches the 
request text") + void testRegexMatch() { + RuleBasedRoutingStrategy strategy = + strategy( + Map.of( + "default", + "small", + "rules", + Collections.singletonList( + Map.of( + "model", + "big", + "regex", + "\\bSELECT\\b.*\\bFROM\\b")))); + + assertEquals("big", route(strategy, "select id from users")); + } + + @Test + @DisplayName("rules are evaluated in order; the first match wins") + void testFirstMatchWins() { + RuleBasedRoutingStrategy strategy = + strategy( + Map.of( + "default", + "small", + "rules", + Arrays.asList( + Map.of( + "model", + "big", + "keywords", + Collections.singletonList("urgent")), + Map.of( + "model", + "small", + "keywords", + Collections.singletonList("urgent"))))); + + assertEquals("big", route(strategy, "this is urgent")); + } + + @Test + @DisplayName("decision keys on the first user message for tool-call stickiness") + void testRoutesOnFirstUserMessage() { + RuleBasedRoutingStrategy strategy = + strategy( + Map.of( + "default", + "small", + "rules", + Collections.singletonList( + Map.of( + "model", + "big", + "keywords", + Collections.singletonList("code"))))); + + // Later tool-calling round: original "code" request plus assistant/tool messages. 
+ List conversation = + Arrays.asList( + RoutingTestSupport.user("write code"), + new ChatMessage(MessageRole.ASSISTANT, "calling a tool"), + new ChatMessage(MessageRole.TOOL, "tool result with no keywords")); + + assertEquals( + "big", + strategy.route(RoutingTestSupport.routingContext(conversation, CANDIDATES, null))); + } + + @Test + @DisplayName("missing default is rejected at construction") + void testMissingDefaultRejected() { + assertThrows( + IllegalArgumentException.class, + () -> strategy(Map.of("rules", Collections.emptyList()))); + } +} diff --git a/e2e-test/test-scripts/check_resource_consistency.py b/e2e-test/test-scripts/check_resource_consistency.py index 82a23ab6f..bf8e17998 100644 --- a/e2e-test/test-scripts/check_resource_consistency.py +++ b/e2e-test/test-scripts/check_resource_consistency.py @@ -118,7 +118,12 @@ def get_python_resource_name_map(python_path: Path) -> dict: from flink_agents.api.resource import ResourceName python_map = {} - for resource_name in ["ChatModel", "EmbeddingModel", "VectorStore"]: + for resource_name in [ + "ChatModel", + "EmbeddingModel", + "VectorStore", + "RoutingStrategy", + ]: if not hasattr(ResourceName, resource_name): continue resource_cls = getattr(ResourceName, resource_name) diff --git a/examples/src/main/java/org/apache/flink/agents/examples/LlmRoutingAgentExample.java b/examples/src/main/java/org/apache/flink/agents/examples/LlmRoutingAgentExample.java new file mode 100644 index 000000000..0b3974027 --- /dev/null +++ b/examples/src/main/java/org/apache/flink/agents/examples/LlmRoutingAgentExample.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.examples; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.agents.ReActAgent; +import org.apache.flink.agents.api.chat.model.routing.ChatModelRouter; +import org.apache.flink.agents.api.chat.model.routing.LlmRoutingStrategy; +import org.apache.flink.agents.api.chat.model.routing.RuleBasedRoutingStrategy; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceName; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.examples.agents.CustomTypesAndResources; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.connector.file.src.reader.TextLineInputFormat; +import org.apache.flink.core.fs.Path; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.types.Row; + +import java.io.File; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.examples.WorkflowSingleAgentExample.copyResource; + +/** + * Java example demonstrating pluggable LLM routing. + * + *

Two Ollama models are registered as routable candidates — a small/cheap model and a larger + * model — together with a {@link ChatModelRouter} that is itself registered as a {@code + * CHAT_MODEL}. The router picks, per request, which candidate should serve it. Because the router + * is a drop-in chat model, an ordinary {@link ReActAgent} simply points at it; nothing else in the + * pipeline changes. + * + *

This example uses the built-in {@link RuleBasedRoutingStrategy} (deterministic, no extra model + * call): requests mentioning code/SQL/errors go to the larger model, everything else to the small + * one. {@link #llmRoutingStrategy()} shows how to swap in the {@link LlmRoutingStrategy} + * (LLM-as-router) instead, and any user-supplied {@code RoutingStrategy} class can be plugged in + * the same way. + * + *

Prerequisite: a local Ollama server with the two models pulled (see {@code model} values + * below). + */ +public class LlmRoutingAgentExample { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static final String SMALL_MODEL = "smallModel"; + private static final String BIG_MODEL = "bigModel"; + + public static void main(String[] args) throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + agentsEnv.getConfig().set(AgentExecutionOptions.NUM_ASYNC_THREADS, 2); + + // 1) One shared connection, and two candidate chat-model setups registered by name. + // 2) The router, registered as a CHAT_MODEL, that routes between the two candidates. + agentsEnv + .addResource( + "ollamaChatModelConnection", + ResourceType.CHAT_MODEL_CONNECTION, + CustomTypesAndResources.OLLAMA_SERVER_DESCRIPTOR) + .addResource(SMALL_MODEL, ResourceType.CHAT_MODEL, candidate("qwen3:1.7b")) + .addResource(BIG_MODEL, ResourceType.CHAT_MODEL, candidate("qwen3:8b")); + + File inputDataFile = copyResource("input_data.txt"); + + DataStream productReviewStream = + env.fromSource( + FileSource.forRecordStreamFormat( + new TextLineInputFormat(), + new Path(inputDataFile.getAbsolutePath())) + .monitorContinuously(Duration.ofMinutes(1)) + .build(), + WatermarkStrategy.noWatermarks(), + "llm-routing-example") + .map( + inputStr -> { + Row row = Row.withNames(); + CustomTypesAndResources.ProductReview productReview = + MAPPER.readValue( + inputStr, + CustomTypesAndResources.ProductReview.class); + row.setField("id", productReview.getId()); + row.setField("review", productReview.getReview()); + return row; + }); + + // The agent uses the router as its chat model — routing is fully transparent to the agent. + // Swap ruleBasedStrategy() for llmRoutingStrategy() to route with an LLM-as-router instead. 
+ ReActAgent routedAgent = + new ReActAgent( + routerDescriptor(ruleBasedStrategy()), + CustomTypesAndResources.REVIEW_ANALYSIS_REACT_PROMPT, + CustomTypesAndResources.ProductReviewAnalysisRes.class); + + DataStream resultStream = + agentsEnv.fromDataStream(productReviewStream).apply(routedAgent).toDataStream(); + + resultStream.print(); + + agentsEnv.execute(); + } + + /** A candidate chat-model setup sharing the registered Ollama connection. */ + static ResourceDescriptor candidate(String model) { + return ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.OLLAMA_SETUP) + .addInitialArgument("connection", "ollamaChatModelConnection") + .addInitialArgument("model", model) + .build(); + } + + /** Candidate specs the router and strategies reason about (name + description). */ + static List candidateSpecs() { + List candidates = new ArrayList<>(); + Map small = new HashMap<>(); + small.put("name", SMALL_MODEL); + small.put("description", "Fast, cheap model for simple chit-chat and short answers."); + Map big = new HashMap<>(); + big.put("name", BIG_MODEL); + big.put("description", "Stronger model for code, SQL, and complex reasoning."); + candidates.add(small); + candidates.add(big); + return candidates; + } + + /** A router resource descriptor wrapping the given strategy. */ + static ResourceDescriptor routerDescriptor(ResourceDescriptor strategy) { + return ResourceDescriptor.Builder.newBuilder(ChatModelRouter.class.getName()) + .addInitialArgument(ChatModelRouter.ARG_CANDIDATES, candidateSpecs()) + .addInitialArgument(ChatModelRouter.ARG_STRATEGY, strategy) + .addInitialArgument(ChatModelRouter.ARG_FALLBACK, true) + .build(); + } + + /** + * Built-in rule-based strategy: keywords route to the larger model; otherwise the small one. 
+ */ + static ResourceDescriptor ruleBasedStrategy() { + Map codeRule = new HashMap<>(); + codeRule.put("model", BIG_MODEL); + codeRule.put("keywords", Arrays.asList("code", "sql", "error", "exception", "stacktrace")); + + return ResourceDescriptor.Builder.newBuilder(RuleBasedRoutingStrategy.class.getName()) + .addInitialArgument("default", SMALL_MODEL) + .addInitialArgument("rules", Collections.singletonList(codeRule)) + .build(); + } + + /** + * Built-in LLM-as-router strategy: a small judge model chooses the candidate. Swap this into + * {@link #routerDescriptor(ResourceDescriptor)} to use it instead of the rule-based strategy. + */ + static ResourceDescriptor llmRoutingStrategy() { + return ResourceDescriptor.Builder.newBuilder(LlmRoutingStrategy.class.getName()) + .addInitialArgument("judge_model", SMALL_MODEL) + .addInitialArgument("default", SMALL_MODEL) + .build(); + } +} diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index 17a4c96d9..800cbfc9b 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -356,5 +356,17 @@ class Java: # Milvus MILVUS_VECTOR_STORE = "org.apache.flink.agents.integrations.vectorstores.milvus.MilvusVectorStore" + class RoutingStrategy: + """RoutingStrategy resource names (for a ChatModelRouter's ``strategy``).""" + + class Java: + """Java implementations of RoutingStrategy.""" + + # Rule-based (deterministic keyword/regex rules) + RULE_BASED = "org.apache.flink.agents.api.chat.model.routing.RuleBasedRoutingStrategy" + + # LLM-as-router (a judge model picks the candidate) + LLM = "org.apache.flink.agents.api.chat.model.routing.LlmRoutingStrategy" + # MCP resource names MCP_SERVER = "flink_agents.integrations.mcp.mcp.MCPServer" From 3f773d7f05107bafddac5e6787c73cd9ef4ecf24 Mon Sep 17 00:00:00 2001 From: purshotam shah Date: Tue, 16 Jun 2026 16:34:14 -0700 Subject: [PATCH 2/2] [api][routing] Address review: docs, empty-name guard, open-before-chat test 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review follow-ups from @weiqingy on the routing PR: - ChatModelRouter.open(): document the load-bearing invariant the no-op relies on — a routed candidate is lazily open()-ed by ResourceCache.getResource() on first resolution, so its connection is non-null before chat() runs. - CachingStrategy / LlmRoutingStrategy: soften "runs once per conversation" to "typically once" and document that memoization is best-effort (a concurrent first-touch on the same key can double-compute; synchronized map, last-writer-wins, benign — so no locking). - RoutingCandidate: reject an empty name (not just null) — an empty name has no resolvable resource and would make LlmRoutingStrategy.parseChoice's whole-token match over-match arbitrary boundaries (mis-route). - Tests: add ChatModelRouterTest cases pinning the open-before-chat invariant (candidate resolved through an opening ResourceContext, mirroring ResourceCache; plus the negative case proving it is load-bearing), and RoutingCandidateTest for the null/empty name guards. 39 routing tests pass; spotless:check clean under JDK 17. 
--- .../chat/model/routing/CachingStrategy.java | 8 ++- .../chat/model/routing/ChatModelRouter.java | 11 +++- .../model/routing/LlmRoutingStrategy.java | 5 +- .../chat/model/routing/RoutingCandidate.java | 8 ++- .../model/routing/ChatModelRouterTest.java | 48 ++++++++++++++ .../model/routing/RoutingCandidateTest.java | 47 ++++++++++++++ .../model/routing/RoutingTestSupport.java | 62 +++++++++++++++++++ 7 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidateTest.java diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java index 428d789ae..704ba0c48 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/CachingStrategy.java @@ -25,7 +25,13 @@ /** * A {@link RoutingStrategy} decorator that memoizes the wrapped strategy's decision per * conversation (keyed on {@link RoutingContext#firstUserMessage()}), so an expensive selection — - * e.g. an LLM judge — runs once per conversation rather than on every tool-call round. + * e.g. an LLM judge — typically runs once per conversation rather than on every tool-call round. + * + *
<p>
This is best-effort memoization, not a hard guarantee: the lookup and compute are not + * atomic, so two async-pool threads racing on a key's first touch may both miss and both + * invoke the delegate (last-writer-wins on the same key). The backing map is synchronized, so there + * is no corruption, and the redundant compute is benign — hence no locking. Once a value is cached, + * subsequent rounds are served from it. * *
<p>
The cache is a bounded LRU with real eviction (oldest entries are dropped past the * capacity), so it never grows without bound and never silently stops caching. Empty keys (requests diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java index 338e9857e..962d36c8f 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouter.java @@ -155,10 +155,19 @@ public ChatModelRouter(ResourceDescriptor descriptor, ResourceContext resourceCo /** * The router has no connection of its own to resolve (it delegates to candidate models, each of * which resolves its own). Override to skip the base class's connection resolution. + * + *
<p>
Invariant this relies on: a routed candidate is resolved through {@code + * ResourceContext.getResource(name, CHAT_MODEL)} at {@link #chat} time, and the runtime {@code + * ResourceCache} lazily {@code open()}s a resource when it is first resolved — so a candidate's + * connection is non-null before its {@code chat()} runs. The router therefore does not need to + * open anything here; opening is the resolved candidate's responsibility, performed for it on + * first resolution. (Do not eagerly open candidates here: that would defeat the lazy, per-use + * resolution the cache provides.) */ @Override public void open() { - // no-op + // no-op; see the invariant in the Javadoc above (candidates are lazily opened on + // resolution) } /** diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java index 034539777..67fe44180 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/LlmRoutingStrategy.java @@ -47,8 +47,9 @@ * transient judge failure (resolve/call error) it instead abstains by returning * {@code null}, so the router degrades to its own default and a wrapping {@link CachingStrategy} * does not pin the conversation to a fallback; the judge is retried on the next round. It does not - * cache — wrap it in {@link CachingStrategy} (the router does this by default) so the judge runs - * once per conversation. + * cache — wrap it in {@link CachingStrategy} (the router does this by default) so the judge + * typically runs once per conversation (best-effort; see {@link CachingStrategy} for the concurrent + * first-touch caveat). * *
<p>
Security note. The user's message is sent to the judge model, so this routing decision * is susceptible to prompt injection (a crafted message could steer the choice). Do not gate cost, diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidate.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidate.java index ea724ed80..d788ba489 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidate.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidate.java @@ -39,7 +39,13 @@ public class RoutingCandidate { private final Map metadata; public RoutingCandidate(String name, String description, Map metadata) { - this.name = Objects.requireNonNull(name, "candidate name must not be null"); + Objects.requireNonNull(name, "candidate name must not be null"); + if (name.isEmpty()) { + // An empty name has no valid resource to resolve, and would make the whole-token match + // in LlmRoutingStrategy.parseChoice over-match arbitrary boundaries (mis-routing). + throw new IllegalArgumentException("candidate name must not be empty"); + } + this.name = name; this.description = description != null ? 
description : ""; this.metadata = metadata != null diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java index e05633b00..1340f81a4 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/ChatModelRouterTest.java @@ -430,6 +430,54 @@ void testNullResourceContextRejected() throws Exception { Collections.emptyMap())); } + @Test + @DisplayName("a routed candidate is open()-ed before chat() (open-before-chat invariant)") + void testCandidateOpenedBeforeChat() throws Exception { + Map registry = new HashMap<>(); + registry.put("small", new RoutingTestSupport.OpenRequiringModel("small")); + RoutingTestSupport.OpenRequiringModel big = + new RoutingTestSupport.OpenRequiringModel("big"); + registry.put("big", big); + // A context that opens a resource on resolution, like the runtime + // ResourceCache.getResource() + // (ResourceCache.java calls resource.open() before returning it). This pins the invariant + // the + // router's no-op open() relies on: the chosen candidate is opened before its chat() runs. 
+ ResourceContext ctx = RoutingTestSupport.openingContext(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + ChatMessage response = + router.chat( + Collections.singletonList(RoutingTestSupport.user("please write code")), + Collections.emptyMap(), + Collections.emptyMap()); + + assertEquals("handled-by:big", response.getContent()); + assertTrue(big.opened, "the routed candidate must have been opened before chat()"); + } + + @Test + @DisplayName("the open-before-chat invariant is load-bearing: a non-opened candidate fails") + void testCandidateNotOpenedFails() throws Exception { + // Same candidates, but a plain context that does NOT open on resolution -> chat() must + // fail, + // proving the invariant above is real and not vacuously satisfied. + Map registry = new HashMap<>(); + registry.put("small", new RoutingTestSupport.OpenRequiringModel("small")); + registry.put("big", new RoutingTestSupport.OpenRequiringModel("big")); + ResourceContext ctx = RoutingTestSupport.context(registry); + + ChatModelRouter router = router(ruleStrategy(), false, Arrays.asList("small", "big"), ctx); + assertThrows( + RuntimeException.class, + () -> + router.chat( + Collections.singletonList( + RoutingTestSupport.user("please write code")), + Collections.emptyMap(), + Collections.emptyMap())); + } + @Test @DisplayName("connection name is non-null so retry metrics never NPE") void testStableConnectionName() throws Exception { diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidateTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidateTest.java new file mode 100644 index 000000000..ce77aa39b --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingCandidateTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.chat.model.routing; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** Tests for {@link RoutingCandidate} construction guards. 
*/ +class RoutingCandidateTest { + + @Test + @DisplayName("a null candidate name is rejected") + void testRejectsNullName() { + assertThrows(NullPointerException.class, () -> new RoutingCandidate(null)); + } + + @Test + @DisplayName("an empty candidate name is rejected (would over-match in whole-token parsing)") + void testRejectsEmptyName() { + assertThrows(IllegalArgumentException.class, () -> new RoutingCandidate("")); + } + + @Test + @DisplayName("a non-empty name is accepted") + void testAcceptsName() { + assertEquals("gpt-4o", new RoutingCandidate("gpt-4o").getName()); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java index 2610238e2..90a7a70ed 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/routing/RoutingTestSupport.java @@ -117,6 +117,47 @@ public ChatMessage chat( } } + /** + * A chat model that must be {@code open()}-ed before {@code chat()} — a stand-in for a real + * {@code BaseChatModelSetup} whose backend connection is resolved in {@code open()} and would + * be {@code null} (NPE on {@code chat()}) otherwise. Used to pin the router's open-before-chat + * invariant. + */ + static final class OpenRequiringModel extends BaseChatModelSetup { + final String tag; + boolean opened = false; + int callCount = 0; + + OpenRequiringModel(String tag) { + super(emptyDescriptor(OpenRequiringModel.class), null); + this.tag = tag; + } + + @Override + public void open() { + // Stand-in for resolving the backend connection; must run before chat(). 
+ this.opened = true; + } + + @Override + public Map getParameters() { + return Collections.emptyMap(); + } + + @Override + public ChatMessage chat( + List messages, + Map promptArgs, + Map modelParams) { + if (!opened) { + throw new IllegalStateException( + "chat() called before open(): the backend connection would be null"); + } + this.callCount++; + return new ChatMessage(MessageRole.ASSISTANT, "handled-by:" + tag); + } + } + /** A {@link ResourceContext} backed by a fixed name → resource map. */ static ResourceContext context(Map byName) { return ResourceContext.fromGetResource( @@ -129,6 +170,27 @@ static ResourceContext context(Map byName) { }); } + /** + * A {@link ResourceContext} that {@code open()}s each resolved resource before returning it, + * mirroring the runtime {@code ResourceCache.getResource()} contract (lazy open on first + * resolution). Use this to exercise the router's open-before-chat invariant. + */ + static ResourceContext openingContext(Map byName) { + return ResourceContext.fromGetResource( + (name, type) -> { + Resource resource = byName.get(name); + if (resource == null) { + throw new RuntimeException("No resource registered for name: " + name); + } + try { + resource.open(); + } catch (Exception e) { + throw new RuntimeException("Failed to open resource: " + name, e); + } + return resource; + }); + } + static ChatMessage user(String content) { return new ChatMessage(MessageRole.USER, content); }