workers-ai-provider
Version:
Workers AI provider for the Vercel AI SDK
203 lines (190 loc) • 7.01 kB
TypeScript
import { LanguageModelV2, EmbeddingModelV2, ImageModelV2 } from '@ai-sdk/provider';
/**
 * A plain string, or any value that exposes a `toString(): string` method.
 */
type StringLike =
    | string
    | { toString(): string };
/**
 * Options accepted by the AutoRAG chat model.
 *
 * Besides the named `safePrompt` flag, every other key is a passthrough
 * setting that is provided directly to the run function.
 */
type AutoRAGChatSettings = {
    /**
     * When `true`, a safety prompt is injected ahead of every conversation.
     * Defaults to `false`.
     */
    safePrompt?: boolean;
} & Record<string, StringLike>;
/**
 * Produces the union of keys of `Source` whose property type is assignable
 * to `Target`. Keys whose property does not match collapse to `never` and so
 * drop out of the resulting union.
 */
type value2key<Source, Target> = {
    [Key in keyof Source]: Source[Key] extends Target ? Key : never;
}[keyof Source];
/**
 * Names of `AiModels` entries that are `BaseAiTextGeneration` models,
 * excluding any name whose entry is also a `BaseAiTextToImage` model.
 */
type TextGenerationModels = Exclude<
    value2key<AiModels, BaseAiTextGeneration>,
    value2key<AiModels, BaseAiTextToImage>
>;
/**
 * Names of `AiModels` entries that are `BaseAiTextToImage` models.
 */
type ImageGenerationModels = value2key<AiModels, BaseAiTextToImage>;
/**
 * Names of `AiModels` entries that are `BaseAiTextEmbeddings` models.
 */
type EmbeddingModels = value2key<AiModels, BaseAiTextEmbeddings>;
/**
 * Configuration object handed to the `AutoRAGChatLanguageModel` constructor:
 * provider name, AutoRAG binding and optional AI Gateway options.
 */
type AutoRAGChatConfig = { provider: string; binding: AutoRAG; gateway?: GatewayOptions };
/**
 * AI SDK `LanguageModelV2` implementation for chat requests made through a
 * Cloudflare AutoRAG binding (supplied via `AutoRAGChatConfig.binding`).
 */
declare class AutoRAGChatLanguageModel implements LanguageModelV2 {
/** AI SDK language-model specification version implemented by this class. */
readonly specificationVersion = "v2";
/**
 * Default mode for structured-object generation.
 * NOTE(review): this member was part of LanguageModelV1; LanguageModelV2 does
 * not appear to declare it — confirm whether anything still reads it.
 */
readonly defaultObjectGenerationMode = "json";
/**
 * Regular-expression URL patterns, grouped under string keys, that the model
 * reports it can handle natively; may be provided asynchronously.
 * (The AI SDK spec keys these by media type — confirm.)
 */
readonly supportedUrls: Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>>;
/** Identifier of the text-generation model. */
readonly modelId: TextGenerationModels;
/** Chat settings (presumably the `settings` constructor argument — confirm). */
readonly settings: AutoRAGChatSettings;
/** Presumably the `config` constructor argument; its type is erased in this declaration. */
private readonly config;
/**
 * @param modelId - Text-generation model identifier.
 * @param settings - Chat settings, including passthrough options.
 * @param config - Provider name, AutoRAG binding and optional gateway options.
 */
constructor(modelId: TextGenerationModels, settings: AutoRAGChatSettings, config: AutoRAGChatConfig);
/** Provider name reported to the AI SDK (presumably `config.provider` — confirm). */
get provider(): string;
/** Private helper (the name suggests it assembles call arguments); its signature is erased here. */
private getArgs;
/** Runs a single non-streaming generation, per `LanguageModelV2["doGenerate"]`. */
doGenerate(options: Parameters<LanguageModelV2["doGenerate"]>[0]): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>>;
/** Runs a streaming generation, per `LanguageModelV2["doStream"]`. */
doStream(options: Parameters<LanguageModelV2["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>>;
}
/**
 * Configuration object handed to the `WorkersAIEmbeddingModel` constructor:
 * provider name, Workers AI binding and optional AI Gateway options.
 */
type WorkersAIEmbeddingConfig = {
    provider: string;
    binding: Ai;
    gateway?: GatewayOptions;
};
/**
 * Options accepted by a Workers AI embedding model.
 *
 * Besides the named fields, any other key is an arbitrary provider-specific
 * option that is forwarded unmodified.
 */
type WorkersAIEmbeddingSettings = {
    /** Optional Cloudflare AI Gateway options. */
    gateway?: GatewayOptions;
    /** Per-call embedding cap (presumably surfaced by the model's `maxEmbeddingsPerCall` getter — confirm). */
    maxEmbeddingsPerCall?: number;
    /** Parallel-call flag (presumably surfaced by the model's `supportsParallelCalls` getter — confirm). */
    supportsParallelCalls?: boolean;
} & Record<string, StringLike>;
/**
 * AI SDK `EmbeddingModelV2<string>` implementation that embeds string values
 * with a Workers AI embedding model, using the `Ai` binding from
 * `WorkersAIEmbeddingConfig`.
 */
declare class WorkersAIEmbeddingModel implements EmbeddingModelV2<string> {
/**
 * Semantic version of the {@link EmbeddingModelV2} specification implemented
 * by this class. It never changes.
 */
readonly specificationVersion = "v2";
/** Identifier of the embedding model. */
readonly modelId: EmbeddingModels;
/** Presumably the `config` constructor argument; its type is erased in this declaration. */
private readonly config;
/** Presumably the `settings` constructor argument; its type is erased in this declaration. */
private readonly settings;
/**
 * Provider name exposed for diagnostics and error reporting.
 */
get provider(): string;
/**
 * Maximum number of values the AI SDK should send per `doEmbed` call (per the
 * EmbeddingModelV2 contract). Presumably derived from
 * `settings.maxEmbeddingsPerCall` — confirm against the implementation.
 */
get maxEmbeddingsPerCall(): number;
/**
 * Whether the AI SDK may issue several `doEmbed` calls concurrently.
 * Presumably derived from `settings.supportsParallelCalls` — confirm.
 */
get supportsParallelCalls(): boolean;
/**
 * @param modelId - Embedding model identifier.
 * @param settings - Embedding settings, including passthrough options.
 * @param config - Provider name, `Ai` binding and optional gateway options.
 */
constructor(modelId: EmbeddingModels, settings: WorkersAIEmbeddingSettings, config: WorkersAIEmbeddingConfig);
/**
 * Embeds `values`, per `EmbeddingModelV2<string>["doEmbed"]`.
 * NOTE(review): only `values` is destructured here; other call options the
 * spec defines (e.g. abort signal, headers, provider options) are not named —
 * confirm whether they are intentionally ignored.
 */
doEmbed({ values, }: Parameters<EmbeddingModelV2<string>["doEmbed"]>[0]): Promise<Awaited<ReturnType<EmbeddingModelV2<string>["doEmbed"]>>>;
}
/**
 * Options accepted by a Workers AI chat model.
 *
 * Besides the named fields, every other key is a passthrough setting that is
 * provided directly to the run function.
 */
type WorkersAIChatSettings = {
    /**
     * When `true`, a safety prompt is injected ahead of every conversation.
     * Defaults to `false`.
     */
    safePrompt?: boolean;
    /**
     * Optional Cloudflare AI Gateway options.
     */
    gateway?: GatewayOptions;
} & Record<string, StringLike>;
/**
 * Configuration object handed to the `WorkersAIChatLanguageModel` constructor:
 * provider name, Workers AI binding and optional AI Gateway options.
 */
type WorkersAIChatConfig = { provider: string; binding: Ai; gateway?: GatewayOptions };
/**
 * AI SDK `LanguageModelV2` implementation backed by a Workers AI
 * text-generation model, using the `Ai` binding from `WorkersAIChatConfig`.
 */
declare class WorkersAIChatLanguageModel implements LanguageModelV2 {
/** AI SDK language-model specification version implemented by this class. */
readonly specificationVersion = "v2";
/**
 * Default mode for structured-object generation.
 * NOTE(review): this member was part of LanguageModelV1; LanguageModelV2 does
 * not appear to declare it — confirm whether anything still reads it.
 */
readonly defaultObjectGenerationMode = "json";
/**
 * Regular-expression URL patterns, grouped under string keys, that the model
 * reports it can handle natively; may be provided asynchronously.
 * (The AI SDK spec keys these by media type — confirm.)
 */
readonly supportedUrls: Record<string, RegExp[]> | PromiseLike<Record<string, RegExp[]>>;
/** Identifier of the text-generation model. */
readonly modelId: TextGenerationModels;
/** Chat settings (presumably the `settings` constructor argument — confirm). */
readonly settings: WorkersAIChatSettings;
/** Presumably the `config` constructor argument; its type is erased in this declaration. */
private readonly config;
/**
 * @param modelId - Text-generation model identifier.
 * @param settings - Chat settings, including gateway and passthrough options.
 * @param config - Provider name, `Ai` binding and optional gateway options.
 */
constructor(modelId: TextGenerationModels, settings: WorkersAIChatSettings, config: WorkersAIChatConfig);
/** Provider name reported to the AI SDK (presumably `config.provider` — confirm). */
get provider(): string;
/** Private helper (the name suggests it assembles call arguments); its signature is erased here. */
private getArgs;
/** Runs a single non-streaming generation, per `LanguageModelV2["doGenerate"]`. */
doGenerate(options: Parameters<LanguageModelV2["doGenerate"]>[0]): Promise<Awaited<ReturnType<LanguageModelV2["doGenerate"]>>>;
/** Runs a streaming generation, per `LanguageModelV2["doStream"]`. */
doStream(options: Parameters<LanguageModelV2["doStream"]>[0]): Promise<Awaited<ReturnType<LanguageModelV2["doStream"]>>>;
}
/**
 * Configuration object handed to the `WorkersAIImageModel` constructor:
 * provider name, Workers AI binding and optional AI Gateway options.
 */
type WorkersAIImageConfig = { provider: string; binding: Ai; gateway?: GatewayOptions };
/**
 * Options accepted by a Workers AI image model.
 */
type WorkersAIImageSettings = {
    /** Per-call image cap (presumably surfaced by the model's `maxImagesPerCall` getter — confirm). */
    maxImagesPerCall?: number;
};
/**
 * AI SDK `ImageModelV2` implementation backed by a Workers AI text-to-image
 * model, using the `Ai` binding from `WorkersAIImageConfig`.
 */
declare class WorkersAIImageModel implements ImageModelV2 {
/** Identifier of the text-to-image model. */
readonly modelId: ImageGenerationModels;
/** Image settings (presumably the `settings` constructor argument — confirm). */
readonly settings: WorkersAIImageSettings;
/** Provider name, `Ai` binding and optional gateway options. */
readonly config: WorkersAIImageConfig;
/** AI SDK image-model specification version implemented by this class. */
readonly specificationVersion = "v2";
/** Maximum images per call (presumably `settings.maxImagesPerCall` or a built-in default — confirm). */
get maxImagesPerCall(): number;
/** Provider name reported to the AI SDK (presumably `config.provider` — confirm). */
get provider(): string;
/**
 * @param modelId - Text-to-image model identifier.
 * @param settings - Image settings.
 * @param config - Provider name, `Ai` binding and optional gateway options.
 */
constructor(modelId: ImageGenerationModels, settings: WorkersAIImageSettings, config: WorkersAIImageConfig);
/**
 * Generates images, per `ImageModelV2["doGenerate"]`.
 * NOTE(review): only `prompt`, `n`, `size`, `aspectRatio` and `seed` are
 * destructured; other call options the spec defines are not named here —
 * confirm whether they are intentionally ignored.
 */
doGenerate({ prompt, n, size, aspectRatio, seed, }: Parameters<ImageModelV2["doGenerate"]>[0]): Promise<Awaited<ReturnType<ImageModelV2["doGenerate"]>>>;
}
/**
 * Options for `createWorkersAI`.
 *
 * Exactly one connection mode may be used: a Workers AI `binding`, or an
 * `accountId` + `apiKey` pair. The `?: never` members make the two modes
 * mutually exclusive at the type level. Optional AI Gateway options can be
 * combined with either mode.
 */
type WorkersAISettings = ({
/**
* Provide a Cloudflare AI binding.
*/
binding: Ai;
/**
* API credentials must be absent when a binding is given.
*/
accountId?: never;
apiKey?: never;
} | {
/**
* Provide Cloudflare API credentials directly. Must be used if a binding is not specified.
*/
accountId: string;
apiKey: string;
/**
* The binding must be absent when credentials are used directly.
*/
binding?: never;
}) & {
/**
* Optionally specify AI Gateway options.
*/
gateway?: GatewayOptions;
};
/**
 * Workers AI provider returned by `createWorkersAI`. It can be called directly
 * to create a chat model, and it exposes factories for chat, text-embedding and
 * image-generation models.
 */
interface WorkersAI {
/**
 * Creates a model for text generation; same signature and return type as `chat`.
 */
(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
/**
 * Creates a model for text generation.
 */
chat(modelId: TextGenerationModels, settings?: WorkersAIChatSettings): WorkersAIChatLanguageModel;
/**
 * Creates a text-embedding model.
 */
embedding(modelId: EmbeddingModels, settings?: WorkersAIEmbeddingSettings): WorkersAIEmbeddingModel;
/**
 * Creates a text-embedding model; same signature and return type as `embedding`.
 */
textEmbedding(modelId: EmbeddingModels, settings?: WorkersAIEmbeddingSettings): WorkersAIEmbeddingModel;
/**
 * Creates a text-embedding model; same signature and return type as `embedding`.
 */
textEmbeddingModel(modelId: EmbeddingModels, settings?: WorkersAIEmbeddingSettings): WorkersAIEmbeddingModel;
/**
 * Creates a model for image generation.
 */
image(modelId: ImageGenerationModels, settings?: WorkersAIImageSettings): WorkersAIImageModel;
}
/**
 * Create a Workers AI provider instance.
 *
 * @param options - Either a Workers AI `binding` or an `accountId`/`apiKey`
 *   pair (mutually exclusive), plus optional `gateway` options.
 * @returns A `WorkersAI` provider for creating chat, embedding and image models.
 */
declare function createWorkersAI(options: WorkersAISettings): WorkersAI;
/**
 * Options for `createAutoRAG`.
 */
type AutoRAGSettings = {
    /** The Cloudflare AutoRAG binding. */
    binding: AutoRAG;
};
/**
 * AutoRAG provider returned by `createAutoRAG`. It can be called directly or
 * through `chat` to create an AutoRAG chat model.
 */
interface AutoRAGProvider {
    /**
     * Creates a model for text generation; same signature and return type as `chat`.
     *
     * @param settings - Optional chat settings (safety-prompt flag plus passthrough options).
     */
    // Parameter renamed from `options` to `settings` to match `chat` and the
    // `WorkersAI` signatures. Parameter names are not part of a call
    // signature's type, so callers are unaffected.
    (settings?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
    /**
     * Creates a model for text generation.
     *
     * @param settings - Optional chat settings (safety-prompt flag plus passthrough options).
     */
    chat(settings?: AutoRAGChatSettings): AutoRAGChatLanguageModel;
}
/**
 * Create an AutoRAG provider instance.
 *
 * @param options - Holds the Cloudflare AutoRAG `binding` to use.
 * @returns An `AutoRAGProvider`, callable directly or via `.chat()` to obtain
 *   an `AutoRAGChatLanguageModel`.
 */
declare function createAutoRAG(options: AutoRAGSettings): AutoRAGProvider;
// Names exported from this module: the two provider factories and their
// settings/provider types.
export { type AutoRAGProvider, type AutoRAGSettings, type WorkersAI, type WorkersAISettings, createAutoRAG, createWorkersAI };