UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers include: openai, huggingface, bedrock, replicate, and vertexai.

274 lines 12.2 kB
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.VertexAIDriver = void 0;
exports.trimModelName = trimModelName;
const core_1 = require("@llumiverse/core");
const api_fetch_client_1 = require("@vertesia/api-fetch-client");
const google_auth_library_1 = require("google-auth-library");
const embeddings_text_js_1 = require("./embeddings/embeddings-text.js");
const models_js_1 = require("./models.js");
const embeddings_image_js_1 = require("./embeddings/embeddings-image.js");
const aiplatform_1 = require("@google-cloud/aiplatform");
const vertex_sdk_1 = require("@anthropic-ai/vertex-sdk");
const imagen_js_1 = require("./models/imagen.js");
const genai_1 = require("@google/genai");

// Base Vertex AI API host. Regional hosts prefix it with the region,
// e.g. 'us-central1-aiplatform.googleapis.com'.
const API_BASE_PATH = "aiplatform.googleapis.com";

/**
 * Strip a trailing "@<suffix>" (e.g. a version tag) from a model name.
 * @param {string} model - model name, e.g. "some-model@001"
 * @returns {string} the part before the last "@", or `model` unchanged when it contains no "@"
 */
function trimModelName(model) {
    const i = model.lastIndexOf("@");
    return i > -1 ? model.substring(0, i) : model;
}

/**
 * Vertex AI API host for a location.
 * The "global" location has no regional host and is served from the bare API host;
 * every other location uses "<region>-aiplatform.googleapis.com".
 * @param {string} region - Vertex AI location, e.g. "us-central1" or "global"
 * @returns {string} host name without scheme
 */
function getVertexHost(region) {
    return region === "global" ? API_BASE_PATH : `${region}-${API_BASE_PATH}`;
}

/**
 * Build an HTTP Authorization header value from a google-auth client.
 * The result of getAccessToken() is accepted either as a bare string or as an
 * object carrying a `token` field.
 * @param authClient - object exposing getAccessToken() (GoogleAuth or an auth client)
 * @returns {Promise<string>} "Bearer <token>"
 */
async function getBearerToken(authClient) {
    const accessTokenResponse = await authClient.getAccessToken();
    const token = typeof accessTokenResponse === "string" ? accessTokenResponse : accessTokenResponse?.token;
    return `Bearer ${token}`;
}

/**
 * Build a JSON REST client rooted at
 * https://<host>/<apiVersion>/projects/<project>/locations/<region>.
 * @param {object} params
 * @param {string} params.region - Vertex AI location
 * @param {string} params.project - Google Cloud project id
 * @param {string} [params.apiEndpoint] - host override; defaults to the host for `region`
 * @param {string} [params.apiVersion="v1"] - API version path segment
 */
function createFetchClient({ region, project, apiEndpoint, apiVersion = "v1", }) {
    const vertexBaseEndpoint = apiEndpoint ?? getVertexHost(region);
    return new api_fetch_client_1.FetchClient(`https://${vertexBaseEndpoint}/${apiVersion}/projects/${project}/locations/${region}`).withHeaders({
        "Content-Type": "application/json",
    });
}

/**
 * Driver for Google Vertex AI.
 *
 * Prompt creation and completion requests are dispatched to per-model definitions
 * obtained from getModelDefinition (./models.js); Imagen models go through
 * ImagenModelDefinition. The underlying API clients (Google GenAI, REST fetch
 * clients, Anthropic Vertex SDK, AI Platform model / model garden services) are
 * created lazily on first use.
 */
class VertexAIDriver extends core_1.AbstractDriver {
    static PROVIDER = "vertexai";
    provider = VertexAIDriver.PROVIDER;
    aiplatform;
    anthropicClient;
    fetchClient;
    googleGenAI;
    llamaClient;
    // Region the cached llamaClient was created for; a different region forces a new client.
    llamaClientRegion;
    modelGarden;
    authClient;

    /**
     * @param options - driver options; this class reads `project`, `region` and `googleAuthOptions`.
     *   When `googleAuthOptions.authClient` is provided it is used as-is, otherwise a
     *   GoogleAuth instance is built from `googleAuthOptions`.
     */
    constructor(options) {
        super(options);
        this.authClient = options.googleAuthOptions?.authClient ?? new google_auth_library_1.GoogleAuth(options.googleAuthOptions);
    }

    /** Lazily create the Google GenAI client bound to the configured project, region and auth. */
    getGoogleGenAIClient() {
        if (!this.googleGenAI) {
            this.googleGenAI = new genai_1.GoogleGenAI({
                project: this.options.project,
                location: this.options.region,
                vertexai: true,
                googleAuthOptions: {
                    authClient: this.authClient,
                },
            });
        }
        return this.googleGenAI;
    }

    /** Lazily create the authenticated v1 REST client for the configured project and region. */
    getFetchClient() {
        if (!this.fetchClient) {
            this.fetchClient = createFetchClient({
                region: this.options.region,
                project: this.options.project,
            }).withAuthCallback(() => getBearerToken(this.authClient));
        }
        return this.fetchClient;
    }

    /**
     * Authenticated v1beta1 REST client used for Llama models.
     * The client is cached and rebuilt whenever a different region is requested.
     * @param {string} [region="us-central1"] - Vertex AI location to target
     */
    getLLamaClient(region = "us-central1") {
        if (!this.llamaClient || this.llamaClientRegion !== region) {
            this.llamaClient = createFetchClient({
                region,
                project: this.options.project,
                apiVersion: "v1beta1",
            }).withAuthCallback(() => getBearerToken(this.authClient));
            this.llamaClientRegion = region;
        }
        return this.llamaClient;
    }

    /**
     * Lazily create the Anthropic Vertex SDK client.
     * The region is fixed to "us-east5" independently of `options.region`.
     */
    getAnthropicClient() {
        if (!this.anthropicClient) {
            this.anthropicClient = new vertex_sdk_1.AnthropicVertex({
                region: "us-east5",
                // Use the driver's configured project like every other client; the env var is
                // kept as a fallback for setups that relied on it. An empty string is never a
                // valid project id, hence `||`.
                projectId: this.options.project || process.env.GOOGLE_PROJECT_ID,
                // NOTE(review): no credentials are forwarded, so the SDK uses its own default
                // Google auth rather than this.authClient — confirm whether that is intended.
            });
        }
        return this.anthropicClient;
    }

    /** Lazily create the AI Platform (v1beta1) ModelService client for the configured region. */
    getAIPlatformClient() {
        if (!this.aiplatform) {
            this.aiplatform = new aiplatform_1.v1beta1.ModelServiceClient({
                projectId: this.options.project,
                apiEndpoint: getVertexHost(this.options.region),
                authClient: this.authClient,
            });
        }
        return this.aiplatform;
    }

    /** Lazily create the AI Platform (v1beta1) ModelGardenService client for the configured region. */
    getModelGardenClient() {
        if (!this.modelGarden) {
            this.modelGarden = new aiplatform_1.v1beta1.ModelGardenServiceClient({
                projectId: this.options.project,
                apiEndpoint: getVertexHost(this.options.region),
                authClient: this.authClient,
            });
        }
        return this.modelGarden;
    }

    /**
     * Validate a completion result, first letting the model definition rewrite the
     * result/options through its optional `preValidationProcessing` hook.
     */
    validateResult(result, options) {
        const modelDef = (0, models_js_1.getModelDefinition)(options.model);
        if (typeof modelDef.preValidationProcessing === "function") {
            const processed = modelDef.preValidationProcessing(result, options);
            result = processed.result;
            options = processed.options;
        }
        super.validateResult(result, options);
    }

    /**
     * Image-output requests never stream; otherwise streaming is allowed only when the
     * model definition sets `can_stream === true`.
     * @returns {Promise<boolean>}
     */
    canStream(options) {
        if (options.output_modality === core_1.Modalities.image) {
            return Promise.resolve(false);
        }
        return Promise.resolve((0, models_js_1.getModelDefinition)(options.model).model.can_stream === true);
    }

    /** Build the provider prompt; model ids containing "imagen" use the Imagen definition. */
    createPrompt(segments, options) {
        if (options.model.includes("imagen")) {
            return new imagen_js_1.ImagenModelDefinition(options.model).createPrompt(this, segments, options);
        }
        return (0, models_js_1.getModelDefinition)(options.model).createPrompt(this, segments, options);
    }

    async requestTextCompletion(prompt, options) {
        return (0, models_js_1.getModelDefinition)(options.model).requestTextCompletion(this, prompt, options);
    }

    async requestTextCompletionStream(prompt, options) {
        return (0, models_js_1.getModelDefinition)(options.model).requestTextCompletionStream(this, prompt, options);
    }

    /**
     * Generate images with Imagen. The model name is the last "/"-separated segment of
     * `options.model` with any "@<suffix>" removed.
     */
    async requestImageGeneration(prompt, options) {
        const modelName = trimModelName(options.model.split("/").pop());
        return new imagen_js_1.ImagenModelDefinition(modelName).requestImageGeneration(this, prompt, options);
    }

    /**
     * List project-deployed models plus supported Model Garden publisher models
     * (google, anthropic, meta) and a few hard-coded extras, de-duplicated by id
     * (first occurrence wins) and sorted by id.
     */
    async listModels(_params) {
        const modelGarden = this.getModelGardenClient();
        const aiplatform = this.getAIPlatformClient();

        // Model Garden publishers whose pretrained models are listed.
        const publishers = ["google", "anthropic", "meta"];
        // Substrings identifying the supported model families per publisher.
        // Meta "maas" models are LLama Models-As-A-Service. Non-maas models are not pre-deployed.
        const supportedModels = {
            google: ["gemini", "imagen"],
            anthropic: ["claude"],
            meta: ["maas"],
        };
        // Additional models not in the listings, but we want to include.
        // TODO: Remove once the models are available in the listing API, or no longer needed.
        const additionalModels = {
            google: ["imagen-3.0-fast-generate-001"],
            anthropic: [],
            meta: [
                "llama-4-maverick-17b-128e-instruct-maas",
                "llama-4-scout-17b-16e-instruct-maas",
                "llama-3.3-70b-instruct-maas",
                "llama-3.2-90b-vision-instruct-maas",
                "llama-3.1-405b-instruct-maas",
                "llama-3.1-70b-instruct-maas",
                "llama-3.1-8b-instruct-maas",
            ],
        };
        // Used to exclude retired models that are still in the listing API but not available
        // for use, or models we do not support yet.
        const unsupportedModelsByPublisher = {
            google: ["gemini-pro", "gemini-ultra"],
            anthropic: [],
            meta: [],
        };

        // The listing calls are independent, so issue them concurrently. Promise.all keeps
        // results in request order, so the merge below sees the same order as serial calls.
        const [[projectModels], ...publisherResponses] = await Promise.all([
            aiplatform.listModels({
                parent: `projects/${this.options.project}/locations/${this.options.region}`,
            }),
            ...publishers.map((publisher) => modelGarden.listPublisherModels({
                parent: `publishers/${publisher}`,
                orderBy: "name",
                listAllVersions: true,
            })),
        ]);

        // Project specific deployed models.
        const models = projectModels.map((model) => ({
            id: model.name?.split("/").pop() ?? "",
            name: model.displayName ?? "",
            provider: "vertexai",
        }));

        for (const [index, publisher] of publishers.entries()) {
            const [listed] = publisherResponses[index];
            // The google publisher returns a 100+ entry list; keep only models exposing
            // the openGenerationAiStudio action.
            const candidates = publisher === "google"
                ? listed.filter((model) => Boolean(model.supportedActions?.openGenerationAiStudio))
                : listed;
            const families = supportedModels[publisher];
            const excluded = unsupportedModelsByPublisher[publisher];

            for (const model of candidates) {
                const modelName = model.name ?? "";
                if (excluded.some((retired) => modelName.includes(retired))) {
                    continue;
                }
                if (!families.some((family) => modelName.includes(family))) {
                    continue;
                }
                models.push(buildPublisherModel(modelName, model.name?.split("/").pop() ?? "", publisher, modelName));
            }

            // Add additional models that are not in the listing.
            for (const additionalModel of additionalModels[publisher]) {
                models.push(buildPublisherModel(`publishers/${publisher}/models/${additionalModel}`, additionalModel, publisher, additionalModel));
            }
        }

        // Remove duplicates (first occurrence of each id wins), then sort by id.
        const uniqueModels = new Map();
        for (const model of models) {
            if (!uniqueModels.has(model.id)) {
                uniqueModels.set(model.id, model);
            }
        }
        return [...uniqueModels.values()].sort((a, b) => a.id.localeCompare(b.id));
    }

    /** Not implemented for this driver; always throws. */
    validateConnection() {
        throw new Error("Method not implemented.");
    }

    /**
     * Generate embeddings. Image embeddings are used when an image is supplied or the
     * model name contains "multimodal" (text and image together are rejected); otherwise
     * text embeddings are computed for `options.text` (empty string when absent).
     */
    async generateEmbeddings(options) {
        if (options.image || options.model?.includes("multimodal")) {
            if (options.text && options.image) {
                throw new Error("Text and Image simultaneous embedding not implemented. Submit separately");
            }
            return (0, embeddings_image_js_1.getEmbeddingsForImages)(this, options);
        }
        const text_options = {
            content: options.text ?? "",
            model: options.model,
        };
        return (0, embeddings_text_js_1.getEmbeddingsForText)(this, text_options);
    }
}
exports.VertexAIDriver = VertexAIDriver;

/**
 * Build a model-list entry for a Model Garden publisher model.
 * @param {string} id - full resource id, e.g. "publishers/<publisher>/models/<name>"
 * @param {string} name - short display name
 * @param {string} owner - publisher name
 * @param {string} capabilityModelName - name passed to getModelCapabilities
 */
function buildPublisherModel(id, name, owner, capabilityModelName) {
    const modelCapability = (0, core_1.getModelCapabilities)(capabilityModelName, "vertexai");
    return {
        id,
        name,
        provider: "vertexai",
        owner,
        input_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.input),
        output_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.output),
        tool_support: modelCapability.tool_support,
    };
}
//# sourceMappingURL=index.js.map