UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

783 lines (704 loc) 32.2 kB
import { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
import { PredictionServiceClient, v1beta1 } from "@google-cloud/aiplatform";
import { type Content, GoogleGenAI, type Model } from "@google/genai";
import {
    type AIModel,
    AbstractDriver,
    type Completion,
    type CompletionChunkObject,
    type CompletionResult,
    type DriverOptions,
    type EmbeddingsOptions,
    type EmbeddingsResult,
    type ExecutionOptions,
    type LlumiverseError,
    type LlumiverseErrorContext,
    type ModelSearchPayload,
    type PromptSegment,
    getConversationMeta,
    getModelCapabilities,
    incrementConversationTurn,
    modelModalitiesToArray,
    stripBase64ImagesFromConversation,
    stripHeartbeatsFromConversation,
    truncateLargeTextInConversation,
} from "@llumiverse/core";
import { FetchClient } from "@vertesia/api-fetch-client";
import { type AuthClient, GoogleAuth, type GoogleAuthOptions } from "google-auth-library";
import { getEmbeddingsForImages } from "./embeddings/embeddings-image.js";
import { type TextEmbeddingsOptions, getEmbeddingsForText } from "./embeddings/embeddings-text.js";
import { getModelDefinition } from "./models.js";
import { ANTHROPIC_REGIONS, NON_GLOBAL_ANTHROPIC_MODELS } from "./models/claude.js";
import { ImagenModelDefinition, type ImagenPrompt } from "./models/imagen.js";

/** Options accepted by {@link VertexAIDriver}. */
export interface VertexAIDriverOptions extends DriverOptions {
    /** Google Cloud project id used in resource paths and client configuration. */
    project: string;
    /** Default region used for regional endpoints. */
    region: string;
    /** Optional options forwarded to `GoogleAuth` and to the GenAI client. */
    googleAuthOptions?: GoogleAuthOptions;
}

/** Prompt shape for `generateContent`-style models: message contents plus an optional system content. */
export interface GenerateContentPrompt {
    contents: Content[];
    system?: Content;
}

//General Prompt type for VertexAI
export type VertexAIPrompt = ImagenPrompt | GenerateContentPrompt;

/**
 * Remove a trailing `@...` suffix from a model name.
 *
 * @param model - Model name, possibly ending in `@<suffix>`.
 * @returns The part of `model` before the last `@`, or `model` unchanged when it has no `@`.
 */
export function trimModelName(model: string) {
    const i = model.lastIndexOf("@");
    return i > -1 ? model.substring(0, i) : model;
}

/**
 * Driver for models reachable through Vertex AI. The underlying service clients are
 * created lazily on first use and cached on the instance.
 */
export class VertexAIDriver extends AbstractDriver<VertexAIDriverOptions, VertexAIPrompt> {
    static PROVIDER = "vertexai";
    provider = VertexAIDriver.PROVIDER;

    aiplatform: v1beta1.ModelServiceClient | undefined;
    anthropicClient: AnthropicVertex | undefined;
    fetchClient: FetchClient | undefined;
    googleGenAI: GoogleGenAI | undefined;
    // Region / flex flag the cached googleGenAI client was built with.
    googleGenAIRegion: string | undefined;
    googleGenAIFlex: boolean | undefined;
    llamaClient: FetchClient & { region?: string } | undefined;
    modelGarden: v1beta1.ModelGardenServiceClient | undefined;
    imagenClient: PredictionServiceClient | undefined;
    googleAuth: GoogleAuth<any>;
    private authClientPromise: Promise<AuthClient> | undefined;

    constructor(options: VertexAIDriverOptions) {
        super(options);

        this.aiplatform = undefined;
        this.anthropicClient = undefined;
        this.fetchClient = undefined;
        this.googleGenAI = undefined;
        this.googleGenAIRegion = undefined;
        this.googleGenAIFlex = undefined;
        this.modelGarden = undefined;
        this.llamaClient = undefined;
        this.imagenClient = undefined;

        this.googleAuth = new GoogleAuth(options.googleAuthOptions) as GoogleAuth<any>;
        this.authClientPromise = undefined;
    }

    /** Resolve the auth client once; the same promise is returned on every later call. */
    private async getAuthClient(): Promise<AuthClient> {
        if (!this.authClientPromise) {
            this.authClientPromise = this.googleAuth.getClient();
        }
        return this.authClientPromise;
    }

    /**
     * Get a GoogleGenAI client for the given region and flex setting.
     * Only one client is cached: asking for a different region or flex value
     * replaces the cached client.
     *
     * @param region - Location for the client; defaults to the driver's region.
     * @param flex - When true, the client sends the "flex" shared-request headers.
     */
    public getGoogleGenAIClient(region: string = this.options.region, flex: boolean = false): GoogleGenAI {
        if (this.googleGenAI && this.googleGenAIRegion === region && this.googleGenAIFlex === flex) {
            // Return existing client if region and flex settings match
            return this.googleGenAI;
        }
        this.googleGenAI = this.buildGoogleGenAIClient(region, flex);
        this.googleGenAIRegion = region;
        this.googleGenAIFlex = flex;
        return this.googleGenAI;
    }

    /** Construct a new (uncached) GoogleGenAI client in Vertex AI mode. */
    private buildGoogleGenAIClient(region: string, flex: boolean): GoogleGenAI {
        return new GoogleGenAI({
            project: this.options.project,
            location: region,
            vertexai: true,
            googleAuthOptions: this.options.googleAuthOptions || {
                scopes: ["https://www.googleapis.com/auth/cloud-platform"],
            },
            ...(flex
                ? {
                      httpOptions: {
                          headers: {
                              "X-Vertex-AI-LLM-Request-Type": "shared",
                              "X-Vertex-AI-LLM-Shared-Request-Type": "flex",
                          },
                      },
                  }
                : {}),
        });
    }

    /** Get the cached fetch client for the driver's region (API version "v1"), creating it on first use. */
    public getFetchClient(): FetchClient {
        //Lazy initialization
        if (!this.fetchClient) {
            this.fetchClient = createFetchClient({
                region: this.options.region,
                project: this.options.project,
            }).withAuthCallback(async () => {
                const token = await this.googleAuth.getAccessToken();
                return `Bearer ${token}`;
            });
        }
        return this.fetchClient;
    }

    /**
     * Get a fetch client targeting the "v1beta1" API in the given region.
     * The cached client is rebuilt whenever a different region is requested.
     *
     * @param region - Region to target; defaults to "us-central1".
     */
    public getLLamaClient(region: string = "us-central1"): FetchClient {
        //Lazy initialization
        if (!this.llamaClient || this.llamaClient["region"] !== region) {
            this.llamaClient = createFetchClient({
                region: region,
                project: this.options.project,
                apiVersion: "v1beta1",
            }).withAuthCallback(async () => {
                const token = await this.googleAuth.getAccessToken();
                return `Bearer ${token}`;
            });
            // Store the region for potential client reuse
            this.llamaClient["region"] = region;
        }
        return this.llamaClient;
    }

    /**
     * Get an AnthropicVertex client for a region.
     *
     * The region's prefix (text before the first "-") is looked up in ANTHROPIC_REGIONS;
     * when present the mapped value is used, otherwise the region is used as given.
     * A client whose mapped region equals the driver's default mapped region is cached;
     * any other region gets a fresh, uncached client on every call.
     *
     * @param region - Requested region; defaults to the driver's region.
     */
    public async getAnthropicClient(region: string = this.options.region): Promise<AnthropicVertex> {
        // Extract region prefix and map if it exists in ANTHROPIC_REGIONS, otherwise use as-is
        const getRegionPrefix = (r: string) => r.split("-")[0];
        const mappedRegion = ANTHROPIC_REGIONS[getRegionPrefix(region)] || region;
        const defaultMappedRegion =
            ANTHROPIC_REGIONS[getRegionPrefix(this.options.region)] || this.options.region;

        // Get auth client to avoid version mismatch with GoogleAuth generic types
        const authClient = await this.getAuthClient();

        // If mapped region is different from default mapped region, create one-off client
        if (mappedRegion !== defaultMappedRegion) {
            return this.createAnthropicClient(mappedRegion, authClient);
        }

        //Lazy initialization for default region
        if (!this.anthropicClient) {
            this.anthropicClient = this.createAnthropicClient(mappedRegion, authClient);
        }
        return this.anthropicClient;
    }

    /** Build an AnthropicVertex client for `region` with the driver's project and request timeout. */
    private createAnthropicClient(region: string, authClient: AuthClient): AnthropicVertex {
        return new AnthropicVertex({
            timeout: ANTHROPIC_REQUEST_TIMEOUT_MS,
            region: region,
            projectId: this.options.project,
            authClient: authClient,
        });
    }

    /** Get the cached ModelServiceClient for the driver's region, creating it on first use. */
    public async getAIPlatformClient(): Promise<v1beta1.ModelServiceClient> {
        //Lazy initialization
        if (!this.aiplatform) {
            const authClient = await this.getAuthClient();
            this.aiplatform = new v1beta1.ModelServiceClient({
                projectId: this.options.project,
                apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
                authClient,
            });
        }
        return this.aiplatform;
    }

    /** Get the cached ModelGardenServiceClient for the driver's region, creating it on first use. */
    public async getModelGardenClient(): Promise<v1beta1.ModelGardenServiceClient> {
        //Lazy initialization
        if (!this.modelGarden) {
            const authClient = await this.getAuthClient();
            this.modelGarden = new v1beta1.ModelGardenServiceClient({
                projectId: this.options.project,
                apiEndpoint: `${this.options.region}-${API_BASE_PATH}`,
                authClient,
            });
        }
        return this.modelGarden;
    }

    /** Get the cached PredictionServiceClient used for Imagen, creating it on first use. */
    public async getImagenClient(): Promise<PredictionServiceClient> {
        //Lazy initialization
        if (!this.imagenClient) {
            // TODO: make location configurable, fixed to us-central1 for now
            const authClient = await this.getAuthClient();
            this.imagenClient = new PredictionServiceClient({
                projectId: this.options.project,
                apiEndpoint: `us-central1-${API_BASE_PATH}`,
                authClient,
            });
        }
        return this.imagenClient;
    }

    /**
     * Validate a completion, first giving the model definition a chance to
     * rewrite the result and options through its `preValidationProcessing` hook.
     */
    validateResult(result: Completion, options: ExecutionOptions) {
        // Optionally preprocess the result before validation
        const modelDef = getModelDefinition(options.model);
        if (typeof modelDef.preValidationProcessing === "function") {
            const processed = modelDef.preValidationProcessing(result, options);
            result = processed.result;
            options = processed.options;
        }
        super.validateResult(result, options);
    }

    /** Image models never stream; other models stream only when their definition sets `can_stream` to true. */
    protected canStream(options: ExecutionOptions): Promise<boolean> {
        if (this.isImageModel(options.model)) {
            return Promise.resolve(false);
        }
        return Promise.resolve(getModelDefinition(options.model).model.can_stream === true);
    }

    /** A model is treated as an image model when its name contains "imagen". */
    protected isImageModel(model: string): boolean {
        return model.includes("imagen");
    }

    /** Build the prompt through the Imagen definition for image models, otherwise through the model's definition. */
    public createPrompt(segments: PromptSegment[], options: ExecutionOptions): Promise<VertexAIPrompt> {
        if (this.isImageModel(options.model)) {
            return new ImagenModelDefinition(options.model).createPrompt(this, segments, options);
        }
        return getModelDefinition(options.model).createPrompt(this, segments, options);
    }

    /** Delegate a non-streaming completion to the model's definition. */
    async requestTextCompletion(prompt: VertexAIPrompt, options: ExecutionOptions): Promise<Completion> {
        return getModelDefinition(options.model).requestTextCompletion(this, prompt, options);
    }

    /** Delegate a streaming completion to the model's definition. */
    async requestTextCompletionStream(
        prompt: VertexAIPrompt,
        options: ExecutionOptions,
    ): Promise<AsyncIterable<CompletionChunkObject>> {
        return getModelDefinition(options.model).requestTextCompletionStream(this, prompt, options);
    }

    /**
     * Build conversation context after streaming completion.
     * Reconstructs the assistant message from accumulated results and applies stripping.
     * Handles both Gemini (Content[]) and Claude (ClaudePrompt) formats.
     *
     * @param prompt - The prompt that was sent.
     * @param result - Accumulated completion results.
     * @param toolUse - Tool calls produced by the model, if any.
     * @param options - Execution options (stripping thresholds, prior conversation).
     * @returns The updated conversation, or `undefined` when the prompt has neither a
     *          `messages` array nor a `contents` array.
     */
    buildStreamingConversation(
        prompt: VertexAIPrompt,
        result: unknown[],
        toolUse: unknown[] | undefined,
        options: ExecutionOptions
    ): Content[] | unknown | undefined {
        // Handle Claude-style prompts (has 'messages' array)
        if ("messages" in prompt && Array.isArray((prompt as any).messages)) {
            return this.buildClaudeStreamingConversation(prompt as any, result, toolUse, options);
        }

        // Only handle Gemini-style prompts with contents array
        if (!("contents" in prompt) || !Array.isArray(prompt.contents)) {
            return undefined;
        }

        // Convert accumulated results to text content for assistant message
        const textContent = completionResultsToText(result as CompletionResult[]);

        // Build parts array for assistant message
        const parts: any[] = [];
        if (textContent) {
            parts.push({ text: textContent });
        }

        // Add function calls if present (Gemini format)
        if (toolUse && toolUse.length > 0) {
            for (const tool of toolUse as any[]) {
                const functionCallPart: any = {
                    functionCall: {
                        name: tool.tool_name,
                        args: tool.tool_input,
                    },
                };
                // Include thought_signature for Gemini thinking models (2.5+/3.0+)
                // This must be preserved in the conversation for subsequent API calls
                if (tool.thought_signature) {
                    functionCallPart.thoughtSignature = tool.thought_signature;
                }
                parts.push(functionCallPart);
            }
        }

        // prompt.contents already includes the conversation history
        // (merged in requestTextCompletionStream via updateConversation),
        // so we use it directly — do NOT prepend options.conversation again.
        let conversation: Content[] = [...prompt.contents];

        // Only add assistant message if there's actual content
        // (Empty text parts can cause API errors)
        if (parts.length > 0) {
            conversation.push({ role: "model", parts: parts });
        }

        // Increment turn counter
        conversation = incrementConversationTurn(conversation) as Content[];

        // Apply stripping based on options
        const currentTurn = getConversationMeta(conversation).turnNumber;
        const stripOptions = {
            keepForTurns: options.stripImagesAfterTurns ?? Infinity,
            currentTurn,
            textMaxTokens: options.stripTextMaxTokens,
        };
        let processedConversation = stripBase64ImagesFromConversation(conversation, stripOptions);
        processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
        processedConversation = stripHeartbeatsFromConversation(processedConversation, {
            keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
            currentTurn,
        });

        // Preserve system instruction in conversation for Gemini multi-turn support.
        // The Gemini API takes system as a separate parameter (not in contents),
        // so we must store it in the conversation wrapper to survive serialization.
        const geminiPrompt = prompt as GenerateContentPrompt;
        if (geminiPrompt.system) {
            if (typeof processedConversation === "object" && processedConversation !== null) {
                processedConversation = {
                    ...(processedConversation as object),
                    _llumiverse_system: geminiPrompt.system,
                };
            }
        }

        return processedConversation;
    }

    /**
     * Build conversation for Claude streaming.
     * Creates assistant message with tool_use blocks in Claude's ContentBlock format.
     *
     * @param prompt - Claude-style prompt (`messages` plus optional `system`).
     * @param result - Accumulated completion results.
     * @param toolUse - Tool calls produced by the model, if any.
     * @param options - Execution options (stripping thresholds, prior conversation).
     * @returns The updated conversation after turn increment and stripping.
     */
    private buildClaudeStreamingConversation(
        prompt: { messages: unknown[]; system?: unknown[] },
        result: unknown[],
        toolUse: unknown[] | undefined,
        options: ExecutionOptions
    ): unknown {
        // Convert accumulated results to text content
        const textContent = completionResultsToText(result as CompletionResult[]);

        // Build Claude-style ContentBlock array for assistant message
        const content: unknown[] = [];

        // Add text block if there's text content
        if (textContent) {
            content.push({ type: "text", text: textContent });
        }

        // Add tool_use blocks in Claude format
        if (toolUse && toolUse.length > 0) {
            for (const tool of toolUse as any[]) {
                content.push({
                    type: "tool_use",
                    id: tool.id,
                    name: tool.tool_name,
                    input: tool.tool_input ?? {},
                });
            }
        }

        // Claude's requestTextCompletionStream does NOT mutate prompt.messages
        // to include history, so we must prepend options.conversation here.
        const existingMessages = (options.conversation as any)?.messages ?? [];
        const existingSystem = (options.conversation as any)?.system ?? prompt.system;

        // Build the new messages array
        const newMessages = [...existingMessages, ...prompt.messages];

        // Only add assistant message if there's actual content
        // (Claude API rejects empty text content blocks)
        if (content.length > 0) {
            newMessages.push({ role: "assistant", content: content });
        }

        // Build the new conversation in ClaudePrompt format
        const conversation = {
            messages: newMessages,
            system: existingSystem,
        };

        // Increment turn counter
        const withTurn = incrementConversationTurn(conversation);

        // Apply stripping based on options
        const currentTurn = getConversationMeta(withTurn).turnNumber;
        const stripOptions = {
            keepForTurns: options.stripImagesAfterTurns ?? Infinity,
            currentTurn,
            textMaxTokens: options.stripTextMaxTokens,
        };
        let processedConversation = stripBase64ImagesFromConversation(withTurn, stripOptions);
        processedConversation = truncateLargeTextInConversation(processedConversation, stripOptions);
        processedConversation = stripHeartbeatsFromConversation(processedConversation, {
            keepForTurns: options.stripHeartbeatsAfterTurns ?? 1,
            currentTurn,
        });

        return processedConversation;
    }

    /**
     * Generate images with Imagen. The model name is the last "/"-separated segment
     * of `_options.model` with any trailing `@...` suffix removed.
     */
    async requestImageGeneration(
        _prompt: ImagenPrompt,
        _options: ExecutionOptions,
    ): Promise<Completion> {
        const splits = _options.model.split("/");
        const modelName = trimModelName(splits[splits.length - 1]);
        return new ImagenModelDefinition(modelName).requestImageGeneration(this, _prompt, _options);
    }

    /** Drain the GenAI model listing pager into an array. */
    async getGenAIModelsArray(client: GoogleGenAI): Promise<Model[]> {
        const models: Model[] = [];
        const pager = await client.models.list();
        for await (const item of pager) {
            models.push(item);
        }
        return models;
    }

    /**
     * List available models: project models from the AI Platform listing, models from the
     * "global" GenAI listing, supported publisher models from Model Garden (plus derived
     * "Global" variants for Gemini >= 2.5 and for Anthropic models), and a fixed list of
     * additional models.
     *
     * @param _params - Unused.
     * @returns Models de-duplicated by id (first occurrence wins) and sorted by id.
     */
    async listModels(_params?: ModelSearchPayload): Promise<AIModel<string>[]> {
        // Get clients
        const modelGarden = await this.getModelGardenClient();
        const aiplatform = await this.getAIPlatformClient();
        const globalGenAiClient = this.getGoogleGenAIClient("global");

        let models: AIModel<string>[] = [];

        //Model Garden Publisher models - Pretrained models
        const publishers = ["google", "anthropic", "meta"];
        // Meta "maas" models are LLama Models-As-A-Service. Non-maas models are not pre-deployed.
        const supportedModels = { google: ["gemini", "imagen"], anthropic: ["claude"], meta: ["maas"] };
        // Additional models not in the listings, but we want to include
        // TODO: Remove once the models are available in the listing API, or no longer needed
        const additionalModels = {
            google: [
                "imagen-3.0-fast-generate-001",
            ],
            anthropic: [],
            meta: [
                "llama-4-maverick-17b-128e-instruct-maas",
                "llama-4-scout-17b-16e-instruct-maas",
                "llama-3.3-70b-instruct-maas",
                "llama-3.2-90b-vision-instruct-maas",
                "llama-3.1-405b-instruct-maas",
                "llama-3.1-70b-instruct-maas",
                "llama-3.1-8b-instruct-maas",
            ],
        };

        //Used to exclude retired models that are still in the listing API but not available for use.
        //Or models we do not support yet
        const unsupportedModelsByPublisher = {
            google: [
                "gemini-pro",
                "gemini-ultra",
                "imagen-product-recontext-preview",
                "embedding",
                "gemini-live-2.5-flash-preview-native-audio",
                "computer-use-preview",
            ],
            anthropic: [],
            meta: [],
        };

        // Start all network requests in parallel
        const aiplatformPromise = aiplatform.listModels({
            parent: `projects/${this.options.project}/locations/${this.options.region}`,
        });
        const publisherPromises = publishers.map(async (publisher) => {
            const [response] = await modelGarden.listPublisherModels({
                parent: `publishers/${publisher}`,
                orderBy: "name",
                listAllVersions: true,
            });
            return { publisher, response };
        });
        const globalGooglePromise = this.getGenAIModelsArray(globalGenAiClient);

        // Await all network requests
        const [aiplatformResult, globalGoogleResult, ...publisherResults] = await Promise.all([
            aiplatformPromise,
            globalGooglePromise,
            ...publisherPromises,
        ]);

        // Process aiplatform models, project specific models
        const [response] = aiplatformResult;
        models = models.concat(
            response.map((model) => ({
                id: model.name?.split("/").pop() ?? "",
                name: model.displayName ?? "",
                provider: "vertexai",
            }))
        );

        // Process global google models from GenAI
        models = models.concat(
            globalGoogleResult.map((model) => describeGlobalModel(model.name, "google"))
        );

        // Process publisher models
        for (const result of publisherResults) {
            const { publisher, response } = result;
            const modelFamily = supportedModels[publisher as keyof typeof supportedModels];
            const retiredModels = unsupportedModelsByPublisher[publisher as keyof typeof unsupportedModelsByPublisher];

            // A model is supported when it is not retired and belongs to a supported family.
            const isSupported = (modelName: string): boolean =>
                !retiredModels.some((retiredModel) => modelName.includes(retiredModel)) &&
                modelFamily.some((family) => modelName.includes(family));

            models = models.concat(
                response
                    .filter((model) => isSupported(model.name ?? ""))
                    .map((model) =>
                        describeModel(
                            model.name ?? "",
                            model.name ?? "",
                            model.name?.split("/").pop() ?? "",
                            publisher,
                        )
                    )
            );

            // Create global google gemini models for Gemini 2.5 and later, if missing from GenAI listing
            if (publisher === "google") {
                // Names listed so far; `models` is not mutated while the filter below runs.
                const existingNames = new Set(models.map((m) => m.name));
                const globalGeminiModels = response
                    .filter((model) => {
                        const modelName = model.name ?? "";
                        if (!isSupported(modelName)) {
                            return false;
                        }
                        const versionMatch = modelName.match(GEMINI_VERSION_PATTERN);
                        if (!versionMatch || !(parseFloat(versionMatch[1]) >= 2.5)) {
                            return false;
                        }
                        // Skip when a "Global <short name>" entry is already present
                        return !existingNames.has("Global " + modelName.split("/").pop());
                    })
                    .map((model) => describeGlobalModel(model.name, publisher));
                models = models.concat(globalGeminiModels);
            }

            // Create global anthropic models for those not in NON_GLOBAL_ANTHROPIC_MODELS
            if (publisher === "anthropic") {
                const globalAnthropicModels = response
                    .filter((model) => {
                        const modelName = model.name ?? "";
                        if (!isSupported(modelName)) {
                            return false;
                        }
                        // claude-3-7 names are always included, before the non-global list is consulted
                        if (modelName.includes("claude-3-7")) {
                            return true;
                        }
                        return !NON_GLOBAL_ANTHROPIC_MODELS.some((nonGlobalModel) => modelName.includes(nonGlobalModel));
                    })
                    .map((model) => describeGlobalModel(model.name, publisher));
                models = models.concat(globalAnthropicModels);
            }

            // Add additional models that are not in the listing
            for (const additionalModel of additionalModels[publisher as keyof typeof additionalModels]) {
                models.push(
                    describeModel(
                        additionalModel,
                        `publishers/${publisher}/models/${additionalModel}`,
                        additionalModel,
                        publisher,
                    )
                );
            }
        }

        //Remove duplicates: keep the first model seen for each id
        const modelsById = new Map<string, AIModel<string>>();
        for (const model of models) {
            if (!modelsById.has(model.id)) {
                modelsById.set(model.id, model);
            }
        }
        return Array.from(modelsById.values()).sort((a, b) => a.id.localeCompare(b.id));
    }

    /** Not implemented: always throws (synchronously, not as a rejected promise). */
    validateConnection(): Promise<boolean> {
        throw new Error("Method not implemented.");
    }

    /**
     * Generate embeddings. Requests with an image, or with a model name containing
     * "multimodal", go to the image embeddings path; everything else to the text path.
     *
     * @throws Error when both `text` and `image` are supplied on the image path.
     */
    async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
        if (options.image || options.model?.includes("multimodal")) {
            if (options.text && options.image) {
                throw new Error("Text and Image simultaneous embedding not implemented. Submit separately");
            }
            return getEmbeddingsForImages(this, options);
        }
        const text_options: TextEmbeddingsOptions = {
            content: options.text ?? "",
            model: options.model,
        };
        return getEmbeddingsForText(this, text_options);
    }

    /**
     * Cleanup Google Cloud clients when the driver is evicted from the cache.
     * The close() calls are not awaited.
     */
    destroy(): void {
        this.aiplatform?.close();
        this.modelGarden?.close();
        this.imagenClient?.close();
    }

    /**
     * Format VertexAI errors by routing to model-specific error handlers.
     * Each model definition (Gemini, Claude, Llama) can provide custom error parsing
     * based on their specific SDK error structures.
     *
     * @param error - The error from the VertexAI/model SDK
     * @param context - Context about where the error occurred
     * @returns A standardized LlumiverseError
     */
    public formatLlumiverseError(
        error: unknown,
        context: LlumiverseErrorContext
    ): LlumiverseError {
        // Get the model definition for this request
        const modelDef = getModelDefinition(context.model);

        // If the model definition provides custom error handling, use it
        if (modelDef.formatLlumiverseError) {
            try {
                return modelDef.formatLlumiverseError(this, error, context);
            } catch {
                // If model-specific handler throws, fall through to default handling
                // This allows model handlers to explicitly opt out for certain errors
            }
        }

        // Fall back to default AbstractDriver error handling
        return super.formatLlumiverseError(error, context);
    }
}

//'us-central1-aiplatform.googleapis.com',
const API_BASE_PATH = "aiplatform.googleapis.com";

// Anthropic client request timeout: 20 minutes, in milliseconds (10 minute default).
// Setting this disables long request error:
// https://github.com/anthropics/anthropic-sdk-typescript?#long-requests
const ANTHROPIC_REQUEST_TIMEOUT_MS = 20 * 60 * 1000;

// Captures the numeric version that follows "gemini-" in a model name, e.g. "2.5".
const GEMINI_VERSION_PATTERN = /gemini-(\d+(?:\.\d+)?)/;

/**
 * Concatenate completion results into one string for an assistant message.
 * Text values are used as-is, JSON values are stringified unless already strings,
 * and images contribute nothing (they remain in the result itself).
 */
function completionResultsToText(results: CompletionResult[]): string {
    return results
        .map((r) => {
            switch (r.type) {
                case "text":
                    return r.value;
                case "json":
                    return typeof r.value === "string" ? r.value : JSON.stringify(r.value);
                case "image":
                    // Skip images in conversation - they're in the result
                    return "";
                default:
                    return String((r as any).value || "");
            }
        })
        .join("");
}

/** Build a model descriptor, filling modalities and tool support from the capabilities of `capabilityKey`. */
function describeModel(capabilityKey: string, id: string, name: string, owner: string): AIModel<string> {
    const modelCapability = getModelCapabilities(capabilityKey, "vertexai");
    return {
        id,
        name,
        provider: "vertexai",
        owner,
        input_modalities: modelModalitiesToArray(modelCapability.input),
        output_modalities: modelModalitiesToArray(modelCapability.output),
        tool_support: modelCapability.tool_support,
    };
}

/** Build the "Global" variant of a model: id prefixed with "locations/global/", name prefixed with "Global ". */
function describeGlobalModel(modelName: string | null | undefined, owner: string): AIModel<string> {
    return describeModel(
        modelName ?? "",
        "locations/global/" + modelName,
        "Global " + modelName?.split("/").pop(),
        owner,
    );
}

/**
 * Create a FetchClient rooted at
 * `https://<endpoint>/<apiVersion>/projects/<project>/locations/<region>` with a JSON content type.
 *
 * @param region - Region used in the path and, unless `apiEndpoint` is given, in the host name.
 * @param project - Project id used in the path.
 * @param apiEndpoint - Optional host override; defaults to `<region>-aiplatform.googleapis.com`.
 * @param apiVersion - API version path segment; defaults to "v1".
 */
function createFetchClient({
    region,
    project,
    apiEndpoint,
    apiVersion = "v1",
}: {
    region: string;
    project: string;
    apiEndpoint?: string;
    apiVersion?: string;
}) {
    const vertexBaseEndpoint = apiEndpoint ?? `${region}-${API_BASE_PATH}`;
    return new FetchClient(
        `https://${vertexBaseEndpoint}/${apiVersion}/projects/${project}/locations/${region}`,
    ).withHeaders({
        "Content-Type": "application/json",
    });
}