UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

962 lines 42.1 kB
import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import { AbstractDriver, Modalities, TrainingJobStatus, getMaxTokensLimitBedrock, modelModalitiesToArray, getModelCapabilities } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import { formatTwelvelabsPegasusPrompt } from "./twelvelabs.js";
// Per-model cache of "does this model support streaming?" answers, consulted by canStream().
const supportStreamingCache = new LRUCache(4096);
// Kinds of Bedrock model identifiers we may be handed (TypeScript-enum compiled output).
var BedrockModelType;
(function (BedrockModelType) {
    BedrockModelType["FoundationModel"] = "foundation-model";
    BedrockModelType["InferenceProfile"] = "inference-profile";
    BedrockModelType["CustomModel"] = "custom-model";
    BedrockModelType["Unknown"] = "unknown";
})(BedrockModelType || (BedrockModelType = {}));
;
// Map a Converse API stopReason onto the driver's normalized finish_reason vocabulary.
// Unrecognized values are passed through unchanged.
function converseFinishReason(reason) {
    //Possible values:
    //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
    if (!reason)
        return undefined;
    switch (reason) {
        case 'end_turn': return "stop";
        case 'max_tokens': return "length";
        default: return reason;
    }
}
//Used to get a max_token value when not specified in the model options. Claude requires it to be set.
function maxTokenFallbackClaude(option) {
    const modelOptions = option.model_options;
    if (modelOptions && typeof modelOptions.max_tokens === "number") {
        return modelOptions.max_tokens;
    }
    else {
        const thinking_budget = modelOptions?.thinking_budget_tokens ?? 0;
        let maxSupportedTokens = getMaxTokensLimitBedrock(option.model) ?? 8192; // Should always return a number for claude, 8192 is to satisfy the TypeScript type checker; // Fallback to the default max tokens limit for the model
        if (option.model.includes('claude-3-7-sonnet') && (modelOptions?.thinking_budget_tokens ?? 0) < 48000) {
            maxSupportedTokens = 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k.
        }
        return Math.min(16000 + thinking_budget, maxSupportedTokens); // Cap to 16k, to avoid taking up too much context window and quota.
    }
}
// Bedrock driver: routes completions through the Converse API (or invokeModel for
// special-cased model families) and exposes Bedrock management/training operations.
export class BedrockDriver extends AbstractDriver {
    static PROVIDER = "bedrock";
    provider = BedrockDriver.PROVIDER;
    _executor;       // lazily-built BedrockRuntime client (inference)
    _service;        // lazily-built Bedrock client (management)
    _service_region; // region the current _service was built for
    constructor(options) {
        super(options);
        if (!options.region) {
            throw new Error("No region found. Set the region in the environment's endpoint URL.");
        }
    }
    // Lazily create the runtime client used for inference calls.
    getExecutor() {
        if (!this._executor) {
            this._executor = new BedrockRuntime({
                region: this.options.region,
                credentials: this.options.credentials,
            });
        }
        return this._executor;
    }
    // Lazily create the management client; rebuilt whenever a different region is requested.
    getService(region = this.options.region) {
        if (!this._service || this._service_region != region) {
            this._service = new Bedrock({
                region: region,
                credentials: this.options.credentials,
            });
            this._service_region = region;
        }
        return this._service;
    }
    // Choose the prompt formatter by model family: Nova Canvas, TwelveLabs Pegasus,
    // or the generic Converse format for everything else.
    async formatPrompt(segments, opts) {
        if (opts.model.includes("canvas")) {
            return await formatNovaPrompt(segments, opts.result_schema);
        }
        if (opts.model.includes("twelvelabs.pegasus")) {
            return await formatTwelvelabsPegasusPrompt(segments, opts);
        }
        return await formatConversePrompt(segments, opts);
    }
    // Convert a non-streaming Converse response into the driver's completion shape.
    getExtractedExecution(result, _prompt, options) {
        let resultText = "";
        let reasoning = "";
        if (result.output?.message?.content) {
            for (const content of result.output.message.content) {
                // Get text output
                if (content.text) {
                    resultText += content.text;
                }
                else if (content.reasoningContent) {
                    // Get reasoning content only if include_thoughts is true
                    const
claudeOptions = options?.model_options;
                    if (claudeOptions?.include_thoughts) {
                        if (content.reasoningContent.reasoningText) {
                            reasoning += content.reasoningContent.reasoningText.text;
                        }
                        else if (content.reasoningContent.redactedContent) {
                            // Handle redacted thinking content
                            const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
                            reasoning += `[Redacted thinking: ${redactedData}]`;
                        }
                    }
                    else {
                        this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
                    }
                }
                else {
                    // Get content block type (first defined key other than $unknown), for diagnostics only.
                    const type = Object.keys(content).find(key => key !== '$unknown' && content[key] !== undefined);
                    this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
                }
            }
            // Add spacing if we have reasoning content
            if (reasoning) {
                reasoning += '\n\n';
            }
        }
        // Reasoning (if any) is prepended to the text; an all-empty string yields an empty result array.
        const completionResult = {
            result: reasoning + resultText ? [{ type: "text", value: reasoning + resultText }] : [],
            token_usage: {
                prompt: result.usage?.inputTokens,
                result: result.usage?.outputTokens,
                total: result.usage?.totalTokens,
            },
            finish_reason: converseFinishReason(result.stopReason),
        };
        return completionResult;
    }
    ;
    // Convert a single ConverseStream event into an incremental completion chunk.
    // Each call handles at most one event type (block start / delta / stop / message stop / metadata).
    getExtractedStream(result, _prompt, options) {
        let output = "";
        let reasoning = "";
        let stop_reason = "";
        let token_usage;
        // Check if we should include thoughts
        const shouldIncludeThoughts = options && options.model_options?.include_thoughts;
        // Handle content block start events (for reasoning blocks)
        if (result.contentBlockStart) {
            // Handle redacted content at block start
            if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
                const reasoningStart = result.contentBlockStart.start;
                if (reasoningStart.reasoningContent?.redactedContent) {
                    const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
                    reasoning = `[Redacted thinking: ${redactedData}]`;
                }
            }
        }
        // Handle content block deltas (text and reasoning)
        if (result.contentBlockDelta) {
            const delta = result.contentBlockDelta.delta;
            if (delta?.text) {
                output = delta.text;
            }
            else if (delta?.reasoningContent && shouldIncludeThoughts) {
                if (delta.reasoningContent.text) {
                    reasoning = delta.reasoningContent.text;
                }
                else if (delta.reasoningContent.redactedContent) {
                    const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent);
                    reasoning = `[Redacted thinking: ${redactedData}]`;
                }
                else if (delta.reasoningContent.signature) {
                    // Handle signature updates for reasoning content - end of thinking
                    reasoning = "\n\n";
                    // Putting logging here so it only triggers once.
                    // NOTE(review): this log message says include_thoughts is false, but this branch
                    // only runs when shouldIncludeThoughts is truthy — message looks wrong; confirm intent.
                    this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
                }
            }
            else if (delta) {
                // Get content block type
                const type = Object.keys(delta).find(key => key !== '$unknown' && delta[key] !== undefined);
                this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
            }
        }
        // Handle content block stop events
        if (result.contentBlockStop) {
            // Content block ended - could be end of reasoning or text block
            // Add minimal spacing for reasoning blocks if not already present
            if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
                reasoning += '\n\n';
            }
        }
        if (result.messageStop) {
            stop_reason = result.messageStop.stopReason ?? "";
        }
        if (result.metadata) {
            token_usage = {
                prompt: result.metadata.usage?.inputTokens,
                result: result.metadata.usage?.outputTokens,
                total: result.metadata.usage?.totalTokens,
            };
        }
        const completionResult = {
            result: reasoning + output ? [{ type: "text", value: reasoning + output }] : [],
            token_usage: token_usage,
            finish_reason: converseFinishReason(stop_reason),
        };
        return completionResult;
    }
    ;
    // Best-effort region extraction from a model id/ARN; falls back to defaultRegion.
    extractRegion(modelString, defaultRegion) {
        // Match region in full ARN pattern
        const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
        if (arnMatch) {
            return arnMatch[1];
        }
        // Match common AWS regions directly in string
        const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
        if (regionMatch) {
            return regionMatch[0];
        }
        return defaultRegion;
    }
    // Resolve streaming support by querying Bedrock. Unknown ids are probed as
    // foundation model, then inference profile, then custom model; profile/custom
    // lookups recurse on the underlying foundation model ARN.
    async getCanStream(model, type) {
        let canStream = false;
        let error = null;
        const region = this.extractRegion(model, this.options.region);
        if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getFoundationModel({ modelIdentifier: model });
                canStream = response.modelDetails?.responseStreamingSupported ?? false;
                return canStream;
            }
            catch (e) {
                error = e;
            }
        }
        if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getInferenceProfile({ inferenceProfileIdentifier: model });
                canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
                return canStream;
            }
            catch (e) {
                error = e;
            }
        }
        if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
            try {
                const response = await this.getService(region).getCustomModel({ modelIdentifier: model });
                canStream = await this.getCanStream(response.baseModelArn ??
"", BedrockModelType.FoundationModel);
                return canStream;
            }
            catch (e) {
                error = e;
            }
        }
        if (error) {
            console.warn("Error on canStream check for model: " + model + " region detected: " + region, error);
        }
        return canStream;
    }
    // Cached front-end over getCanStream(); infers the identifier type from the model string.
    async canStream(options) {
        // // TwelveLabs Pegasus supports streaming according to the documentation
        // if (options.model.includes("twelvelabs.pegasus")) {
        //     return true;
        // }
        let canStream = supportStreamingCache.get(options.model);
        if (canStream == null) {
            let type = BedrockModelType.Unknown;
            if (options.model.includes("foundation-model")) {
                type = BedrockModelType.FoundationModel;
            }
            else if (options.model.includes("inference-profile")) {
                type = BedrockModelType.InferenceProfile;
            }
            else if (options.model.includes("custom-model")) {
                type = BedrockModelType.CustomModel;
            }
            canStream = await this.getCanStream(options.model, type);
            supportStreamingCache.set(options.model, canStream);
        }
        return canStream;
    }
    // Non-streaming completion. Maintains the running conversation, dispatches via
    // Converse, and collects any tool-use requests from the response content.
    async requestTextCompletion(prompt, options) {
        // Handle Twelvelabs Pegasus models
        if (options.model.includes("twelvelabs.pegasus")) {
            return this.requestTwelvelabsPegasusCompletion(prompt, options);
        }
        // Handle other Bedrock models that use Converse API
        const conversePrompt = prompt;
        let conversation = updateConversation(options.conversation, conversePrompt);
        const payload = this.preparePayload(conversation, options);
        const executor = this.getExecutor();
        const res = await executor.converse({
            ...payload,
        });
        conversation = updateConversation(conversation, {
            messages: [res.output?.message ?? { content: [{ text: "" }], role: "assistant" }],
            modelId: conversePrompt.modelId,
        });
        let tool_use = undefined;
        //Get tool requests, we check tool use regardless of finish reason, as you can hit length and still get a valid response.
        tool_use = res.output?.message?.content?.reduce((tools, c) => {
            if (c.toolUse) {
                tools.push({
                    tool_name: c.toolUse.name ?? "",
                    tool_input: c.toolUse.input,
                    id: c.toolUse.toolUseId ?? "",
                });
            }
            return tools;
        }, []);
        //If no tools were used, set to undefined
        if (tool_use && tool_use.length == 0) {
            tool_use = undefined;
        }
        const completion = {
            ...this.getExtractedExecution(res, conversePrompt, options),
            original_response: options.include_original_response ? res : undefined,
            conversation: conversation,
            tool_use: tool_use,
        };
        return completion;
    }
    // Pegasus uses invokeModel (raw JSON body) rather than the Converse API.
    async requestTwelvelabsPegasusCompletion(prompt, options) {
        const executor = this.getExecutor();
        const res = await executor.invokeModel({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(prompt),
        });
        const decoder = new TextDecoder();
        const body = decoder.decode(res.body);
        const result = JSON.parse(body);
        // Extract the response according to TwelveLabs Pegasus format
        let finishReason;
        switch (result.finishReason) {
            case "stop":
                finishReason = "stop";
                break;
            case "length":
                finishReason = "length";
                break;
            default:
                finishReason = result.finishReason;
        }
        return {
            result: result.message ? [{ type: "text", value: result.message }] : [],
            finish_reason: finishReason,
            original_response: options.include_original_response ? result : undefined,
        };
    }
    // Streaming variant of the Pegasus path; each chunk is decoded and JSON-parsed,
    // unparseable chunks are silently dropped (empty result).
    async requestTwelvelabsPegasusCompletionStream(prompt, options) {
        const executor = this.getExecutor();
        const res = await executor.invokeModelWithResponseStream({
            modelId: options.model,
            contentType: "application/json",
            accept: "application/json",
            body: JSON.stringify(prompt),
        });
        if (!res.body) {
            throw new Error("[Bedrock] Stream not found in response");
        }
        return transformAsyncIterator(res.body, (chunk) => {
            if (chunk.chunk?.bytes) {
                const decoder = new TextDecoder();
                const body = decoder.decode(chunk.chunk.bytes);
                try {
                    const result = JSON.parse(body);
                    // Extract streaming response according to TwelveLabs Pegasus format
                    let finishReason;
                    if (result.finishReason) {
                        switch (result.finishReason) {
                            case "stop":
                                finishReason = "stop";
                                break;
                            case "length":
                                finishReason = "length";
                                break;
                            default:
                                finishReason = result.finishReason;
                        }
                    }
                    return {
                        result: result.delta || result.message ? [{ type: "text", value: result.delta || result.message || "" }] : [],
                        finish_reason: finishReason,
                    };
                }
                catch (error) {
                    // If JSON parsing fails, return empty chunk
                    return {
                        result: [],
                    };
                }
            }
            return {
                result: [],
            };
        });
    }
    // Streaming completion over the Converse API; segments are mapped through getExtractedStream().
    async requestTextCompletionStream(prompt, options) {
        // Handle Twelvelabs Pegasus models
        if (options.model.includes("twelvelabs.pegasus")) {
            return this.requestTwelvelabsPegasusCompletionStream(prompt, options);
        }
        // Handle other Bedrock models that use Converse API
        const conversePrompt = prompt;
        const payload = this.preparePayload(conversePrompt, options);
        const executor = this.getExecutor();
        return executor.converseStream({
            ...payload,
        }).then((res) => {
            const stream = res.stream;
            if (!stream) {
                throw new Error("[Bedrock] Stream not found in response");
            }
            return transformAsyncIterator(stream, (streamSegment) => {
                return this.getExtractedStream(streamSegment, conversePrompt, options);
            });
        }).catch((err) => {
            this.logger.error({ error: err }, "[Bedrock] Failed to stream");
            throw err;
        });
    }
    // Build the Converse request for a model, applying per-provider option mapping
    // (additionalModelRequestFields) and message rewrites for models lacking system-prompt support.
    preparePayload(prompt, options) {
        const model_options =
options.model_options ?? { _option_id: "text-fallback" };
        let additionalField = {};
        let supportsJSONPrefill = false;
        if (options.model.includes("amazon")) {
            supportsJSONPrefill = true;
            //Titan models also exists but does not support any additional options
            if (options.model.includes("nova")) {
                additionalField = { inferenceConfig: { topK: model_options.top_k } };
            }
        }
        else if (options.model.includes("claude")) {
            const claude_options = model_options;
            const thinking = claude_options.thinking_mode ?? false;
            // JSON prefill conflicts with extended thinking, so disable it when thinking is on.
            supportsJSONPrefill = !thinking;
            if (options.model.includes("claude-3-7") || options.model.includes("-4-")) {
                additionalField = {
                    ...additionalField,
                    reasoning_config: {
                        type: thinking ? "enabled" : "disabled",
                        budget_tokens: thinking ? (claude_options.thinking_budget_tokens ?? 1024) : undefined,
                    }
                };
                if (thinking && options.model.includes("claude-3-7-sonnet") && ((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 0) > 64000)) {
                    additionalField = { ...additionalField, anthropic_beta: ["output-128k-2025-02-19"] };
                }
            }
            //Needs max_tokens to be set
            if (!model_options.max_tokens) {
                model_options.max_tokens = maxTokenFallbackClaude(options);
            }
            additionalField = { ...additionalField, top_k: model_options.top_k };
        }
        else if (options.model.includes("meta")) {
            //LLaMA models support no additional options
        }
        else if (options.model.includes("mistral")) {
            //7B instruct and 8x7B instruct
            if (options.model.includes("7b")) {
                additionalField = { top_k: model_options.top_k };
                //Does not support system messages
                if (prompt.system && prompt.system?.length != 0) {
                    prompt.messages?.push(converseSystemToMessages(prompt.system));
                    prompt.system = undefined;
                    prompt.messages = converseConcatMessages(prompt.messages);
                }
            }
            else {
                //Other models such as Mistral Small,Large and Large 2
                //Support no additional fields.
            }
        }
        else if (options.model.includes("ai21")) {
            //Jamba models support no additional options
            //Jurassic 2 models do.
            if (options.model.includes("j2")) {
                additionalField = {
                    presencePenalty: { scale: model_options.presence_penalty },
                    frequencyPenalty: { scale: model_options.frequency_penalty },
                };
                //Does not support system messages
                if (prompt.system && prompt.system?.length != 0) {
                    prompt.messages?.push(converseSystemToMessages(prompt.system));
                    prompt.system = undefined;
                    prompt.messages = converseConcatMessages(prompt.messages);
                }
            }
        }
        else if (options.model.includes("cohere.command")) {
            // If last message is "```json", remove it.
            //Command R and R plus
            if (options.model.includes("cohere.command-r")) {
                additionalField = {
                    k: model_options.top_k,
                    frequency_penalty: model_options.frequency_penalty,
                    presence_penalty: model_options.presence_penalty,
                };
            }
            else {
                // Command non-R
                additionalField = { k: model_options.top_k };
                //Does not support system messages
                if (prompt.system && prompt.system?.length != 0) {
                    prompt.messages?.push(converseSystemToMessages(prompt.system));
                    prompt.system = undefined;
                    prompt.messages = converseConcatMessages(prompt.messages);
                }
            }
        }
        else if (options.model.includes("palmyra")) {
            const palmyraOptions = model_options;
            additionalField = {
                seed: palmyraOptions?.seed,
                presence_penalty: palmyraOptions?.presence_penalty,
                frequency_penalty: palmyraOptions?.frequency_penalty,
                min_tokens: palmyraOptions?.min_tokens,
            };
        }
        else if (options.model.includes("deepseek")) {
            //DeepSeek models support no additional options
        }
        else if (options.model.includes("gpt-oss")) {
            const gptOssOptions = model_options;
            additionalField = {
                reasoning_effort: gptOssOptions?.reasoning_effort,
            };
        }
        //If last message is "```json", add corresponding ``` as a stop sequence.
        if (prompt.messages && prompt.messages.length > 0) {
            if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") {
                const stopSeq = model_options.stop_sequence;
                if (!stopSeq) {
                    model_options.stop_sequence = ["```"];
                }
                else if (!stopSeq.includes("```")) {
                    stopSeq.push("```");
                    model_options.stop_sequence = stopSeq;
                }
            }
        }
        const tool_defs = getToolDefinitions(options.tools);
        // Use prefill when there is a schema and tools are not being used
        if (supportsJSONPrefill && options.result_schema && !tool_defs) {
            prompt.messages = converseJSONprefill(prompt.messages);
        }
        const request = {
            messages: prompt.messages,
            system: prompt.system,
            modelId: options.model,
            inferenceConfig: {
                maxTokens: model_options.max_tokens,
                temperature: model_options.temperature,
                topP: model_options.top_p,
                stopSequences: model_options.stop_sequence,
            },
            additionalModelRequestFields: {
                ...additionalField,
            }
        };
        //Only add tools if they are defined and not empty
        if (tool_defs?.length) {
            request.toolConfig = {
                tools: tool_defs,
            };
        }
        return request;
    }
    // Generate image(s) via invokeModel (Nova Canvas payload format).
    async requestImageGeneration(prompt, options) {
        if (options.output_modality !== Modalities.image) {
            throw new Error(`Image generation requires image output_modality`);
        }
        if (options.model_options?._option_id !== "bedrock-nova-canvas") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const model_options = options.model_options;
        const executor = this.getExecutor();
        const taskType = model_options.taskType ??
NovaImageGenerationTaskType.TEXT_IMAGE; this.logger.info("Task type: " + taskType); if (typeof prompt === "string") { throw new Error("Bad prompt format"); } const payload = await formatNovaImageGenerationPayload(taskType, prompt, options); const res = await executor.invokeModel({ modelId: options.model, contentType: "application/json", accept: "application/json", body: JSON.stringify(payload), }, { requestTimeout: 60000 * 5 }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const bedrockResult = JSON.parse(body); return { error: bedrockResult.error, result: bedrockResult.images.map((image) => ({ type: "image", value: image })) }; } async startTraining(dataset, options) { //convert options.params to Record<string, string> const params = {}; for (const [key, value] of Object.entries(options.params || {})) { params[key] = String(value); } if (!this.options.training_bucket) { throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options"); } const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials }); const stream = await dataset.getStream(); const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name); const service = this.getService(); const response = await service.send(new CreateModelCustomizationJobCommand({ jobName: options.name + "-job", customModelName: options.name, roleArn: this.options.training_role_arn || undefined, baseModelIdentifier: options.model, clientRequestToken: "llumiverse-" + Date.now(), trainingDataConfig: { s3Uri: `s3://${upload.Bucket}/${upload.Key}`, }, outputDataConfig: undefined, hyperParameters: params, //TODO not supported? 
//customizationType: "FINE_TUNING", })); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: response.jobArn })); return jobInfo(job, response.jobArn); } async cancelTraining(jobId) { const service = this.getService(); await service.send(new StopModelCustomizationJobCommand({ jobIdentifier: jobId })); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } async getTrainingJob(jobId) { const service = this.getService(); const job = await service.send(new GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } // ===================== management API ================== async validateConnection() { const service = this.getService(); this.logger.debug("[Bedrock] validating connection", service.config.credentials.name); //return true as if the client has been initialized, it means the connection is valid return true; } async listTrainableModels() { this.logger.debug("[Bedrock] listing trainable models"); return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false); } async listModels() { this.logger.debug("[Bedrock] listing models"); // exclude trainable models since they are not executable // exclude embedding models, not to be used for typical completions. const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false; return this._listModels(filter); } async _listModels(foundationFilter) { const service = this.getService(); const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([ service.listFoundationModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions."); return undefined; }), service.listCustomModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list custom models. 
Check if the user has the right permissions."); return undefined; }), service.listInferenceProfiles({}).catch(() => { this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions."); return undefined; }), ]); if (!foundationModelsList?.modelSummaries) { throw new Error("Foundation models not found"); } let foundationModels = foundationModelsList.modelSummaries || []; if (foundationFilter) { foundationModels = foundationModels.filter(foundationFilter); } const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek", "writer", "openai", "twelvelabs", "qwen"]; const unsupportedModelsByPublisher = { amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"], anthropic: [], cohere: ["rerank", "embed"], ai21: [], mistral: [], meta: [], deepseek: [], writer: [], openai: [], twelvelabs: ["marengo"], qwen: [], }; // Helper function to check if model should be filtered out const shouldIncludeModel = (modelId, providerName) => { if (!modelId || !providerName) return false; const normalizedProvider = providerName.toLowerCase(); // Check if provider is supported const isProviderSupported = supportedPublishers.some(provider => normalizedProvider.includes(provider)); if (!isProviderSupported) return false; // Check if model is in the unsupported list for its provider for (const provider of supportedPublishers) { if (normalizedProvider.includes(provider)) { const unsupportedModels = unsupportedModelsByPublisher[provider] || []; return !unsupportedModels.some(unsupported => modelId.toLowerCase().includes(unsupported)); } } return true; }; foundationModels = foundationModels.filter(m => shouldIncludeModel(m.modelId, m.providerName)); const aiModels = foundationModels.map((m) => { if (!m.modelId) { throw new Error("modelId not found"); } const modelCapability = getModelCapabilities(m.modelArn ?? m.modelId, this.provider); const model = { id: m.modelArn ?? 
m.modelId, name: `${m.providerName} ${m.modelName}`, provider: this.provider, owner: m.providerName, can_stream: m.responseStreamingSupported ?? false, input_modalities: m.inputModalities ? formatAmazonModalities(m.inputModalities) : modelModalitiesToArray(modelCapability.input), output_modalities: m.outputModalities ? formatAmazonModalities(m.outputModalities) : modelModalitiesToArray(modelCapability.input), tool_support: modelCapability.tool_support, }; return model; }); //add custom models if (customModelsList?.modelSummaries) { customModelsList.modelSummaries.forEach((m) => { if (!m.modelArn) { throw new Error("Model ID not found"); } const modelCapability = getModelCapabilities(m.modelArn, this.provider); const model = { id: m.modelArn, name: m.modelName ?? m.modelArn, provider: this.provider, owner: "custom", description: `Custom model from ${m.baseModelName}`, is_custom: true, input_modalities: modelModalitiesToArray(modelCapability.input), output_modalities: modelModalitiesToArray(modelCapability.output), tool_support: modelCapability.tool_support, }; aiModels.push(model); this.validateConnection; }); } //add inference profiles if (inferenceProfilesList?.inferenceProfileSummaries) { inferenceProfilesList.inferenceProfileSummaries.forEach((p) => { if (!p.inferenceProfileArn) { throw new Error("Profile ARN not found"); } // Apply the same filtering logic to inference profiles based on their name const profileId = p.inferenceProfileId || ""; const profileName = p.inferenceProfileName || ""; // Extract provider name from profile name or ID let providerName = ""; for (const provider of supportedPublishers) { if (profileName.toLowerCase().includes(provider) || profileId.toLowerCase().includes(provider)) { providerName = provider; break; } } const modelCapability = getModelCapabilities(p.inferenceProfileArn ?? p.inferenceProfileId, this.provider); if (providerName && shouldIncludeModel(profileId, providerName)) { const model = { id: p.inferenceProfileArn ?? 
p.inferenceProfileId, name: p.inferenceProfileName ?? p.inferenceProfileArn, provider: this.provider, owner: providerName, input_modalities: modelModalitiesToArray(modelCapability.input), output_modalities: modelModalitiesToArray(modelCapability.output), tool_support: modelCapability.tool_support, }; aiModels.push(model); } }); } return aiModels; } async generateEmbeddings({ text, image, model }) { this.logger.info("[Bedrock] Generating embeddings with model " + model); // Handle TwelveLabs Marengo models if (model?.includes("twelvelabs.marengo")) { return this.generateTwelvelabsMarengoEmbeddings({ text, image, model }); } // Handle other Bedrock embedding models const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0"; const modelID = model ?? defaultModel; const invokeBody = { inputText: text, inputImage: image }; const executor = this.getExecutor(); const res = await executor.invokeModel({ modelId: modelID, contentType: "application/json", body: JSON.stringify(invokeBody), }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const result = JSON.parse(body); if (!result.embedding) { throw new Error("Embeddings not found"); } return { values: result.embedding, model: modelID, token_count: result.inputTextTokenCount }; } async generateTwelvelabsMarengoEmbeddings({ text, image, model }) { const executor = this.getExecutor(); // Prepare the request payload for TwelveLabs Marengo let invokeBody = { inputType: "text" }; if (text) { invokeBody.inputText = text; invokeBody.inputType = "text"; } if (image) { // For the embeddings interface, image is expected to be base64 invokeBody.mediaSource = { base64String: image }; invokeBody.inputType = "image"; } const res = await executor.invokeModel({ modelId: model, contentType: "application/json", accept: "application/json", body: JSON.stringify(invokeBody), }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const result = JSON.parse(body); 
// TwelveLabs Marengo returns embedding data if (!result.embedding) { throw new Error("Embeddings not found in TwelveLabs Marengo response"); } return { values: result.embedding, model: model, // TwelveLabs Marengo doesn't return token count in the same way token_count: undefined }; } } function jobInfo(job, jobId) { const jobStatus = job.status; let status = TrainingJobStatus.running; let details; if (jobStatus === ModelCustomizationJobStatus.COMPLETED) { status = TrainingJobStatus.succeeded; } else if (jobStatus === ModelCustomizationJobStatus.FAILED) { status = TrainingJobStatus.failed; details = job.failureMessage || "error"; } else if (jobStatus === ModelCustomizationJobStatus.STOPPED) { status = TrainingJobStatus.cancelled; } else { status = TrainingJobStatus.running; details = jobStatus; } job.baseModelArn; return { id: jobId, model: job.outputModelArn, status, details }; } function getToolDefinitions(tools) { return tools ? tools.map(getToolDefinition) : undefined; } function getToolDefinition(tool) { return { toolSpec: { name: tool.name, description: tool.description, inputSchema: { json: tool.input_schema, } } }; } /** * Update the conversation messages * @param prompt * @param response * @returns */ function updateConversation(conversation, prompt) { const combinedMessages = [...(conversation?.messages || []), ...(prompt.messages || [])]; const combinedSystem = prompt.system || conversation?.system; return { modelId: prompt?.modelId || conversation?.modelId, messages: combinedMessages.length > 0 ? combinedMessages : [], system: combinedSystem && combinedSystem.length > 0 ? 
combinedSystem : undefined, }; } function formatAmazonModalities(modalities) { const standardizedModalities = []; for (const modality of modalities) { if (modality === ModelModality.TEXT) { standardizedModalities.push("text"); } else if (modality === ModelModality.IMAGE) { standardizedModalities.push("image"); } else if (modality === ModelModality.EMBEDDING) { standardizedModalities.push("embedding"); } else if (modality == "SPEECH") { standardizedModalities.push("audio"); } else if (modality == "VIDEO") { standardizedModalities.push("video"); } else { // Handle other modalities as needed standardizedModalities.push(modality.toString().toLowerCase()); } } return standardizedModalities; } //# sourceMappingURL=index.js.map