UNPKG

better-auth

Version:

The most comprehensive authentication framework for TypeScript.

293 lines (291 loc) • 11.7 kB
import { originCheck } from "../../api/middlewares/origin-check.mjs";
import { symmetricDecrypt, symmetricEncrypt } from "../../crypto/index.mjs";
import { parseSetCookieHeader } from "../../cookies/cookie-utils.mjs";
import "../../cookies/index.mjs";
import "../../api/index.mjs";
import { getOrigin } from "../../utils/url.mjs";
import { parseJSON } from "../../client/parser.mjs";
import { checkSkipProxy, resolveCurrentURL } from "./utils.mjs";
import * as z from "zod";
import { createAuthEndpoint, createAuthMiddleware } from "@better-auth/core/api";

//#region src/plugins/oauth-proxy/index.ts
const oAuthProxyQuerySchema = z.object({
	callbackURL: z.string().meta({ description: "The URL to redirect to after the proxy" }),
	cookies: z.string().meta({ description: "The cookies to set after the proxy" })
});

/**
 * Build the URL that `/oauth-proxy-callback` redirects to when the proxied
 * payload cannot be used.
 *
 * Uses `options.onAPIError.errorURL` when configured. Otherwise it falls back to
 * the `/error` route under `options.basePath`, using the same
 * `basePath || "/api/auth"` default as the sign-in before-hook. The message is
 * percent-encoded, and `&` is used instead of `?` when the error URL already
 * carries a query string.
 *
 * @param {object} ctx - Endpoint context (reads `ctx.context.options`).
 * @param {string} message - Reason placed in the `error` query parameter.
 * @returns {string} The redirect target.
 */
const getProxyErrorURL = (ctx, message) => {
	const { options } = ctx.context;
	// `options.baseURL` may be unset; an empty prefix yields a same-origin
	// relative path instead of the literal string "undefined".
	const fallbackErrorURL = `${options.baseURL || ""}${options.basePath || "/api/auth"}/error`;
	const errorURL = options.onAPIError?.errorURL || fallbackErrorURL;
	const separator = errorURL.includes("?") ? "&" : "?";
	return `${errorURL}${separator}error=${encodeURIComponent(message)}`;
};

/**
 * Percent-decode a cookie value, returning the input unchanged instead of
 * throwing `URIError` when it contains a malformed escape sequence.
 *
 * @param {string} value - Raw cookie value from a parsed Set-Cookie header.
 * @returns {string} The decoded value, or `value` itself if decoding fails.
 */
const decodeCookieValue = (value) => {
	try {
		return decodeURIComponent(value);
	} catch {
		return value;
	}
};

/**
 * OAuth proxy plugin.
 *
 * Flow, as implemented by the hooks below:
 * - Before `/sign-in/social` and `/sign-in/oauth2` (unless `checkSkipProxy`
 *   says to skip), `body.callbackURL` is rewritten to
 *   `<current origin><basePath>/oauth-proxy-callback?callbackURL=<original>`.
 * - After sign-in, when `oauthConfig.storeStateStrategy === "cookie"`, the
 *   provider URL's `state` is replaced with an encrypted package holding the
 *   original state and the `oauth_state` cookie value.
 * - Before `/callback*` and `/oauth2/callback*`, such a package is decrypted.
 *   The hook then temporarily switches to the "database" state strategy, serves
 *   the packaged cookie through a wrapped `findVerificationValue`, and sets
 *   `skipStateCookieCheck`.
 * - Before `/callback/:id` with a non-cookie state strategy, if the stored
 *   state's `callbackURL` targets the proxy callback, `skipStateCookieCheck`
 *   is set.
 * - After the callback:
 *   - A redirect to the proxy callback on the production origin is unwrapped
 *     straight to its `callbackURL`.
 *   - Otherwise, the response's Set-Cookie header and a timestamp are encrypted
 *     and appended as `cookies=`.
 *   - Any snapshot of the temporarily modified context is restored.
 * - `GET /oauth-proxy-callback` decrypts that payload, checks its age, sets
 *   the cookies on the current origin and redirects to `callbackURL`.
 *
 * @param {object} [opts] - Plugin options (also passed through to
 *   `checkSkipProxy` / `resolveCurrentURL`, whose use of them is defined in
 *   `./utils.mjs`).
 * @param {number} [opts.maxAge=60] - Maximum accepted age of the encrypted
 *   cookie payload, in seconds.
 * @param {string} [opts.productionURL] - URL whose origin is treated as
 *   "production". Defaults to `options.baseURL`, then `ctx.context.baseURL`.
 */
const oAuthProxy = (opts) => {
	const maxAge = opts?.maxAge ?? 60;
	return {
		id: "oauth-proxy",
		options: opts,
		endpoints: {
			oAuthProxy: createAuthEndpoint("/oauth-proxy-callback", {
				method: "GET",
				operationId: "oauthProxyCallback",
				query: oAuthProxyQuerySchema,
				use: [originCheck((ctx) => ctx.query.callbackURL)],
				metadata: { openapi: {
					operationId: "oauthProxyCallback",
					description: "OAuth Proxy Callback",
					parameters: [{
						in: "query",
						name: "callbackURL",
						required: true,
						description: "The URL to redirect to after the proxy"
					}, {
						in: "query",
						name: "cookies",
						required: true,
						description: "The cookies to set after the proxy"
					}],
					responses: { 302: {
						description: "Redirect",
						headers: { Location: {
							description: "The URL to redirect to",
							schema: { type: "string" }
						} }
					} }
				} }
			}, async (ctx) => {
				// Decrypt the payload produced by the callback after-hook.
				let decryptedPayload = null;
				try {
					decryptedPayload = await symmetricDecrypt({
						key: ctx.context.secret,
						data: ctx.query.cookies
					});
				} catch (e) {
					ctx.context.logger.error("Failed to decrypt OAuth proxy cookies:", e);
				}
				if (!decryptedPayload) {
					throw ctx.redirect(getProxyErrorURL(ctx, "OAuthProxy - Invalid cookies or secret"));
				}

				let payload;
				try {
					payload = parseJSON(decryptedPayload);
				} catch (e) {
					ctx.context.logger.error("Failed to parse OAuth proxy payload:", e);
					throw ctx.redirect(getProxyErrorURL(ctx, "OAuthProxy - Invalid payload format"));
				}

				// `Number.isFinite` (not `typeof === "number"`) so that a NaN timestamp
				// cannot slip past the age comparison below.
				if (typeof payload?.cookies !== "string" || !payload.cookies || !Number.isFinite(payload.timestamp)) {
					ctx.context.logger.error("OAuth proxy payload missing required fields");
					throw ctx.redirect(getProxyErrorURL(ctx, "OAuthProxy - Invalid payload structure"));
				}

				// Reject stale payloads; timestamps up to 10s in the future are tolerated
				// (clock skew).
				const age = (Date.now() - payload.timestamp) / 1e3;
				if (age > maxAge || age < -10) {
					ctx.context.logger.error(`OAuth proxy payload expired or invalid (age: ${age}s, maxAge: ${maxAge}s)`);
					throw ctx.redirect(getProxyErrorURL(ctx, "OAuthProxy - Payload expired or invalid"));
				}

				const decryptedCookies = payload.cookies;
				const isSecureContext = resolveCurrentURL(ctx, opts).protocol === "https:";
				const parsedCookies = parseSetCookieHeader(decryptedCookies);
				// Only path, expires, sameSite, httpOnly and max-age are carried over;
				// a domain attribute is not copied, and `secure` follows the protocol of
				// the current request rather than the original cookie.
				const processedCookies = Array.from(parsedCookies.entries()).map(([name, attrs]) => {
					const options = {};
					if (attrs.path) options.path = attrs.path;
					if (attrs.expires) options.expires = attrs.expires;
					if (attrs.samesite) options.sameSite = attrs.samesite;
					if (attrs.httponly) options.httpOnly = true;
					if (attrs["max-age"] !== void 0) options.maxAge = attrs["max-age"];
					if (isSecureContext) options.secure = true;
					return {
						name,
						options,
						value: decodeCookieValue(attrs.value)
					};
				});
				for (const cookie of processedCookies) ctx.setCookie(cookie.name, cookie.value, cookie.options);
				throw ctx.redirect(ctx.query.callbackURL);
			})
		},
		hooks: {
			before: [
				{
					matcher(context) {
						return !!(context.path?.startsWith("/sign-in/social") || context.path?.startsWith("/sign-in/oauth2"));
					},
					// Route the provider's final redirect through this server's
					// /oauth-proxy-callback, remembering the caller's callbackURL.
					handler: createAuthMiddleware(async (ctx) => {
						if (checkSkipProxy(ctx, opts)) return;
						if (!ctx.body) return;
						const currentURL = resolveCurrentURL(ctx, opts);
						const originalCallbackURL = ctx.body.callbackURL || ctx.context.baseURL;
						const newCallbackURL = `${currentURL.origin}${ctx.context.options.basePath || "/api/auth"}/oauth-proxy-callback?callbackURL=${encodeURIComponent(originalCallbackURL)}`;
						ctx.body.callbackURL = newCallbackURL;
					})
				},
				{
					matcher(context) {
						return !!(context.path?.startsWith("/callback") || context.path?.startsWith("/oauth2/callback"));
					},
					// Unwrap a state package created by the sign-in after-hook.
					handler: createAuthMiddleware(async (ctx) => {
						const state = ctx.query?.state || ctx.body?.state;
						if (!state || typeof state !== "string") return;
						let statePackage;
						try {
							statePackage = parseJSON(await symmetricDecrypt({
								key: ctx.context.secret,
								data: state
							}));
						} catch {
							// Not a proxy package (e.g. a plain state value); leave the request alone.
							return;
						}
						if (!statePackage?.isOAuthProxy || !statePackage.state || !statePackage.stateCookie) return;

						let stateCookieValue;
						try {
							stateCookieValue = await symmetricDecrypt({
								key: ctx.context.secret,
								data: statePackage.stateCookie
							});
							// Parsed only to validate it; the parsed result is not used.
							parseJSON(stateCookieValue);
						} catch (e) {
							ctx.context.logger.error("Failed to decrypt OAuth proxy state cookie:", e);
							return;
						}

						// Remember what is about to be overridden; the last callback
						// after-hook restores it.
						ctx.context._oauthProxySnapshot = {
							storeStateStrategy: ctx.context.oauthConfig.storeStateStrategy,
							skipStateCookieCheck: ctx.context.oauthConfig.skipStateCookieCheck,
							internalAdapter: ctx.context.internalAdapter
						};
						const originalAdapter = ctx.context.internalAdapter;
						const capturedStatePackage = statePackage;
						ctx.context.oauthConfig.storeStateStrategy = "database";
						// Serve the packaged state-cookie value as if it were a stored
						// verification row; other identifiers fall through to the real adapter.
						ctx.context.internalAdapter = {
							...ctx.context.internalAdapter,
							findVerificationValue: async (identifier) => {
								if (identifier === capturedStatePackage.state) return {
									id: `oauth-proxy-${capturedStatePackage.state}`,
									identifier: capturedStatePackage.state,
									value: stateCookieValue,
									createdAt: /* @__PURE__ */ new Date(),
									updatedAt: /* @__PURE__ */ new Date(),
									expiresAt: new Date(Date.now() + 600 * 1e3)
								};
								return originalAdapter.findVerificationValue(identifier);
							}
						};
						if (ctx.query?.state) ctx.query.state = statePackage.state;
						if (ctx.body?.state) ctx.body.state = statePackage.state;
						ctx.context.oauthConfig.skipStateCookieCheck = true;
					})
				},
				{
					matcher() {
						return true;
					},
					// Non-cookie state strategy: skip the state-cookie check when the stored
					// state was created for a proxied sign-in.
					handler: createAuthMiddleware(async (ctx) => {
						if (ctx.path !== "/callback/:id") return;
						if (ctx.context.oauthConfig.storeStateStrategy === "cookie") return;
						if (ctx.context._oauthProxySnapshot) return;
						const state = ctx.query?.state || ctx.body?.state;
						if (!state) return;
						const data = await ctx.context.internalAdapter.findVerificationValue(state);
						if (!data) return;
						let parsedState;
						try {
							parsedState = parseJSON(data.value);
						} catch {
							parsedState = void 0;
						}
						const storedCallbackURL = parsedState?.callbackURL;
						if (typeof storedCallbackURL !== "string" || !storedCallbackURL.includes("/oauth-proxy-callback")) return;
						ctx.context._oauthProxySnapshot = {
							storeStateStrategy: ctx.context.oauthConfig.storeStateStrategy,
							skipStateCookieCheck: ctx.context.oauthConfig.skipStateCookieCheck,
							internalAdapter: ctx.context.internalAdapter
						};
						ctx.context.oauthConfig.skipStateCookieCheck = true;
					})
				}
			],
			after: [
				{
					matcher(context) {
						return !!(context.path?.startsWith("/sign-in/social") || context.path?.startsWith("/sign-in/oauth2"));
					},
					// Cookie state strategy: embed the oauth_state cookie in the provider's
					// `state` parameter so the proxied callback can recover it.
					handler: createAuthMiddleware(async (ctx) => {
						if (checkSkipProxy(ctx, opts)) return;
						if (ctx.context.oauthConfig.storeStateStrategy !== "cookie") return;
						const signInResponse = ctx.context.returned;
						if (!signInResponse || typeof signInResponse !== "object" || !("url" in signInResponse)) return;
						const { url: providerURL } = signInResponse;
						if (typeof providerURL !== "string") return;
						const oauthURL = new URL(providerURL);
						const originalState = oauthURL.searchParams.get("state");
						if (!originalState) return;
						const setCookieHeader = ctx.context.responseHeaders?.get("set-cookie");
						if (!setCookieHeader) return;
						const stateCookie = ctx.context.createAuthCookie("oauth_state");
						const stateCookieAttrs = parseSetCookieHeader(setCookieHeader).get(stateCookie.name);
						if (!stateCookieAttrs?.value) return;
						const stateCookieValue = stateCookieAttrs.value;
						try {
							const statePackage = {
								state: originalState,
								stateCookie: stateCookieValue,
								isOAuthProxy: true
							};
							const encryptedPackage = await symmetricEncrypt({
								key: ctx.context.secret,
								data: JSON.stringify(statePackage)
							});
							oauthURL.searchParams.set("state", encryptedPackage);
							ctx.context.returned = {
								...signInResponse,
								url: oauthURL.toString()
							};
						} catch (e) {
							ctx.context.logger.error("Failed to encrypt OAuth proxy state package:", e);
						}
					})
				},
				{
					matcher(context) {
						return !!(context.path?.startsWith("/callback") || context.path?.startsWith("/oauth2/callback"));
					},
					// Forward the session cookies to the proxy callback, or unwrap the
					// redirect when it already targets the production origin.
					handler: createAuthMiddleware(async (ctx) => {
						const headers = ctx.context.responseHeaders;
						const location = headers?.get("location");
						if (!location?.includes("/oauth-proxy-callback?callbackURL") || !location.startsWith("http")) return;
						const productionOrigin = getOrigin(opts?.productionURL || ctx.context.options.baseURL || ctx.context.baseURL);
						const locationURL = new URL(location);
						if (locationURL.origin === productionOrigin) {
							const newLocation = locationURL.searchParams.get("callbackURL");
							if (!newLocation) return;
							ctx.setHeader("location", newLocation);
							return;
						}
						const setCookies = headers?.get("set-cookie");
						if (!setCookies) return;
						const payload = {
							cookies: setCookies,
							timestamp: Date.now()
						};
						const encryptedCookies = await symmetricEncrypt({
							key: ctx.context.secret,
							data: JSON.stringify(payload)
						});
						// `location` is known to contain "?callbackURL", so append with "&".
						const locationWithCookies = `${location}&cookies=${encodeURIComponent(encryptedCookies)}`;
						ctx.setHeader("location", locationWithCookies);
					})
				},
				{
					matcher(context) {
						return !!(context.path?.startsWith("/callback") || context.path?.startsWith("/oauth2/callback"));
					},
					// Undo the temporary overrides recorded by the callback before-hooks.
					// NOTE(review): if after-hooks are skipped when the endpoint throws a
					// non-redirect error, this restore would not run — confirm framework behavior.
					handler: createAuthMiddleware(async (ctx) => {
						const contextWithSnapshot = ctx.context;
						const snapshot = contextWithSnapshot._oauthProxySnapshot;
						if (snapshot) {
							ctx.context.oauthConfig.storeStateStrategy = snapshot.storeStateStrategy;
							ctx.context.oauthConfig.skipStateCookieCheck = snapshot.skipStateCookieCheck;
							ctx.context.internalAdapter = snapshot.internalAdapter;
							contextWithSnapshot._oauthProxySnapshot = void 0;
						}
					})
				}
			]
		}
	};
};
//#endregion

export { oAuthProxy };
//# sourceMappingURL=index.mjs.map