diff --git a/go/adk/pkg/a2a/executor.go b/go/adk/pkg/a2a/executor.go index aa5957c42..b8a49ea12 100644 --- a/go/adk/pkg/a2a/executor.go +++ b/go/adk/pkg/a2a/executor.go @@ -11,6 +11,7 @@ import ( "github.com/a2aproject/a2a-go/a2asrv" "github.com/a2aproject/a2a-go/a2asrv/eventqueue" "github.com/go-logr/logr" + "github.com/kagent-dev/kagent/go/adk/pkg/auth" "github.com/kagent-dev/kagent/go/adk/pkg/models" "github.com/kagent-dev/kagent/go/adk/pkg/session" "github.com/kagent-dev/kagent/go/adk/pkg/skills" @@ -117,6 +118,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont sessionID := reqCtx.ContextID ctx = withBearerToken(ctx) + ctx = auth.WithUserID(ctx, userID) e.logger.Info("Execute", "taskID", reqCtx.TaskID, diff --git a/go/adk/pkg/auth/token.go b/go/adk/pkg/auth/token.go index e4a2092d5..bd857bcb4 100644 --- a/go/adk/pkg/auth/token.go +++ b/go/adk/pkg/auth/token.go @@ -8,6 +8,21 @@ import ( "time" ) +type contextKey int + +const userIDKey contextKey = iota + +// WithUserID returns a copy of ctx that carries the user ID for injection into +// outgoing HTTP requests by TokenRoundTripper. +func WithUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, userIDKey, userID) +} + +func userIDFromContext(ctx context.Context) string { + id, _ := ctx.Value(userIDKey).(string) + return id +} + const kagentTokenPath = "/var/run/secrets/tokens/kagent-token" // KAgentTokenService reads a k8s token from a file and reloads it periodically @@ -61,6 +76,9 @@ func (s *KAgentTokenService) AddHeaders(req *http.Request) { if token := s.GetToken(); token != "" { req.Header.Set("Authorization", "Bearer "+token) } + if userID := userIDFromContext(req.Context()); userID != "" { + req.Header.Set("X-User-Id", userID) + } } // readToken reads the token from the file diff --git a/go/core/internal/httpserver/auth/proxy_authn.go b/go/core/internal/httpserver/auth/proxy_authn.go index d3bcd98fa..6ed7154e5 100644 --- a/go/core/internal/httpserver/auth/proxy_authn.go +++ b/go/core/internal/httpserver/auth/proxy_authn.go @@ -27,67 +27,74 @@ func NewProxyAuthenticator(userIDClaim string) *ProxyAuthenticator { func (a *ProxyAuthenticator) Authenticate(ctx context.Context, reqHeaders http.Header, query url.Values) (auth.Session, error) { authHeader := reqHeaders.Get("Authorization") - - // Always read agent identity from X-Agent-Name header (used by agents calling back) agentID := reqHeaders.Get("X-Agent-Name") - // If we have a Bearer token, parse JWT - if tokenString, ok := strings.CutPrefix(authHeader, "Bearer "); ok { - // Parse JWT without validation (oauth2-proxy or k8s service account already validated) - rawClaims, err := parseJWTPayload(tokenString) - if err != nil { - return nil, ErrUnauthenticated - } + tokenString, ok := strings.CutPrefix(authHeader, "Bearer ") + if !ok { + return nil, ErrUnauthenticated + } + + // Parse JWT without validation (oauth2-proxy or k8s service account already validated) + rawClaims, err := parseJWTPayload(tokenString) + if err != nil { + return nil, ErrUnauthenticated + } - userID, _ := rawClaims[a.userIDClaim].(string) - if userID == "" && a.userIDClaim != "sub" { + if agentID != "" { + // Agent call: the Bearer SA token authenticates the pod; the caller's + // identity should be supplied explicitly via X-User-Id / user_id. + // Fall back to the SA sub claim for direct calls to agent pods that + // do not yet propagate the caller identity. + userID := userIDFromRequest(reqHeaders, query) + if userID == "" { userID, _ = rawClaims["sub"].(string) } if userID == "" { return nil, ErrUnauthenticated } - return &SimpleSession{ P: auth.Principal{ - User: auth.User{ID: userID}, - Agent: auth.Agent{ID: agentID}, - Claims: rawClaims, + User: auth.User{ID: userID}, + Agent: auth.Agent{ID: agentID}, }, authHeader: authHeader, }, nil } - // Fall back to service account auth for internal agent-to-controller calls. - // Requires X-Agent-Name to identify the calling agent. - if agentID == "" { - return nil, ErrUnauthenticated - } - - // Agents authenticate via user_id query param or X-User-Id header - userID := query.Get("user_id") - if userID == "" { - userID = reqHeaders.Get("X-User-Id") + // Direct user call: identity comes from the OIDC JWT claims. + userID, _ := rawClaims[a.userIDClaim].(string) + if userID == "" && a.userIDClaim != "sub" { + userID, _ = rawClaims["sub"].(string) } if userID == "" { return nil, ErrUnauthenticated } - return &SimpleSession{ P: auth.Principal{ - User: auth.User{ - ID: userID, - }, - Agent: auth.Agent{ - ID: agentID, - }, + User: auth.User{ID: userID}, + Claims: rawClaims, }, authHeader: authHeader, }, nil } +// userIDFromRequest returns the user identity from the user_id query param or +// X-User-Id header, preferring the query param. +func userIDFromRequest(headers http.Header, query url.Values) string { + if v := query.Get("user_id"); v != "" { + return v + } + return headers.Get("X-User-Id") +} + func (a *ProxyAuthenticator) UpstreamAuth(r *http.Request, session auth.Session, upstreamPrincipal auth.Principal) error { - if simpleSession, ok := session.(*SimpleSession); ok && simpleSession.authHeader != "" { - r.Header.Set("Authorization", simpleSession.authHeader) + if simpleSession, ok := session.(*SimpleSession); ok { + if simpleSession.authHeader != "" { + r.Header.Set("Authorization", simpleSession.authHeader) + } + if userID := simpleSession.P.User.ID; userID != "" { + r.Header.Set("X-User-Id", userID) + } } return nil } diff --git a/go/core/internal/httpserver/auth/proxy_authn_test.go b/go/core/internal/httpserver/auth/proxy_authn_test.go index abd18dd74..b6af3260d 100644 --- a/go/core/internal/httpserver/auth/proxy_authn_test.go +++ b/go/core/internal/httpserver/auth/proxy_authn_test.go @@ -159,112 +159,63 @@ func TestProxyAuthenticator_Authenticate(t *testing.T) { } } -func TestProxyAuthenticator_JWTWithAgentHeader(t *testing.T) { +func TestProxyAuthenticator_AgentCalls(t *testing.T) { tests := []struct { name string - claims map[string]any - agentName string + headers map[string]string + queryParams map[string]string wantUserID string wantAgentID string + wantErr bool }{ { - name: "extracts agent identity from header when JWT is present", - claims: map[string]any{ - "sub": "system:serviceaccount:kagent:kebab-agent", - "iss": "https://kubernetes.default.svc.cluster.local", - "aud": []any{"kagent"}, + name: "agent with SA Bearer token and X-User-Id header uses header identity", + headers: map[string]string{ + "Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}), + "X-Agent-Name": "kagent/test-agent", + "X-User-Id": "user@example.com", }, - agentName: "kagent__NS__kebab_agent", - wantUserID: "system:serviceaccount:kagent:kebab-agent", - wantAgentID: "kagent__NS__kebab_agent", + wantUserID: "user@example.com", + wantAgentID: "kagent/test-agent", }, { - name: "works with OIDC JWT and agent header", - claims: map[string]any{ - "sub": "user123", - "email": "user@example.com", + name: "agent with SA Bearer token and user_id query param uses query identity", + headers: map[string]string{ + "Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}), + "X-Agent-Name": "kagent/test-agent", }, - agentName: "kagent__NS__my_agent", - wantUserID: "user123", - wantAgentID: "kagent__NS__my_agent", - }, - { - name: "handles JWT without agent header", - claims: map[string]any{ - "sub": "user123", + queryParams: map[string]string{ + "user_id": "user@example.com", }, - agentName: "", - wantUserID: "user123", - wantAgentID: "", + wantUserID: "user@example.com", + wantAgentID: "kagent/test-agent", }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - auth := authimpl.NewProxyAuthenticator("") - - headers := http.Header{} - token := createTestJWT(tt.claims) - headers.Set("Authorization", "Bearer "+token) - if tt.agentName != "" { - headers.Set("X-Agent-Name", tt.agentName) - } - - session, err := auth.Authenticate(context.Background(), headers, url.Values{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - principal := session.Principal() - if principal.User.ID != tt.wantUserID { - t.Errorf("User.ID = %q, want %q", principal.User.ID, tt.wantUserID) - } - if principal.Agent.ID != tt.wantAgentID { - t.Errorf("Agent.ID = %q, want %q", principal.Agent.ID, tt.wantAgentID) - } - }) - } -} - -func TestProxyAuthenticator_ServiceAccountFallback(t *testing.T) { - tests := []struct { - name string - headers map[string]string - queryParams map[string]string - wantUserID string - wantAgentID string - wantErr bool - }{ { - name: "authenticates via user_id query param with agent name", - queryParams: map[string]string{ - "user_id": "system:serviceaccount:kagent:kebab-agent", - }, + name: "agent with no X-User-Id falls back to SA sub claim", headers: map[string]string{ - "X-Agent-Name": "kagent/kebab-agent", + "Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}), + "X-Agent-Name": "kagent/test-agent", }, - wantUserID: "system:serviceaccount:kagent:kebab-agent", - wantAgentID: "kagent/kebab-agent", - wantErr: false, + wantUserID: "system:serviceaccount:kagent:test-agent", + wantAgentID: "kagent/test-agent", }, + // Error cases. { - name: "authenticates via X-User-Id header with agent name", + name: "agent without Bearer token is rejected", headers: map[string]string{ - "X-User-Id": "system:serviceaccount:kagent:test-agent", "X-Agent-Name": "kagent/test-agent", + "X-User-Id": "user@example.com", }, - wantUserID: "system:serviceaccount:kagent:test-agent", - wantAgentID: "kagent/test-agent", - wantErr: false, + wantErr: true, }, { - name: "returns error when no auth method available", + name: "no token and no X-Agent-Name is rejected", wantErr: true, }, { - name: "returns error when no X-Agent-Name header for fallback", + name: "user_id without X-Agent-Name is rejected", queryParams: map[string]string{ - "user_id": "system:serviceaccount:kagent:kebab-agent", + "user_id": "user@example.com", }, wantErr: true, }, @@ -339,4 +290,9 @@ func TestProxyAuthenticator_UpstreamAuth(t *testing.T) { if got := req.Header.Get("Authorization"); got != authHeader { t.Errorf("Authorization header = %q, want %q", got, authHeader) } + + // Verify X-User-Id is forwarded so downstream A2A runtimes receive the real user identity + if got := req.Header.Get("X-User-Id"); got != "user123" { + t.Errorf("X-User-Id header = %q, want %q", got, "user123") + } } diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index da08895a5..d344d89e1 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -49,7 +49,6 @@ async def create_session( response = await self.client.post( "/api/sessions", json=request_data, - headers={"X-User-ID": user_id}, ) response.raise_for_status() @@ -88,10 +87,7 @@ async def get_session( url += "&limit=-1" # Make API call to get session - response: httpx.Response = await self.client.get( - url, - headers={"X-User-ID": user_id}, - ) + response: httpx.Response = await self.client.get(url) if response.status_code == 404: return None response.raise_for_status() @@ -131,7 +127,7 @@ async def get_session( @override async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsResponse: # Make API call to list sessions - response = await self.client.get(f"/api/sessions?user_id={user_id}", headers={"X-User-ID": user_id}) + response = await self.client.get(f"/api/sessions?user_id={user_id}") response.raise_for_status() data = response.json() @@ -151,10 +147,7 @@ def list_sessions_sync(self, *, app_name: str, user_id: str) -> ListSessionsResp @override async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None: # Make API call to delete session - response = await self.client.delete( - f"/api/sessions/{session_id}?user_id={user_id}", - headers={"X-User-ID": user_id}, - ) + response = await self.client.delete(f"/api/sessions/{session_id}?user_id={user_id}") response.raise_for_status() @override @@ -172,7 +165,6 @@ async def append_event(self, session: Session, event: Event) -> Event: response = await self.client.post( f"/api/sessions/{session.id}/events?user_id={session.user_id}", json=event_data, - headers={"X-User-ID": session.user_id}, ) response.raise_for_status() diff --git a/python/packages/kagent-adk/src/kagent/adk/_token.py b/python/packages/kagent-adk/src/kagent/adk/_token.py index 34bae4988..07fe74adb 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_token.py +++ b/python/packages/kagent-adk/src/kagent/adk/_token.py @@ -4,6 +4,7 @@ from typing import Any, Optional import httpx +from kagent.core.a2a import get_request_user_id KAGENT_TOKEN_PATH = "/var/run/secrets/tokens/kagent-token" logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def event_hooks(self): """Returns a dictionary of event hooks for the application to use when creating the httpx.AsyncClient. """ - return {"request": [self._add_bearer_token]} + return {"request": [self._add_headers]} async def _update_token_loop(self) -> None: self.token = await self._read_kagent_token() @@ -61,12 +62,13 @@ async def _refresh_token(self): async with self.update_lock: self.token = token - async def _add_bearer_token(self, request: httpx.Request): - # Your function to generate headers dynamically + async def _add_headers(self, request: httpx.Request): token = await self._get_token() headers = {"X-Agent-Name": self.app_name} if token: headers["Authorization"] = f"Bearer {token}" + if user_id := get_request_user_id(): + headers["X-User-Id"] = user_id request.headers.update(headers) diff --git a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py index 70da62764..3de48d635 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/__init__.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/__init__.py @@ -1,4 +1,5 @@ from ._config import get_a2a_max_content_length +from ._context import get_request_user_id, set_request_user_id from ._consts import ( A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, @@ -33,6 +34,8 @@ __all__ = [ "get_a2a_max_content_length", + "get_request_user_id", + "set_request_user_id", "KAgentRequestContextBuilder", "KAgentTaskStore", "get_kagent_metadata_key", diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_context.py b/python/packages/kagent-core/src/kagent/core/a2a/_context.py new file mode 100644 index 000000000..2081c9b8f --- /dev/null +++ b/python/packages/kagent-core/src/kagent/core/a2a/_context.py @@ -0,0 +1,17 @@ +from contextvars import ContextVar + +_current_user_id: ContextVar[str | None] = ContextVar("kagent_user_id", default=None) + + +def set_request_user_id(user_id: str | None) -> None: + """Store the caller's user ID for the current async context. + + Must be called before any outgoing HTTP requests to the kagent controller + so that the token service event hook can inject X-User-Id. + """ + _current_user_id.set(user_id) + + +def get_request_user_id() -> str | None: + """Return the caller's user ID for the current async context.""" + return _current_user_id.get() diff --git a/python/packages/kagent-core/src/kagent/core/a2a/_requests.py b/python/packages/kagent-core/src/kagent/core/a2a/_requests.py index 35a4e2670..13b36ffa9 100644 --- a/python/packages/kagent-core/src/kagent/core/a2a/_requests.py +++ b/python/packages/kagent-core/src/kagent/core/a2a/_requests.py @@ -6,6 +6,8 @@ from a2a.server.tasks import TaskStore from a2a.types import MessageSendParams, Task +from ._context import set_request_user_id + # --- Configure Logging --- logger = logging.getLogger(__name__) @@ -47,6 +49,7 @@ async def build( user_id = headers.get("x-user-id", None) if user_id: context.user = KAgentUser(user_id=user_id) + set_request_user_id(user_id) # Propagate x-kagent-source so downstream code (e.g. session # creation) can tag this session as agent-originated. source = headers.get("x-kagent-source", None)