diff --git a/src/vite.ts b/src/vite.ts index c407647..f6f6762 100644 --- a/src/vite.ts +++ b/src/vite.ts @@ -6,6 +6,11 @@ import { type CaptunTunnel, createCaptunTunnel, type CreateCaptunTunnelOptions, + type WebSocketConnectResult, + type WebSocketHandle, + webSocketHandleFromSocket, + isWebSocketUpgradeRequest, + pipeWebSocketToHandle, type TunnelReady, } from "./index.js"; import { captunHealthResponse, isCaptunHealthRequest } from "./tunnel-health.js"; @@ -13,11 +18,14 @@ import { captunHealthResponse, isCaptunHealthRequest } from "./tunnel-health.js" /** * Options for the {@link captun} Vite plugin. * - * Every `createCaptunTunnel` option except `fetch` (which the plugin wires to - * the Vite server) is passed through verbatim — see + * Every `createCaptunTunnel` option except `fetch` and `connectWebSocket` + * (which the plugin wires to the Vite server) is passed through verbatim — see * {@link CreateCaptunTunnelOptions} for `gateway`, `name`, and `token`. */ -export type CaptunVitePluginOptions = Omit & { +export type CaptunVitePluginOptions = Omit< + CreateCaptunTunnelOptions, + "fetch" | "connectWebSocket" +> & { /** * Called once the tunnel is connected, with the public `url` and the * reusable connect `token` when the gateway provides or accepts one. @@ -51,8 +59,8 @@ export type CaptunVitePluginOptions = Omit & * tunnel on demand, make the plugin conditional in your Vite config: * `plugins: [process.env.TUNNEL ? captun() : undefined]`. * - * WebSockets are not forwarded, so Vite HMR only works on the local URL — - * the tunnel is for plain HTTP: webhooks, previews, and e2e tests. + * HTTP and WebSockets are both forwarded, so Vite HMR can connect through the + * tunnel URL. */ export default function captun(options: CaptunVitePluginOptions = {}): Plugin { const { onTunnel, onError, ...tunnelOptions } = options; @@ -69,6 +77,7 @@ export default function captun(options: CaptunVitePluginOptions = {}): Plugin { tunnel = await createCaptunTunnel({ ...tunnelOptions, fetch: forwardToLocalServer(localOrigin(protocol, address)), + connectWebSocket: connectLocalWebSocket(localOrigin(protocol, address)), }); } catch (error) { if (onError) onError(error); @@ -130,6 +139,12 @@ function localOrigin(protocol: "http" | "https", address: AddressInfo) { function forwardToLocalServer(origin: string) { return (request: Request): Response | Promise => { + if (isWebSocketUpgradeRequest(request)) { + return new Response("Use connectWebSocket for WebSocket tunnel requests\n", { + status: 400, + }); + } + // The reserved health path is answered by every Tunnel Client itself. if (isCaptunHealthRequest(request)) return captunHealthResponse(); const url = new URL(request.url); @@ -161,3 +176,74 @@ function forwardToLocalServer(origin: string) { return fetch(new URL(url.pathname + url.search, origin), init); }; } + +function connectLocalWebSocket(origin: string) { + return async (request: Request, remote: WebSocketHandle): Promise => { + const url = new URL(request.url); + const targetUrl = new URL(url.pathname + url.search, origin); + targetUrl.protocol = targetUrl.protocol === "https:" ? "wss:" : "ws:"; + const targetSocket = new WebSocket(targetUrl, { + protocols: request.headers + .get("sec-websocket-protocol") + ?.split(",") + .map((protocol) => protocol.trim()) + .filter(Boolean), + headers: forwardedHandshakeHeaders(request.headers), + // Node's WebSocket (undici) accepts { protocols, headers }; the DOM type doesn't. + } as unknown as string[]); + + try { + await waitForWebSocketOpen(targetSocket); + if (targetSocket.readyState !== WebSocket.OPEN) + throw new Error("WebSocket closed after open"); + } catch { + targetSocket.close(); + return { + accepted: false, + response: new Response( + `Request reached the captun Vite plugin, but ${targetUrl.origin} did not accept the WebSocket\n`, + { status: 502 }, + ), + }; + } + + pipeWebSocketToHandle(targetSocket, remote); + return { + accepted: true, + protocol: targetSocket.protocol || undefined, + socket: webSocketHandleFromSocket(targetSocket), + }; + }; +} + +function forwardedHandshakeHeaders(headers: Headers) { + const skip = new Set(["connection", "host", "keep-alive", "te", "trailer", "upgrade"]); + const forwarded: Record = {}; + for (const [name, value] of headers) { + if (skip.has(name) || name.startsWith("sec-websocket-") || name.startsWith("proxy-")) continue; + forwarded[name] = value; + } + return forwarded; +} + +async function waitForWebSocketOpen(socket: WebSocket) { + if (socket.readyState === WebSocket.OPEN) return; + if (socket.readyState !== WebSocket.CONNECTING) throw new Error("WebSocket closed before open"); + + const listeners = new AbortController(); + await new Promise((resolveOpen, rejectOpen) => { + const settle = (callback: () => void) => { + listeners.abort(); + callback(); + }; + socket.addEventListener("open", () => settle(resolveOpen), { signal: listeners.signal }); + socket.addEventListener("error", () => settle(() => rejectOpen(new Error("WebSocket error"))), { + signal: listeners.signal, + }); + socket.addEventListener( + "close", + () => settle(() => rejectOpen(new Error("WebSocket closed before open"))), + { signal: listeners.signal }, + ); + }); +} diff --git a/test/vite-plugin.test.ts b/test/vite-plugin.test.ts index 4fe60aa..046b8bd 100644 --- a/test/vite-plugin.test.ts +++ b/test/vite-plugin.test.ts @@ -31,6 +31,32 @@ test.concurrent("serves the dev server through a tunnel", async ({ task }) => { expect(await response.text()).toContain("captun vite fixture"); }); +test.concurrent("forwards Vite HMR WebSockets through a tunnel", async ({ task }) => { + await using worker = await createCaptunWorkerFixture({}); + await using root = await createAppFixture(); + const ready = Promise.withResolvers(); + await using server = await devServer({ + root: root.path, + plugins: [ + captun({ gateway: worker.origin, name: tunnelName(task.name), onTunnel: ready.resolve }), + ], + }); + + const tunnel = await ready.promise; + const socket = new WebSocket( + `${tunnel.url}/?token=${server.config.webSocketToken}`.replace(/^http/, "ws"), + "vite-hmr", + ); + try { + await waitForWebSocket(socket); + await expect(nextWebSocketMessage(socket).then(webSocketMessageJson)).resolves.toMatchObject({ + type: "connected", + }); + } finally { + socket.close(); + } +}); + test.concurrent("forwards request bodies", async ({ task }) => { await using worker = await createCaptunWorkerFixture({}); await using root = await createAppFixture(); @@ -288,3 +314,34 @@ function tunnelName(testName: string) { const hash = createHash("sha256").update(seed).digest("hex").slice(0, 12); return `${prefix}-${hash}`; } + +function waitForWebSocket(socket: WebSocket) { + if (socket.readyState === WebSocket.OPEN) return Promise.resolve(); + return new Promise((resolveOpen, rejectOpen) => { + socket.addEventListener("open", () => resolveOpen(), { once: true }); + socket.addEventListener("error", () => rejectOpen(new Error("WebSocket error")), { + once: true, + }); + socket.addEventListener("close", () => rejectOpen(new Error("WebSocket closed")), { + once: true, + }); + }); +} + +function nextWebSocketMessage(socket: WebSocket) { + return new Promise((resolveMessage, rejectMessage) => { + socket.addEventListener("message", (event) => resolveMessage(event.data), { once: true }); + socket.addEventListener("error", () => rejectMessage(new Error("WebSocket error")), { + once: true, + }); + socket.addEventListener("close", () => rejectMessage(new Error("WebSocket closed")), { + once: true, + }); + }); +} + +function webSocketMessageJson(data: unknown) { + if (typeof data !== "string") + throw new Error(`Expected text WebSocket message, got ${typeof data}`); + return JSON.parse(data) as unknown; +}