Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cspell-dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ vendia
withrequired
typeof
instanceof
codegen
codegen
awslambda
streamify
Writables
71 changes: 71 additions & 0 deletions src/__tests__/integrationV2Stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import {
ApolloServer,
type ApolloServerOptions,
type BaseContext,
} from '@apollo/server';
import {
type CreateServerForIntegrationTestsOptions,
defineIntegrationTestSuite,
} from '@apollo/server-integration-testsuite';
import { createServer } from 'http';
import { handlers, startServerAndCreateLambdaHandler } from '..';
import { urlForHttpServer } from './mockServer';
import {
createMockV2StreamServer,
installStreamMock,
} from './mockAPIGatewayV2StreamServer';

describe('lambdaHandlerV2Stream', () => {
beforeAll(() => {
// Must run before startServerAndCreateLambdaHandler so that
// awslambda.streamifyResponse is available at handler-creation time.
installStreamMock();
});

afterAll(() => {
delete (global as any).awslambda;
});

describe.each([true, false])(
'With base64 encoding set to %s',
(shouldBase64Encode) => {
defineIntegrationTestSuite(
async function (
serverOptions: ApolloServerOptions<BaseContext>,
testOptions?: CreateServerForIntegrationTestsOptions,
) {
const httpServer = createServer();
const server = new ApolloServer({ ...serverOptions });

const handler = startServerAndCreateLambdaHandler(
server,
handlers.createAPIGatewayProxyEventV2StreamRequestHandler(),
{ ...testOptions },
);

httpServer.addListener(
'request',
createMockV2StreamServer(handler, shouldBase64Encode),
);

await new Promise<void>((resolve) => {
httpServer.listen({ port: 0 }, resolve);
});

return {
server,
url: urlForHttpServer(httpServer),
async extraCleanup() {
await new Promise<void>((resolve) => {
httpServer.close(() => resolve());
});
},
};
},
{
serverIsStartedInBackground: true,
},
);
},
);
});
92 changes: 91 additions & 1 deletion src/__tests__/middleware.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ApolloServer } from '@apollo/server';
import type { APIGatewayProxyEventV2 } from 'aws-lambda';
import { handlers, startServerAndCreateLambdaHandler } from '..';
import { handlers, middleware, startServerAndCreateLambdaHandler } from '..';
import gql from 'graphql-tag';
import { type DocumentNode, print } from 'graphql';

Expand Down Expand Up @@ -149,3 +149,93 @@ describe('Response mutation', () => {
expect(result.cookies).toContain(cookieValue);
});
});

describe('runMiddleware', () => {
const mockHandler: handlers.RequestHandler<any, any> = {
fromEvent: jest.fn(),
toSuccessResult: jest.fn(),
toErrorResult: jest
.fn()
.mockReturnValue({ statusCode: 500, body: 'internal error' }),
};

beforeEach(() => {
jest.clearAllMocks();
});

it('returns continue with empty middleware list', async () => {
const result = await middleware.runMiddleware({}, [], mockHandler);
expect(result.status).toBe('continue');
if (result.status === 'continue') {
expect(result.middleware).toEqual([]);
}
});

it('collects result-callback middleware into the continue result', async () => {
const callback = jest.fn().mockResolvedValue(undefined);
const result = await middleware.runMiddleware(
{},
[async () => callback],
mockHandler,
);
expect(result.status).toBe('continue');
if (result.status === 'continue') {
expect(result.middleware).toHaveLength(1);
expect(result.middleware[0]).toBe(callback);
}
});

it('returns early result when middleware returns an object', async () => {
const earlyResult = { statusCode: 418 };
const result = await middleware.runMiddleware(
{},
[async () => earlyResult],
mockHandler,
);
expect(result.status).toBe('result');
if (result.status === 'result') {
expect(result.result).toBe(earlyResult);
}
});

it('calls toErrorResult and returns result when middleware throws', async () => {
const error = new Error('middleware exploded');
const result = await middleware.runMiddleware(
{},
[
async () => {
throw error;
},
],
mockHandler,
);
expect(mockHandler.toErrorResult).toHaveBeenCalledWith(error);
expect(result.status).toBe('result');
if (result.status === 'result') {
expect(result.result).toEqual({
statusCode: 500,
body: 'internal error',
});
}
});

it('invokes collected callbacks before returning an error result on throw', async () => {
const callback = jest.fn().mockResolvedValue(undefined);
const error = new Error('late throw');
const result = await middleware.runMiddleware(
{},
[
async () => callback,
async () => {
throw error;
},
],
mockHandler,
);
expect(result.status).toBe('result');
expect(callback).toHaveBeenCalledWith({
statusCode: 500,
body: 'internal error',
});
});
});
117 changes: 117 additions & 0 deletions src/__tests__/mockAPIGatewayV2StreamServer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import url from 'node:url';
import { Writable } from 'node:stream';
import type { IncomingMessage, ServerResponse } from 'node:http';
import type {
APIGatewayProxyEventV2,
Context as LambdaContext,
Handler,
} from 'aws-lambda';

// Per-invocation state keyed by event object (unique per request).
// Using WeakMaps so concurrent requests don't overwrite each other's state.
const activeWritables = new WeakMap<object, Writable>();
const metadataCallbacks = new WeakMap<
Writable,
(metadata: { statusCode: number; headers: Record<string, string> }) => void
>();

// Call this once (e.g. in beforeAll) before startServerAndCreateLambdaHandler
// runs so that awslambda.streamifyResponse is available when the handler is built.
export function installStreamMock() {
(globalThis as any).awslambda = {
streamifyResponse: (fn: any) => async (event: any, context: any) => {
await fn(event, activeWritables.get(event), context);
},
HttpResponseStream: {
from: (w: Writable, metadata: any) => {
metadataCallbacks.get(w)?.(metadata);
return w;
},
},
};
}

export function createMockV2StreamServer(
handler: Handler<APIGatewayProxyEventV2, any>,
shouldBase64Encode: boolean,
) {
return (req: IncomingMessage, res: ServerResponse) => {
let body = '';
req.on('data', (chunk) => (body += chunk));
// this is an unawaited async function, but anything that causes it to
// reject should cause a test to fail
req.on('end', async () => {
const event = v2EventFromRequest(shouldBase64Encode)(req, body);

// Pipe Lambda stream writes directly to the HTTP response so incremental
// delivery chunks reach the client as they are produced.
const writable = new Writable({
write(chunk, _enc, cb) {
res.write(chunk);
cb();
},
});

// Register per-invocation state using unique objects as WeakMap keys.
activeWritables.set(event, writable);
metadataCallbacks.set(writable, (metadata) => {
res.statusCode = metadata.statusCode ?? 200;
for (const [key, value] of Object.entries(metadata.headers ?? {})) {
res.setHeader(key, String(value));
}
});

await handler(
event,
{ functionName: 'someFunc' } as LambdaContext,
() => {
throw new Error("we don't use callback");
},
);

activeWritables.delete(event);
res.end();
});
};
}

function v2EventFromRequest(shouldBase64Encode: boolean) {
return function (req: IncomingMessage, body: string): APIGatewayProxyEventV2 {
const urlObject = url.parse(req.url || '', false);

type TestEventType = Partial<
Omit<APIGatewayProxyEventV2, 'requestContext'> & {
requestContext: Partial<
Omit<APIGatewayProxyEventV2['requestContext'], 'http'> & {
http: Partial<APIGatewayProxyEventV2['requestContext']['http']>;
}
>;
}
>;

const event: TestEventType = {
version: '2.0',
body: shouldBase64Encode
? Buffer.from(body, 'utf8').toString('base64')
: body,
rawQueryString: urlObject.search?.replace(/^\?/, '') ?? '',
headers: Object.fromEntries(
Object.entries(req.headers).map(([name, value]) => {
if (Array.isArray(value)) {
return [name, value.join(',')];
} else {
return [name, value];
}
}),
),
requestContext: {
http: {
method: req.method,
path: req.url,
},
},
isBase64Encoded: shouldBase64Encode,
};
return event as APIGatewayProxyEventV2;
};
}
Loading