diff --git a/.betterer.results b/.betterer.results index 79aa3963..169b85f2 100644 --- a/.betterer.results +++ b/.betterer.results @@ -9,8 +9,8 @@ exports[`TypeScript Strict Mode`] = { [327, 15, 4, "tsc: Expected 2 arguments, but got 1.", "2087764327"], [350, 15, 4, "tsc: Expected 2 arguments, but got 1.", "2087764327"] ], - "src/server/transports/jsonrpc/jsonrpc_transport_handler.ts:1878858438": [ - [89, 12, 10, "tsc: Variable \'rpcRequest\' is used before being assigned.", "3927050741"] + "src/server/transports/jsonrpc/jsonrpc_transport_handler.ts:3366537262": [ + [90, 12, 10, "tsc: Variable \'rpcRequest\' is used before being assigned.", "3927050741"] ] }` }; diff --git a/src/client/factory.ts b/src/client/factory.ts index b97ccb4e..e49d07f9 100644 --- a/src/client/factory.ts +++ b/src/client/factory.ts @@ -4,6 +4,7 @@ import { AgentCardResolver } from './card-resolver.js'; import { Client, ClientConfig } from './multitransport-client.js'; import { JsonRpcTransportFactory } from './transports/json_rpc_transport.js'; import { RestTransportFactory } from './transports/rest_transport.js'; +import { TenantTransportDecorator } from './transports/tenant_transport_decorator.js'; import { TransportFactory } from './transports/transport.js'; export interface ClientFactoryOptions { @@ -95,29 +96,40 @@ export class ClientFactory { /** * Creates a new client from the provided agent card. + * + * When the selected `AgentInterface` declares a non-empty `tenant` value + * (per spec Section 4.4.6), the transport is automatically wrapped with a + * {@link TenantTransportDecorator} so the default tenant is applied to every + * request without requiring callers to set it manually. */ async createFromAgentCard(agentCard: AgentCard): Promise { const interfaces = agentCard.supportedInterfaces ?? []; - const urlsPerAgentTransports = new CaseInsensitiveMap(); - for (const i of interfaces) { - const existing = urlsPerAgentTransports.get(i.protocolBinding); - if (!existing || i.protocolVersion === '1.0') { - urlsPerAgentTransports.set(i.protocolBinding, i.url); + + const bestInterfacePerProtocol = new CaseInsensitiveMap<(typeof interfaces)[number]>(); + for (const agentInterface of interfaces) { + const existing = bestInterfacePerProtocol.get(agentInterface.protocolBinding); + if (!existing || agentInterface.protocolVersion === '1.0') { + bestInterfacePerProtocol.set(agentInterface.protocolBinding, agentInterface); } } + const transportsByPreference = [ ...(this.options.preferredTransports ?? []), ...interfaces.map((i) => i.protocolBinding), ]; - for (const transport of transportsByPreference) { - const url = urlsPerAgentTransports.get(transport); - const factory = this.transportsByName.get(transport); - if (factory && url) { - return new Client( - await factory.create(url, agentCard), - agentCard, - this.options.clientConfig - ); + for (const transportName of transportsByPreference) { + const selectedInterface = bestInterfacePerProtocol.get(transportName); + const factory = this.transportsByName.get(transportName); + if (factory && selectedInterface) { + let transport = await factory.create(selectedInterface.url, agentCard); + + // If the agent interface declares a default tenant, wrap the transport + // so the tenant is automatically applied to all requests. + if (selectedInterface.tenant) { + transport = new TenantTransportDecorator(transport, selectedInterface.tenant); + } + + return new Client(transport, agentCard, this.options.clientConfig); } } throw new Error( diff --git a/src/client/index.ts b/src/client/index.ts index ea1ea5f2..91c931be 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -10,6 +10,7 @@ export { } from './card-resolver.js'; export { Client, type ClientConfig, type RequestOptions } from './multitransport-client.js'; export type { Transport, TransportFactory } from './transports/transport.js'; +export { TenantTransportDecorator } from './transports/tenant_transport_decorator.js'; export { ClientFactory, ClientFactoryOptions } from './factory.js'; export { JsonRpcTransportFactory } from './transports/json_rpc_transport.js'; export { RestTransportFactory } from './transports/rest_transport.js'; diff --git a/src/client/multitransport-client.ts b/src/client/multitransport-client.ts index 82f1eda3..51fc7ff6 100644 --- a/src/client/multitransport-client.ts +++ b/src/client/multitransport-client.ts @@ -77,13 +77,17 @@ export class Client { /** * If the current agent card supports the extended feature, it will try to fetch the extended agent card from the server, * Otherwise it will return the current agent card value. + * + * When a default tenant is configured (via `TenantTransportDecorator`, wired + * automatically by `ClientFactory` from `AgentInterface.tenant`), the tenant + * is applied to the request transparently. */ async getAgentCard(options?: RequestOptions): Promise { if (this.agentCard.capabilities?.extendedAgentCard) { this.agentCard = await this.executeWithInterceptors( { method: 'getAgentCard' }, options, - (_, options) => this.transport.getExtendedAgentCard(options) + (_, options) => this.transport.getExtendedAgentCard({ tenant: '' }, options) ); } return this.agentCard; diff --git a/src/client/transports/grpc/grpc_transport.ts b/src/client/transports/grpc/grpc_transport.ts index 3b13f88d..9aed2c57 100644 --- a/src/client/transports/grpc/grpc_transport.ts +++ b/src/client/transports/grpc/grpc_transport.ts @@ -69,10 +69,13 @@ export class GrpcTransport implements Transport { return PROTOCOL_NAME; } - async getExtendedAgentCard(options?: RequestOptions): Promise { + async getExtendedAgentCard( + params: GetExtendedAgentCardRequest, + options?: RequestOptions + ): Promise { const rpcResponse = await this._sendGrpcRequest( 'getExtendedAgentCard', - { tenant: '' }, + params, options, this.grpcClient.getExtendedAgentCard.bind(this.grpcClient) ); diff --git a/src/client/transports/json_rpc_transport.ts b/src/client/transports/json_rpc_transport.ts index fdf70cd2..6d8af770 100644 --- a/src/client/transports/json_rpc_transport.ts +++ b/src/client/transports/json_rpc_transport.ts @@ -17,6 +17,7 @@ import { Transport, TransportFactory } from './transport.js'; import { CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, MessageFns, SendMessageRequest, SubscribeToTaskRequest, @@ -51,12 +52,15 @@ export class JsonRpcTransport implements Transport { return PROTOCOL_NAME; } - async getExtendedAgentCard(options?: RequestOptions): Promise { - const rpcResponse = await this._sendRpcRequest( + async getExtendedAgentCard( + params: GetExtendedAgentCardRequest, + options?: RequestOptions + ): Promise { + const rpcResponse = await this._sendRpcRequest( 'GetExtendedAgentCard', - undefined, + params, options, - undefined + GetExtendedAgentCardRequest ); return AgentCard.fromJSON(rpcResponse.result); } diff --git a/src/client/transports/rest_transport.ts b/src/client/transports/rest_transport.ts index 5f0724e4..4eef8c92 100644 --- a/src/client/transports/rest_transport.ts +++ b/src/client/transports/rest_transport.ts @@ -20,6 +20,7 @@ import { AgentCard, CancelTaskRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -60,14 +61,22 @@ export class RestTransport implements Transport { this.customFetchImpl = options.fetchImpl; } + private _buildPath(path: string, tenant?: string): string { + return tenant ? '/' + encodeURIComponent(tenant) + path : path; + } + get protocolName(): string { return PROTOCOL_NAME; } - async getExtendedAgentCard(options?: RequestOptions): Promise { + async getExtendedAgentCard( + params: GetExtendedAgentCardRequest, + options?: RequestOptions + ): Promise { + const path = this._buildPath('/extendedAgentCard', params.tenant); const response = await this._sendRequest( 'GET', - '/extendedAgentCard', + path, undefined, options, undefined, @@ -81,9 +90,10 @@ export class RestTransport implements Transport { options?: RequestOptions ): Promise { const requestBody = params; + const path = this._buildPath('/message:send', params.tenant); const response = await this._sendRequest( 'POST', - '/message:send', + path, requestBody, options, SendMessageRequest, @@ -97,24 +107,22 @@ export class RestTransport implements Transport { options?: RequestOptions ): AsyncGenerator { const requestBody = SendMessageRequest.toJSON(params); - yield* this._sendStreamingRequest('/message:stream', requestBody, options); + const path = this._buildPath('/message:stream', params.tenant); + yield* this._sendStreamingRequest(path, requestBody, options); } async createTaskPushNotificationConfig( params: TaskPushNotificationConfig, options?: RequestOptions ): Promise { - const response = await this._sendRequest< - TaskPushNotificationConfig, - TaskPushNotificationConfig - >( - 'POST', + const path = this._buildPath( `/tasks/${encodeURIComponent(params.taskId)}/pushNotificationConfigs`, - params, - options, + params.tenant + ); + const response = await this._sendRequest< TaskPushNotificationConfig, TaskPushNotificationConfig - ); + >('POST', path, params, options, TaskPushNotificationConfig, TaskPushNotificationConfig); return response; } @@ -122,9 +130,15 @@ export class RestTransport implements Transport { params: GetTaskPushNotificationConfigRequest, options?: RequestOptions ): Promise { - const response = await this._sendRequest( + const path = this._buildPath( + `/tasks/${encodeURIComponent(params.taskId)}/pushNotificationConfigs/${encodeURIComponent( + params.id + )}`, + params.tenant + ); + const response = await this._sendRequest( 'GET', - `/tasks/${params.taskId}/pushNotificationConfigs/${params.id}`, + path, undefined, options, undefined, @@ -137,9 +151,13 @@ export class RestTransport implements Transport { params: ListTaskPushNotificationConfigsRequest, options?: RequestOptions ): Promise { - const response = await this._sendRequest( + const path = this._buildPath( + `/tasks/${encodeURIComponent(params.taskId)}/pushNotificationConfigs`, + params.tenant + ); + const response = await this._sendRequest( 'GET', - `/tasks/${params.taskId}/pushNotificationConfigs`, + path, undefined, options, undefined, @@ -152,24 +170,26 @@ export class RestTransport implements Transport { params: DeleteTaskPushNotificationConfigRequest, options?: RequestOptions ): Promise { - await this._sendRequest( - 'DELETE', - `/tasks/${params.taskId}/pushNotificationConfigs/${params.id}`, - undefined, - options, - undefined, - undefined + const path = this._buildPath( + `/tasks/${encodeURIComponent(params.taskId)}/pushNotificationConfigs/${encodeURIComponent( + params.id + )}`, + params.tenant ); + await this._sendRequest('DELETE', path, undefined, options, undefined, undefined); } async getTask(params: GetTaskRequest, options?: RequestOptions): Promise { const queryParams = new URLSearchParams(); if (params.historyLength !== undefined) { - queryParams.set('historyLength', String(params.historyLength)); + queryParams.set('historyLength', params.historyLength.toString()); } const queryString = queryParams.toString(); - const path = `/tasks/${params.id}${queryString ? `?${queryString}` : ''}`; - const response = await this._sendRequest( + const path = this._buildPath( + `/tasks/${encodeURIComponent(params.id)}${queryString ? `?${queryString}` : ''}`, + params.tenant + ); + const response = await this._sendRequest( 'GET', path, undefined, @@ -181,9 +201,10 @@ export class RestTransport implements Transport { } async cancelTask(params: CancelTaskRequest, options?: RequestOptions): Promise { - const response = await this._sendRequest( + const path = this._buildPath(`/tasks/${encodeURIComponent(params.id)}:cancel`, params.tenant); + const response = await this._sendRequest( 'POST', - `/tasks/${params.id}:cancel`, + path, undefined, options, undefined, @@ -194,7 +215,6 @@ export class RestTransport implements Transport { async listTasks(params: ListTasksRequest, options?: RequestOptions): Promise { const queryParams = new URLSearchParams(); - if (params.tenant) queryParams.set('tenant', params.tenant); if (params.contextId) queryParams.set('contextId', params.contextId); if (params.status !== undefined && params.status !== TaskState.TASK_STATE_UNSPECIFIED) { queryParams.set('status', taskStateToJSON(params.status)); @@ -209,9 +229,9 @@ export class RestTransport implements Transport { queryParams.set('includeArtifacts', String(params.includeArtifacts)); const queryString = queryParams.toString(); - const path = `/tasks${queryString ? `?${queryString}` : ''}`; + const path = this._buildPath(`/tasks${queryString ? `?${queryString}` : ''}`, params.tenant); - const response = await this._sendRequest( + const response = await this._sendRequest( 'GET', path, undefined, @@ -226,7 +246,11 @@ export class RestTransport implements Transport { params: SubscribeToTaskRequest, options?: RequestOptions ): AsyncGenerator { - yield* this._sendStreamingRequest(`/tasks/${params.id}:subscribe`, undefined, options); + const path = this._buildPath( + `/tasks/${encodeURIComponent(params.id)}:subscribe`, + params.tenant + ); + yield* this._sendStreamingRequest(path, undefined, options); } private _fetch(...args: Parameters): ReturnType { diff --git a/src/client/transports/tenant_transport_decorator.ts b/src/client/transports/tenant_transport_decorator.ts new file mode 100644 index 00000000..620a48a1 --- /dev/null +++ b/src/client/transports/tenant_transport_decorator.ts @@ -0,0 +1,141 @@ +import { + AgentCard, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + SendMessageRequest, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +} from '../../index.js'; +import { RequestOptions } from '../multitransport-client.js'; +import { Transport } from './transport.js'; +import { SendMessageResult } from '../../index.js'; + +/** + * A transport decorator that attaches a default tenant to all requests. + * + * When an `AgentInterface` declares a `tenant` value (per spec Section 4.4.6), + * this decorator ensures every outbound request carries that tenant unless the + * caller has already specified one. This mirrors the behavior of the Python SDK's + * `TenantTransportDecorator`. + * + * The factory wires this decorator automatically when `AgentInterface.tenant` is + * non-empty, so callers do not need to manually set tenant on every request. + */ +export class TenantTransportDecorator implements Transport { + constructor( + private readonly base: Transport, + private readonly defaultTenant: string + ) {} + + get protocolName(): string { + return this.base.protocolName; + } + + /** + * Returns the request tenant if non-empty, otherwise falls back to the default. + */ + private _resolveTenant(tenant: string | undefined): string { + return tenant || this.defaultTenant; + } + + async getExtendedAgentCard( + params: GetExtendedAgentCardRequest, + options?: RequestOptions + ): Promise { + return this.base.getExtendedAgentCard( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async sendMessage( + params: SendMessageRequest, + options?: RequestOptions + ): Promise { + return this.base.sendMessage( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async *sendMessageStream( + params: SendMessageRequest, + options?: RequestOptions + ): AsyncGenerator { + yield* this.base.sendMessageStream( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async getTask(params: GetTaskRequest, options?: RequestOptions): Promise { + return this.base.getTask({ ...params, tenant: this._resolveTenant(params.tenant) }, options); + } + + async cancelTask(params: CancelTaskRequest, options?: RequestOptions): Promise { + return this.base.cancelTask({ ...params, tenant: this._resolveTenant(params.tenant) }, options); + } + + async listTasks(params: ListTasksRequest, options?: RequestOptions): Promise { + return this.base.listTasks({ ...params, tenant: this._resolveTenant(params.tenant) }, options); + } + + async createTaskPushNotificationConfig( + params: TaskPushNotificationConfig, + options?: RequestOptions + ): Promise { + return this.base.createTaskPushNotificationConfig( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async getTaskPushNotificationConfig( + params: GetTaskPushNotificationConfigRequest, + options?: RequestOptions + ): Promise { + return this.base.getTaskPushNotificationConfig( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async listTaskPushNotificationConfig( + params: ListTaskPushNotificationConfigsRequest, + options?: RequestOptions + ): Promise { + return this.base.listTaskPushNotificationConfig( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async deleteTaskPushNotificationConfig( + params: DeleteTaskPushNotificationConfigRequest, + options?: RequestOptions + ): Promise { + return this.base.deleteTaskPushNotificationConfig( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } + + async *resubscribeTask( + params: SubscribeToTaskRequest, + options?: RequestOptions + ): AsyncGenerator { + yield* this.base.resubscribeTask( + { ...params, tenant: this._resolveTenant(params.tenant) }, + options + ); + } +} diff --git a/src/client/transports/transport.ts b/src/client/transports/transport.ts index 8685e1c8..2273b701 100644 --- a/src/client/transports/transport.ts +++ b/src/client/transports/transport.ts @@ -9,6 +9,7 @@ import { ListTaskPushNotificationConfigsResponse, DeleteTaskPushNotificationConfigRequest, GetTaskRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, SubscribeToTaskRequest, SendMessageResult, @@ -20,7 +21,10 @@ import { RequestOptions } from '../multitransport-client.js'; export interface Transport { get protocolName(): string; - getExtendedAgentCard(options?: RequestOptions): Promise; + getExtendedAgentCard( + params: GetExtendedAgentCardRequest, + options?: RequestOptions + ): Promise; sendMessage(params: SendMessageRequest, options?: RequestOptions): Promise; diff --git a/src/server/agent_execution/agent_executor.ts b/src/server/agent_execution/agent_executor.ts index 8477bdc3..4c7f2b68 100644 --- a/src/server/agent_execution/agent_executor.ts +++ b/src/server/agent_execution/agent_executor.ts @@ -4,6 +4,11 @@ import { RequestContext } from './request_context.js'; export interface AgentExecutor { /** * Executes the agent logic based on the request context and publishes events. + * + * In multi-tenant deployments, the tenant identifier is available via + * `requestContext.context.tenant`. Implementations MAY use this to scope + * agent behavior or data access by tenant. + * * @param requestContext The context of the current request. * @param eventBus The bus to publish execution events to. */ diff --git a/src/server/context.ts b/src/server/context.ts index 37c17c66..fce82392 100644 --- a/src/server/context.ts +++ b/src/server/context.ts @@ -4,11 +4,17 @@ import { User } from './authentication/user.js'; export class ServerCallContext { private readonly _requestedExtensions?: Extensions; private readonly _user?: User; + private readonly _tenant?: string; private _activatedExtensions?: Extensions; - constructor(requestedExtensions?: Extensions, user?: User) { + constructor(requestedExtensions?: Extensions, user?: User, tenant?: string) { this._requestedExtensions = requestedExtensions; this._user = user; + this._tenant = tenant; + } + + get tenant(): string | undefined { + return this._tenant; } get user(): User | undefined { diff --git a/src/server/express/rest_handler.ts b/src/server/express/rest_handler.ts index ebb56c3c..7ea5227c 100644 --- a/src/server/express/rest_handler.ts +++ b/src/server/express/rest_handler.ts @@ -108,16 +108,19 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { /** * Builds a ServerCallContext from the Express request. - * Extracts protocol extensions from headers and builds user from request. + * Extracts protocol extensions from headers, builds user from request, + * and extracts tenant from the URL path parameter if present. * * @param req - Express request object - * @returns ServerCallContext with requested extensions and authenticated user + * @returns ServerCallContext with requested extensions, authenticated user, and tenant */ const buildContext = async (req: Request): Promise => { const user = await options.userBuilder(req); + const tenant = (req.params.tenant as string) || undefined; return new ServerCallContext( Extensions.parseServiceParameter(req.header(HTTP_EXTENSION_HEADER)), - user + user, + tenant ); }; @@ -260,6 +263,71 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { // Route Handlers // ============================================================================ + /** + * Middleware that resolves tenant from the URL path parameter and normalizes + * it into the request so downstream handlers don't need to deal with tenant + * resolution at all. + * + * For tenant-prefixed routes (`/:tenant/...`), the path tenant is the + * canonical source (per spec: "provided as a path parameter"). If the + * request body or query string also carries a tenant that differs, a warning + * is logged and the path tenant wins. + * + * The resolved tenant is written to: + * - `req.body.tenant` for POST / PUT / DELETE requests that may carry a JSON body + * - `req.query.tenant` for GET requests that use query parameters + * + * Non-tenant-prefixed routes pass through unchanged. + */ + const tenantMiddleware = (req: Request, _res: Response, next: () => void): void => { + const pathTenant = req.params.tenant as string | undefined; + if (!pathTenant) { + next(); + return; + } + + // Detect conflict with body tenant (POST / PUT / DELETE with JSON body) + const bodyTenant = req.body?.tenant as string | undefined; + if (bodyTenant && bodyTenant !== pathTenant) { + console.warn( + `Tenant mismatch: URL path tenant "${pathTenant}" differs from request body ` + + `tenant "${bodyTenant}". Using path tenant as the canonical value.` + ); + } + + // Detect conflict with query tenant (GET) + const queryTenant = req.query?.tenant as string | undefined; + if (queryTenant && queryTenant !== pathTenant) { + console.warn( + `Tenant mismatch: URL path tenant "${pathTenant}" differs from query param ` + + `tenant "${queryTenant}". Using path tenant as the canonical value.` + ); + } + + // Normalize: write path tenant into both body and query so handlers can + // read it from whichever source they naturally consume. + if (req.body) { + req.body.tenant = pathTenant; + } + (req.query as Record).tenant = pathTenant; + + next(); + }; + + /** + * Helper to register routes with and without optional tenant prefix. + * Tenant-prefixed routes get `tenantMiddleware` applied automatically, + * so individual handlers never need to resolve tenant themselves. + */ + const registerRoute = ( + method: 'get' | 'post' | 'delete' | 'put', + path: string, + handler: AsyncRouteHandler + ) => { + router[method](path, asyncHandler(handler)); + router[method](`/:tenant${path}`, tenantMiddleware, asyncHandler(handler)); + }; + /** * GET /extendedAgentCard * @@ -268,14 +336,14 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 200 OK with agent card * @returns 500 Internal Server Error on failure */ - router.get( - '/extendedAgentCard', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.getAuthenticatedExtendedAgentCard(context); - sendResponse(res, HTTP_STATUS.OK, context, result, AgentCard); - }) - ); + registerRoute('get', '/extendedAgentCard', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.getAuthenticatedExtendedAgentCard( + { tenant: (req.query.tenant as string) || '' }, + context + ); + sendResponse(res, HTTP_STATUS.OK, context, result, AgentCard); + }); /** * POST /message:send @@ -288,22 +356,19 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 201 Created with RestMessage or RestTask * @returns 400 Bad Request if message is invalid */ - router.post( - '/message\\:send', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const params = SendMessageRequest.fromJSON(req.body); - const result = await restTransportHandler.sendMessage(params, context); - const protoResult = ToProto.messageSendResult(result); - sendResponse( - res, - HTTP_STATUS.CREATED, - context, - protoResult, - SendMessageResponse - ); - }) - ); + registerRoute('post', '/message\\:send', async (req, res) => { + const context = await buildContext(req); + const params = SendMessageRequest.fromJSON(req.body); + const result = await restTransportHandler.sendMessage(params, context); + const protoResult = ToProto.messageSendResult(result); + sendResponse( + res, + HTTP_STATUS.CREATED, + context, + protoResult, + SendMessageResponse + ); + }); /** * POST /message:stream @@ -317,15 +382,12 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 400 Bad Request if message is invalid * @returns 501 Not Implemented if streaming not supported */ - router.post( - '/message\\:stream', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const params = SendMessageRequest.fromJSON(req.body); - const stream = await restTransportHandler.sendMessageStream(params, context); - await sendStreamResponse(res, stream, context); - }) - ); + registerRoute('post', '/message\\:stream', async (req, res) => { + const context = await buildContext(req); + const params = SendMessageRequest.fromJSON(req.body); + const stream = await restTransportHandler.sendMessageStream(params, context); + await sendStreamResponse(res, stream, context); + }); /** * GET /tasks/:taskId @@ -338,18 +400,16 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 400 Bad Request if historyLength is invalid * @returns 404 Not Found if task doesn't exist */ - router.get( - '/tasks/:taskId', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.getTask( - req.params.taskId, - context, - req.query.historyLength - ); - sendResponse(res, HTTP_STATUS.OK, context, result, Task); - }) - ); + registerRoute('get', '/tasks/:taskId', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.getTask( + req.params.taskId, + context, + req.query.historyLength, + (req.query.tenant as string) || '' + ); + sendResponse(res, HTTP_STATUS.OK, context, result, Task); + }); /** * POST /tasks/:taskId:cancel @@ -362,14 +422,15 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 404 Not Found if task doesn't exist * @returns 409 Conflict if task cannot be canceled */ - router.post( - '/tasks/:taskId\\:cancel', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.cancelTask(req.params.taskId, context); - sendResponse(res, HTTP_STATUS.ACCEPTED, context, result, Task); - }) - ); + registerRoute('post', '/tasks/:taskId\\:cancel', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.cancelTask( + req.params.taskId, + context, + (req.query.tenant as string) || '' + ); + sendResponse(res, HTTP_STATUS.ACCEPTED, context, result, Task); + }); /** * GET /tasks @@ -379,14 +440,11 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 200 OK with ListTasksResponse * @returns 400 Bad Request if filter or pageSize is invalid */ - router.get( - '/tasks', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.listTasks(req.query, context); - sendResponse(res, HTTP_STATUS.OK, context, result, ListTasksResponse); - }) - ); + registerRoute('get', '/tasks', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.listTasks(req.query, context); + sendResponse(res, HTTP_STATUS.OK, context, result, ListTasksResponse); + }); /** * POST /tasks/:taskId:subscribe @@ -399,14 +457,15 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 404 Not Found if task doesn't exist * @returns 501 Not Implemented if streaming not supported */ - router.post( - '/tasks/:taskId\\:subscribe', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const stream = await restTransportHandler.resubscribe(req.params.taskId, context); - await sendStreamResponse(res, stream, context); - }) - ); + registerRoute('post', '/tasks/:taskId\\:subscribe', async (req, res) => { + const context = await buildContext(req); + const stream = await restTransportHandler.resubscribe( + req.params.taskId, + context, + (req.query.tenant as string) || '' + ); + await sendStreamResponse(res, stream, context); + }); /** * POST /tasks/:taskId/pushNotificationConfigs @@ -419,21 +478,18 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 201 Created with TaskPushNotificationConfig * @returns 501 Not Implemented if push notifications not supported */ - router.post( - '/tasks/:taskId/pushNotificationConfigs', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const params = TaskPushNotificationConfig.fromJSON(req.body); - const result = await restTransportHandler.createTaskPushNotificationConfig(params, context); - sendResponse( - res, - HTTP_STATUS.CREATED, - context, - result, - TaskPushNotificationConfig - ); - }) - ); + registerRoute('post', '/tasks/:taskId/pushNotificationConfigs', async (req, res) => { + const context = await buildContext(req); + const params = TaskPushNotificationConfig.fromJSON(req.body); + const result = await restTransportHandler.createTaskPushNotificationConfig(params, context); + sendResponse( + res, + HTTP_STATUS.CREATED, + context, + result, + TaskPushNotificationConfig + ); + }); /** * GET /tasks/:taskId/pushNotificationConfigs @@ -444,23 +500,21 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 200 OK with array of TaskPushNotificationConfig * @returns 404 Not Found if task doesn't exist */ - router.get( - '/tasks/:taskId/pushNotificationConfigs', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.listTaskPushNotificationConfigs( - req.params.taskId, - context - ); - sendResponse( - res, - HTTP_STATUS.OK, - context, - result, - ListTaskPushNotificationConfigsResponse - ); - }) - ); + registerRoute('get', '/tasks/:taskId/pushNotificationConfigs', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.listTaskPushNotificationConfigs( + req.params.taskId, + context, + (req.query.tenant as string) || '' + ); + sendResponse( + res, + HTTP_STATUS.OK, + context, + result, + ListTaskPushNotificationConfigsResponse + ); + }); /** * GET /tasks/:taskId/pushNotificationConfigs/:configId @@ -472,24 +526,22 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 200 OK with TaskPushNotificationConfig * @returns 404 Not Found if task or config doesn't exist */ - router.get( - '/tasks/:taskId/pushNotificationConfigs/:configId', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - const result = await restTransportHandler.getTaskPushNotificationConfig( - req.params.taskId, - req.params.configId, - context - ); - sendResponse( - res, - HTTP_STATUS.OK, - context, - result, - TaskPushNotificationConfig - ); - }) - ); + registerRoute('get', '/tasks/:taskId/pushNotificationConfigs/:configId', async (req, res) => { + const context = await buildContext(req); + const result = await restTransportHandler.getTaskPushNotificationConfig( + req.params.taskId, + req.params.configId, + context, + (req.query.tenant as string) || '' + ); + sendResponse( + res, + HTTP_STATUS.OK, + context, + result, + TaskPushNotificationConfig + ); + }); /** * DELETE /tasks/:taskId/pushNotificationConfigs/:configId @@ -501,18 +553,16 @@ export function restHandler(options: RestHandlerOptions): RequestHandler { * @returns 204 No Content on success * @returns 404 Not Found if task or config doesn't exist */ - router.delete( - '/tasks/:taskId/pushNotificationConfigs/:configId', - asyncHandler(async (req, res) => { - const context = await buildContext(req); - await restTransportHandler.deleteTaskPushNotificationConfig( - req.params.taskId, - req.params.configId, - context - ); - sendResponse(res, HTTP_STATUS.NO_CONTENT, context); - }) - ); + registerRoute('delete', '/tasks/:taskId/pushNotificationConfigs/:configId', async (req, res) => { + const context = await buildContext(req); + await restTransportHandler.deleteTaskPushNotificationConfig( + req.params.taskId, + req.params.configId, + context, + (req.query.tenant as string) || '' + ); + sendResponse(res, HTTP_STATUS.NO_CONTENT, context); + }); return router; } diff --git a/src/server/grpc/grpc_service.ts b/src/server/grpc/grpc_service.ts index aa285b10..1839001e 100644 --- a/src/server/grpc/grpc_service.ts +++ b/src/server/grpc/grpc_service.ts @@ -64,7 +64,8 @@ export function grpcService(options: GrpcServiceOptions): A2AServiceServer { const requestHandler = options.requestHandler; /** - * Helper to wrap Unary calls with common logic (context, metadata, error handling) + * Helper to wrap Unary calls with common logic (context, metadata, error handling). + * Extracts tenant from the request if available and enriches the context. */ const wrapUnaryWithConverter = async ( call: grpc.ServerUnaryCall, @@ -73,7 +74,7 @@ export function grpcService(options: GrpcServiceOptions): A2AServiceServer { converter: (res: TResult) => TRes ) => { try { - const context = await buildContext(call, options.userBuilder); + const context = await _buildContext(call, options.userBuilder); const result = await handler(call.request, context); call.sendMetadata(buildMetadata(context)); callback(null, converter(result)); @@ -91,14 +92,15 @@ export function grpcService(options: GrpcServiceOptions): A2AServiceServer { }; /** - * Helper to wrap Streaming calls with common logic (context, metadata, error handling) + * Helper to wrap Streaming calls with common logic (context, metadata, error handling). + * Extracts tenant from the request if available and enriches the context. */ const wrapStreaming = async ( call: grpc.ServerWritableStream, handler: (req: TReq, ctx: ServerCallContext) => AsyncGenerator ) => { try { - const context = await buildContext(call, options.userBuilder); + const context = await _buildContext(call, options.userBuilder); const stream = await handler(call.request, context); call.sendMetadata(buildMetadata(context)); for await (const responsePart of stream) { @@ -202,8 +204,8 @@ export function grpcService(options: GrpcServiceOptions): A2AServiceServer { call: grpc.ServerUnaryCall, callback: grpc.sendUnaryData ): Promise { - return wrapUnary(call, callback, (_params, context) => - requestHandler.getAuthenticatedExtendedAgentCard(context) + return wrapUnary(call, callback, (params, context) => + requestHandler.getAuthenticatedExtendedAgentCard(params, context) ); }, listTasks( @@ -241,15 +243,15 @@ const mapToError = (error: unknown): Partial => { }; }; -const buildContext = async ( +const _buildContext = async ( call: grpc.ServerUnaryCall | grpc.ServerWritableStream, userBuilder: UserBuilder ): Promise => { const user = await userBuilder(call); const extensionHeaders = call.metadata.get(HTTP_EXTENSION_HEADER); const extensionString = extensionHeaders.map((v) => v.toString()).join(','); - - return new ServerCallContext(Extensions.parseServiceParameter(extensionString), user); + const tenant = (call.request as Record)?.tenant as string | undefined; + return new ServerCallContext(Extensions.parseServiceParameter(extensionString), user, tenant); }; const buildMetadata = (context: ServerCallContext): grpc.Metadata => { diff --git a/src/server/push_notification/push_notification_store.ts b/src/server/push_notification/push_notification_store.ts index f96b7bdd..7b88ced3 100644 --- a/src/server/push_notification/push_notification_store.ts +++ b/src/server/push_notification/push_notification_store.ts @@ -1,6 +1,12 @@ import { TaskPushNotificationConfig } from '../../index.js'; import { ServerCallContext } from '../context.js'; +/** + * Interface for push notification configuration storage. + * + * Implementations SHOULD use `context.tenant` (when present) to scope data access, + * ensuring push notification configs from one tenant are not accessible to another. + */ export interface PushNotificationStore { save( taskId: string, @@ -11,15 +17,44 @@ export interface PushNotificationStore { delete(taskId: string, context: ServerCallContext, configId?: string): Promise; } +/** + * In-memory push notification config store with tenant-scoped data isolation. + * A nested Map structure (tenant -> taskId -> configs[]) is used so that tenant + * scoping is structural, imposing no restrictions on task ID format. + */ export class InMemoryPushNotificationStore implements PushNotificationStore { - private store: Map = new Map(); + // Outer map: tenant key ('' for global/no-tenant) -> inner map of taskId -> configs + private store: Map> = new Map(); + + private _tenantKey(context: ServerCallContext): string { + return context.tenant ?? ''; + } + + private _getTenantBucket( + context: ServerCallContext + ): Map | undefined { + return this.store.get(this._tenantKey(context)); + } + + private _getOrCreateTenantBucket( + context: ServerCallContext + ): Map { + const key = this._tenantKey(context); + let bucket = this.store.get(key); + if (!bucket) { + bucket = new Map(); + this.store.set(key, bucket); + } + return bucket; + } async save( taskId: string, - _context: ServerCallContext, + context: ServerCallContext, pushNotificationConfig: TaskPushNotificationConfig ): Promise { - const configs = this.store.get(taskId) || []; + const bucket = this._getOrCreateTenantBucket(context); + const configs = bucket.get(taskId) || []; // Set ID if it's not already set if (!pushNotificationConfig.id) { @@ -34,21 +69,26 @@ export class InMemoryPushNotificationStore implements PushNotificationStore { // Add the new/updated config configs.push(pushNotificationConfig); - this.store.set(taskId, configs); + bucket.set(taskId, configs); } - async load(taskId: string, _context: ServerCallContext): Promise { - const configs = this.store.get(taskId); + async load(taskId: string, context: ServerCallContext): Promise { + const configs = this._getTenantBucket(context)?.get(taskId); return configs || []; } - async delete(taskId: string, _context: ServerCallContext, configId?: string): Promise { + async delete(taskId: string, context: ServerCallContext, configId?: string): Promise { // If no configId is provided, use taskId as the configId (backward compatibility) if (configId === undefined) { configId = taskId; } - const configs = this.store.get(taskId); + const bucket = this._getTenantBucket(context); + if (!bucket) { + return; + } + + const configs = bucket.get(taskId); if (!configs) { return; } @@ -59,9 +99,7 @@ export class InMemoryPushNotificationStore implements PushNotificationStore { } if (configs.length === 0) { - this.store.delete(taskId); - } else { - this.store.set(taskId, configs); + bucket.delete(taskId); } } } diff --git a/src/server/request_handler/a2a_request_handler.ts b/src/server/request_handler/a2a_request_handler.ts index 76569771..968ee1b2 100644 --- a/src/server/request_handler/a2a_request_handler.ts +++ b/src/server/request_handler/a2a_request_handler.ts @@ -6,6 +6,7 @@ import { ListTaskPushNotificationConfigsRequest, GetTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, CancelTaskRequest, GetTaskRequest, SubscribeToTaskRequest, @@ -20,7 +21,10 @@ import { ServerCallContext } from '../context.js'; export interface A2ARequestHandler { getAgentCard(): Promise; - getAuthenticatedExtendedAgentCard(context: ServerCallContext): Promise; + getAuthenticatedExtendedAgentCard( + params: GetExtendedAgentCardRequest, + context: ServerCallContext + ): Promise; sendMessage(params: SendMessageRequest, context: ServerCallContext): Promise; diff --git a/src/server/request_handler/default_request_handler.ts b/src/server/request_handler/default_request_handler.ts index 83808ee1..3340a25b 100644 --- a/src/server/request_handler/default_request_handler.ts +++ b/src/server/request_handler/default_request_handler.ts @@ -21,6 +21,7 @@ import { SendMessageRequest, GetTaskRequest, CancelTaskRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, ListTaskPushNotificationConfigsRequest, DeleteTaskPushNotificationConfigRequest, @@ -55,6 +56,26 @@ import { ServerCallContext } from '../context.js'; import { DEFAULT_PAGE_SIZE } from '../../constants.js'; import { TERMINAL_STATE_LIST } from '../utils.js'; +/** + * Default implementation of the A2A request handler. + * + * ## Multi-Tenancy + * + * This handler supports multi-tenant deployments through the `tenant` field present + * on all request objects (per A2A spec Sections 3.1.x and 4.4.6). The tenant value + * flows through the system as follows: + * + * 1. **Transport layer** extracts tenant from the protocol-specific source: + * - REST: URL path prefix (`/:tenant/...`) + * - JSON-RPC: `params.tenant` in the request body + * - gRPC: `tenant` field in the request message + * + * 2. **`ServerCallContext.tenant`** carries the tenant to all downstream components, + * including `TaskStore`, `PushNotificationStore`, and `AgentExecutor`. + * + * 3. **`InMemoryTaskStore`** and **`InMemoryPushNotificationStore`** use `context.tenant` + * to scope data with composite keys (`{tenant}:{id}`), providing tenant isolation. + */ export class DefaultRequestHandler implements A2ARequestHandler { private readonly agentCard: AgentCard; private readonly taskStore: TaskStore; @@ -92,7 +113,10 @@ export class DefaultRequestHandler implements A2ARequestHandler { return this.agentCard; } - async getAuthenticatedExtendedAgentCard(context: ServerCallContext): Promise { + async getAuthenticatedExtendedAgentCard( + _params: GetExtendedAgentCardRequest, + context: ServerCallContext + ): Promise { if (!this.agentCard.capabilities?.extendedAgentCard) { throw new UnsupportedOperationError('Agent does not support authenticated extended card.'); } diff --git a/src/server/store.ts b/src/server/store.ts index f3f6511b..4518e27b 100644 --- a/src/server/store.ts +++ b/src/server/store.ts @@ -6,13 +6,18 @@ import { RequestMalformedError } from '../errors.js'; /** * Simplified interface for task storage providers. * Stores and retrieves the task. + * + * Implementations SHOULD use `context.tenant` (when present) to scope data access. + * Per spec Section 13.1, servers MUST ensure appropriate scope limitation based on + * the authenticated caller's authorization boundaries, which includes tenant isolation + * in multi-tenant deployments. */ export interface TaskStore { /** * Saves a task. * Overwrites existing data if the task ID exists. * @param task The task to save. - * @param context The context of the current call. + * @param context The context of the current call. Use `context.tenant` for tenant-scoped storage. * @returns A promise resolving when the save operation is complete. */ save(task: Task, context: ServerCallContext): Promise; @@ -20,7 +25,7 @@ export interface TaskStore { /** * Loads a task by task ID. * @param taskId The ID of the task to load. - * @param context The context of the current call. + * @param context The context of the current call. Use `context.tenant` for tenant-scoped lookups. * @returns A promise resolving to an object containing the Task, or undefined if not found. */ load(taskId: string, context: ServerCallContext): Promise; @@ -28,7 +33,7 @@ export interface TaskStore { /** * Lists tasks with filtering and pagination. * @param params Filtering and pagination parameters. - * @param context The context of the current call. + * @param context The context of the current call. Use `context.tenant` for tenant-scoped listing. */ list(params: ListTasksRequest, context: ServerCallContext): Promise; } @@ -37,25 +42,44 @@ export interface TaskStore { // InMemoryTaskStore // ======================== // -// Methods in InMemoryTaskStore accept ServerCallContext but do not use it. -// This is intentional to match the TaskStore interface. +// InMemoryTaskStore provides tenant-scoped data isolation using `context.tenant`. +// A nested Map structure (tenant -> taskId -> Task) is used so that tenant scoping +// is structural rather than key-convention based, imposing no restrictions on task ID format. -// Use Task directly for storage export class InMemoryTaskStore implements TaskStore { - private store: Map = new Map(); + // Outer map: tenant key ('' for global/no-tenant) -> inner map of taskId -> Task + private store: Map> = new Map(); - async load(taskId: string, _context: ServerCallContext): Promise { - const entry = this.store.get(taskId); + private _tenantKey(context: ServerCallContext): string { + return context.tenant ?? ''; + } + + private _getTenantBucket(context: ServerCallContext): Map | undefined { + return this.store.get(this._tenantKey(context)); + } + + private _getOrCreateTenantBucket(context: ServerCallContext): Map { + const key = this._tenantKey(context); + let bucket = this.store.get(key); + if (!bucket) { + bucket = new Map(); + this.store.set(key, bucket); + } + return bucket; + } + + async load(taskId: string, context: ServerCallContext): Promise { + const entry = this._getTenantBucket(context)?.get(taskId); // Return copies to prevent external mutation return entry ? { ...entry } : undefined; } - async save(task: Task, _context: ServerCallContext): Promise { + async save(task: Task, context: ServerCallContext): Promise { // Store copies to prevent internal mutation if caller reuses objects - this.store.set(task.id, { ...task }); + this._getOrCreateTenantBucket(context).set(task.id, { ...task }); } - async list(params: ListTasksRequest, _context: ServerCallContext): Promise { + async list(params: ListTasksRequest, context: ServerCallContext): Promise { const { contextId, status, @@ -66,7 +90,8 @@ export class InMemoryTaskStore implements TaskStore { includeArtifacts = false, } = params; - let tasks = Array.from(this.store.values()); + const bucket = this._getTenantBucket(context); + let tasks = bucket ? Array.from(bucket.values()) : []; // Filter by contextId if (contextId) { diff --git a/src/server/transports/jsonrpc/jsonrpc_transport_handler.ts b/src/server/transports/jsonrpc/jsonrpc_transport_handler.ts index 51649a30..e9134c2e 100644 --- a/src/server/transports/jsonrpc/jsonrpc_transport_handler.ts +++ b/src/server/transports/jsonrpc/jsonrpc_transport_handler.ts @@ -5,6 +5,7 @@ import { SendMessageRequest, SubscribeToTaskRequest, GetTaskRequest, + GetExtendedAgentCardRequest, CancelTaskRequest, TaskPushNotificationConfig, GetTaskPushNotificationConfigRequest, @@ -98,6 +99,15 @@ export class JsonRpcTransportHandler { throw new RequestMalformedError(`Invalid method parameters.`); } + // For JSON-RPC, tenant is inside the params body. Extract it and enrich the + // context so downstream components (stores, executors) can scope by tenant. + const paramsTenant = (rpcRequest.params as Record | undefined)?.tenant as + | string + | undefined; + if (paramsTenant && !context.tenant) { + context = new ServerCallContext(context.requestedExtensions, context.user, paramsTenant); + } + if (method === 'SendStreamingMessage' || method === 'SubscribeToTask') { const params = rpcRequest.params; const agentCard = await this.requestHandler.getAgentCard(); @@ -208,7 +218,10 @@ export class JsonRpcTransportHandler { break; case 'GetExtendedAgentCard': result = AgentCard.toJSON( - await this.requestHandler.getAuthenticatedExtendedAgentCard(context) + await this.requestHandler.getAuthenticatedExtendedAgentCard( + GetExtendedAgentCardRequest.fromJSON(rpcRequest.params), + context + ) ); break; default: diff --git a/src/server/transports/rest/rest_transport_handler.ts b/src/server/transports/rest/rest_transport_handler.ts index c3cb7446..113ab3cb 100644 --- a/src/server/transports/rest/rest_transport_handler.ts +++ b/src/server/transports/rest/rest_transport_handler.ts @@ -16,6 +16,7 @@ import { StreamResponse, GetTaskRequest, CancelTaskRequest, + GetExtendedAgentCardRequest, ListTasksRequest, ListTasksResponse, TaskState, @@ -119,8 +120,11 @@ export class RestTransportHandler { /** * Gets the authenticated extended agent card. */ - async getAuthenticatedExtendedAgentCard(context: ServerCallContext): Promise { - return this.requestHandler.getAuthenticatedExtendedAgentCard(context); + async getAuthenticatedExtendedAgentCard( + params: GetExtendedAgentCardRequest, + context: ServerCallContext + ): Promise { + return this.requestHandler.getAuthenticatedExtendedAgentCard(params, context); } /** @@ -166,9 +170,10 @@ export class RestTransportHandler { async getTask( taskId: string, context: ServerCallContext, - historyLength?: unknown + historyLength?: unknown, + tenant?: string ): Promise { - const params: GetTaskRequest = { id: taskId, historyLength: 0, tenant: '' }; + const params: GetTaskRequest = { id: taskId, historyLength: 0, tenant: tenant || '' }; if (historyLength !== undefined) { params.historyLength = this.parseHistoryLength(historyLength); } @@ -178,8 +183,8 @@ export class RestTransportHandler { /** * Cancels a task. */ - async cancelTask(taskId: string, context: ServerCallContext): Promise { - const params: CancelTaskRequest = { id: taskId, tenant: '', metadata: {} }; + async cancelTask(taskId: string, context: ServerCallContext, tenant?: string): Promise { + const params: CancelTaskRequest = { id: taskId, tenant: tenant || '', metadata: {} }; return this.requestHandler.cancelTask(params, context); } @@ -212,10 +217,11 @@ export class RestTransportHandler { */ async resubscribe( taskId: string, - context: ServerCallContext + context: ServerCallContext, + tenant?: string ): Promise> { await this.requireCapability('streaming'); - return this.requestHandler.resubscribe({ id: taskId, tenant: '' }, context); + return this.requestHandler.resubscribe({ id: taskId, tenant: tenant || '' }, context); } /** @@ -238,10 +244,11 @@ export class RestTransportHandler { */ async listTaskPushNotificationConfigs( taskId: string, - context: ServerCallContext + context: ServerCallContext, + tenant?: string ): Promise { const result = await this.requestHandler.listTaskPushNotificationConfigs( - { taskId, pageSize: 0, pageToken: '', tenant: '' }, + { taskId, pageSize: 0, pageToken: '', tenant: tenant || '' }, context ); return result; @@ -253,10 +260,11 @@ export class RestTransportHandler { async getTaskPushNotificationConfig( taskId: string, configId: string, - context: ServerCallContext + context: ServerCallContext, + tenant?: string ): Promise { const config = await this.requestHandler.getTaskPushNotificationConfig( - { taskId, id: configId, tenant: '' }, + { taskId, id: configId, tenant: tenant || '' }, context ); return config; @@ -268,10 +276,11 @@ export class RestTransportHandler { async deleteTaskPushNotificationConfig( taskId: string, configId: string, - context: ServerCallContext + context: ServerCallContext, + tenant?: string ): Promise { await this.requestHandler.deleteTaskPushNotificationConfig( - { taskId, id: configId, tenant: '' }, + { taskId, id: configId, tenant: tenant || '' }, context ); } diff --git a/test/client/factory.spec.ts b/test/client/factory.spec.ts index 0677372e..5f671f79 100644 --- a/test/client/factory.spec.ts +++ b/test/client/factory.spec.ts @@ -1,6 +1,7 @@ import { describe, it, beforeEach, expect, vi, Mock } from 'vitest'; import { ClientFactory, ClientFactoryOptions } from '../../src/client/factory.js'; import { Transport } from '../../src/client/transports/transport.js'; +import { TenantTransportDecorator } from '../../src/client/transports/tenant_transport_decorator.js'; import { AgentCard } from '../../src/index.js'; import { Client } from '../../src/client/multitransport-client.js'; import { CallInterceptor } from '../../src/client/interceptors.js'; @@ -299,6 +300,40 @@ describe('ClientFactory', () => { 'a2a/my-agent-card.json' ); }); + + it('should wrap transport with TenantTransportDecorator when interface has tenant', async () => { + agentCard.supportedInterfaces = [ + { + url: 'http://transport1.com', + protocolBinding: 'Transport1', + tenant: 'my-tenant', + protocolVersion: '1.0.0', + }, + ]; + const factory = new ClientFactory({ transports: [mockTransportFactory1] }); + + const client = await factory.createFromAgentCard(agentCard); + + expect(client).to.be.instanceOf(Client); + expect(client.transport).to.be.instanceOf(TenantTransportDecorator); + }); + + it('should NOT wrap transport with TenantTransportDecorator when interface has no tenant', async () => { + agentCard.supportedInterfaces = [ + { + url: 'http://transport1.com', + protocolBinding: 'Transport1', + tenant: '', + protocolVersion: '1.0.0', + }, + ]; + const factory = new ClientFactory({ transports: [mockTransportFactory1] }); + + const client = await factory.createFromAgentCard(agentCard); + + expect(client).to.be.instanceOf(Client); + expect(client.transport).not.to.be.instanceOf(TenantTransportDecorator); + }); }); describe('ClientFactoryOptions.createFrom', () => { diff --git a/test/client/multitransport-client.spec.ts b/test/client/multitransport-client.spec.ts index d2e7c7bd..f131ddbc 100644 --- a/test/client/multitransport-client.spec.ts +++ b/test/client/multitransport-client.spec.ts @@ -77,8 +77,10 @@ describe('Client', () => { }; client = new Client(transport, agentCardWithExtendedSupport); - let caughtOptions; - transport.getExtendedAgentCard.mockImplementation(async (options) => { + let caughtParams: unknown; + let caughtOptions: unknown; + transport.getExtendedAgentCard.mockImplementation(async (params, options) => { + caughtParams = params; caughtOptions = options; return extendedAgentCard; }); @@ -90,6 +92,7 @@ describe('Client', () => { expect(transport.getExtendedAgentCard).toHaveBeenCalledTimes(1); expect(result).to.equal(extendedAgentCard); + expect(caughtParams).to.deep.equal({ tenant: '' }); expect(caughtOptions).to.equal(expectedOptions); }); diff --git a/test/client/transports/grpc_transport.spec.ts b/test/client/transports/grpc_transport.spec.ts index 10883d66..b1310585 100644 --- a/test/client/transports/grpc_transport.spec.ts +++ b/test/client/transports/grpc_transport.spec.ts @@ -89,7 +89,7 @@ describe('GrpcTransport', () => { const mockCard = createMockAgentCard(); mockUnarySuccess(mockGrpcClient.getExtendedAgentCard as Mock, mockCard); - const result = await transport.getExtendedAgentCard(); + const result = await transport.getExtendedAgentCard({ tenant: '' }); expect(result).toEqual(mockCard); expect(mockGrpcClient.getExtendedAgentCard).toHaveBeenCalled(); diff --git a/test/client/transports/rest_transport.spec.ts b/test/client/transports/rest_transport.spec.ts index 046df52b..1ed37dbc 100644 --- a/test/client/transports/rest_transport.spec.ts +++ b/test/client/transports/rest_transport.spec.ts @@ -101,6 +101,19 @@ describe('RestTransport', () => { ); }); + it('should send message with tenant prefix successfully', async () => { + const messageParams = createMessageParams(); + messageParams.tenant = 'tenant1'; + const mockResponse = createMockProtoMessage(); + + mockFetch.mockResolvedValue(createRestResponse(mockResponse)); + + await transport.sendMessage(messageParams); + + const [url] = mockFetch.mock.calls[0]; + expect(url).to.equal(`${endpoint}/tenant1/message:send`); + }); + it('should correctly add the extension headers', async () => { const messageParams = createMessageParams(); const expectedExtensions = 'extension1,extension2'; @@ -141,6 +154,18 @@ describe('RestTransport', () => { expect(options?.method).to.equal('GET'); }); + it('should get task with tenant prefix successfully', async () => { + const taskId = 'task-123'; + const mockTask = createMockProtoTask(taskId); + + mockFetch.mockResolvedValue(createRestResponse(mockTask)); + + await transport.getTask({ id: taskId, tenant: 'tenant1', historyLength: 0 }); + + const [url] = mockFetch.mock.calls[0]; + expect(url).to.equal(`${endpoint}/tenant1/tasks/${taskId}?historyLength=0`); + }); + it('should pass historyLength as query parameter', async () => { const taskId = 'task-123'; const historyLength = 10; @@ -193,6 +218,44 @@ describe('RestTransport', () => { }); }); + describe('listTasks', () => { + it('should list tasks successfully', async () => { + const mockResponse = { tasks: [] as any[], nextPageToken: '', pageSize: 0, totalSize: 0 }; + mockFetch.mockResolvedValue(createRestResponse(mockResponse)); + + const result = await transport.listTasks({ + tenant: '', + contextId: '', + status: TaskState.TASK_STATE_UNSPECIFIED, + pageToken: '', + statusTimestampAfter: '', + }); + + expect(result).to.deep.equal(mockResponse); + expect(mockFetch).toHaveBeenCalledTimes(1); + + const [url, options] = mockFetch.mock.calls[0]; + expect(url).to.equal(`${endpoint}/tasks`); + expect(options?.method).to.equal('GET'); + }); + + it('should list tasks with tenant prefix successfully', async () => { + const mockResponse = { tasks: [] as any[], nextPageToken: '', pageSize: 0, totalSize: 0 }; + mockFetch.mockResolvedValue(createRestResponse(mockResponse)); + + await transport.listTasks({ + tenant: 'tenant1', + contextId: '', + status: TaskState.TASK_STATE_UNSPECIFIED, + pageToken: '', + statusTimestampAfter: '', + }); + + const [url] = mockFetch.mock.calls[0]; + expect(url).to.equal(`${endpoint}/tenant1/tasks`); + }); + }); + describe('getExtendedAgentCard', () => { it('should get extended agent card successfully', async () => { const mockCard: AgentCard = { @@ -227,7 +290,7 @@ describe('RestTransport', () => { mockFetch.mockResolvedValue(createRestResponse(AgentCard.toJSON(mockCard))); - const result = await transport.getExtendedAgentCard(); + const result = await transport.getExtendedAgentCard({ tenant: '' }); expect(result).toEqual(expect.objectContaining(mockCard)); expect(mockFetch).toHaveBeenCalledTimes(1); @@ -236,6 +299,45 @@ describe('RestTransport', () => { expect(url).to.equal(`${endpoint}/extendedAgentCard`); expect(options?.method).to.equal('GET'); }); + + it('should get extended agent card with tenant prefix', async () => { + const mockCard: AgentCard = { + name: 'Test Agent', + description: 'A test agent for testing', + capabilities: { + streaming: true, + pushNotifications: true, + extensions: [], + }, + skills: [], + defaultInputModes: ['text'], + defaultOutputModes: ['text'], + supportedInterfaces: [ + { + url: endpoint, + protocolBinding: 'HTTP+JSON', + tenant: 'my-tenant', + protocolVersion: '1.0.0', + }, + ], + version: '1.0.0', + provider: { + url: '', + organization: '', + }, + securityRequirements: [], + securitySchemes: {}, + documentationUrl: '', + signatures: [], + }; + + mockFetch.mockResolvedValue(createRestResponse(AgentCard.toJSON(mockCard))); + + await transport.getExtendedAgentCard({ tenant: 'my-tenant' }); + + const [url] = mockFetch.mock.calls[0]; + expect(url).to.equal(`${endpoint}/my-tenant/extendedAgentCard`); + }); }); describe('Push Notification Config', () => { diff --git a/test/client/transports/tenant_transport_decorator.spec.ts b/test/client/transports/tenant_transport_decorator.spec.ts new file mode 100644 index 00000000..b56463ba --- /dev/null +++ b/test/client/transports/tenant_transport_decorator.spec.ts @@ -0,0 +1,316 @@ +import { describe, it, beforeEach, expect, vi, Mock } from 'vitest'; +import { TenantTransportDecorator } from '../../../src/client/transports/tenant_transport_decorator.js'; +import { Transport } from '../../../src/client/transports/transport.js'; +import { SendMessageRequest } from '../../../src/types/pb/a2a.js'; + +/** Drains an async generator to completion. */ +async function drain(gen: AsyncGenerator): Promise { + while (!(await gen.next()).done) { + // consume all values + } +} + +describe('TenantTransportDecorator', () => { + const DEFAULT_TENANT = 'default-tenant'; + let mockTransport: Record, Mock> & { + protocolName: string; + }; + let decorator: TenantTransportDecorator; + + beforeEach(() => { + mockTransport = { + getExtendedAgentCard: vi.fn().mockResolvedValue({}), + sendMessage: vi.fn().mockResolvedValue({}), + sendMessageStream: vi.fn().mockReturnValue((async function* () {})()), + createTaskPushNotificationConfig: vi.fn().mockResolvedValue({}), + getTaskPushNotificationConfig: vi.fn().mockResolvedValue({}), + listTaskPushNotificationConfig: vi.fn().mockResolvedValue({ configs: [] }), + deleteTaskPushNotificationConfig: vi.fn().mockResolvedValue(undefined), + getTask: vi.fn().mockResolvedValue({}), + cancelTask: vi.fn().mockResolvedValue({}), + listTasks: vi.fn().mockResolvedValue({ tasks: [] }), + resubscribeTask: vi.fn().mockReturnValue((async function* () {})()), + protocolName: 'MockTransport', + }; + decorator = new TenantTransportDecorator(mockTransport, DEFAULT_TENANT); + }); + + it('should expose the base transport protocol name', () => { + expect(decorator.protocolName).to.equal('MockTransport'); + }); + + describe('default tenant application', () => { + it('should apply default tenant to sendMessage when tenant is empty', async () => { + await decorator.sendMessage({ + tenant: '', + message: undefined, + configuration: undefined, + metadata: {}, + }); + + const passedParams = mockTransport.sendMessage.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should preserve caller-specified tenant on sendMessage', async () => { + await decorator.sendMessage({ + tenant: 'custom-tenant', + message: undefined, + configuration: undefined, + metadata: {}, + }); + + const passedParams = mockTransport.sendMessage.mock.calls[0][0]; + expect(passedParams.tenant).to.equal('custom-tenant'); + }); + + it('should apply default tenant to getExtendedAgentCard when tenant is empty', async () => { + await decorator.getExtendedAgentCard({ tenant: '' }); + + const passedParams = mockTransport.getExtendedAgentCard.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to getTask when tenant is empty', async () => { + await decorator.getTask({ id: 'task-1', tenant: '', historyLength: 0 }); + + const passedParams = mockTransport.getTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should preserve caller-specified tenant on getTask', async () => { + await decorator.getTask({ id: 'task-1', tenant: 'override', historyLength: 0 }); + + const passedParams = mockTransport.getTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal('override'); + }); + + it('should apply default tenant to cancelTask when tenant is empty', async () => { + await decorator.cancelTask({ id: 'task-1', tenant: '', metadata: {} }); + + const passedParams = mockTransport.cancelTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to listTasks when tenant is empty', async () => { + await decorator.listTasks({ + tenant: '', + contextId: '', + status: undefined, + pageToken: '', + statusTimestampAfter: '', + }); + + const passedParams = mockTransport.listTasks.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to createTaskPushNotificationConfig when tenant is empty', async () => { + await decorator.createTaskPushNotificationConfig({ + tenant: '', + id: 'config-1', + taskId: 'task-1', + url: 'https://example.com', + token: '', + authentication: undefined, + }); + + const passedParams = mockTransport.createTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to getTaskPushNotificationConfig when tenant is empty', async () => { + await decorator.getTaskPushNotificationConfig({ id: 'cfg-1', taskId: 'task-1', tenant: '' }); + + const passedParams = mockTransport.getTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to listTaskPushNotificationConfig when tenant is empty', async () => { + await decorator.listTaskPushNotificationConfig({ + taskId: 'task-1', + tenant: '', + pageSize: 0, + pageToken: '', + }); + + const passedParams = mockTransport.listTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to deleteTaskPushNotificationConfig when tenant is empty', async () => { + await decorator.deleteTaskPushNotificationConfig({ + id: 'cfg-1', + taskId: 'task-1', + tenant: '', + }); + + const passedParams = mockTransport.deleteTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to sendMessageStream when tenant is empty', async () => { + await drain( + decorator.sendMessageStream({ + tenant: '', + message: undefined, + configuration: undefined, + metadata: {}, + }) + ); + + const passedParams = mockTransport.sendMessageStream.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should apply default tenant to resubscribeTask when tenant is empty', async () => { + await drain(decorator.resubscribeTask({ id: 'task-1', tenant: '' })); + + const passedParams = mockTransport.resubscribeTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + }); + + describe('caller-specified tenant preservation', () => { + const CALLER_TENANT = 'caller-tenant'; + + it('should preserve caller-specified tenant on getExtendedAgentCard', async () => { + await decorator.getExtendedAgentCard({ tenant: CALLER_TENANT }); + + const passedParams = mockTransport.getExtendedAgentCard.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on cancelTask', async () => { + await decorator.cancelTask({ id: 'task-1', tenant: CALLER_TENANT, metadata: {} }); + + const passedParams = mockTransport.cancelTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on listTasks', async () => { + await decorator.listTasks({ + tenant: CALLER_TENANT, + contextId: '', + status: undefined, + pageToken: '', + statusTimestampAfter: '', + }); + + const passedParams = mockTransport.listTasks.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on createTaskPushNotificationConfig', async () => { + await decorator.createTaskPushNotificationConfig({ + tenant: CALLER_TENANT, + id: 'config-1', + taskId: 'task-1', + url: 'https://example.com', + token: '', + authentication: undefined, + }); + + const passedParams = mockTransport.createTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on getTaskPushNotificationConfig', async () => { + await decorator.getTaskPushNotificationConfig({ + id: 'cfg-1', + taskId: 'task-1', + tenant: CALLER_TENANT, + }); + + const passedParams = mockTransport.getTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on listTaskPushNotificationConfig', async () => { + await decorator.listTaskPushNotificationConfig({ + taskId: 'task-1', + tenant: CALLER_TENANT, + pageSize: 0, + pageToken: '', + }); + + const passedParams = mockTransport.listTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on deleteTaskPushNotificationConfig', async () => { + await decorator.deleteTaskPushNotificationConfig({ + id: 'cfg-1', + taskId: 'task-1', + tenant: CALLER_TENANT, + }); + + const passedParams = mockTransport.deleteTaskPushNotificationConfig.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on sendMessageStream', async () => { + await drain( + decorator.sendMessageStream({ + tenant: CALLER_TENANT, + message: undefined, + configuration: undefined, + metadata: {}, + }) + ); + + const passedParams = mockTransport.sendMessageStream.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + + it('should preserve caller-specified tenant on resubscribeTask', async () => { + await drain( + decorator.resubscribeTask({ + id: 'task-1', + tenant: CALLER_TENANT, + }) + ); + + const passedParams = mockTransport.resubscribeTask.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(CALLER_TENANT); + }); + }); + + describe('other fields passthrough', () => { + it('should not mutate the original params object', async () => { + const params: SendMessageRequest = { + tenant: '', + message: undefined, + configuration: undefined, + metadata: { key: 'value' }, + }; + const original = { ...params }; + await decorator.sendMessage(params); + + // Original object should be unchanged + expect(params).to.deep.equal(original); + // But the base transport received the resolved tenant + const passedParams = mockTransport.sendMessage.mock.calls[0][0]; + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + + it('should forward all non-tenant fields unchanged', async () => { + await decorator.getTask({ id: 'my-task', tenant: '', historyLength: 42 }); + + const passedParams = mockTransport.getTask.mock.calls[0][0]; + expect(passedParams.id).to.equal('my-task'); + expect(passedParams.historyLength).to.equal(42); + expect(passedParams.tenant).to.equal(DEFAULT_TENANT); + }); + }); + + it('should pass through RequestOptions to the base transport', async () => { + const options = { signal: new AbortController().signal }; + await decorator.sendMessage( + { tenant: '', message: undefined, configuration: undefined, metadata: {} }, + options + ); + + expect(mockTransport.sendMessage.mock.calls[0][1]).to.equal(options); + }); +}); diff --git a/test/e2e.spec.ts b/test/e2e.spec.ts index 8b323c1b..64a53421 100644 --- a/test/e2e.spec.ts +++ b/test/e2e.spec.ts @@ -10,7 +10,7 @@ import { RequestContext, } from '../src/server/index.js'; import { AgentEvent } from '../src/server/events/execution_event_bus.js'; -import { AgentCard, Message, Role, TaskState, StreamResponse } from '../src/index.js'; +import { AgentCard, Message, Role, Task, TaskState, StreamResponse } from '../src/index.js'; import { agentCardHandler } from '../src/server/express/agent_card_handler.js'; import { jsonRpcHandler } from '../src/server/express/json_rpc_handler.js'; import { restHandler } from '../src/server/express/rest_handler.js'; @@ -292,6 +292,211 @@ describe('Client E2E tests', () => { }); }); +describe('Multi-tenancy E2E tests', () => { + // Only REST supports tenant-prefixed URL routing. JSON-RPC uses body params, + // and gRPC uses request message fields (both tested via the transport handler unit tests). + describe('[REST] tenant-scoped routing', () => { + let app: Express; + let server: Server; + let agentExecutor: TestAgentExecutor; + let agentCard: AgentCard; + let clientFactory: ClientFactory; + + beforeEach(async () => { + agentExecutor = new TestAgentExecutor(); + agentCard = { + name: 'Test Agent', + description: 'A multi-tenant test agent', + version: '1.0.0', + supportedInterfaces: [ + { + url: 'localhost', + protocolBinding: 'HTTP+JSON', + tenant: 'test-tenant', + protocolVersion: '1.0.0', + }, + ], + capabilities: { + streaming: true, + pushNotifications: true, + extensions: [], + }, + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [], + provider: { url: '', organization: '' }, + documentationUrl: '', + securityRequirements: [], + securitySchemes: {}, + signatures: [], + }; + const requestHandler = new DefaultRequestHandler( + agentCard, + new InMemoryTaskStore(), + agentExecutor + ); + + app = express(); + app.use( + '/a2a/rest', + restHandler({ requestHandler, userBuilder: UserBuilder.noAuthentication }) + ); + + server = app.listen(); + const address = server.address() as AddressInfo; + agentCard.supportedInterfaces![0].url = `http://localhost:${address.port}/a2a/rest`; + clientFactory = new ClientFactory(); + }); + + afterEach(() => { + server.close(); + }); + + it('should send a message via tenant-prefixed route and retrieve the task', async () => { + const tenant = 'test-tenant'; + agentExecutor.events = [ + AgentEvent.task({ + id: '1', + contextId: '2', + status: { + state: TaskState.TASK_STATE_SUBMITTED, + timestamp: undefined, + message: undefined, + }, + artifacts: [], + history: [], + metadata: {}, + }), + AgentEvent.statusUpdate({ + taskId: '1', + contextId: '2', + status: { + state: TaskState.TASK_STATE_COMPLETED, + timestamp: undefined, + message: undefined, + }, + metadata: {}, + }), + ]; + const client = await clientFactory.createFromAgentCard(agentCard); + + const result = await client.sendMessage({ + tenant, + message: createTestMessage('msg-1', 'Hello from tenant'), + configuration: undefined, + metadata: {}, + }); + + // Result should be a Task (not a Message) since we published task events + expect('id' in result).to.equal(true); + const task = result as Task; + expect(task.status?.state).to.equal(TaskState.TASK_STATE_COMPLETED); + + // Should be able to retrieve the task via the same tenant + const retrieved = await client.getTask({ + id: task.id, + tenant, + historyLength: 10, + }); + expect(retrieved.id).to.equal(task.id); + }); + + it('should isolate tasks between tenants', async () => { + const requestHandler = new DefaultRequestHandler( + agentCard, + new InMemoryTaskStore(), + agentExecutor + ); + + // Create a separate server with a fresh store + const isolationApp = express(); + isolationApp.use( + '/a2a/rest', + restHandler({ requestHandler, userBuilder: UserBuilder.noAuthentication }) + ); + const isolationServer = isolationApp.listen(); + const address = isolationServer.address() as AddressInfo; + + try { + const baseUrl = `http://localhost:${address.port}/a2a/rest`; + + // Send message as tenant-A + agentExecutor.events = [ + AgentEvent.task({ + id: 'task-a', + contextId: 'ctx-a', + status: { + state: TaskState.TASK_STATE_SUBMITTED, + timestamp: undefined, + message: undefined, + }, + artifacts: [], + history: [], + metadata: {}, + }), + AgentEvent.statusUpdate({ + taskId: 'task-a', + contextId: 'ctx-a', + status: { + state: TaskState.TASK_STATE_COMPLETED, + timestamp: undefined, + message: undefined, + }, + metadata: {}, + }), + ]; + + const tenantACard = { + ...agentCard, + supportedInterfaces: [ + { + url: baseUrl, + protocolBinding: 'HTTP+JSON', + tenant: 'tenant-A', + protocolVersion: '1.0.0', + }, + ], + }; + const clientA = await clientFactory.createFromAgentCard(tenantACard); + const resultA = await clientA.sendMessage({ + tenant: 'tenant-A', + message: createTestMessage('msg-a', 'Hello from A'), + configuration: undefined, + metadata: {}, + }); + expect('id' in resultA).to.equal(true); + + // Try to get tenant-A's task as tenant-B -- should fail + const tenantBCard = { + ...agentCard, + supportedInterfaces: [ + { + url: baseUrl, + protocolBinding: 'HTTP+JSON', + tenant: 'tenant-B', + protocolVersion: '1.0.0', + }, + ], + }; + const clientB = await clientFactory.createFromAgentCard(tenantBCard); + try { + await clientB.getTask({ + id: (resultA as any).id, + tenant: 'tenant-B', + historyLength: 0, + }); + // Should not reach here + expect.fail('Expected TaskNotFoundError'); + } catch (error: unknown) { + expect((error as Error).name).to.equal('TaskNotFoundError'); + } + } finally { + isolationServer.close(); + } + }); + }); +}); + const removeUndefinedFields = (obj: any) => JSON.parse(JSON.stringify(obj)); function createTestMessage(id: string, text: string): Message { return { diff --git a/test/server/default_request_handler.spec.ts b/test/server/default_request_handler.spec.ts index bbe18c81..bfc459e2 100644 --- a/test/server/default_request_handler.spec.ts +++ b/test/server/default_request_handler.spec.ts @@ -3020,7 +3020,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { it('getAuthenticatedExtendedAgentCard should fail if the agent card does not support extended agent card', async () => { let caughtError; try { - await handler.getAuthenticatedExtendedAgentCard(serverCallContext); + await handler.getAuthenticatedExtendedAgentCard({ tenant: '' }, serverCallContext); } catch (error: any) { caughtError = error; } finally { @@ -3040,7 +3040,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { ); let caughtError; try { - await handler.getAuthenticatedExtendedAgentCard(serverCallContext); + await handler.getAuthenticatedExtendedAgentCard({ tenant: '' }, serverCallContext); } catch (error: any) { caughtError = error; } finally { @@ -3061,7 +3061,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { ); const context = new ServerCallContext(undefined, new A2AUser(true)); - const agentCard = await handler.getAuthenticatedExtendedAgentCard(context); + const agentCard = await handler.getAuthenticatedExtendedAgentCard({ tenant: '' }, context); assert.deepEqual(agentCard, extendedAgentCard); }); @@ -3077,7 +3077,7 @@ describe('DefaultRequestHandler as A2ARequestHandler', () => { ); const context = new ServerCallContext(undefined, new A2AUser(false)); - const agentCard = await handler.getAuthenticatedExtendedAgentCard(context); + const agentCard = await handler.getAuthenticatedExtendedAgentCard({ tenant: '' }, context); assert(agentCard.capabilities.extensions.length === 1); assert.deepEqual(agentCard.capabilities.extensions[0], { uri: 'requested-extension-uri', diff --git a/test/server/express/rest_handler.spec.ts b/test/server/express/rest_handler.spec.ts index c0f8d2b8..cac20379 100644 --- a/test/server/express/rest_handler.spec.ts +++ b/test/server/express/rest_handler.spec.ts @@ -128,6 +128,13 @@ describe('restHandler', () => { assert.deepEqual(response.body.name, testAgentCard.name); }); + it('should return the agent card with 200 OK when tenant is provided', async () => { + const response = await request(app).get('/tenant1/extendedAgentCard').expect(200); + + expect(mockRequestHandler.getAuthenticatedExtendedAgentCard as Mock).toHaveBeenCalledTimes(1); + assert.deepEqual(response.body.name, testAgentCard.name); + }); + it('should return 400 if getAuthenticatedExtendedAgentCard fails', async () => { (mockRequestHandler.getAuthenticatedExtendedAgentCard as Mock).mockRejectedValue( new RequestMalformedError('Card fetch failed') @@ -164,6 +171,23 @@ describe('restHandler', () => { assert.isUndefined(response.body.kind); }); + it('should accept message with tenant prefix and pass tenant to handler', async () => { + const message = ProtoMessage.toJSON(testMessage); + (mockRequestHandler.sendMessage as Mock).mockResolvedValue(testTask); + + await request(app).post('/tenant1/message:send').send({ message }).expect(201); + + expect(mockRequestHandler.sendMessage).toHaveBeenCalledWith( + expect.objectContaining({ + tenant: 'tenant1', + message: expect.objectContaining({ + messageId: 'msg-1', + }), + }), + expect.anything() + ); + }); + it('should return 400 when message is invalid', async () => { (mockRequestHandler.sendMessage as Mock).mockRejectedValue( new RequestMalformedError('Message is required') diff --git a/test/server/grpc/grpc_handler.spec.ts b/test/server/grpc/grpc_handler.spec.ts index 11b1dac4..27b898fa 100644 --- a/test/server/grpc/grpc_handler.spec.ts +++ b/test/server/grpc/grpc_handler.spec.ts @@ -108,7 +108,7 @@ describe('grpcHandler', () => { describe('getExtendedAgentCard', () => { it('should return agent card via gRPC callback', async () => { - const call = createMockUnaryCall({}); + const call = createMockUnaryCall({ tenant: '' }); const callback = vi.fn(); await handler.getExtendedAgentCard(call, callback); @@ -123,7 +123,7 @@ describe('grpcHandler', () => { (mockRequestHandler.getAuthenticatedExtendedAgentCard as Mock).mockRejectedValue( new TaskNotFoundError('Not Found') ); - const call = createMockUnaryCall({}); + const call = createMockUnaryCall({ tenant: '' }); const callback = vi.fn(); await handler.getExtendedAgentCard(call, callback); @@ -132,6 +132,17 @@ describe('grpcHandler', () => { assert.equal(err.code, grpc.status.NOT_FOUND); assert.equal(err.details, 'Not Found'); }); + + it('should pass tenant from request to request handler', async () => { + const call = createMockUnaryCall({ tenant: 'test-tenant' }); + const callback = vi.fn(); + await handler.getExtendedAgentCard(call, callback); + + expect(mockRequestHandler.getAuthenticatedExtendedAgentCard).toHaveBeenCalledWith( + expect.objectContaining({ tenant: 'test-tenant' }), + expect.anything() + ); + }); }); describe('sendMessage', () => { diff --git a/test/server/jsonrpc_transport_handler.spec.ts b/test/server/jsonrpc_transport_handler.spec.ts index c260394e..c0d1e150 100644 --- a/test/server/jsonrpc_transport_handler.spec.ts +++ b/test/server/jsonrpc_transport_handler.spec.ts @@ -204,6 +204,23 @@ describe('JsonRpcTransportHandler', () => { }); }); + describe('Method handling', () => { + it('should pass tenant from params to getAuthenticatedExtendedAgentCard', async () => { + const request = { + jsonrpc: '2.0', + method: 'GetExtendedAgentCard', + id: 1, + params: { tenant: 'test-tenant' }, + }; + await transportHandler.handle(request, defaultContext); + + expect(mockRequestHandler.getAuthenticatedExtendedAgentCard).toHaveBeenCalledWith( + expect.objectContaining({ tenant: 'test-tenant' }), + expect.anything() + ); + }); + }); + describe('Error mapping', () => { it('should map RequestMalformedError to code and message', async () => { const mappedError = JsonRpcTransportHandler.mapToJSONRPCError( diff --git a/test/server/rest_transport_handler.spec.ts b/test/server/rest_transport_handler.spec.ts index a05a07ba..20bfa580 100644 --- a/test/server/rest_transport_handler.spec.ts +++ b/test/server/rest_transport_handler.spec.ts @@ -137,7 +137,10 @@ describe('RestTransportHandler', () => { describe('getAuthenticatedExtendedAgentCard', () => { it('should return extended agent card from request handler', async () => { - const card = await transportHandler.getAuthenticatedExtendedAgentCard(mockContext); + const card = await transportHandler.getAuthenticatedExtendedAgentCard( + { tenant: '' }, + mockContext + ); expect(card).to.deep.equal(testAgentCard); expect(mockRequestHandler.getAuthenticatedExtendedAgentCard as Mock).toHaveBeenCalledTimes(1); diff --git a/test/server/tenant_isolation.spec.ts b/test/server/tenant_isolation.spec.ts new file mode 100644 index 00000000..e4451e44 --- /dev/null +++ b/test/server/tenant_isolation.spec.ts @@ -0,0 +1,233 @@ +import { describe, it, expect, beforeEach } from 'vitest'; +import { InMemoryTaskStore } from '../../src/server/store.js'; +import { InMemoryPushNotificationStore } from '../../src/server/push_notification/push_notification_store.js'; +import { ServerCallContext } from '../../src/server/context.js'; +import { Task, TaskState, TaskPushNotificationConfig } from '../../src/index.js'; + +function createContext(tenant?: string): ServerCallContext { + return new ServerCallContext(undefined, undefined, tenant); +} + +function createTask(id: string, contextId: string = 'ctx-1'): Task { + return { + id, + contextId, + status: { + state: TaskState.TASK_STATE_COMPLETED, + timestamp: new Date().toISOString(), + message: undefined, + }, + artifacts: [], + history: [], + metadata: {}, + }; +} + +describe('InMemoryTaskStore tenant isolation', () => { + let store: InMemoryTaskStore; + + beforeEach(() => { + store = new InMemoryTaskStore(); + }); + + it('should save and load a task without tenant (global scope)', async () => { + const ctx = createContext(); + const task = createTask('task-1'); + await store.save(task, ctx); + + const loaded = await store.load('task-1', ctx); + expect(loaded).toBeDefined(); + expect(loaded!.id).to.equal('task-1'); + }); + + it('should save and load a task with tenant', async () => { + const ctx = createContext('tenant-A'); + const task = createTask('task-1'); + await store.save(task, ctx); + + const loaded = await store.load('task-1', ctx); + expect(loaded).toBeDefined(); + expect(loaded!.id).to.equal('task-1'); + }); + + it('should isolate tasks between tenants', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + await store.save(createTask('task-1'), ctxA); + + // Tenant A can load the task + const loadedA = await store.load('task-1', ctxA); + expect(loadedA).toBeDefined(); + + // Tenant B cannot load the same task + const loadedB = await store.load('task-1', ctxB); + expect(loadedB).toBeUndefined(); + }); + + it('should allow same task ID in different tenants', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + const taskA = createTask('task-1', 'ctx-A'); + const taskB = createTask('task-1', 'ctx-B'); + + await store.save(taskA, ctxA); + await store.save(taskB, ctxB); + + const loadedA = await store.load('task-1', ctxA); + const loadedB = await store.load('task-1', ctxB); + + expect(loadedA!.contextId).to.equal('ctx-A'); + expect(loadedB!.contextId).to.equal('ctx-B'); + }); + + it('should list only tasks belonging to the tenant', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + await store.save(createTask('task-a1'), ctxA); + await store.save(createTask('task-a2'), ctxA); + await store.save(createTask('task-b1'), ctxB); + + const listA = await store.list( + { + tenant: 'tenant-A', + contextId: '', + status: undefined, + pageSize: 10, + pageToken: '', + statusTimestampAfter: '', + }, + ctxA + ); + + expect(listA.tasks).toHaveLength(2); + expect(listA.tasks.map((t) => t.id).sort()).toEqual(['task-a1', 'task-a2']); + + const listB = await store.list( + { + tenant: 'tenant-B', + contextId: '', + status: undefined, + pageSize: 10, + pageToken: '', + statusTimestampAfter: '', + }, + ctxB + ); + + expect(listB.tasks).toHaveLength(1); + expect(listB.tasks[0].id).to.equal('task-b1'); + }); + + it('should isolate tenant-scoped tasks from global scope', async () => { + const ctxGlobal = createContext(); + const ctxTenant = createContext('tenant-A'); + + await store.save(createTask('global-task'), ctxGlobal); + await store.save(createTask('tenant-task'), ctxTenant); + + // Global context should not see tenant tasks + const globalList = await store.list( + { + tenant: '', + contextId: '', + status: undefined, + pageSize: 10, + pageToken: '', + statusTimestampAfter: '', + }, + ctxGlobal + ); + expect(globalList.tasks).toHaveLength(1); + expect(globalList.tasks[0].id).to.equal('global-task'); + + // Tenant context should not see global tasks + const tenantList = await store.list( + { + tenant: 'tenant-A', + contextId: '', + status: undefined, + pageSize: 10, + pageToken: '', + statusTimestampAfter: '', + }, + ctxTenant + ); + expect(tenantList.tasks).toHaveLength(1); + expect(tenantList.tasks[0].id).to.equal('tenant-task'); + }); +}); + +describe('InMemoryPushNotificationStore tenant isolation', () => { + let store: InMemoryPushNotificationStore; + + const createConfig = ( + id: string, + taskId: string, + tenant: string = '' + ): TaskPushNotificationConfig => ({ + tenant, + id, + taskId, + url: `https://notify.example.com/${id}`, + token: 'secret', + authentication: undefined, + }); + + beforeEach(() => { + store = new InMemoryPushNotificationStore(); + }); + + it('should isolate configs between tenants', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + await store.save('task-1', ctxA, createConfig('config-1', 'task-1', 'tenant-A')); + + // Tenant A can load the config + const loadedA = await store.load('task-1', ctxA); + expect(loadedA).toHaveLength(1); + expect(loadedA[0].id).to.equal('config-1'); + + // Tenant B cannot load tenant A's configs + const loadedB = await store.load('task-1', ctxB); + expect(loadedB).toHaveLength(0); + }); + + it('should allow same task ID configs in different tenants', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + await store.save('task-1', ctxA, createConfig('config-a', 'task-1', 'tenant-A')); + await store.save('task-1', ctxB, createConfig('config-b', 'task-1', 'tenant-B')); + + const loadedA = await store.load('task-1', ctxA); + const loadedB = await store.load('task-1', ctxB); + + expect(loadedA).toHaveLength(1); + expect(loadedA[0].id).to.equal('config-a'); + expect(loadedB).toHaveLength(1); + expect(loadedB[0].id).to.equal('config-b'); + }); + + it('should delete configs only within the tenant scope', async () => { + const ctxA = createContext('tenant-A'); + const ctxB = createContext('tenant-B'); + + await store.save('task-1', ctxA, createConfig('config-1', 'task-1', 'tenant-A')); + await store.save('task-1', ctxB, createConfig('config-1', 'task-1', 'tenant-B')); + + // Delete from tenant A + await store.delete('task-1', ctxA, 'config-1'); + + // Tenant A config is gone + const loadedA = await store.load('task-1', ctxA); + expect(loadedA).toHaveLength(0); + + // Tenant B config still exists + const loadedB = await store.load('task-1', ctxB); + expect(loadedB).toHaveLength(1); + }); +});