UNPKG

@trpc/tanstack-react-query

Version:

TanStack React Query Integration for tRPC

315 lines (280 loc) 8.72 kB
import type { SkipToken } from '@tanstack/react-query';
import { hashKey, skipToken } from '@tanstack/react-query';
import type { TRPCClientErrorLike, TRPCUntypedClient } from '@trpc/client';
import type { TRPCConnectionState } from '@trpc/client/unstable-internals';
import type { Unsubscribable } from '@trpc/server/observable';
import type { inferAsyncIterableYield } from '@trpc/server/unstable-core-do-not-import';
import * as React from 'react';
import type {
  DefaultFeatureFlags,
  FeatureFlags,
  ResolverDef,
  TRPCQueryKey,
  TRPCQueryOptionsResult,
} from './types';
import { createTRPCOptionsResult, readQueryKey } from './utils';

/**
 * Subscription options accepted when the input may be `skipToken`.
 * Unlike {@link UnusedSkipTokenTRPCSubscriptionOptionsIn}, this allows an
 * explicit `enabled` flag.
 */
interface BaseTRPCSubscriptionOptionsIn<TOutput, TError> {
  enabled?: boolean;
  onStarted?: () => void;
  onData?: (data: inferAsyncIterableYield<TOutput>) => void;
  onError?: (err: TError) => void;
  onConnectionStateChange?: (state: TRPCConnectionState<TError>) => void;
}

/**
 * Subscription callbacks accepted when the input is a concrete value (no
 * `skipToken`); carries no `enabled` flag.
 */
interface UnusedSkipTokenTRPCSubscriptionOptionsIn<TOutput, TError> {
  onStarted?: () => void;
  onData?: (data: inferAsyncIterableYield<TOutput>) => void;
  onError?: (err: TError) => void;
  onConnectionStateChange?: (state: TRPCConnectionState<TError>) => void;
}

/**
 * Resolved subscription options produced by `subscriptionOptions` and
 * consumed by {@link useSubscription}.
 */
interface TRPCSubscriptionOptionsOut<
  TOutput,
  TError,
  TFeatureFlags extends FeatureFlags,
> extends UnusedSkipTokenTRPCSubscriptionOptionsIn<TOutput, TError>,
    TRPCQueryOptionsResult {
  /** Whether the hook should open the subscription. */
  enabled: boolean;
  /** Key identifying this procedure + input; its hash drives re-subscription. */
  queryKey: TRPCQueryKey<TFeatureFlags['keyPrefix']>;
  /** Opens the subscription with the given callbacks. */
  subscribe: (
    innerOpts: UnusedSkipTokenTRPCSubscriptionOptionsIn<TOutput, TError>,
  ) => Unsubscribable;
}

/**
 * Call signature of `subscriptionOptions` for a subscription procedure.
 * The first overload takes a concrete input; the second also accepts
 * `skipToken` together with an optional `enabled` flag.
 */
export interface TRPCSubscriptionOptions<
  TDef extends ResolverDef,
  TFeatureFlags extends FeatureFlags = DefaultFeatureFlags,
> {
  (
    input: TDef['input'],
    opts?: UnusedSkipTokenTRPCSubscriptionOptionsIn<
      inferAsyncIterableYield<TDef['output']>,
      TRPCClientErrorLike<TDef>
    >,
  ): TRPCSubscriptionOptionsOut<
    inferAsyncIterableYield<TDef['output']>,
    TRPCClientErrorLike<TDef>,
    TFeatureFlags
  >;
  (
    input: TDef['input'] | SkipToken,
    opts?: BaseTRPCSubscriptionOptionsIn<
      inferAsyncIterableYield<TDef['output']>,
      TRPCClientErrorLike<TDef>
    >,
  ): TRPCSubscriptionOptionsOut<
    inferAsyncIterableYield<TDef['output']>,
    TRPCClientErrorLike<TDef>,
    TFeatureFlags
  >;
}

/**
 * Status reported by {@link useSubscription}:
 * - `idle`: not subscribed (disabled, or the connection reported `idle`)
 * - `connecting`: subscribing, or the connection reported `connecting`
 * - `pending`: started (`onStarted`) or receiving data (`onData`)
 * - `error`: `onError` fired
 */
export type TRPCSubscriptionStatus =
  | 'idle'
  | 'connecting'
  | 'pending'
  | 'error';

/** Fields shared by every {@link useSubscription} result variant. */
export interface TRPCSubscriptionBaseResult<TOutput, TError> {
  status: TRPCSubscriptionStatus;
  data: undefined | TOutput;
  error: null | TError;
  /**
   * Reset the subscription
   */
  reset: () => void;
}

/** Result while not subscribed; no data and no error. */
export interface TRPCSubscriptionIdleResult<TOutput>
  extends TRPCSubscriptionBaseResult<TOutput, null> {
  status: 'idle';
  data: undefined;
  error: null;
}

/** Result while (re)connecting; may keep previous data and a connection error. */
export interface TRPCSubscriptionConnectingResult<TOutput, TError>
  extends TRPCSubscriptionBaseResult<TOutput, TError> {
  status: 'connecting';
  data: undefined | TOutput;
  error: TError | null;
}

/** Result once the subscription has started; holds the latest data, if any. */
export interface TRPCSubscriptionPendingResult<TOutput>
  extends TRPCSubscriptionBaseResult<TOutput, undefined> {
  status: 'pending';
  data: TOutput | undefined;
  error: null;
}

/** Result after `onError`; keeps the last data received before the error. */
export interface TRPCSubscriptionErrorResult<TOutput, TError>
  extends TRPCSubscriptionBaseResult<TOutput, TError> {
  status: 'error';
  data: TOutput | undefined;
  error: TError;
}

/** Discriminated union (on `status`) returned by {@link useSubscription}. */
export type TRPCSubscriptionResult<TOutput, TError> =
  | TRPCSubscriptionIdleResult<TOutput>
  | TRPCSubscriptionConnectingResult<TOutput, TError>
  | TRPCSubscriptionErrorResult<TOutput, TError>
  | TRPCSubscriptionPendingResult<TOutput>;

type AnyTRPCSubscriptionOptionsIn =
  | BaseTRPCSubscriptionOptionsIn<unknown, unknown>
  | UnusedSkipTokenTRPCSubscriptionOptionsIn<unknown, unknown>;

type AnyTRPCSubscriptionOptionsOut<TFeatureFlags extends FeatureFlags> =
  TRPCSubscriptionOptionsOut<unknown, unknown, TFeatureFlags>;

/**
 * Builds the options object consumed by {@link useSubscription}: copies the
 * caller's options, resolves `enabled`, and binds `subscribe` to this
 * procedure's dotted path and the input stored in `queryKey`.
 *
 * @internal
 */
export const trpcSubscriptionOptions = <
  TFeatureFlags extends FeatureFlags,
>(args: {
  subscribe: typeof TRPCUntypedClient.prototype.subscription;
  path: string[];
  queryKey: TRPCQueryKey<TFeatureFlags['keyPrefix']>;
  opts?: AnyTRPCSubscriptionOptionsIn;
}): AnyTRPCSubscriptionOptionsOut<TFeatureFlags> => {
  const { subscribe, path, queryKey, opts = {} } = args;
  const input = readQueryKey(queryKey)?.args?.input;

  // `skipToken` is a sentinel, not a real input: it must disable the
  // subscription even when the caller also passes `enabled: true`, otherwise
  // the sentinel itself would be sent as the subscription input.
  const enabled =
    input !== skipToken && ('enabled' in opts ? !!opts.enabled : true);

  const _subscribe: ReturnType<
    TRPCSubscriptionOptions<any, TFeatureFlags>
  >['subscribe'] = (innerOpts) => {
    return subscribe(path.join('.'), input ?? undefined, innerOpts);
  };

  return {
    ...opts,
    enabled,
    subscribe: _subscribe,
    queryKey,
    trpc: createTRPCOptionsResult({ path }),
  };
};

/** Placeholder stored in the subscription ref while nothing is subscribed. */
const noopUnsubscribe = (): void => {
  // noop
};

/**
 * Runs the unsubscribe function held in `ref` and replaces it with a noop,
 * so calling this again (e.g. effect cleanup followed by `reset()`) never
 * unsubscribes the same subscription twice.
 */
function unsubscribeRef(ref: { current: () => void }): void {
  const unsubscribe = ref.current;
  ref.current = noopUnsubscribe;
  unsubscribe();
}

/**
 * React hook that runs a tRPC subscription described by `opts` (typically
 * the output of `subscriptionOptions`).
 *
 * - Subscribes while `opts.enabled` is true and re-subscribes whenever the
 *   hash of `opts.queryKey` or `opts.enabled` changes; when disabled the
 *   result is reset to the `idle` state.
 * - The user callbacks (`onStarted`, `onData`, `onError`,
 *   `onConnectionStateChange`) are read from the latest render's `opts`.
 * - The returned object is a Proxy: the component only re-renders when a
 *   property it has actually read changes.
 *
 * @param opts Resolved subscription options.
 * @returns The current subscription result, including a `reset()` function
 *   that tears down and restarts the subscription.
 */
export function useSubscription<TOutput, TError>(
  opts: TRPCSubscriptionOptionsOut<TOutput, TError, any>,
): TRPCSubscriptionResult<TOutput, TError> {
  type $Result = TRPCSubscriptionResult<TOutput, TError>;

  // Always points at the latest options so subscription callbacks never
  // call stale user handlers.
  const optsRef = React.useRef(opts);
  optsRef.current = opts;

  // Result keys the component has read; only changes to these re-render.
  const trackedProps = React.useRef(new Set<keyof $Result>([]));

  const addTrackedProp = React.useCallback((key: keyof $Result) => {
    trackedProps.current.add(key);
  }, []);

  type Unsubscribe = () => void;
  const currentSubscriptionRef = React.useRef<Unsubscribe>(noopUnsubscribe);

  // NOTE: `updateState` and `getInitialState` are declared below; they are
  // only dereferenced when `reset` is invoked (from the effect or by the
  // user), by which point both bindings exist.
  const reset = React.useCallback((): void => {
    // unsubscribe from the previous subscription (no-op if already torn down)
    unsubscribeRef(currentSubscriptionRef);

    // `connecting` when enabled, `idle` when disabled
    updateState(getInitialState);

    if (!opts.enabled) {
      return;
    }

    const subscription = opts.subscribe({
      onStarted: () => {
        optsRef.current.onStarted?.();
        updateState((prev) => ({
          ...(prev as any),
          status: 'pending',
          error: null,
        }));
      },
      onData: (data) => {
        optsRef.current.onData?.(data);
        updateState((prev) => ({
          ...(prev as any),
          status: 'pending',
          data,
          error: null,
        }));
      },
      onError: (error) => {
        optsRef.current.onError?.(error);
        updateState((prev) => ({
          ...(prev as any),
          status: 'error',
          error,
        }));
      },
      onConnectionStateChange: (result) => {
        optsRef.current.onConnectionStateChange?.(result);
        updateState((prev) => {
          switch (result.state) {
            case 'connecting':
              return {
                ...prev,
                status: 'connecting',
                error: result.error,
              };
            case 'pending':
              // handled in onStarted
              return prev;
            case 'idle':
              return {
                ...prev,
                status: 'idle',
                data: undefined,
                error: null,
              };
          }
        });
      },
    });

    currentSubscriptionRef.current = () => {
      subscription.unsubscribe();
    };
    // eslint-disable-next-line react-hooks/react-compiler, react-hooks/exhaustive-deps
  }, [hashKey(opts.queryKey), opts.enabled]);

  /** Fresh initial result for the current `enabled` value. */
  const getInitialState = React.useCallback((): $Result => {
    return opts.enabled
      ? {
          data: undefined,
          error: null,
          status: 'connecting',
          reset,
        }
      : {
          data: undefined,
          error: null,
          status: 'idle',
          reset,
        };
  }, [opts.enabled, reset]);

  // Latest result; the returned proxy reads from here.
  const resultRef = React.useRef<$Result>(getInitialState());

  const [state, setState] = React.useState<$Result>(
    trackResult(resultRef, addTrackedProp),
  );

  // Expose the latest `reset` through the proxy. The proxy has no `set`
  // trap, so assigning on `state` would only write to the proxy's original
  // target, while reads go through `resultRef.current`.
  resultRef.current.reset = reset;

  /**
   * Applies `callback` to the current result and re-renders only if a
   * property the component has read changed.
   */
  const updateState = React.useCallback(
    (callback: (prevState: $Result) => $Result) => {
      const prev = resultRef.current;
      const next = (resultRef.current = callback(prev));

      let shouldUpdate = false;
      for (const key of trackedProps.current) {
        if (prev[key] !== next[key]) {
          shouldUpdate = true;
          break;
        }
      }
      if (shouldUpdate) {
        setState(trackResult(resultRef, addTrackedProp));
      }
    },
    [addTrackedProp],
  );

  React.useEffect(() => {
    // Always run `reset`: when enabled it (re)subscribes; when disabled it
    // moves the result back to `idle` instead of leaving a stale status and
    // data from the torn-down subscription.
    reset();

    return () => {
      unsubscribeRef(currentSubscriptionRef);
    };
  }, [reset]);

  return state;
}

/**
 * Wraps `result.current` in a Proxy that reports every property read to
 * `onTrackResult` and always returns the value from the ref's *current*
 * object, so an older proxy still exposes the latest result.
 *
 * Writes through the proxy are not intercepted and land on the object that
 * was current when the proxy was created.
 */
function trackResult<T extends object>(
  result: React.RefObject<T>,
  onTrackResult: (key: keyof T) => void,
): T {
  const trackedResult = new Proxy(result.current, {
    get(_target, prop) {
      onTrackResult(prop as keyof T);
      // Bypass target, so that we always get the latest value
      return result.current[prop as keyof T];
    },
  });

  return trackedResult;
}