UNPKG

@tanstack/ai

Version:

Type-safe TypeScript AI SDK for streaming chat, tool calling, agents, structured outputs, and multimodal generation.

262 lines (247 loc) 8.45 kB
import type { Modality } from './types'

// ---------------------------------------------------------------------------
// Extended model definitions
// ---------------------------------------------------------------------------

/**
 * Describes one custom model that can be layered on top of an existing
 * adapter factory via `extendAdapter`.
 *
 * @template TName - Literal string type of the model name
 * @template TInput - Readonly array of input modalities the model supports
 * @template TOptions - Provider options type associated with the model
 * @template TFeatures - Readonly array of declared feature names
 * @template TTools - Readonly array of declared provider tool names
 *
 * @example
 * ```typescript
 * const customModels = [
 *   createModel('my-custom-model', ['text', 'image']),
 * ] as const
 * ```
 */
export interface ExtendedModelDef<
  TName extends string = string,
  TInput extends ReadonlyArray<Modality> = ReadonlyArray<Modality>,
  TOptions = unknown,
  TFeatures extends ReadonlyArray<string> = ReadonlyArray<string>,
  TTools extends ReadonlyArray<string> = ReadonlyArray<string>,
> {
  /** Identifier of the model */
  name: TName
  /** Input modalities the model supports */
  input: TInput
  /** Type brand carrying the provider options - use `{} as YourOptionsType` */
  modelOptions: TOptions
  /** Declared features, if any (e.g. 'reasoning', 'structured_outputs') */
  features?: TFeatures
  /** Declared provider tools, if any (e.g. 'web_search') */
  tools?: TTools
}

/**
 * Capabilities object accepted as the second argument of `createModel`.
 * Every field is optional.
 */
export interface ModelCapabilities<
  TInput extends ReadonlyArray<Modality> = ReadonlyArray<Modality>,
  TFeatures extends ReadonlyArray<string> = ReadonlyArray<string>,
  TTools extends ReadonlyArray<string> = ReadonlyArray<string>,
  TOptions = unknown,
> {
  input?: TInput
  features?: TFeatures
  tools?: TTools
  modelOptions?: TOptions
}

/**
 * Builds a custom model definition to hand to `extendAdapter`.
 *
 * The `const` type parameters keep literal types intact, so callers do not
 * need `as const` on the individual arguments.
 *
 * Two call shapes are supported:
 * - positional: `createModel(name, input)` with an array of modalities;
 * - capabilities object: `createModel(name, { input, features, tools, modelOptions })`.
 *
 * At runtime, `modelOptions` falls back to `{}` in both shapes, and the
 * capabilities form falls back to `['text']` when `input` is omitted.
 *
 * @template TName - The model name (inferred from the first argument)
 * @template TInput - The input modalities array (positional form, inferred)
 * @template TCaps - The capabilities object (object form, inferred)
 *
 * @param name - The model name identifier (literal string)
 * @param input - Supported input modalities (positional form), or a
 *   capabilities object (object form)
 * @returns A typed model definition suitable for `extendAdapter`
 *
 * @example
 * ```typescript
 * import { extendAdapter, createModel } from '@tanstack/ai'
 * import { openaiText } from '@tanstack/ai-openai'
 *
 * // Positional form - literal types are inferred
 * const customModels = [
 *   createModel('my-fine-tuned-gpt4', ['text', 'image']),
 *   createModel('local-llama', ['text']),
 * ] as const
 *
 * const myOpenai = extendAdapter(openaiText, customModels)
 * ```
 *
 * @example
 * ```typescript
 * // Capabilities object form - declare features and provider tools
 * const reasoner = createModel('reasoner', {
 *   input: ['text'],
 *   features: ['reasoning', 'structured_outputs'],
 *   tools: ['web_search'],
 * })
 * ```
 */
// Overload 1 - positional array of input modalities
export function createModel<
  const TName extends string,
  const TInput extends ReadonlyArray<Modality>,
>(name: TName, input: TInput): ExtendedModelDef<TName, TInput>
// Overload 2 - capabilities object
export function createModel<
  const TName extends string,
  const TCaps extends ModelCapabilities,
>(
  name: TName,
  capabilities: TCaps,
): ExtendedModelDef<
  TName,
  TCaps['input'] extends ReadonlyArray<Modality>
    ? TCaps['input']
    : ReadonlyArray<Modality>,
  TCaps['modelOptions'],
  TCaps['features'] extends ReadonlyArray<string>
    ? TCaps['features']
    : ReadonlyArray<string>,
  TCaps['tools'] extends ReadonlyArray<string>
    ? TCaps['tools']
    : ReadonlyArray<string>
>
// Shared implementation behind both overloads
export function createModel(
  name: string,
  second: ReadonlyArray<Modality> | ModelCapabilities,
): ExtendedModelDef {
  if (!Array.isArray(second)) {
    // `Array.isArray` cannot rule out `ReadonlyArray` on the false branch,
    // hence the assertion to the capabilities shape.
    const { input, modelOptions, features, tools } =
      second as ModelCapabilities
    return {
      name,
      input: input ?? (['text'] as ReadonlyArray<Modality>),
      modelOptions: modelOptions ?? {},
      // Passed through as given; the keys are present even when undefined.
      features,
      tools,
    }
  }

  // Positional form: the caller's array is used as-is.
  return { name, input: second, modelOptions: {} }
}

// ---------------------------------------------------------------------------
// Type extraction utilities
// ---------------------------------------------------------------------------

/**
 * Union of the `name` members across an array of model definitions.
 */
type ExtractCustomModelNames<TDefs extends ReadonlyArray<ExtendedModelDef>> =
  TDefs[number]['name']

// ---------------------------------------------------------------------------
// Factory type inference
// ---------------------------------------------------------------------------

/**
 * Widest factory shape that `extendAdapter` takes: a function whose first
 * parameter is the model. Because parameters are contravariant, `never`
 * parameters together with an `unknown` return admit every factory without
 * falling back to `any`.
 */
type AnyAdapterFactory = (model: never, ...args: Array<never>) => unknown

/**
 * Pulls the model parameter type out of an adapter factory.
 * For generic functions such as `<T extends Union>(model: T)` this yields `T`,
 * which TypeScript treats as the constraint union in parameter position.
 * Falls back to `string` when the model type is not a string type or the
 * factory shape does not match.
 */
type InferFactoryModels<TFactory> = TFactory extends (
  model: infer TModel,
  ...args: Array<never>
) => unknown
  ? TModel extends string
    ? TModel
    : string
  : string

/**
 * Pulls the adapter (return) type out of a factory function; `never` when
 * the input is not a function.
 */
type InferAdapterReturn<TFactory> = TFactory extends (
  ...args: Array<never>
) => infer TReturn
  ? TReturn
  : never

/**
 * Parameter types that follow the model parameter of a factory, with labels
 * and optionality preserved (e.g. `[apiKey: string, config?: C]`).
 * Note: for overloaded factories this resolves against the last overload
 * (a `Parameters` limitation).
 */
type InferRestArgs<TFactory extends AnyAdapterFactory> =
  Parameters<TFactory> extends [unknown?, ...infer TRest] ? TRest : []

/**
 * Signature of the factory returned by `extendAdapter`: the model parameter
 * accepts the original model names plus the custom ones, while the remaining
 * parameters and the return type are taken from the original factory.
 */
type ExtendedFactory<
  TFactory extends AnyAdapterFactory,
  TDefs extends ReadonlyArray<ExtendedModelDef>,
> = (
  model: InferFactoryModels<TFactory> | ExtractCustomModelNames<TDefs>,
  ...args: InferRestArgs<TFactory>
) => InferAdapterReturn<TFactory>

// ---------------------------------------------------------------------------
// extendAdapter
// ---------------------------------------------------------------------------

/**
 * Widens an adapter factory so that it also accepts custom model names.
 *
 * The returned factory takes both the original models (keeping the original
 * type inference) and the custom models described by the definitions.
 *
 * At runtime the original factory is returned unchanged - no validation is
 * performed. The original factory's signature, including any config
 * parameters, is preserved.
 *
 * @param factory - The original adapter factory function (e.g., `openaiText`, `anthropicText`)
 * @param _customModels - Custom model definitions; used for type inference only
 * @returns A factory function that accepts both original and custom models
 *
 * @example
 * ```typescript
 * import { extendAdapter, createModel } from '@tanstack/ai'
 * import { openaiText } from '@tanstack/ai-openai'
 *
 * // Define custom models
 * const customModels = [
 *   createModel('my-fine-tuned-gpt4', ['text', 'image']),
 *   createModel('local-llama', ['text']),
 * ] as const
 *
 * // Create extended adapter
 * const myOpenai = extendAdapter(openaiText, customModels)
 *
 * // Use with original models - full type inference preserved
 * const gpt4 = myOpenai('gpt-4o')
 *
 * // Use with custom models
 * const custom = myOpenai('my-fine-tuned-gpt4')
 *
 * // Type error: 'invalid-model' is not a valid model
 * // myOpenai('invalid-model')
 *
 * // Works with chat()
 * chat({
 *   adapter: myOpenai('my-fine-tuned-gpt4'),
 *   messages: [...]
 * })
 * ```
 */
export function extendAdapter<
  TFactory extends AnyAdapterFactory,
  const TDefs extends ReadonlyArray<ExtendedModelDef>,
>(factory: TFactory, _customModels: TDefs): ExtendedFactory<TFactory, TDefs>
// The implementation signature is kept at the plain `AnyAdapterFactory` width;
// the deliberate widening of the model union happens in the overload above.
export function extendAdapter(
  factory: AnyAdapterFactory,
  _customModels: ReadonlyArray<ExtendedModelDef>,
): AnyAdapterFactory {
  // Pure pass-through: `_customModels` exists solely to drive type inference,
  // and model names are not validated - callers are trusted to pass valid ones.
  return factory
}