diff --git a/src/common/util.ts b/src/common/util.ts index d9195b62..4c9366b4 100644 --- a/src/common/util.ts +++ b/src/common/util.ts @@ -33,6 +33,13 @@ export const normalize = (url: string, keepTrailing = false): string => { return url.replace(/\/\/+/g, "/").replace(/\/+$/, keepTrailing ? "/" : "") } +/** + * Remove leading and trailing slashes. + */ +export const trimSlashes = (url: string): string => { + return url.replace(/^\/+|\/+$/g, "") +} + /** * Get options embedded in the HTML or query params. */ @@ -75,3 +82,17 @@ export const getOptions = (): T => { return options } + +/** + * Wrap the value in an array if it's not already an array. If the value is + * undefined return an empty array. + */ +export const arrayify = (value?: T | T[]): T[] => { + if (Array.isArray(value)) { + return value + } + if (typeof value === "undefined") { + return [] + } + return [value] +} diff --git a/src/node/app/vscode.ts b/src/node/app/vscode.ts index 681f8336..fb5ed308 100644 --- a/src/node/app/vscode.ts +++ b/src/node/app/vscode.ts @@ -14,7 +14,7 @@ import { WorkbenchOptions, } from "../../../lib/vscode/src/vs/server/ipc" import { HttpCode, HttpError } from "../../common/http" -import { generateUuid } from "../../common/util" +import { arrayify, generateUuid } from "../../common/util" import { Args } from "../cli" import { HttpProvider, HttpProviderOptions, HttpResponse, Route } from "../http" import { settings } from "../settings" @@ -224,8 +224,7 @@ export class VscodeHttpProvider extends HttpProvider { } for (let i = 0; i < startPaths.length; ++i) { const startPath = startPaths[i] - const url = - startPath && (typeof startPath.url === "string" ? [startPath.url] : startPath.url || []).find((p) => !!p) + const url = arrayify(startPath && startPath.url).find((p) => !!p) if (startPath && url) { return { url, diff --git a/src/node/http.ts b/src/node/http.ts index 313ecbeb..216ff5a2 100644 --- a/src/node/http.ts +++ b/src/node/http.ts @@ -12,7 +12,7 @@ import { Readable } from "stream" import * as tls from "tls" import * as url from "url" import { HttpCode, HttpError } from "../common/http" -import { normalize, Options, plural, split } from "../common/util" +import { arrayify, normalize, Options, plural, split, trimSlashes } from "../common/util" import { SocketProxyProvider } from "./socket" import { getMediaMime, paths } from "./util" @@ -287,7 +287,7 @@ export abstract class HttpProvider { * Helper to error on invalid methods (default GET). */ protected ensureMethod(request: http.IncomingMessage, method?: string | string[]): void { - const check = Array.isArray(method) ? method : [method || "GET"] + const check = arrayify(method || "GET") if (!request.method || !check.includes(request.method)) { throw new HttpError(`Unsupported method ${request.method}`, HttpCode.BadRequest) } @@ -559,7 +559,7 @@ export class HttpServer { }, ...args, ) - const endpoints = (typeof endpoint === "string" ? [endpoint] : endpoint).map((e) => e.replace(/^\/+|\/+$/g, "")) + const endpoints = arrayify(endpoint).map(trimSlashes) endpoints.forEach((endpoint) => { if (/\//.test(endpoint)) { throw new Error(`Only top-level endpoints are supported (got ${endpoint})`)