From 7b2752a62cd770d411aa9abb30b2082efc312dba Mon Sep 17 00:00:00 2001 From: Asher Date: Thu, 5 Nov 2020 12:58:37 -0600 Subject: [PATCH] Move websocket routes into a separate app This is mostly so we don't have to do any wacky patching but it also makes it so we don't have to keep checking if the request is a web socket request every time we add middleware. --- src/node/app.ts | 9 ++-- src/node/entry.ts | 4 +- src/node/http.ts | 110 -------------------------------------- src/node/proxy.ts | 5 +- src/node/routes/index.ts | 52 ++++++++++++++---- src/node/routes/proxy.ts | 5 +- src/node/routes/vscode.ts | 37 +++++++------ src/node/wsRouter.ts | 57 ++++++++++++++++++++ 8 files changed, 134 insertions(+), 145 deletions(-) create mode 100644 src/node/wsRouter.ts diff --git a/src/node/app.ts b/src/node/app.ts index 171a7c4d..448ec966 100644 --- a/src/node/app.ts +++ b/src/node/app.ts @@ -4,12 +4,12 @@ import { promises as fs } from "fs" import http from "http" import * as httpolyglot from "httpolyglot" import { DefaultedArgs } from "./cli" -import { handleUpgrade } from "./http" +import { handleUpgrade } from "./wsRouter" /** * Create an Express app and an HTTP/S server to serve it. */ -export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Server]> => { +export const createApp = async (args: DefaultedArgs): Promise<[Express, Express, http.Server]> => { const app = express() const server = args.cert @@ -39,9 +39,10 @@ export const createApp = async (args: DefaultedArgs): Promise<[Express, http.Ser } }) - handleUpgrade(app, server) + const wsApp = express() + handleUpgrade(wsApp, server) - return [app, server] + return [app, wsApp, server] } /** diff --git a/src/node/entry.ts b/src/node/entry.ts index c192158b..431aa8f4 100644 --- a/src/node/entry.ts +++ b/src/node/entry.ts @@ -102,9 +102,9 @@ const main = async (args: DefaultedArgs): Promise => { throw new Error("Please pass in a password via the config file or $PASSWORD") } - const [app, server] = await createApp(args) + const [app, wsApp, server] = await createApp(args) const serverAddress = ensureAddress(server) - await register(app, server, args) + await register(app, wsApp, server, args) logger.info(`Using config file ${humanPath(args.config)}`) logger.info(`HTTP server listening on ${serverAddress} ${args.link ? "(randomized by --link)" : ""}`) diff --git a/src/node/http.ts b/src/node/http.ts index 71b938a2..f259d103 100644 --- a/src/node/http.ts +++ b/src/node/http.ts @@ -1,8 +1,6 @@ import { field, logger } from "@coder/logger" import * as express from "express" import * as expressCore from "express-serve-static-core" -import * as http from "http" -import * as net from "net" import qs from "qs" import safeCompare from "safe-compare" import { HttpCode, HttpError } from "../common/http" @@ -135,111 +133,3 @@ export const getCookieDomain = (host: string, proxyDomains: string[]): string | logger.debug("got cookie doman", field("host", host)) return host || undefined } - -declare module "express" { - function Router(options?: express.RouterOptions): express.Router & WithWebsocketMethod - - type WebSocketRequestHandler = ( - req: express.Request & WithWebSocket, - res: express.Response, - next: express.NextFunction, - ) => void | Promise - - type WebSocketMethod = (route: expressCore.PathParams, ...handlers: WebSocketRequestHandler[]) => T - - interface WithWebSocket { - ws: net.Socket - head: Buffer - } - - interface WithWebsocketMethod { - ws: WebSocketMethod - } -} - -interface WebsocketRequest extends express.Request, express.WithWebSocket { - _ws_handled: boolean -} - -function isWebSocketRequest(req: express.Request): req is WebsocketRequest { - return !!(req as WebsocketRequest).ws -} - -export const handleUpgrade = (app: express.Express, server: http.Server): void => { - server.on("upgrade", (req, socket, head) => { - socket.on("error", () => socket.destroy()) - - req.ws = socket - req.head = head - req._ws_handled = false - - const res = new http.ServerResponse(req) - res.writeHead = function writeHead(statusCode: number) { - if (statusCode > 200) { - socket.destroy(new Error(`${statusCode}`)) - } - return res - } - - // Send the request off to be handled by Express. - ;(app as any).handle(req, res, () => { - if (!req._ws_handled) { - socket.destroy(new Error("Not found")) - } - }) - }) -} - -/** - * Patch Express routers to handle web sockets. - * - * Not using express-ws since the ws-wrapped sockets don't work with the proxy. - */ -function patchRouter(): void { - // This works because Router is also the prototype assigned to the routers it - // returns. - - // Store this since the original method will be overridden. - const originalGet = (express.Router as any).prototype.get - - // Inject the `ws` method. - ;(express.Router as any).prototype.ws = function ws( - route: expressCore.PathParams, - ...handlers: express.WebSocketRequestHandler[] - ) { - originalGet.apply(this, [ - route, - ...handlers.map((handler) => { - const wrapped: express.Handler = (req, res, next) => { - if (isWebSocketRequest(req)) { - req._ws_handled = true - return handler(req, res, next) - } - next() - } - return wrapped - }), - ]) - return this - } - // Overwrite `get` so we can distinguish between websocket and non-websocket - // routes. - ;(express.Router as any).prototype.get = function get(route: expressCore.PathParams, ...handlers: express.Handler[]) { - originalGet.apply(this, [ - route, - ...handlers.map((handler) => { - const wrapped: express.Handler = (req, res, next) => { - if (!isWebSocketRequest(req)) { - return handler(req, res, next) - } - next() - } - return wrapped - }), - ]) - return this - } -} - -// This needs to happen before anything creates a router. -patchRouter() diff --git a/src/node/proxy.ts b/src/node/proxy.ts index bfc6af5b..4343d334 100644 --- a/src/node/proxy.ts +++ b/src/node/proxy.ts @@ -2,6 +2,7 @@ import { Request, Router } from "express" import proxyServer from "http-proxy" import { HttpCode, HttpError } from "../common/http" import { authenticated, ensureAuthenticated } from "./http" +import { Router as WsRouter } from "./wsRouter" export const proxy = proxyServer.createProxyServer({}) proxy.on("error", (error, _, res) => { @@ -82,7 +83,9 @@ router.all("*", (req, res, next) => { }) }) -router.ws("*", (req, _, next) => { +export const wsRouter = WsRouter() + +wsRouter.ws("*", (req, _, next) => { const port = maybeProxy(req) if (!port) { return next() diff --git a/src/node/routes/index.ts b/src/node/routes/index.ts index 910f5b69..8e5d3c18 100644 --- a/src/node/routes/index.ts +++ b/src/node/routes/index.ts @@ -1,7 +1,7 @@ import { logger } from "@coder/logger" import bodyParser from "body-parser" import cookieParser from "cookie-parser" -import { ErrorRequestHandler, Express } from "express" +import * as express from "express" import { promises as fs } from "fs" import http from "http" import * as path from "path" @@ -15,6 +15,7 @@ import { replaceTemplates } from "../http" import { loadPlugins } from "../plugin" import * as domainProxy from "../proxy" import { getMediaMime, paths } from "../util" +import { WebsocketRequest } from "../wsRouter" import * as health from "./health" import * as login from "./login" import * as proxy from "./proxy" @@ -36,7 +37,12 @@ declare global { /** * Register all routes and middleware. */ -export const register = async (app: Express, server: http.Server, args: DefaultedArgs): Promise => { +export const register = async ( + app: express.Express, + wsApp: express.Express, + server: http.Server, + args: DefaultedArgs, +): Promise => { const heart = new Heart(path.join(paths.data, "heartbeat"), async () => { return new Promise((resolve, reject) => { server.getConnections((error, count) => { @@ -50,14 +56,28 @@ export const register = async (app: Express, server: http.Server, args: Defaulte }) app.disable("x-powered-by") + wsApp.disable("x-powered-by") app.use(cookieParser()) + wsApp.use(cookieParser()) + app.use(bodyParser.json()) app.use(bodyParser.urlencoded({ extended: true })) - app.use(async (req, res, next) => { + const common: express.RequestHandler = (req, _, next) => { heart.beat() + // Add common variables routes can use. + req.args = args + req.heart = heart + + next() + } + + app.use(common) + wsApp.use(common) + + app.use(async (req, res, next) => { // If we're handling TLS ensure all requests are redirected to HTTPS. // TODO: This does *NOT* work if you have a base path since to specify the // protocol we need to specify the whole path. @@ -72,23 +92,28 @@ export const register = async (app: Express, server: http.Server, args: Defaulte return res.send(await fs.readFile(resourcePath)) } - // Add common variables routes can use. - req.args = args - req.heart = heart - - return next() + next() }) app.use("/", domainProxy.router) + wsApp.use("/", domainProxy.wsRouter.router) + app.use("/", vscode.router) + wsApp.use("/", vscode.wsRouter.router) + app.use("/vscode", vscode.router) + wsApp.use("/vscode", vscode.wsRouter.router) + app.use("/healthz", health.router) + if (args.auth === AuthType.Password) { app.use("/login", login.router) } + app.use("/proxy", proxy.router) + wsApp.use("/proxy", proxy.wsRouter.router) + app.use("/static", _static.router) app.use("/update", update.router) - app.use("/vscode", vscode.router) await loadPlugins(app, args) @@ -96,7 +121,7 @@ export const register = async (app: Express, server: http.Server, args: Defaulte throw new HttpError("Not Found", HttpCode.NotFound) }) - const errorHandler: ErrorRequestHandler = async (err, req, res, next) => { + const errorHandler: express.ErrorRequestHandler = async (err, req, res, next) => { const resourcePath = path.resolve(rootPath, "src/browser/pages/error.html") res.set("Content-Type", getMediaMime(resourcePath)) try { @@ -117,4 +142,11 @@ export const register = async (app: Express, server: http.Server, args: Defaulte } app.use(errorHandler) + + const wsErrorHandler: express.ErrorRequestHandler = async (err, req) => { + logger.error(`${err.message} ${err.stack}`) + ;(req as WebsocketRequest).ws.destroy(err) + } + + wsApp.use(wsErrorHandler) } diff --git a/src/node/routes/proxy.ts b/src/node/routes/proxy.ts index 59db92d9..ff6f4067 100644 --- a/src/node/routes/proxy.ts +++ b/src/node/routes/proxy.ts @@ -3,6 +3,7 @@ import qs from "qs" import { HttpCode, HttpError } from "../../common/http" import { authenticated, redirect } from "../http" import { proxy } from "../proxy" +import { Router as WsRouter } from "../wsRouter" export const router = Router() @@ -35,7 +36,9 @@ router.all("/(:port)(/*)?", (req, res) => { }) }) -router.ws("/(:port)(/*)?", (req) => { +export const wsRouter = WsRouter() + +wsRouter.ws("/(:port)(/*)?", (req) => { proxy.ws(req, req.ws, req.head, { ignorePath: true, target: getProxyTarget(req, true), diff --git a/src/node/routes/vscode.ts b/src/node/routes/vscode.ts index c936571c..db2dc207 100644 --- a/src/node/routes/vscode.ts +++ b/src/node/routes/vscode.ts @@ -6,6 +6,7 @@ import { commit, rootPath, version } from "../constants" import { authenticated, ensureAuthenticated, redirect, replaceTemplates } from "../http" import { getMediaMime, pathToFsPath } from "../util" import { VscodeProvider } from "../vscode" +import { Router as WsRouter } from "../wsRouter" export const router = Router() @@ -53,23 +54,6 @@ router.get("/", async (req, res) => { ) }) -router.ws("/", ensureAuthenticated, async (req) => { - const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - const reply = crypto - .createHash("sha1") - .update(req.headers["sec-websocket-key"] + magic) - .digest("base64") - req.ws.write( - [ - "HTTP/1.1 101 Switching Protocols", - "Upgrade: websocket", - "Connection: Upgrade", - `Sec-WebSocket-Accept: ${reply}`, - ].join("\r\n") + "\r\n\r\n", - ) - await vscode.sendWebsocket(req.ws, req.query) -}) - /** * TODO: Might currently be unused. */ @@ -103,3 +87,22 @@ router.get("/webview/*", ensureAuthenticated, async (req, res) => { await fs.readFile(path.join(vscode.vsRootPath, "out/vs/workbench/contrib/webview/browser/pre", req.params[0])), ) }) + +export const wsRouter = WsRouter() + +wsRouter.ws("/", ensureAuthenticated, async (req) => { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + const reply = crypto + .createHash("sha1") + .update(req.headers["sec-websocket-key"] + magic) + .digest("base64") + req.ws.write( + [ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + `Sec-WebSocket-Accept: ${reply}`, + ].join("\r\n") + "\r\n\r\n", + ) + await vscode.sendWebsocket(req.ws, req.query) +}) diff --git a/src/node/wsRouter.ts b/src/node/wsRouter.ts new file mode 100644 index 00000000..1a057f0f --- /dev/null +++ b/src/node/wsRouter.ts @@ -0,0 +1,57 @@ +import * as express from "express" +import * as expressCore from "express-serve-static-core" +import * as http from "http" +import * as net from "net" + +export const handleUpgrade = (app: express.Express, server: http.Server): void => { + server.on("upgrade", (req, socket, head) => { + socket.on("error", () => socket.destroy()) + + req.ws = socket + req.head = head + req._ws_handled = false + + // Send the request off to be handled by Express. + ;(app as any).handle(req, new http.ServerResponse(req), () => { + if (!req._ws_handled) { + socket.destroy(new Error("Not found")) + } + }) + }) +} + +export interface WebsocketRequest extends express.Request { + ws: net.Socket + head: Buffer +} + +interface InternalWebsocketRequest extends WebsocketRequest { + _ws_handled: boolean +} + +export type WebSocketHandler = ( + req: WebsocketRequest, + res: express.Response, + next: express.NextFunction, +) => void | Promise + +export class WebsocketRouter { + public readonly router = express.Router() + + public ws(route: expressCore.PathParams, ...handlers: WebSocketHandler[]): void { + this.router.get( + route, + ...handlers.map((handler) => { + const wrapped: express.Handler = (req, res, next) => { + ;(req as InternalWebsocketRequest)._ws_handled = true + return handler(req as WebsocketRequest, res, next) + } + return wrapped + }), + ) + } +} + +export function Router(): WebsocketRouter { + return new WebsocketRouter() +}