@trpc/server
Version:
708 lines (648 loc) • 19.4 kB
text/typescript
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import {
isObservable,
observableToAsyncIterable,
} from '../../observable/observable';
import { getErrorShape } from '../error/getErrorShape';
import { getTRPCErrorFromUnknown, TRPCError } from '../error/TRPCError';
import type { ProcedureType } from '../procedure';
import {
type AnyRouter,
type inferRouterContext,
type inferRouterError,
} from '../router';
import type { TRPCResponse } from '../rpc';
import { isPromise, jsonlStreamProducer } from '../stream/jsonl';
import { sseHeaders, sseStreamProducer } from '../stream/sse';
import { transformTRPCResponse } from '../transformer';
import { isAsyncIterable, isObject, run } from '../utils';
import { getRequestInfo } from './contentType';
import { getHTTPStatusCode } from './getHTTPStatusCode';
import type {
HTTPBaseHandlerOptions,
ResolveHTTPRequestOptionsContextFn,
TRPCRequestInfo,
} from './types';
function errorToAsyncIterable(err: TRPCError): AsyncIterable<never> {
return run(async function* () {
throw err;
});
}
type HTTPMethods =
| 'GET'
| 'POST'
| 'HEAD'
| 'OPTIONS'
| 'PUT'
| 'DELETE'
| 'PATCH';
const TYPE_ACCEPTED_METHOD_MAP: Record<ProcedureType, HTTPMethods[]> = {
mutation: ['POST'],
query: ['GET'],
subscription: ['GET'],
};
const TYPE_ACCEPTED_METHOD_MAP_WITH_METHOD_OVERRIDE: Record<
ProcedureType,
HTTPMethods[]
> = {
// never allow GET to do a mutation
mutation: ['POST'],
query: ['GET', 'POST'],
subscription: ['GET', 'POST'],
};
interface ResolveHTTPRequestOptions<TRouter extends AnyRouter>
extends HTTPBaseHandlerOptions<TRouter, Request> {
createContext: ResolveHTTPRequestOptionsContextFn<TRouter>;
req: Request;
path: string;
/**
* If the request had an issue before reaching the handler
*/
error: TRPCError | null;
}
function initResponse<TRouter extends AnyRouter, TRequest>(initOpts: {
ctx: inferRouterContext<TRouter> | undefined;
info: TRPCRequestInfo | undefined;
responseMeta?: HTTPBaseHandlerOptions<TRouter, TRequest>['responseMeta'];
untransformedJSON:
| TRPCResponse<unknown, inferRouterError<TRouter>>
| TRPCResponse<unknown, inferRouterError<TRouter>>[]
| null;
errors: TRPCError[];
headers: Headers;
}) {
const {
ctx,
info,
responseMeta,
untransformedJSON,
errors = [],
headers,
} = initOpts;
let status = untransformedJSON ? getHTTPStatusCode(untransformedJSON) : 200;
const eagerGeneration = !untransformedJSON;
const data = eagerGeneration
? []
: Array.isArray(untransformedJSON)
? untransformedJSON
: [untransformedJSON];
const meta =
responseMeta?.({
ctx,
info,
paths: info?.calls.map((call) => call.path),
data,
errors,
eagerGeneration,
type:
info?.calls.find((call) => call.procedure?._def.type)?.procedure?._def
.type ?? 'unknown',
}) ?? {};
if (meta.headers) {
if (meta.headers instanceof Headers) {
for (const [key, value] of meta.headers.entries()) {
headers.append(key, value);
}
} else {
/**
* @deprecated, delete in v12
*/
for (const [key, value] of Object.entries(meta.headers)) {
if (Array.isArray(value)) {
for (const v of value) {
headers.append(key, v);
}
} else if (typeof value === 'string') {
headers.set(key, value);
}
}
}
}
if (meta.status) {
status = meta.status;
}
return {
status,
};
}
function caughtErrorToData<TRouter extends AnyRouter>(
cause: unknown,
errorOpts: {
opts: Pick<
ResolveHTTPRequestOptions<TRouter>,
'onError' | 'req' | 'router'
>;
ctx: inferRouterContext<TRouter> | undefined;
type: ProcedureType | 'unknown';
path?: string;
input?: unknown;
},
) {
const { router, req, onError } = errorOpts.opts;
const error = getTRPCErrorFromUnknown(cause);
onError?.({
error,
path: errorOpts.path,
input: errorOpts.input,
ctx: errorOpts.ctx,
type: errorOpts.type,
req,
});
const untransformedJSON = {
error: getErrorShape({
config: router._def._config,
error,
type: errorOpts.type,
path: errorOpts.path,
input: errorOpts.input,
ctx: errorOpts.ctx,
}),
};
const transformedJSON = transformTRPCResponse(
router._def._config,
untransformedJSON,
);
const body = JSON.stringify(transformedJSON);
return {
error,
untransformedJSON,
body,
};
}
/**
* Check if a value is a stream-like object
* - if it's an async iterable
* - if it's an object with async iterables or promises
*/
function isDataStream(v: unknown) {
if (!isObject(v)) {
return false;
}
if (isAsyncIterable(v)) {
return true;
}
return (
Object.values(v).some(isPromise) || Object.values(v).some(isAsyncIterable)
);
}
type ResultTuple<T> = [undefined, T] | [TRPCError, undefined];
export async function resolveResponse<TRouter extends AnyRouter>(
opts: ResolveHTTPRequestOptions<TRouter>,
): Promise<Response> {
const { router, req } = opts;
const headers = new Headers([['vary', 'trpc-accept']]);
const config = router._def._config;
const url = new URL(req.url);
if (req.method === 'HEAD') {
// can be used for lambda warmup
return new Response(null, {
status: 204,
});
}
const allowBatching = opts.allowBatching ?? opts.batching?.enabled ?? true;
const allowMethodOverride =
(opts.allowMethodOverride ?? false) && req.method === 'POST';
type $Context = inferRouterContext<TRouter>;
const infoTuple: ResultTuple<TRPCRequestInfo> = await run(async () => {
try {
return [
undefined,
await getRequestInfo({
req,
path: decodeURIComponent(opts.path),
router,
searchParams: url.searchParams,
headers: opts.req.headers,
url,
}),
];
} catch (cause) {
return [getTRPCErrorFromUnknown(cause), undefined];
}
});
interface ContextManager {
valueOrUndefined: () => $Context | undefined;
value: () => $Context;
create: (info: TRPCRequestInfo) => Promise<void>;
}
const ctxManager: ContextManager = run(() => {
let result: ResultTuple<$Context> | undefined = undefined;
return {
valueOrUndefined: () => {
if (!result) {
return undefined;
}
return result[1];
},
value: () => {
const [err, ctx] = result!;
if (err) {
throw err;
}
return ctx;
},
create: async (info) => {
if (result) {
throw new Error(
'This should only be called once - report a bug in tRPC',
);
}
try {
const ctx = await opts.createContext({
info,
});
result = [undefined, ctx];
} catch (cause) {
result = [getTRPCErrorFromUnknown(cause), undefined];
}
},
};
});
const methodMapper = allowMethodOverride
? TYPE_ACCEPTED_METHOD_MAP_WITH_METHOD_OVERRIDE
: TYPE_ACCEPTED_METHOD_MAP;
/**
* @deprecated
*/
const isStreamCall = req.headers.get('trpc-accept') === 'application/jsonl';
const experimentalSSE = config.sse?.enabled ?? true;
try {
const [infoError, info] = infoTuple;
if (infoError) {
throw infoError;
}
if (info.isBatchCall && !allowBatching) {
throw new TRPCError({
code: 'BAD_REQUEST',
message: `Batching is not enabled on the server`,
});
}
/* istanbul ignore if -- @preserve */
if (isStreamCall && !info.isBatchCall) {
throw new TRPCError({
message: `Streaming requests must be batched (you can do a batch of 1)`,
code: 'BAD_REQUEST',
});
}
await ctxManager.create(info);
interface RPCResultOk {
data: unknown;
}
type RPCResult = ResultTuple<RPCResultOk>;
const rpcCalls = info.calls.map(async (call): Promise<RPCResult> => {
const proc = call.procedure;
try {
if (opts.error) {
throw opts.error;
}
if (!proc) {
throw new TRPCError({
code: 'NOT_FOUND',
message: `No procedure found on path "${call.path}"`,
});
}
if (!methodMapper[proc._def.type].includes(req.method as HTTPMethods)) {
throw new TRPCError({
code: 'METHOD_NOT_SUPPORTED',
message: `Unsupported ${req.method}-request to ${proc._def.type} procedure at path "${call.path}"`,
});
}
if (proc._def.type === 'subscription') {
/* istanbul ignore if -- @preserve */
if (info.isBatchCall) {
throw new TRPCError({
code: 'BAD_REQUEST',
message: `Cannot batch subscription calls`,
});
}
}
const data: unknown = await proc({
path: call.path,
getRawInput: call.getRawInput,
ctx: ctxManager.value(),
type: proc._def.type,
signal: opts.req.signal,
});
return [undefined, { data }];
} catch (cause) {
const error = getTRPCErrorFromUnknown(cause);
const input = call.result();
opts.onError?.({
error,
path: call.path,
input,
ctx: ctxManager.valueOrUndefined(),
type: call.procedure?._def.type ?? 'unknown',
req: opts.req,
});
return [error, undefined];
}
});
// ----------- response handlers -----------
if (!info.isBatchCall) {
const [call] = info.calls;
const [error, result] = await rpcCalls[0]!;
switch (info.type) {
case 'unknown':
case 'mutation':
case 'query': {
// httpLink
headers.set('content-type', 'application/json');
if (isDataStream(result?.data)) {
throw new TRPCError({
code: 'UNSUPPORTED_MEDIA_TYPE',
message:
'Cannot use stream-like response in non-streaming request - use httpBatchStreamLink',
});
}
const res: TRPCResponse<unknown, inferRouterError<TRouter>> = error
? {
error: getErrorShape({
config,
ctx: ctxManager.valueOrUndefined(),
error,
input: call!.result(),
path: call!.path,
type: info.type,
}),
}
: { result: { data: result.data } };
const headResponse = initResponse({
ctx: ctxManager.valueOrUndefined(),
info,
responseMeta: opts.responseMeta,
errors: error ? [error] : [],
headers,
untransformedJSON: [res],
});
return new Response(
JSON.stringify(transformTRPCResponse(config, res)),
{
status: headResponse.status,
headers,
},
);
}
case 'subscription': {
// httpSubscriptionLink
const iterable: AsyncIterable<unknown> = run(() => {
if (error) {
return errorToAsyncIterable(error);
}
if (!experimentalSSE) {
return errorToAsyncIterable(
new TRPCError({
code: 'METHOD_NOT_SUPPORTED',
message: 'Missing experimental flag "sseSubscriptions"',
}),
);
}
if (!isObservable(result.data) && !isAsyncIterable(result.data)) {
return errorToAsyncIterable(
new TRPCError({
message: `Subscription ${
call!.path
} did not return an observable or a AsyncGenerator`,
code: 'INTERNAL_SERVER_ERROR',
}),
);
}
const dataAsIterable = isObservable(result.data)
? observableToAsyncIterable(result.data, opts.req.signal)
: result.data;
return dataAsIterable;
});
const stream = sseStreamProducer({
...config.sse,
data: iterable,
serialize: (v) => config.transformer.output.serialize(v),
formatError(errorOpts) {
const error = getTRPCErrorFromUnknown(errorOpts.error);
const input = call?.result();
const path = call?.path;
const type = call?.procedure?._def.type ?? 'unknown';
opts.onError?.({
error,
path,
input,
ctx: ctxManager.valueOrUndefined(),
req: opts.req,
type,
});
const shape = getErrorShape({
config,
ctx: ctxManager.valueOrUndefined(),
error,
input,
path,
type,
});
return shape;
},
});
for (const [key, value] of Object.entries(sseHeaders)) {
headers.set(key, value);
}
const headResponse = initResponse({
ctx: ctxManager.valueOrUndefined(),
info,
responseMeta: opts.responseMeta,
errors: [],
headers,
untransformedJSON: null,
});
return new Response(stream, {
headers,
status: headResponse.status,
});
}
}
}
// batch response handlers
if (info.accept === 'application/jsonl') {
// httpBatchStreamLink
headers.set('content-type', 'application/json');
headers.set('transfer-encoding', 'chunked');
const headResponse = initResponse({
ctx: ctxManager.valueOrUndefined(),
info,
responseMeta: opts.responseMeta,
errors: [],
headers,
untransformedJSON: null,
});
const stream = jsonlStreamProducer({
...config.jsonl,
/**
* Example structure for `maxDepth: 4`:
* {
* // 1
* 0: {
* // 2
* result: {
* // 3
* data: // 4
* }
* }
* }
*/
maxDepth: Infinity,
data: rpcCalls.map(async (res) => {
const [error, result] = await res;
const call = info.calls[0];
if (error) {
return {
error: getErrorShape({
config,
ctx: ctxManager.valueOrUndefined(),
error,
input: call!.result(),
path: call!.path,
type: call!.procedure?._def.type ?? 'unknown',
}),
};
}
/**
* Not very pretty, but we need to wrap nested data in promises
* Our stream producer will only resolve top-level async values or async values that are directly nested in another async value
*/
const iterable = isObservable(result.data)
? observableToAsyncIterable(result.data, opts.req.signal)
: Promise.resolve(result.data);
return {
result: Promise.resolve({
data: iterable,
}),
};
}),
serialize: config.transformer.output.serialize,
onError: (cause) => {
opts.onError?.({
error: getTRPCErrorFromUnknown(cause),
path: undefined,
input: undefined,
ctx: ctxManager.valueOrUndefined(),
req: opts.req,
type: info?.type ?? 'unknown',
});
},
formatError(errorOpts) {
const call = info?.calls[errorOpts.path[0] as any];
const error = getTRPCErrorFromUnknown(errorOpts.error);
const input = call?.result();
const path = call?.path;
const type = call?.procedure?._def.type ?? 'unknown';
// no need to call `onError` here as it will be propagated through the stream itself
const shape = getErrorShape({
config,
ctx: ctxManager.valueOrUndefined(),
error,
input,
path,
type,
});
return shape;
},
});
return new Response(stream, {
headers,
status: headResponse.status,
});
}
// httpBatchLink
/**
* Non-streaming response:
* - await all responses in parallel, blocking on the slowest one
* - create headers with known response body
* - return a complete HTTPResponse
*/
headers.set('content-type', 'application/json');
const results: RPCResult[] = (await Promise.all(rpcCalls)).map(
(res): RPCResult => {
const [error, result] = res;
if (error) {
return res;
}
if (isDataStream(result.data)) {
return [
new TRPCError({
code: 'UNSUPPORTED_MEDIA_TYPE',
message:
'Cannot use stream-like response in non-streaming request - use httpBatchStreamLink',
}),
undefined,
];
}
return res;
},
);
const resultAsRPCResponse = results.map(
(
[error, result],
index,
): TRPCResponse<unknown, inferRouterError<TRouter>> => {
const call = info.calls[index]!;
if (error) {
return {
error: getErrorShape({
config,
ctx: ctxManager.valueOrUndefined(),
error,
input: call.result(),
path: call.path,
type: call.procedure?._def.type ?? 'unknown',
}),
};
}
return {
result: { data: result.data },
};
},
);
const errors = results
.map(([error]) => error)
.filter(Boolean) as TRPCError[];
const headResponse = initResponse({
ctx: ctxManager.valueOrUndefined(),
info,
responseMeta: opts.responseMeta,
untransformedJSON: resultAsRPCResponse,
errors,
headers,
});
return new Response(
JSON.stringify(transformTRPCResponse(config, resultAsRPCResponse)),
{
status: headResponse.status,
headers,
},
);
} catch (cause) {
const [_infoError, info] = infoTuple;
const ctx = ctxManager.valueOrUndefined();
// we get here if
// - batching is called when it's not enabled
// - `createContext()` throws
// - `router._def._config.transformer.output.serialize()` throws
// - post body is too large
// - input deserialization fails
// - `errorFormatter` return value is malformed
const { error, untransformedJSON, body } = caughtErrorToData(cause, {
opts,
ctx: ctxManager.valueOrUndefined(),
type: info?.type ?? 'unknown',
});
const headResponse = initResponse({
ctx,
info,
responseMeta: opts.responseMeta,
untransformedJSON,
errors: [error],
headers,
});
return new Response(body, {
status: headResponse.status,
headers,
});
}
}