diff --git a/cspell-dict.txt b/cspell-dict.txt index 07875afb..90413c75 100644 --- a/cspell-dict.txt +++ b/cspell-dict.txt @@ -13,4 +13,7 @@ vendia withrequired typeof instanceof -codegen \ No newline at end of file +codegen +awslambda +streamify +Writables \ No newline at end of file diff --git a/src/__tests__/integrationV2Stream.test.ts b/src/__tests__/integrationV2Stream.test.ts new file mode 100644 index 00000000..0ac6c96e --- /dev/null +++ b/src/__tests__/integrationV2Stream.test.ts @@ -0,0 +1,71 @@ +import { + ApolloServer, + type ApolloServerOptions, + type BaseContext, +} from '@apollo/server'; +import { + type CreateServerForIntegrationTestsOptions, + defineIntegrationTestSuite, +} from '@apollo/server-integration-testsuite'; +import { createServer } from 'http'; +import { handlers, startServerAndCreateLambdaHandler } from '..'; +import { urlForHttpServer } from './mockServer'; +import { + createMockV2StreamServer, + installStreamMock, +} from './mockAPIGatewayV2StreamServer'; + +describe('lambdaHandlerV2Stream', () => { + beforeAll(() => { + // Must run before startServerAndCreateLambdaHandler so that + // awslambda.streamifyResponse is available at handler-creation time. + installStreamMock(); + }); + + afterAll(() => { + delete (global as any).awslambda; + }); + + describe.each([true, false])( + 'With base64 encoding set to %s', + (shouldBase64Encode) => { + defineIntegrationTestSuite( + async function ( + serverOptions: ApolloServerOptions, + testOptions?: CreateServerForIntegrationTestsOptions, + ) { + const httpServer = createServer(); + const server = new ApolloServer({ ...serverOptions }); + + const handler = startServerAndCreateLambdaHandler( + server, + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(), + { ...testOptions }, + ); + + httpServer.addListener( + 'request', + createMockV2StreamServer(handler, shouldBase64Encode), + ); + + await new Promise((resolve) => { + httpServer.listen({ port: 0 }, resolve); + }); + + return { + server, + url: urlForHttpServer(httpServer), + async extraCleanup() { + await new Promise((resolve) => { + httpServer.close(() => resolve()); + }); + }, + }; + }, + { + serverIsStartedInBackground: true, + }, + ); + }, + ); +}); diff --git a/src/__tests__/middleware.test.ts b/src/__tests__/middleware.test.ts index 165af027..6ede6a8d 100644 --- a/src/__tests__/middleware.test.ts +++ b/src/__tests__/middleware.test.ts @@ -1,6 +1,6 @@ import { ApolloServer } from '@apollo/server'; import type { APIGatewayProxyEventV2 } from 'aws-lambda'; -import { handlers, startServerAndCreateLambdaHandler } from '..'; +import { handlers, middleware, startServerAndCreateLambdaHandler } from '..'; import gql from 'graphql-tag'; import { type DocumentNode, print } from 'graphql'; @@ -149,3 +149,93 @@ describe('Response mutation', () => { expect(result.cookies).toContain(cookieValue); }); }); + +describe('runMiddleware', () => { + const mockHandler: handlers.RequestHandler = { + fromEvent: jest.fn(), + toSuccessResult: jest.fn(), + toErrorResult: jest + .fn() + .mockReturnValue({ statusCode: 500, body: 'internal error' }), + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('returns continue with empty middleware list', async () => { + const result = await middleware.runMiddleware({}, [], mockHandler); + expect(result.status).toBe('continue'); + if (result.status === 'continue') { + expect(result.middleware).toEqual([]); + } + }); + + it('collects result-callback middleware into the continue result', async () => { + const callback = jest.fn().mockResolvedValue(undefined); + const result = await middleware.runMiddleware( + {}, + [async () => callback], + mockHandler, + ); + expect(result.status).toBe('continue'); + if (result.status === 'continue') { + expect(result.middleware).toHaveLength(1); + expect(result.middleware[0]).toBe(callback); + } + }); + + it('returns early result when middleware returns an object', async () => { + const earlyResult = { statusCode: 418 }; + const result = await middleware.runMiddleware( + {}, + [async () => earlyResult], + mockHandler, + ); + expect(result.status).toBe('result'); + if (result.status === 'result') { + expect(result.result).toBe(earlyResult); + } + }); + + it('calls toErrorResult and returns result when middleware throws', async () => { + const error = new Error('middleware exploded'); + const result = await middleware.runMiddleware( + {}, + [ + async () => { + throw error; + }, + ], + mockHandler, + ); + expect(mockHandler.toErrorResult).toHaveBeenCalledWith(error); + expect(result.status).toBe('result'); + if (result.status === 'result') { + expect(result.result).toEqual({ + statusCode: 500, + body: 'internal error', + }); + } + }); + + it('invokes collected callbacks before returning an error result on throw', async () => { + const callback = jest.fn().mockResolvedValue(undefined); + const error = new Error('late throw'); + const result = await middleware.runMiddleware( + {}, + [ + async () => callback, + async () => { + throw error; + }, + ], + mockHandler, + ); + expect(result.status).toBe('result'); + expect(callback).toHaveBeenCalledWith({ + statusCode: 500, + body: 'internal error', + }); + }); +}); diff --git a/src/__tests__/mockAPIGatewayV2StreamServer.ts b/src/__tests__/mockAPIGatewayV2StreamServer.ts new file mode 100644 index 00000000..78c1fdb5 --- /dev/null +++ b/src/__tests__/mockAPIGatewayV2StreamServer.ts @@ -0,0 +1,117 @@ +import url from 'node:url'; +import { Writable } from 'node:stream'; +import type { IncomingMessage, ServerResponse } from 'node:http'; +import type { + APIGatewayProxyEventV2, + Context as LambdaContext, + Handler, +} from 'aws-lambda'; + +// Per-invocation state keyed by event object (unique per request). +// Using WeakMaps so concurrent requests don't overwrite each other's state. +const activeWritables = new WeakMap(); +const metadataCallbacks = new WeakMap< + Writable, + (metadata: { statusCode: number; headers: Record }) => void +>(); + +// Call this once (e.g. in beforeAll) before startServerAndCreateLambdaHandler +// runs so that awslambda.streamifyResponse is available when the handler is built. +export function installStreamMock() { + (globalThis as any).awslambda = { + streamifyResponse: (fn: any) => async (event: any, context: any) => { + await fn(event, activeWritables.get(event), context); + }, + HttpResponseStream: { + from: (w: Writable, metadata: any) => { + metadataCallbacks.get(w)?.(metadata); + return w; + }, + }, + }; +} + +export function createMockV2StreamServer( + handler: Handler, + shouldBase64Encode: boolean, +) { + return (req: IncomingMessage, res: ServerResponse) => { + let body = ''; + req.on('data', (chunk) => (body += chunk)); + // this is an unawaited async function, but anything that causes it to + // reject should cause a test to fail + req.on('end', async () => { + const event = v2EventFromRequest(shouldBase64Encode)(req, body); + + // Pipe Lambda stream writes directly to the HTTP response so incremental + // delivery chunks reach the client as they are produced. + const writable = new Writable({ + write(chunk, _enc, cb) { + res.write(chunk); + cb(); + }, + }); + + // Register per-invocation state using unique objects as WeakMap keys. + activeWritables.set(event, writable); + metadataCallbacks.set(writable, (metadata) => { + res.statusCode = metadata.statusCode ?? 200; + for (const [key, value] of Object.entries(metadata.headers ?? {})) { + res.setHeader(key, String(value)); + } + }); + + await handler( + event, + { functionName: 'someFunc' } as LambdaContext, + () => { + throw new Error("we don't use callback"); + }, + ); + + activeWritables.delete(event); + res.end(); + }); + }; +} + +function v2EventFromRequest(shouldBase64Encode: boolean) { + return function (req: IncomingMessage, body: string): APIGatewayProxyEventV2 { + const urlObject = url.parse(req.url || '', false); + + type TestEventType = Partial< + Omit & { + requestContext: Partial< + Omit & { + http: Partial; + } + >; + } + >; + + const event: TestEventType = { + version: '2.0', + body: shouldBase64Encode + ? Buffer.from(body, 'utf8').toString('base64') + : body, + rawQueryString: urlObject.search?.replace(/^\?/, '') ?? '', + headers: Object.fromEntries( + Object.entries(req.headers).map(([name, value]) => { + if (Array.isArray(value)) { + return [name, value.join(',')]; + } else { + return [name, value]; + } + }), + ), + requestContext: { + http: { + method: req.method, + path: req.url, + }, + }, + isBase64Encoded: shouldBase64Encode, + }; + return event as APIGatewayProxyEventV2; + }; +} diff --git a/src/__tests__/streamHandler.test.ts b/src/__tests__/streamHandler.test.ts new file mode 100644 index 00000000..0f355c0c --- /dev/null +++ b/src/__tests__/streamHandler.test.ts @@ -0,0 +1,311 @@ +import { Writable } from 'stream'; +import { ApolloServer, HeaderMap } from '@apollo/server'; +import type { APIGatewayProxyEventV2 } from 'aws-lambda'; +import { handlers, startServerAndCreateLambdaHandler } from '..'; +import gql from 'graphql-tag'; +import { type DocumentNode, print } from 'graphql'; + +function createV2Event( + doc: DocumentNode, + options: { isBase64Encoded?: boolean; contentType?: string } = {}, +): APIGatewayProxyEventV2 { + const body = JSON.stringify({ query: print(doc) }); + return { + version: '2', + headers: { 'content-type': options.contentType ?? 'application/json' }, + isBase64Encoded: options.isBase64Encoded ?? false, + rawQueryString: '', + requestContext: { http: { method: 'POST' } } as any, + rawPath: '/', + routeKey: '/', + body: options.isBase64Encoded ? Buffer.from(body).toString('base64') : body, + }; +} + +type CapturedStream = { + getBody: () => string; + getMetadata: () => + | { + statusCode: number; + headers: Record; + cookies?: string[]; + } + | undefined; +}; + +function setupAWSLambdaMock(): CapturedStream { + let capturedMetadata: + | { statusCode: number; headers: Record } + | undefined; + const chunks: string[] = []; + + const writable = new Writable({ + write(chunk, _enc, cb) { + chunks.push(chunk.toString()); + cb(); + }, + }); + + (global as any).awslambda = { + streamifyResponse: (fn: any) => async (event: any, context: any) => { + await fn(event, writable, context); + }, + HttpResponseStream: { + from: (w: any, metadata: any) => { + capturedMetadata = metadata; + return w; + }, + }, + }; + + return { + getBody: () => chunks.join(''), + getMetadata: () => capturedMetadata, + }; +} + +const typeDefs = `#graphql + type Query { + hello: String + } +`; + +const resolvers = { + Query: { + hello: () => 'world', + }, +}; + +describe('isStreamRequestHandler', () => { + it('returns true for a stream handler', () => { + const h = handlers.createAPIGatewayProxyEventV2StreamRequestHandler(); + expect(handlers.isStreamRequestHandler(h)).toBe(true); + }); + + it('returns false for a regular handler', () => { + const h = handlers.createAPIGatewayProxyEventV2RequestHandler(); + expect(handlers.isStreamRequestHandler(h)).toBe(false); + }); +}); + +describe('createStreamRequestHandler', () => { + const streamHandler = + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(); + + describe('fromEvent', () => { + it('parses HTTP method, headers, query string, and JSON body', () => { + const event = createV2Event(gql` + query { + hello + } + `); + const req = streamHandler.fromEvent(event); + expect(req.method).toBe('POST'); + expect(req.headers.get('content-type')).toBe('application/json'); + expect(req.search).toBe(''); + expect((req.body as any).query).toBeDefined(); + }); + + it('decodes base64-encoded JSON body', () => { + const event = createV2Event( + gql` + query { + hello + } + `, + { isBase64Encoded: true }, + ); + const req = streamHandler.fromEvent(event); + expect((req.body as any).query).toBeDefined(); + }); + + it('parses text/plain body as a string', () => { + const event: APIGatewayProxyEventV2 = { + ...createV2Event( + gql` + query { + hello + } + `, + { contentType: 'text/plain' }, + ), + body: '{ hello }', + }; + const req = streamHandler.fromEvent(event); + expect(req.body).toBe('{ hello }'); + }); + + it('returns empty string for an unknown content type', () => { + const event: APIGatewayProxyEventV2 = { + ...createV2Event( + gql` + query { + hello + } + `, + { contentType: 'application/xml' }, + ), + body: '', + }; + const req = streamHandler.fromEvent(event); + expect(req.body).toBe(''); + }); + + it('returns empty string when body is absent', () => { + const event: APIGatewayProxyEventV2 = { + ...createV2Event(gql` + query { + hello + } + `), + body: undefined, + }; + const req = streamHandler.fromEvent(event); + expect(req.body).toBe(''); + }); + }); + + describe('buildHTTPMetadata', () => { + it('maps response status and headers to HttpMetadata', async () => { + const responseHeaders = new HeaderMap(); + responseHeaders.set('content-type', 'application/json'); + const metadata = await streamHandler.buildHTTPMetadata({ + status: 201, + headers: responseHeaders, + body: { kind: 'complete', string: '' }, + } as any); + expect(metadata.statusCode).toBe(201); + expect(metadata.headers['content-type']).toBe('application/json'); + }); + + it('defaults status to 200 when status is undefined', async () => { + const metadata = await streamHandler.buildHTTPMetadata({ + status: undefined, + headers: new HeaderMap(), + body: { kind: 'complete', string: '' }, + } as any); + expect(metadata.statusCode).toBe(200); + }); + }); + + describe('toErrorResult', () => { + it('returns 400 metadata and the error message as body', async () => { + const result = await streamHandler.toErrorResult( + new Error('something failed'), + ); + expect(result.metadata.statusCode).toBe(400); + expect(result.body).toBe('something failed'); + }); + }); +}); + +describe('Stream Lambda Handler', () => { + let captured: CapturedStream; + + beforeEach(() => { + captured = setupAWSLambdaMock(); + }); + + afterEach(() => { + delete (global as any).awslambda; + }); + + it('processes a basic GraphQL query and streams the response', async () => { + const server = new ApolloServer({ typeDefs, resolvers }); + const lambdaHandler = startServerAndCreateLambdaHandler( + server, + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(), + ); + + await lambdaHandler( + createV2Event(gql` + query { + hello + } + `), + {} as any, + () => {}, + ); + + expect(captured.getMetadata()?.statusCode).toBe(200); + expect(JSON.parse(captured.getBody())).toEqual({ + data: { hello: 'world' }, + }); + }); + + it('short-circuits and ends the stream when middleware returns metadata', async () => { + const server = new ApolloServer({ typeDefs, resolvers }); + const lambdaHandler = startServerAndCreateLambdaHandler( + server, + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(), + { + middleware: [async () => ({ statusCode: 401, headers: {} })], + }, + ); + + await lambdaHandler( + createV2Event(gql` + query { + hello + } + `), + {} as any, + () => {}, + ); + + expect(captured.getMetadata()?.statusCode).toBe(401); + expect(captured.getBody()).toBe(''); + }); + + it('writes a 400 error response when event parsing fails', async () => { + const server = new ApolloServer({ typeDefs, resolvers }); + const lambdaHandler = startServerAndCreateLambdaHandler( + server, + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(), + ); + + const badEvent: APIGatewayProxyEventV2 = { + version: '2', + headers: { 'content-type': 'application/json' }, + isBase64Encoded: false, + rawQueryString: '', + requestContext: { http: { method: 'POST' } } as any, + rawPath: '/', + routeKey: '/', + body: 'not-valid-json', + }; + + await lambdaHandler(badEvent, {} as any, () => {}); + + expect(captured.getMetadata()?.statusCode).toBe(400); + expect(captured.getBody()).toContain('JSON'); + }); + + it('writes a 400 error response when middleware throws', async () => { + const server = new ApolloServer({ typeDefs, resolvers }); + const lambdaHandler = startServerAndCreateLambdaHandler( + server, + handlers.createAPIGatewayProxyEventV2StreamRequestHandler(), + { + middleware: [ + async () => { + throw new Error('middleware exploded'); + }, + ], + }, + ); + + await lambdaHandler( + createV2Event(gql` + query { + hello + } + `), + {} as any, + () => {}, + ); + + expect(captured.getMetadata()?.statusCode).toBe(400); + expect(captured.getBody()).toBe('middleware exploded'); + }); +}); diff --git a/src/awslambda.ts b/src/awslambda.ts new file mode 100644 index 00000000..23ae6aae --- /dev/null +++ b/src/awslambda.ts @@ -0,0 +1,30 @@ +import type { Writable } from 'stream'; +import type { Handler, Context } from 'aws-lambda'; + +const anyGlobal: any = global; + +export namespace awslambda { + export type HttpMetadata = { + statusCode: number; + headers: Record; + cookies?: string[]; + }; + + export namespace HttpResponseStream { + export function from(writable: Writable, metadata: HttpMetadata): Writable { + return anyGlobal.awslambda.HttpResponseStream.from(writable, metadata); + } + } + + export type StreamHandler = ( + event: Event, + responseStream: Writable, + context: Context, + ) => void; + + export function streamifyResponse( + handler: StreamHandler, + ): Handler { + return anyGlobal.awslambda.streamifyResponse(handler); + } +} diff --git a/src/lambdaHandler.ts b/src/lambdaHandler.ts index 997373e6..ea3e2b40 100644 --- a/src/lambdaHandler.ts +++ b/src/lambdaHandler.ts @@ -5,42 +5,49 @@ import type { } from '@apollo/server'; import type { WithRequired } from '@apollo/utils.withrequired'; import type { Context, Handler } from 'aws-lambda'; -import type { LambdaResponse, MiddlewareFn } from './middleware'; -import type { - RequestHandler, - RequestHandlerEvent, - RequestHandlerResult, +import { + runMiddleware, + type LambdaResponse, + type MiddlewareFn, +} from './middleware'; +import { + isStreamRequestHandler, + type RequestHandler, + type RequestHandlerEvent, + type RequestHandlerResult, + type StreamRequestHandler, } from './request-handlers/_create'; +import { awslambda } from './awslambda'; +import type { Writable } from 'stream'; export interface LambdaContextFunctionArgument< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, > { - event: RH extends RequestHandler ? EventType : never; + event: RequestHandlerEvent; context: Context; } export interface LambdaHandlerOptions< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, > { middleware?: Array>; context?: ContextFunction<[LambdaContextFunctionArgument], TContext>; } -export type LambdaHandler> = Handler< - RequestHandlerEvent, - RequestHandlerResult ->; +export type LambdaHandler< + RH extends RequestHandler | StreamRequestHandler, +> = Handler, RequestHandlerResult>; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, >( server: ApolloServer, handler: RH, options?: LambdaHandlerOptions, ): LambdaHandler; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, >( server: ApolloServer, @@ -48,7 +55,7 @@ export function startServerAndCreateLambdaHandler< options: WithRequired, 'context'>, ): LambdaHandler; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, >( server: ApolloServer, @@ -70,24 +77,95 @@ export function startServerAndCreateLambdaHandler< TContext > = options?.context ?? defaultContext; + if (isStreamRequestHandler(handler)) { + return awslambda.streamifyResponse>( + async (event, responseStream, context) => { + let resultMiddlewareFns: Array< + LambdaResponse> + > = []; + let httpResponseStream: Writable | undefined; + try { + const middlewareResult = await runMiddleware( + event, + options?.middleware ?? [], + handler, + ); + if (middlewareResult.status === 'result') { + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + middlewareResult.result, + ); + httpResponseStream.end(); + return; + } + resultMiddlewareFns = middlewareResult.middleware; + + const httpGraphQLRequest = handler.fromEvent(event); + + const response = await server.executeHTTPGraphQLRequest({ + httpGraphQLRequest, + context: () => { + return contextFunction({ + event, + context, + }); + }, + }); + + const metadata = await handler.buildHTTPMetadata(response); + + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + metadata, + ); + + if (response.body.kind === 'complete') { + httpResponseStream.write(response.body.string); + httpResponseStream.end(); + return; + } + + for await (const chunk of response.body.asyncIterator) { + httpResponseStream.write(chunk); + } + httpResponseStream.end(); + } catch (e) { + const { metadata, body } = await handler.toErrorResult(e); + + if (httpResponseStream) { + httpResponseStream.write(body); + httpResponseStream.end(); + return; + } + + for (const resultMiddlewareFn of resultMiddlewareFns) { + await resultMiddlewareFn(metadata as any); + } + + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + metadata, + ); + httpResponseStream.write(body); + httpResponseStream.end(); + } + }, + ); + } + return async function (event, context) { - const resultMiddlewareFns: Array>> = + let resultMiddlewareFns: Array>> = []; try { - for (const middlewareFn of options?.middleware ?? []) { - const middlewareReturnValue = await middlewareFn(event); - // If the middleware returns an object, we assume it's a LambdaResponse - if ( - typeof middlewareReturnValue === 'object' && - middlewareReturnValue !== null - ) { - return middlewareReturnValue; - } - // If the middleware returns a function, we assume it's a result callback - if (middlewareReturnValue) { - resultMiddlewareFns.push(middlewareReturnValue); - } + const middlewareResult = await runMiddleware( + event, + options?.middleware ?? [], + handler, + ); + if (middlewareResult.status === 'result') { + return middlewareResult.result; } + resultMiddlewareFns = middlewareResult.middleware; const httpGraphQLRequest = handler.fromEvent(event); diff --git a/src/middleware.ts b/src/middleware.ts index 7ba042b7..49f33584 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -1,4 +1,10 @@ -import type { RequestHandler } from './request-handlers/_create'; +import { + isStreamRequestHandler, + type RequestHandler, + type RequestHandlerEvent, + type RequestHandlerResult, + type StreamRequestHandler, +} from './request-handlers/_create'; export type LambdaResponse = (result: ResultType) => Promise; @@ -6,7 +12,64 @@ export type LambdaRequest = ( event: EventType, ) => Promise | ResultType | void>; -export type MiddlewareFn> = - RH extends RequestHandler - ? LambdaRequest - : never; +export type MiddlewareFn< + RH extends RequestHandler | StreamRequestHandler, +> = LambdaRequest, RequestHandlerResult>; + +export async function runMiddleware< + RH extends RequestHandler | StreamRequestHandler, +>( + event: RequestHandlerEvent, + middleware: Array>, + handler: RH, +): Promise< + | { + status: 'result'; + result: RequestHandlerResult; + } + | { + status: 'continue'; + middleware: Array>>; + } +> { + const resultMiddlewareFns: Array>> = + []; + try { + for (const middlewareFn of middleware) { + const middlewareReturnValue = await middlewareFn(event); + // If the middleware returns an object, we assume it's an early result + if ( + typeof middlewareReturnValue === 'object' && + middlewareReturnValue !== null + ) { + return { + status: 'result', + result: middlewareReturnValue, + }; + } + // If the middleware returns a function, we assume it's a result callback + if (middlewareReturnValue) { + resultMiddlewareFns.push(middlewareReturnValue); + } + } + return { + status: 'continue', + middleware: resultMiddlewareFns, + }; + } catch (e) { + if (isStreamRequestHandler(handler)) { + throw e; + } + + const result = handler.toErrorResult(e); + + for (const resultMiddlewareFn of resultMiddlewareFns) { + await resultMiddlewareFn(result); + } + + return { + status: 'result', + result, + }; + } +} diff --git a/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts b/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts new file mode 100644 index 00000000..1b2ba83f --- /dev/null +++ b/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts @@ -0,0 +1,38 @@ +import type { APIGatewayProxyEventV2 } from 'aws-lambda'; +import { createStreamRequestHandler } from './_create'; +import { HeaderMap } from '@apollo/server'; + +export const createAPIGatewayProxyEventV2StreamRequestHandler = < + Event extends APIGatewayProxyEventV2 = APIGatewayProxyEventV2, +>() => { + return createStreamRequestHandler({ + parseHttpMethod(event) { + return event.requestContext.http.method; + }, + parseHeaders(event) { + const headerMap = new HeaderMap(); + for (const [key, value] of Object.entries(event.headers ?? {})) { + headerMap.set(key, value ?? ''); + } + return headerMap; + }, + parseBody(event, headers) { + if (event.body) { + const contentType = headers.get('content-type'); + const parsedBody = event.isBase64Encoded + ? Buffer.from(event.body, 'base64').toString('utf8') + : event.body; + if (contentType?.startsWith('application/json')) { + return JSON.parse(parsedBody); + } + if (contentType?.startsWith('text/plain')) { + return parsedBody; + } + } + return ''; + }, + parseQueryParams(event) { + return event.rawQueryString; + }, + }); +}; diff --git a/src/request-handlers/_create.ts b/src/request-handlers/_create.ts index 945b97e3..db5d22ec 100644 --- a/src/request-handlers/_create.ts +++ b/src/request-handlers/_create.ts @@ -3,6 +3,7 @@ import type { HTTPGraphQLRequest, HTTPGraphQLResponse, } from '@apollo/server'; +import type { awslambda } from '../awslambda'; export interface RequestHandler { fromEvent: (event: EventType) => HTTPGraphQLRequest; @@ -10,11 +11,35 @@ export interface RequestHandler { toErrorResult: (error: unknown) => ResultType; } -export type RequestHandlerEvent> = - T extends RequestHandler ? EventType : never; +export interface StreamRequestHandler { + type: 'stream'; + fromEvent: (event: EventType) => HTTPGraphQLRequest; + buildHTTPMetadata: ( + response: HTTPGraphQLResponse, + ) => Promise; + toErrorResult: (error: unknown) => Promise<{ + metadata: awslambda.HttpMetadata; + body: string; + }>; +} -export type RequestHandlerResult> = - T extends RequestHandler ? ResultType : never; +export type RequestHandlerEvent< + T extends RequestHandler | StreamRequestHandler, +> = + T extends StreamRequestHandler + ? EventType + : T extends RequestHandler + ? EventType + : never; + +export type RequestHandlerResult< + T extends RequestHandler | StreamRequestHandler, +> = + T extends StreamRequestHandler + ? awslambda.HttpMetadata + : T extends RequestHandler + ? ResultType + : never; export type EventParser = | { @@ -51,3 +76,51 @@ export function createRequestHandler( toErrorResult: resultGenerator.error, }; } + +export function createStreamRequestHandler( + eventParser: EventParser, +): StreamRequestHandler { + return { + type: 'stream', + fromEvent(event) { + if (typeof eventParser === 'function') { + return eventParser(event); + } + const headers = eventParser.parseHeaders(event); + return { + method: eventParser.parseHttpMethod(event), + headers, + search: eventParser.parseQueryParams(event), + body: eventParser.parseBody(event, headers), + }; + }, + buildHTTPMetadata: async (response) => { + const { headers, status, body } = response; + + return { + statusCode: status ?? 200, + headers: { + ...Object.fromEntries(headers), + ...(body.kind === 'complete' + ? { 'content-length': Buffer.byteLength(body.string).toString() } + : {}), + }, + }; + }, + toErrorResult: async (error) => { + return { + metadata: { + statusCode: 400, + headers: {}, + }, + body: (error as Error).message, + }; + }, + }; +} + +export function isStreamRequestHandler( + handler: RequestHandler | StreamRequestHandler, +): handler is StreamRequestHandler { + return 'type' in handler && handler.type === 'stream'; +} diff --git a/src/request-handlers/_index.ts b/src/request-handlers/_index.ts index 9372f10f..901fb1f1 100644 --- a/src/request-handlers/_index.ts +++ b/src/request-handlers/_index.ts @@ -1,4 +1,5 @@ export { createALBEventRequestHandler } from './ALBEventRequestHandler'; export { createAPIGatewayProxyEventRequestHandler } from './APIGatewayProxyEventRequestHandler'; export { createAPIGatewayProxyEventV2RequestHandler } from './APIGatewayProxyEventV2RequestHandler'; +export { createAPIGatewayProxyEventV2StreamRequestHandler } from './APIGatewayProxyEventV2StreamRequestHandler'; export * from './_create';