Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions src/vite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@ import {
type CaptunTunnel,
createCaptunTunnel,
type CreateCaptunTunnelOptions,
type WebSocketConnectResult,
type WebSocketHandle,
webSocketHandleFromSocket,
isWebSocketUpgradeRequest,
pipeWebSocketToHandle,
type TunnelReady,
} from "./index.js";
import { captunHealthResponse, isCaptunHealthRequest } from "./tunnel-health.js";

/**
* Options for the {@link captun} Vite plugin.
*
* Every `createCaptunTunnel` option except `fetch` (which the plugin wires to
* the Vite server) is passed through verbatim — see
* Every `createCaptunTunnel` option except `fetch` and `connectWebSocket`
* (which the plugin wires to the Vite server) is passed through verbatim — see
* {@link CreateCaptunTunnelOptions} for `gateway`, `name`, and `token`.
*/
export type CaptunVitePluginOptions = Omit<CreateCaptunTunnelOptions, "fetch"> & {
export type CaptunVitePluginOptions = Omit<
CreateCaptunTunnelOptions,
"fetch" | "connectWebSocket"
> & {
/**
* Called once the tunnel is connected, with the public `url` and the
* reusable connect `token` when the gateway provides or accepts one.
Expand Down Expand Up @@ -51,8 +59,8 @@ export type CaptunVitePluginOptions = Omit<CreateCaptunTunnelOptions, "fetch"> &
* tunnel on demand, make the plugin conditional in your Vite config:
* `plugins: [process.env.TUNNEL ? captun() : undefined]`.
*
* WebSockets are not forwarded, so Vite HMR only works on the local URL —
* the tunnel is for plain HTTP: webhooks, previews, and e2e tests.
* HTTP and WebSockets are both forwarded, so Vite HMR can connect through the
* tunnel URL.
*/
export default function captun(options: CaptunVitePluginOptions = {}): Plugin {
const { onTunnel, onError, ...tunnelOptions } = options;
Expand All @@ -69,6 +77,7 @@ export default function captun(options: CaptunVitePluginOptions = {}): Plugin {
tunnel = await createCaptunTunnel({
...tunnelOptions,
fetch: forwardToLocalServer(localOrigin(protocol, address)),
connectWebSocket: connectLocalWebSocket(localOrigin(protocol, address)),
});
} catch (error) {
if (onError) onError(error);
Expand Down Expand Up @@ -130,6 +139,12 @@ function localOrigin(protocol: "http" | "https", address: AddressInfo) {

function forwardToLocalServer(origin: string) {
return (request: Request): Response | Promise<Response> => {
if (isWebSocketUpgradeRequest(request)) {
return new Response("Use connectWebSocket for WebSocket tunnel requests\n", {
status: 400,
});
}

// The reserved health path is answered by every Tunnel Client itself.
if (isCaptunHealthRequest(request)) return captunHealthResponse();
const url = new URL(request.url);
Expand Down Expand Up @@ -161,3 +176,74 @@ function forwardToLocalServer(origin: string) {
return fetch(new URL(url.pathname + url.search, origin), init);
};
}

function connectLocalWebSocket(origin: string) {
return async (request: Request, remote: WebSocketHandle): Promise<WebSocketConnectResult> => {
const url = new URL(request.url);
const targetUrl = new URL(url.pathname + url.search, origin);
targetUrl.protocol = targetUrl.protocol === "https:" ? "wss:" : "ws:";
const targetSocket = new WebSocket(targetUrl, {
protocols: request.headers
.get("sec-websocket-protocol")
?.split(",")
.map((protocol) => protocol.trim())
.filter(Boolean),
headers: forwardedHandshakeHeaders(request.headers),
// Node's WebSocket (undici) accepts { protocols, headers }; the DOM type doesn't.
} as unknown as string[]);

try {
await waitForWebSocketOpen(targetSocket);
if (targetSocket.readyState !== WebSocket.OPEN)
throw new Error("WebSocket closed after open");
} catch {
targetSocket.close();
return {
accepted: false,
response: new Response(
`Request reached the captun Vite plugin, but ${targetUrl.origin} did not accept the WebSocket\n`,
{ status: 502 },
),
};
}

pipeWebSocketToHandle(targetSocket, remote);
return {
accepted: true,
protocol: targetSocket.protocol || undefined,
socket: webSocketHandleFromSocket(targetSocket),
};
};
}

function forwardedHandshakeHeaders(headers: Headers) {
const skip = new Set(["connection", "host", "keep-alive", "te", "trailer", "upgrade"]);
const forwarded: Record<string, string> = {};
for (const [name, value] of headers) {
if (skip.has(name) || name.startsWith("sec-websocket-") || name.startsWith("proxy-")) continue;
forwarded[name] = value;
}
return forwarded;
}

async function waitForWebSocketOpen(socket: WebSocket) {
if (socket.readyState === WebSocket.OPEN) return;
if (socket.readyState !== WebSocket.CONNECTING) throw new Error("WebSocket closed before open");

const listeners = new AbortController();
await new Promise<void>((resolveOpen, rejectOpen) => {
const settle = (callback: () => void) => {
listeners.abort();
callback();
};
socket.addEventListener("open", () => settle(resolveOpen), { signal: listeners.signal });
socket.addEventListener("error", () => settle(() => rejectOpen(new Error("WebSocket error"))), {
signal: listeners.signal,
});
socket.addEventListener(
"close",
() => settle(() => rejectOpen(new Error("WebSocket closed before open"))),
{ signal: listeners.signal },
);
});
}
57 changes: 57 additions & 0 deletions test/vite-plugin.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ test.concurrent("serves the dev server through a tunnel", async ({ task }) => {
expect(await response.text()).toContain("captun vite fixture");
});

test.concurrent("forwards Vite HMR WebSockets through a tunnel", async ({ task }) => {
await using worker = await createCaptunWorkerFixture({});
await using root = await createAppFixture();
const ready = Promise.withResolvers<TunnelReady>();
await using server = await devServer({
root: root.path,
plugins: [
captun({ gateway: worker.origin, name: tunnelName(task.name), onTunnel: ready.resolve }),
],
});

const tunnel = await ready.promise;
const socket = new WebSocket(
`${tunnel.url}/?token=${server.config.webSocketToken}`.replace(/^http/, "ws"),
"vite-hmr",
);
try {
await waitForWebSocket(socket);
await expect(nextWebSocketMessage(socket).then(webSocketMessageJson)).resolves.toMatchObject({
type: "connected",
});
} finally {
socket.close();
}
});

test.concurrent("forwards request bodies", async ({ task }) => {
await using worker = await createCaptunWorkerFixture({});
await using root = await createAppFixture();
Expand Down Expand Up @@ -288,3 +314,34 @@ function tunnelName(testName: string) {
const hash = createHash("sha256").update(seed).digest("hex").slice(0, 12);
return `${prefix}-${hash}`;
}

function waitForWebSocket(socket: WebSocket) {
if (socket.readyState === WebSocket.OPEN) return Promise.resolve();
return new Promise<void>((resolveOpen, rejectOpen) => {
socket.addEventListener("open", () => resolveOpen(), { once: true });
socket.addEventListener("error", () => rejectOpen(new Error("WebSocket error")), {
once: true,
});
socket.addEventListener("close", () => rejectOpen(new Error("WebSocket closed")), {
once: true,
});
});
}

function nextWebSocketMessage(socket: WebSocket) {
return new Promise<unknown>((resolveMessage, rejectMessage) => {
socket.addEventListener("message", (event) => resolveMessage(event.data), { once: true });
socket.addEventListener("error", () => rejectMessage(new Error("WebSocket error")), {
once: true,
});
socket.addEventListener("close", () => rejectMessage(new Error("WebSocket closed")), {
once: true,
});
});
}

function webSocketMessageJson(data: unknown) {
if (typeof data !== "string")
throw new Error(`Expected text WebSocket message, got ${typeof data}`);
return JSON.parse(data) as unknown;
}
Loading