@tanstack/ai
Version:
Type-safe TypeScript AI SDK for streaming chat, tool calling, agents, structured outputs, and multimodal generation.
262 lines (247 loc) • 8.45 kB
text/typescript
import type { Modality } from './types'
// ===========================
// Extended Model Definition
// ===========================
/**
 * Definition for a custom model to add to an adapter.
 *
 * `createModel` builds values of this shape; they are then passed to
 * `extendAdapter`, which only uses their types.
 *
 * @template TName - The model name as a literal string type
 * @template TInput - Array of supported input modalities
 * @template TOptions - Provider options type for this model
 * @template TFeatures - Array of declared feature names (e.g. `'reasoning'`)
 * @template TTools - Array of declared provider tool names (e.g. `'web_search'`)
 *
 * @example
 * ```typescript
 * const customModels = [
 *   createModel('my-custom-model', ['text', 'image']),
 * ] as const
 * ```
 */
export interface ExtendedModelDef<
  TName extends string = string,
  TInput extends ReadonlyArray<Modality> = ReadonlyArray<Modality>,
  TOptions = unknown,
  TFeatures extends ReadonlyArray<string> = ReadonlyArray<string>,
  TTools extends ReadonlyArray<string> = ReadonlyArray<string>,
> {
  /** The model name identifier */
  name: TName
  /** Supported input modalities for this model */
  input: TInput
  /**
   * Type brand for provider options - use `{} as YourOptionsType`.
   * Only the type matters: `extendAdapter` never reads the value at runtime.
   */
  modelOptions: TOptions
  /** Optional declared features (e.g. 'reasoning', 'structured_outputs') */
  features?: TFeatures
  /** Optional declared provider tools (e.g. 'web_search') */
  tools?: TTools
}
/**
 * Capabilities accepted by the object form of `createModel`.
 *
 * Every field is optional; `createModel` falls back to `['text']` for a
 * missing `input`.
 *
 * @template TInput - Array of supported input modalities
 * @template TFeatures - Array of declared feature names
 * @template TTools - Array of declared provider tool names
 * @template TOptions - Provider options type brand
 */
export interface ModelCapabilities<
  TInput extends ReadonlyArray<Modality> = ReadonlyArray<Modality>,
  TFeatures extends ReadonlyArray<string> = ReadonlyArray<string>,
  TTools extends ReadonlyArray<string> = ReadonlyArray<string>,
  TOptions = unknown,
> {
  /** Supported input modalities. */
  input?: TInput
  /** Provider options type brand - use `{} as YourOptionsType`. */
  modelOptions?: TOptions
  /** Declared features, e.g. `'reasoning'` or `'structured_outputs'`. */
  features?: TFeatures
  /** Declared provider tools, e.g. `'web_search'`. */
  tools?: TTools
}
/**
 * Creates a custom model definition for use with `extendAdapter` from a
 * positional array of input modalities.
 *
 * This helper provides literal type inference (its type parameters are
 * `const`) without requiring manual `as const` casts on individual
 * properties. Definitions created this way declare no features or tools;
 * use the capabilities-object overload for those.
 *
 * @template TName - The model name (inferred from argument)
 * @template TInput - The input modalities array (inferred from argument)
 *
 * @param name - The model name identifier (literal string)
 * @param input - Array of supported input modalities
 * @returns A properly typed model definition for use with `extendAdapter`
 *
 * @example
 * ```typescript
 * import { extendAdapter, createModel } from '@tanstack/ai'
 * import { openaiText } from '@tanstack/ai-openai'
 *
 * // Define custom models with full type inference
 * const customModels = [
 *   createModel('my-fine-tuned-gpt4', ['text', 'image']),
 *   createModel('local-llama', ['text']),
 * ] as const
 *
 * const myOpenai = extendAdapter(openaiText, customModels)
 * ```
 */
export function createModel<
  const TName extends string,
  const TInput extends ReadonlyArray<Modality>,
>(name: TName, input: TInput): ExtendedModelDef<TName, TInput>
/**
 * Creates a custom model definition for use with `extendAdapter` from a
 * capabilities object, which can additionally declare features, provider
 * tools and a provider-options type brand.
 *
 * Fallbacks for omitted fields: `input` becomes `['text']` and
 * `modelOptions` becomes `{}` (also when passed as `null`/`undefined`).
 * `features` and `tools` are left off the returned definition entirely
 * when not provided.
 *
 * @template TName - The model name (inferred from argument)
 * @template TCaps - The capabilities object (inferred from argument)
 *
 * @param name - The model name identifier (literal string)
 * @param capabilities - Input modalities, features, tools and options brand
 * @returns A properly typed model definition for use with `extendAdapter`
 *
 * @example
 * ```typescript
 * // Capabilities object form - declare features and provider tools
 * const reasoner = createModel('reasoner', {
 *   input: ['text'],
 *   features: ['reasoning', 'structured_outputs'],
 *   tools: ['web_search'],
 * })
 * ```
 */
export function createModel<
  const TName extends string,
  const TCaps extends ModelCapabilities,
>(
  name: TName,
  capabilities: TCaps,
): ExtendedModelDef<
  TName,
  TCaps['input'] extends ReadonlyArray<Modality>
    ? TCaps['input']
    : ReadonlyArray<Modality>,
  TCaps['modelOptions'],
  TCaps['features'] extends ReadonlyArray<string>
    ? TCaps['features']
    : ReadonlyArray<string>,
  TCaps['tools'] extends ReadonlyArray<string>
    ? TCaps['tools']
    : ReadonlyArray<string>
>
// Implementation
export function createModel(
  name: string,
  second: ReadonlyArray<Modality> | ModelCapabilities,
): ExtendedModelDef {
  // Positional form: the array itself is the input modality list.
  if (Array.isArray(second)) {
    return { name, input: second, modelOptions: {} }
  }
  // `Array.isArray` narrows to `any[]`, which does not remove
  // `ReadonlyArray` from the false branch, so the object form needs a cast.
  const caps = second as ModelCapabilities
  const input: ReadonlyArray<Modality> = caps.input ?? ['text']
  const model: ExtendedModelDef = {
    name,
    input,
    modelOptions: caps.modelOptions ?? {},
  }
  // Attach optional declarations only when provided, so the object form never
  // yields explicit `undefined` properties (matching the positional form and
  // staying valid under `exactOptionalPropertyTypes`).
  if (caps.features !== undefined) {
    model.features = caps.features
  }
  if (caps.tools !== undefined) {
    model.tools = caps.tools
  }
  return model
}
// ===========================
// Type Extraction Utilities
// ===========================
/**
 * Union of the `name` literals declared across a list of model definitions,
 * e.g. definitions named `'a'` and `'b'` yield `'a' | 'b'`.
 */
type ExtractCustomModelNames<TDefs extends ReadonlyArray<ExtendedModelDef>> =
  TDefs[number] extends infer TDef extends ExtendedModelDef
    ? TDef['name']
    : never
// ===========================
// Factory Type Inference
// ===========================
/**
 * Upper bound for the factories `extendAdapter` can wrap: a function whose
 * first parameter is the model. Parameter types are compared
 * contravariantly, so `never` parameters together with an `unknown` result
 * let every factory signature satisfy this type without resorting to `any`.
 */
type AnyAdapterFactory = (
  ...args: [model: never, ...rest: Array<never>]
) => unknown
/**
 * Model-name union accepted by a factory's first parameter.
 * A generic factory such as `<T extends Union>(model: T)` is matched via its
 * constraint, so this yields `Union`. A first parameter that is not
 * string-typed (or a non-function input) collapses to plain `string`.
 */
type InferFactoryModels<TFactory> = TFactory extends (
  model: infer TModel extends string,
  ...args: Array<never>
) => unknown
  ? TModel
  : string
/**
 * Result type produced by an adapter factory; `never` when `TFactory` is
 * not a function type.
 */
type InferAdapterReturn<TFactory> =
  TFactory extends (...args: Array<never>) => infer TAdapter ? TAdapter : never
/**
 * Parameters a factory accepts after its leading model argument, keeping
 * their labels and optionality (e.g. `[apiKey: string, config?: C]`).
 * Caveat: `Parameters` only sees the final signature of an overloaded
 * function, so overloaded factories are described by their last overload.
 */
type InferRestArgs<TFactory extends AnyAdapterFactory> =
  Parameters<TFactory> extends [unknown?, ...infer TTail]
    ? TTail
    : []
/**
 * Model-name union accepted by an extended factory: whatever the original
 * factory's first parameter allows plus every custom `name` in `TDefs`.
 */
type ExtendedFactoryModel<
  TFactory extends AnyAdapterFactory,
  TDefs extends ReadonlyArray<ExtendedModelDef>,
> = InferFactoryModels<TFactory> | ExtractCustomModelNames<TDefs>
/**
 * Call signature handed back by `extendAdapter`: the original factory's
 * signature with the leading model parameter widened to
 * `ExtendedFactoryModel`, while all trailing parameters (labels and
 * optionality included) and the return type are copied over unchanged.
 */
type ExtendedFactory<
  TFactory extends AnyAdapterFactory,
  TDefs extends ReadonlyArray<ExtendedModelDef>,
> = (
  model: ExtendedFactoryModel<TFactory, TDefs>,
  ...args: InferRestArgs<TFactory>
) => InferAdapterReturn<TFactory>
// ===========================
// extendAdapter Function
// ===========================
/**
 * Extends an existing adapter factory with additional custom models.
 *
 * The extended adapter accepts both original models (with full original type inference)
 * and custom models (with types from your definitions).
 *
 * This is a type-level operation only: at runtime the original factory is
 * returned as-is (the very same function reference) and model names are not
 * validated. The original factory's remaining parameters (e.g. API keys or
 * config objects) and its return type are preserved in the extended signature.
 * For overloaded factories, TypeScript infers from the last overload only.
 *
 * @template TFactory - The adapter factory being extended
 * @template TDefs - The custom model definitions (inferred as a readonly tuple)
 *
 * @param factory - The original adapter factory function (e.g., `openaiText`, `anthropicText`)
 * @param _customModels - Array of custom model definitions (typically built with
 *   `createModel`); only its type is used, the value is never read
 * @returns `factory` itself, typed to also accept the custom model names
 *
 * @example
 * ```typescript
 * import { extendAdapter, createModel } from '@tanstack/ai'
 * import { openaiText } from '@tanstack/ai-openai'
 *
 * // Define custom models
 * const customModels = [
 *   createModel('my-fine-tuned-gpt4', ['text', 'image']),
 *   createModel('local-llama', ['text']),
 * ] as const
 *
 * // Create extended adapter
 * const myOpenai = extendAdapter(openaiText, customModels)
 *
 * // Use with original models - full type inference preserved
 * const gpt4 = myOpenai('gpt-4o')
 *
 * // Use with custom models
 * const custom = myOpenai('my-fine-tuned-gpt4')
 *
 * // Type error: 'invalid-model' is not a valid model
 * // myOpenai('invalid-model')
 *
 * // Works with chat()
 * chat({
 *   adapter: myOpenai('my-fine-tuned-gpt4'),
 *   messages: [...]
 * })
 * ```
 */
export function extendAdapter<
  TFactory extends AnyAdapterFactory,
  const TDefs extends ReadonlyArray<ExtendedModelDef>,
>(factory: TFactory, _customModels: TDefs): ExtendedFactory<TFactory, TDefs>
// The implementation signature stays at the honest `AnyAdapterFactory` width;
// the overload above performs the deliberate model-union widening.
export function extendAdapter(
  factory: AnyAdapterFactory,
  _customModels: ReadonlyArray<ExtendedModelDef>,
): AnyAdapterFactory {
  // At runtime, we simply pass through to the original factory.
  // The _customModels parameter is only used for type inference.
  // No runtime validation - users are trusted to pass valid model names.
  return factory
}