diff --git a/splunklib/ai/base_agent.py b/splunklib/ai/base_agent.py index 3e9de535..596d452e 100644 --- a/splunklib/ai/base_agent.py +++ b/splunklib/ai/base_agent.py @@ -94,9 +94,9 @@ def __init__( ] self._middleware = ( - *{m for m in predefined_before if type(m) not in user_middleware_types}, + *[m for m in predefined_before if type(m) not in user_middleware_types], *user_middleware, - *{m for m in predefined_after if type(m) not in user_middleware_types}, + *[m for m in predefined_after if type(m) not in user_middleware_types], ) self._trace_id = secrets.token_hex(16) # 32 Hex characters diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index 66957f8d..714a58b2 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -26,6 +26,7 @@ TimeoutLimitMiddleware, TokenLimitExceededException, TokenLimitMiddleware, + StructuredOutputRetryLimitMiddleware, ) from splunklib.ai.messages import AIMessage, AgentResponse from splunklib.ai.middleware import AgentMiddleware, AgentRequest, AgentState, ModelRequest, ModelResponse @@ -173,3 +174,12 @@ async def test_raises_when_steps_in_request_reach_limit(self) -> None: await mw.model_middleware(_make_model_request(total_steps=2), _noop_model_handler) with self.assertRaises(StepsLimitExceededException): await mw.model_middleware(_make_model_request(total_steps=3), _noop_model_handler) + + +def test_default_middleware() -> None: + agent = _make_agent() + mw = list(agent.middleware or []) + assert isinstance(mw[0], StructuredOutputRetryLimitMiddleware) + assert isinstance(mw[1], TokenLimitMiddleware) + assert isinstance(mw[2], StepLimitMiddleware) + assert isinstance(mw[3], TimeoutLimitMiddleware)