Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions agent_assembly/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"ToolExecutionBlockedError",
"MCPToolBlockedError",
"PolicyViolationError",
"OpTerminatedError",
]


Expand Down Expand Up @@ -60,3 +61,17 @@ def __init__(

class PolicyViolationError(ToolExecutionBlockedError):
"""Exception raised when policy blocks tool execution."""


class OpTerminatedError(AssemblyError):
"""Raised when the gateway terminates an in-flight op (AAASM-1422 PR-E).

Carries the originating `op_id` so callers can correlate the failure
against the operation they were awaiting. Surfaced by
`OpControlSubscriber.await_op` when an `OP_CONTROL_SIGNAL_TERMINATE`
arrives for the awaited op.
"""

def __init__(self, message: str, *, op_id: str) -> None:
super().__init__(message)
self.op_id = op_id
216 changes: 216 additions & 0 deletions agent_assembly/op_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""Gateway → SDK op-control consumer (AAASM-1422 PR-E / AAASM-1654).

Subscribes to ``PolicyService.OpControlStream`` and exposes a per-``op_id``
cooperative-pause / fast-fail-terminate state machine through ``await_op``.

The subscriber runs on a daemon background thread that reads the gRPC stream
and dispatches each ``OpControlMessage`` to a per-op state slot. Application
code awaits the slot via :meth:`OpControlSubscriber.await_op`:

* ``OP_CONTROL_SIGNAL_PAUSE`` → ``await_op`` blocks until ``RESUME`` arrives.
* ``OP_CONTROL_SIGNAL_RESUME`` → ``await_op`` returns immediately.
* ``OP_CONTROL_SIGNAL_TERMINATE`` → ``await_op`` raises
:class:`agent_assembly.exceptions.OpTerminatedError`.

If a signal arrives for an ``op_id`` no one is currently awaiting, it's
buffered into the per-op slot so the next ``await_op`` call sees it.

Out of scope for PR-E (deferred):
- Reconnection / heartbeat on stream close (caller observes via
``stream_alive`` and re-instantiates if desired).
- Auto-wiring into ``init_assembly`` / adapter ``check_action`` hooks
(separate sub-task once the adapter surface is stable).
"""

from __future__ import annotations

import threading
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Protocol

import grpc

from agent_assembly.exceptions import OpTerminatedError
from agent_assembly.proto import common_pb2, policy_pb2, policy_pb2_grpc

__all__ = ["OpControlSubscriber", "OpControlState"]


class _OpControlStub(Protocol):
"""Structural type for the gRPC stub method this module needs.

Lets tests inject a hand-rolled stub without standing up a gRPC server.
"""

def OpControlStream( # noqa: N802 — gRPC method name
self,
request: policy_pb2.OpControlSubscribeRequest,
) -> Iterator[policy_pb2.OpControlMessage]: ...


@dataclass
class OpControlState:
"""Per-op state slot used by the cooperative-pause machine.

Each ``op_id`` the subscriber observes gets one slot. ``await_op`` blocks
on ``event`` whenever ``paused`` is set; on terminate the slot's
``terminated`` flag is latched and subsequent ``await_op`` calls raise
immediately without blocking.
"""

event: threading.Event = field(default_factory=threading.Event)
paused: bool = False
terminated: bool = False


class OpControlSubscriber:
"""Subscribe to OpControlStream and serve per-op pause/terminate signals.

Construct via :meth:`connect`, never directly — the constructor takes a
pre-wired stub so tests can mock the gRPC layer without touching the
network.

Thread-safe: ``await_op`` may be called from any thread; the underlying
state is guarded by an internal ``threading.Lock``.
"""

def __init__(self, stub: _OpControlStub, agent_id: common_pb2.AgentId) -> None:
self._stub = stub
self._agent_id = agent_id
self._lock = threading.Lock()
self._ops: dict[str, OpControlState] = {}
self._stream_alive = threading.Event()
self._stream_alive.set()
self._reader: threading.Thread | None = None
self._call: grpc.RpcContext | None = None

@classmethod
def connect(
cls,
gateway_url: str,
*,
org_id: str,
team_id: str,
agent_id: str,
channel_factory: grpc.Channel | None = None,
) -> OpControlSubscriber:
"""Open the gRPC channel + subscription stream and start the reader.

``gateway_url`` is the ``host:port`` of the gateway's gRPC endpoint
(no scheme; gRPC uses its own). When ``channel_factory`` is supplied
(tests), it's used instead of opening a fresh insecure channel.
"""
channel = channel_factory or grpc.insecure_channel(gateway_url)
stub = policy_pb2_grpc.PolicyServiceStub(channel) # type: ignore[no-untyped-call]
proto_agent_id = common_pb2.AgentId(org_id=org_id, team_id=team_id, agent_id=agent_id)
subscriber = cls(stub, proto_agent_id)
subscriber._start()
return subscriber

def _start(self) -> None:
"""Open the stream + spawn the reader thread.

Separated from ``connect`` so tests can construct a subscriber with
a hand-rolled stub and call ``_start`` themselves.
"""
request = policy_pb2.OpControlSubscribeRequest(agent_id=self._agent_id)
self._call = self._stub.OpControlStream(request)
self._reader = threading.Thread(
target=self._reader_loop,
name=f"aa-op-control-{self._agent_id.agent_id}",
daemon=True,
)
self._reader.start()

def _reader_loop(self) -> None:
"""Drain the stream and dispatch each message to the matching op slot."""
try:
for message in self._call: # type: ignore[union-attr]
self._dispatch(message)
except grpc.RpcError:
# Stream closed (server shutdown, network drop, etc.) — fall through
# to mark the stream dead so await_op can detect it.
pass
finally:
self._stream_alive.clear()
# Wake any currently-blocked awaiters so they can re-check state.
with self._lock:
for state in self._ops.values():
state.event.set()

def _dispatch(self, message: policy_pb2.OpControlMessage) -> None:
"""Apply one inbound signal to the per-op state slot."""
with self._lock:
state = self._ops.setdefault(message.op_id, OpControlState())
signal = message.signal
if signal == policy_pb2.OP_CONTROL_SIGNAL_PAUSE:
state.paused = True
state.event.clear()
elif signal == policy_pb2.OP_CONTROL_SIGNAL_RESUME:
state.paused = False
state.event.set()
elif signal == policy_pb2.OP_CONTROL_SIGNAL_TERMINATE:
state.terminated = True
state.event.set()

def await_op(self, op_id: str, *, timeout: float | None = None) -> None:
"""Block until ``op_id`` is runnable, or raise on terminate.

Returns immediately when the op is not currently paused. When paused,
blocks on the per-op event up to ``timeout`` seconds. Raises
:class:`OpTerminatedError` if the op has been (or becomes) terminated.

A timeout returns normally — the caller can inspect ``is_paused`` or
retry. This matches the cooperative-pause expectation in the
architecture doc (the SDK yields, it doesn't deadline-enforce).
"""
with self._lock:
state = self._ops.setdefault(op_id, OpControlState())
if state.terminated:
raise OpTerminatedError(
f"op {op_id} was terminated by the gateway",
op_id=op_id,
)
if not state.paused:
return
event = state.event

# Drop the lock while we wait so the reader thread can update state.
event.wait(timeout=timeout)

with self._lock:
if state.terminated:
raise OpTerminatedError(
f"op {op_id} was terminated by the gateway",
op_id=op_id,
)

def is_paused(self, op_id: str) -> bool:
"""Return True iff the gateway has the op currently paused."""
with self._lock:
state = self._ops.get(op_id)
return state.paused if state else False

def is_terminated(self, op_id: str) -> bool:
"""Return True iff the gateway has terminated the op."""
with self._lock:
state = self._ops.get(op_id)
return state.terminated if state else False

def stream_alive(self) -> bool:
"""Return False once the underlying gRPC stream has closed."""
return self._stream_alive.is_set()

def close(self) -> None:
"""Cancel the stream and join the reader thread."""
if self._call is not None:
self._call.cancel()
if self._reader is not None:
self._reader.join(timeout=2.0)

def __enter__(self) -> OpControlSubscriber:
return self

def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
self.close()
Empty file.
44 changes: 44 additions & 0 deletions agent_assembly/proto/common_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions agent_assembly/proto/common_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class Decision(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
DECISION_UNSPECIFIED: _ClassVar[Decision]
ALLOW: _ClassVar[Decision]
DENY: _ClassVar[Decision]
PENDING: _ClassVar[Decision]
REDACT: _ClassVar[Decision]

class ActionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
ACTION_UNSPECIFIED: _ClassVar[ActionType]
LLM_CALL: _ClassVar[ActionType]
TOOL_CALL: _ClassVar[ActionType]
FILE_OPERATION: _ClassVar[ActionType]
NETWORK_CALL: _ClassVar[ActionType]
PROCESS_EXEC: _ClassVar[ActionType]
AGENT_SPAWN: _ClassVar[ActionType]

class RiskTier(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
RISK_UNSPECIFIED: _ClassVar[RiskTier]
LOW: _ClassVar[RiskTier]
MEDIUM: _ClassVar[RiskTier]
HIGH: _ClassVar[RiskTier]
CRITICAL: _ClassVar[RiskTier]
DECISION_UNSPECIFIED: Decision
ALLOW: Decision
DENY: Decision
PENDING: Decision
REDACT: Decision
ACTION_UNSPECIFIED: ActionType
LLM_CALL: ActionType
TOOL_CALL: ActionType
FILE_OPERATION: ActionType
NETWORK_CALL: ActionType
PROCESS_EXEC: ActionType
AGENT_SPAWN: ActionType
RISK_UNSPECIFIED: RiskTier
LOW: RiskTier
MEDIUM: RiskTier
HIGH: RiskTier
CRITICAL: RiskTier

class AgentId(_message.Message):
__slots__ = ("org_id", "team_id", "agent_id")
ORG_ID_FIELD_NUMBER: _ClassVar[int]
TEAM_ID_FIELD_NUMBER: _ClassVar[int]
AGENT_ID_FIELD_NUMBER: _ClassVar[int]
org_id: str
team_id: str
agent_id: str
def __init__(self, org_id: _Optional[str] = ..., team_id: _Optional[str] = ..., agent_id: _Optional[str] = ...) -> None: ...

class Timestamp(_message.Message):
__slots__ = ("unix_ms",)
UNIX_MS_FIELD_NUMBER: _ClassVar[int]
unix_ms: int
def __init__(self, unix_ms: _Optional[int] = ...) -> None: ...
24 changes: 24 additions & 0 deletions agent_assembly/proto/common_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings


GRPC_GENERATED_VERSION = '1.80.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True

if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in common_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
Loading
Loading