@llumiverse/drivers
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
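A minimal usage sketch (an assumption based on the package name, not part of this file): the BedrockDriver implemented below is constructed with a region and optional credentials, and exposes methods such as listModels() and generateEmbeddings().

JavaScript
import { BedrockDriver } from "@llumiverse/drivers";

const driver = new BedrockDriver({
    region: "us-east-1", // required: the constructor throws without a region
    // credentials is optional and falls back to the AWS default credential chain
});
const models = await driver.listModels();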
JavaScript
import { Bedrock, CreateModelCustomizationJobCommand, GetModelCustomizationJobCommand, ModelCustomizationJobStatus, ModelModality, StopModelCustomizationJobCommand } from "@aws-sdk/client-bedrock";
import { BedrockRuntime } from "@aws-sdk/client-bedrock-runtime";
import { S3Client } from "@aws-sdk/client-s3";
import { AbstractDriver, Modalities, TrainingJobStatus, getMaxTokensLimitBedrock, modelModalitiesToArray, getModelCapabilities } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { formatNovaPrompt } from "@llumiverse/core/formatters";
import { LRUCache } from "mnemonist";
import { converseConcatMessages, converseJSONprefill, converseSystemToMessages, formatConversePrompt } from "./converse.js";
import { formatNovaImageGenerationPayload, NovaImageGenerationTaskType } from "./nova-image-payload.js";
import { forceUploadFile } from "./s3.js";
import { formatTwelvelabsPegasusPrompt } from "./twelvelabs.js";
const supportStreamingCache = new LRUCache(4096);
var BedrockModelType;
(function (BedrockModelType) {
BedrockModelType["FoundationModel"] = "foundation-model";
BedrockModelType["InferenceProfile"] = "inference-profile";
BedrockModelType["CustomModel"] = "custom-model";
BedrockModelType["Unknown"] = "unknown";
})(BedrockModelType || (BedrockModelType = {}));
function converseFinishReason(reason) {
//Possible values:
//end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered
if (!reason)
return undefined;
switch (reason) {
case 'end_turn': return "stop";
case 'max_tokens': return "length";
default: return reason;
}
}
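// Illustrative mapping:
//   converseFinishReason("end_turn")   -> "stop"
//   converseFinishReason("max_tokens") -> "length"
//   converseFinishReason("tool_use")   -> "tool_use" (all other reasons pass through unchanged)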
//Used to get a max_token value when not specified in the model options. Claude requires it to be set.
function maxTokenFallbackClaude(option) {
const modelOptions = option.model_options;
if (modelOptions && typeof modelOptions.max_tokens === "number") {
return modelOptions.max_tokens;
}
else {
const thinking_budget = modelOptions?.thinking_budget_tokens ?? 0;
        let maxSupportedTokens = getMaxTokensLimitBedrock(option.model) ?? 8192; // Fallback to the model's default max tokens limit; should always return a number for Claude, 8192 satisfies the TypeScript type checker
        if (option.model.includes('claude-3-7-sonnet') && thinking_budget < 48000) {
            maxSupportedTokens = 64000; // Claude 3.7 can go up to 128k with a beta header, but when no max tokens is specified, we default to 64k.
        }
        return Math.min(16000 + thinking_budget, maxSupportedTokens); // Cap to 16k plus the thinking budget, to avoid taking up too much context window and quota.
}
}
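// Worked example: for claude-3-7-sonnet with thinking_budget_tokens = 32000 and no
// max_tokens set, 32000 < 48000 so maxSupportedTokens becomes 64000, and the function
// returns min(16000 + 32000, 64000) = 48000.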
export class BedrockDriver extends AbstractDriver {
static PROVIDER = "bedrock";
provider = BedrockDriver.PROVIDER;
_executor;
_service;
_service_region;
constructor(options) {
super(options);
if (!options.region) {
throw new Error("No region found. Set the region in the environment's endpoint URL.");
}
}
getExecutor() {
if (!this._executor) {
this._executor = new BedrockRuntime({
region: this.options.region,
credentials: this.options.credentials,
});
}
return this._executor;
}
getService(region = this.options.region) {
if (!this._service || this._service_region != region) {
this._service = new Bedrock({
region: region,
credentials: this.options.credentials,
});
this._service_region = region;
}
return this._service;
}
async formatPrompt(segments, opts) {
if (opts.model.includes("canvas")) {
return await formatNovaPrompt(segments, opts.result_schema);
}
if (opts.model.includes("twelvelabs.pegasus")) {
return await formatTwelvelabsPegasusPrompt(segments, opts);
}
return await formatConversePrompt(segments, opts);
}
getExtractedExecution(result, _prompt, options) {
let resultText = "";
let reasoning = "";
if (result.output?.message?.content) {
for (const content of result.output.message.content) {
// Get text output
if (content.text) {
resultText += content.text;
}
else if (content.reasoningContent) {
// Get reasoning content only if include_thoughts is true
const claudeOptions = options?.model_options;
if (claudeOptions?.include_thoughts) {
if (content.reasoningContent.reasoningText) {
reasoning += content.reasoningContent.reasoningText.text;
}
else if (content.reasoningContent.redactedContent) {
// Handle redacted thinking content
const redactedData = new TextDecoder().decode(content.reasoningContent.redactedContent);
reasoning += `[Redacted thinking: ${redactedData}]`;
}
}
else {
this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
}
}
else {
// Get content block type
const type = Object.keys(content).find(key => key !== '$unknown' && content[key] !== undefined);
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
}
}
// Add spacing if we have reasoning content
if (reasoning) {
reasoning += '\n\n';
}
}
const completionResult = {
            result: (reasoning + resultText) ? [{ type: "text", value: reasoning + resultText }] : [],
token_usage: {
prompt: result.usage?.inputTokens,
result: result.usage?.outputTokens,
total: result.usage?.totalTokens,
},
finish_reason: converseFinishReason(result.stopReason),
};
return completionResult;
}
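    // Illustrative return shape for a completed Converse call (token counts are examples):
    // {
    //     result: [{ type: "text", value: "...reasoning...\n\n...answer..." }],
    //     token_usage: { prompt: 120, result: 45, total: 165 },
    //     finish_reason: "stop"
    // }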
getExtractedStream(result, _prompt, options) {
let output = "";
let reasoning = "";
let stop_reason = "";
let token_usage;
// Check if we should include thoughts
const shouldIncludeThoughts = options && options.model_options?.include_thoughts;
// Handle content block start events (for reasoning blocks)
if (result.contentBlockStart) {
// Handle redacted content at block start
if (result.contentBlockStart.start && 'reasoningContent' in result.contentBlockStart.start && shouldIncludeThoughts) {
const reasoningStart = result.contentBlockStart.start;
if (reasoningStart.reasoningContent?.redactedContent) {
const redactedData = new TextDecoder().decode(reasoningStart.reasoningContent.redactedContent);
reasoning = `[Redacted thinking: ${redactedData}]`;
}
}
}
// Handle content block deltas (text and reasoning)
if (result.contentBlockDelta) {
const delta = result.contentBlockDelta.delta;
if (delta?.text) {
output = delta.text;
}
            else if (delta?.reasoningContent) {
                if (shouldIncludeThoughts) {
                    if (delta.reasoningContent.text) {
                        reasoning = delta.reasoningContent.text;
                    }
                    else if (delta.reasoningContent.redactedContent) {
                        const redactedData = new TextDecoder().decode(delta.reasoningContent.redactedContent);
                        reasoning = `[Redacted thinking: ${redactedData}]`;
                    }
                    else if (delta.reasoningContent.signature) {
                        // A signature update marks the end of the thinking block
                        reasoning = "\n\n";
                    }
                }
                else if (delta.reasoningContent.signature) {
                    // Log on the signature update so this only triggers once per message.
                    this.logger.info("[Bedrock] Not outputting reasoning content as include_thoughts is false");
                }
            }
else if (delta) {
// Get content block type
const type = Object.keys(delta).find(key => key !== '$unknown' && delta[key] !== undefined);
this.logger.info({ type }, "[Bedrock] Unsupported content response type:");
}
}
// Handle content block stop events
if (result.contentBlockStop) {
// Content block ended - could be end of reasoning or text block
// Add minimal spacing for reasoning blocks if not already present
if (reasoning && !reasoning.endsWith('\n\n') && shouldIncludeThoughts) {
reasoning += '\n\n';
}
}
if (result.messageStop) {
stop_reason = result.messageStop.stopReason ?? "";
}
if (result.metadata) {
token_usage = {
prompt: result.metadata.usage?.inputTokens,
result: result.metadata.usage?.outputTokens,
total: result.metadata.usage?.totalTokens,
};
}
const completionResult = {
            result: (reasoning + output) ? [{ type: "text", value: reasoning + output }] : [],
token_usage: token_usage,
finish_reason: converseFinishReason(stop_reason),
};
return completionResult;
}
extractRegion(modelString, defaultRegion) {
// Match region in full ARN pattern
const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/);
if (arnMatch) {
return arnMatch[1];
}
// Match common AWS regions directly in string
const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/);
if (regionMatch) {
return regionMatch[0];
}
return defaultRegion;
}
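    // Illustrative examples:
    //   extractRegion("arn:aws:bedrock:us-west-2:123456789012:inference-profile/us.anthropic.claude-3-haiku-20240307-v1:0", "us-east-1")
    //     -> "us-west-2" (taken from the ARN)
    //   extractRegion("eu-central-1.anthropic.claude-3-haiku-20240307-v1:0", "us-east-1")
    //     -> "eu-central-1" (region matched directly in the string)
    //   extractRegion("anthropic.claude-3-haiku-20240307-v1:0", "us-east-1")
    //     -> "us-east-1" (fallback to the default region)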
async getCanStream(model, type) {
let canStream = false;
let error = null;
const region = this.extractRegion(model, this.options.region);
if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getFoundationModel({
modelIdentifier: model
});
canStream = response.modelDetails?.responseStreamingSupported ?? false;
return canStream;
}
catch (e) {
error = e;
}
}
if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getInferenceProfile({
inferenceProfileIdentifier: model
});
canStream = await this.getCanStream(response.models?.[0].modelArn ?? "", BedrockModelType.FoundationModel);
return canStream;
}
catch (e) {
error = e;
}
}
if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) {
try {
const response = await this.getService(region).getCustomModel({
modelIdentifier: model
});
canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel);
return canStream;
}
catch (e) {
error = e;
}
}
if (error) {
            this.logger.warn({ error, model, region }, "[Bedrock] Error on canStream check");
}
return canStream;
}
async canStream(options) {
// // TwelveLabs Pegasus supports streaming according to the documentation
// if (options.model.includes("twelvelabs.pegasus")) {
// return true;
// }
let canStream = supportStreamingCache.get(options.model);
if (canStream == null) {
let type = BedrockModelType.Unknown;
if (options.model.includes("foundation-model")) {
type = BedrockModelType.FoundationModel;
}
else if (options.model.includes("inference-profile")) {
type = BedrockModelType.InferenceProfile;
}
else if (options.model.includes("custom-model")) {
type = BedrockModelType.CustomModel;
}
canStream = await this.getCanStream(options.model, type);
supportStreamingCache.set(options.model, canStream);
}
return canStream;
}
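    // Illustrative: the first canStream() call for a model id queries the Bedrock
    // management API (getFoundationModel / getInferenceProfile / getCustomModel);
    // subsequent calls for the same id are answered from the LRU cache:
    //   await driver.canStream({ model: "anthropic.claude-3-haiku-20240307-v1:0" }); // API lookup
    //   await driver.canStream({ model: "anthropic.claude-3-haiku-20240307-v1:0" }); // cached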
async requestTextCompletion(prompt, options) {
// Handle Twelvelabs Pegasus models
if (options.model.includes("twelvelabs.pegasus")) {
return this.requestTwelvelabsPegasusCompletion(prompt, options);
}
// Handle other Bedrock models that use Converse API
const conversePrompt = prompt;
let conversation = updateConversation(options.conversation, conversePrompt);
const payload = this.preparePayload(conversation, options);
const executor = this.getExecutor();
const res = await executor.converse({
...payload,
});
conversation = updateConversation(conversation, {
messages: [res.output?.message ?? { content: [{ text: "" }], role: "assistant" }],
modelId: conversePrompt.modelId,
});
        //Get tool requests. We check tool use regardless of finish reason, as you can hit max_tokens and still get a valid response.
        let tool_use = res.output?.message?.content?.reduce((tools, c) => {
if (c.toolUse) {
tools.push({
tool_name: c.toolUse.name ?? "",
tool_input: c.toolUse.input,
id: c.toolUse.toolUseId ?? "",
});
}
return tools;
}, []);
//If no tools were used, set to undefined
if (tool_use && tool_use.length == 0) {
tool_use = undefined;
}
const completion = {
...this.getExtractedExecution(res, conversePrompt, options),
original_response: options.include_original_response ? res : undefined,
conversation: conversation,
tool_use: tool_use,
};
return completion;
}
async requestTwelvelabsPegasusCompletion(prompt, options) {
const executor = this.getExecutor();
const res = await executor.invokeModel({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(prompt),
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const result = JSON.parse(body);
// Extract the response according to TwelveLabs Pegasus format
let finishReason;
switch (result.finishReason) {
case "stop":
finishReason = "stop";
break;
case "length":
finishReason = "length";
break;
default:
finishReason = result.finishReason;
}
return {
result: result.message ? [{ type: "text", value: result.message }] : [],
finish_reason: finishReason,
original_response: options.include_original_response ? result : undefined,
};
}
async requestTwelvelabsPegasusCompletionStream(prompt, options) {
const executor = this.getExecutor();
const res = await executor.invokeModelWithResponseStream({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(prompt),
});
if (!res.body) {
throw new Error("[Bedrock] Stream not found in response");
}
return transformAsyncIterator(res.body, (chunk) => {
if (chunk.chunk?.bytes) {
const decoder = new TextDecoder();
const body = decoder.decode(chunk.chunk.bytes);
try {
const result = JSON.parse(body);
// Extract streaming response according to TwelveLabs Pegasus format
let finishReason;
if (result.finishReason) {
switch (result.finishReason) {
case "stop":
finishReason = "stop";
break;
case "length":
finishReason = "length";
break;
default:
finishReason = result.finishReason;
}
}
return {
                        result: (result.delta || result.message) ? [{ type: "text", value: result.delta || result.message || "" }] : [],
finish_reason: finishReason,
};
}
catch (error) {
// If JSON parsing fails, return empty chunk
return {
result: [],
};
}
}
return {
result: [],
};
});
}
async requestTextCompletionStream(prompt, options) {
// Handle Twelvelabs Pegasus models
if (options.model.includes("twelvelabs.pegasus")) {
return this.requestTwelvelabsPegasusCompletionStream(prompt, options);
}
// Handle other Bedrock models that use Converse API
const conversePrompt = prompt;
const payload = this.preparePayload(conversePrompt, options);
const executor = this.getExecutor();
return executor.converseStream({
...payload,
}).then((res) => {
const stream = res.stream;
if (!stream) {
throw new Error("[Bedrock] Stream not found in response");
}
return transformAsyncIterator(stream, (streamSegment) => {
return this.getExtractedStream(streamSegment, conversePrompt, options);
});
}).catch((err) => {
this.logger.error({ error: err }, "[Bedrock] Failed to stream");
throw err;
});
}
preparePayload(prompt, options) {
const model_options = options.model_options ?? { _option_id: "text-fallback" };
let additionalField = {};
let supportsJSONPrefill = false;
if (options.model.includes("amazon")) {
supportsJSONPrefill = true;
            //Titan models also exist but do not support any additional options
if (options.model.includes("nova")) {
additionalField = { inferenceConfig: { topK: model_options.top_k } };
}
}
else if (options.model.includes("claude")) {
const claude_options = model_options;
const thinking = claude_options.thinking_mode ?? false;
supportsJSONPrefill = !thinking;
if (options.model.includes("claude-3-7") || options.model.includes("-4-")) {
additionalField = {
...additionalField,
reasoning_config: {
type: thinking ? "enabled" : "disabled",
budget_tokens: thinking ? (claude_options.thinking_budget_tokens ?? 1024) : undefined,
}
};
if (thinking && options.model.includes("claude-3-7-sonnet") &&
((claude_options.max_tokens ?? 0) > 64000 || (claude_options.thinking_budget_tokens ?? 0) > 64000)) {
additionalField = {
...additionalField,
anthropic_beta: ["output-128k-2025-02-19"]
};
}
}
//Needs max_tokens to be set
if (!model_options.max_tokens) {
model_options.max_tokens = maxTokenFallbackClaude(options);
}
additionalField = { ...additionalField, top_k: model_options.top_k };
}
else if (options.model.includes("meta")) {
//LLaMA models support no additional options
}
else if (options.model.includes("mistral")) {
//7B instruct and 8x7B instruct
if (options.model.includes("7b")) {
additionalField = { top_k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length != 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
else {
                //Other models, such as Mistral Small, Large, and Large 2,
                //support no additional fields.
}
}
else if (options.model.includes("ai21")) {
//Jamba models support no additional options
//Jurassic 2 models do.
if (options.model.includes("j2")) {
additionalField = {
presencePenalty: { scale: model_options.presence_penalty },
frequencyPenalty: { scale: model_options.frequency_penalty },
};
//Does not support system messages
if (prompt.system && prompt.system?.length != 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
}
else if (options.model.includes("cohere.command")) {
            //Command R and R Plus
if (options.model.includes("cohere.command-r")) {
additionalField = {
k: model_options.top_k,
frequency_penalty: model_options.frequency_penalty,
presence_penalty: model_options.presence_penalty,
};
}
else {
// Command non-R
additionalField = { k: model_options.top_k };
//Does not support system messages
if (prompt.system && prompt.system?.length != 0) {
prompt.messages?.push(converseSystemToMessages(prompt.system));
prompt.system = undefined;
prompt.messages = converseConcatMessages(prompt.messages);
}
}
}
else if (options.model.includes("palmyra")) {
const palmyraOptions = model_options;
additionalField = {
seed: palmyraOptions?.seed,
presence_penalty: palmyraOptions?.presence_penalty,
frequency_penalty: palmyraOptions?.frequency_penalty,
min_tokens: palmyraOptions?.min_tokens,
};
}
else if (options.model.includes("deepseek")) {
//DeepSeek models support no additional options
}
else if (options.model.includes("gpt-oss")) {
const gptOssOptions = model_options;
additionalField = {
reasoning_effort: gptOssOptions?.reasoning_effort,
};
}
//If last message is "```json", add corresponding ``` as a stop sequence.
if (prompt.messages && prompt.messages.length > 0) {
            if (prompt.messages[prompt.messages.length - 1].content?.[0]?.text === "```json") {
const stopSeq = model_options.stop_sequence;
if (!stopSeq) {
model_options.stop_sequence = ["```"];
}
else if (!stopSeq.includes("```")) {
stopSeq.push("```");
model_options.stop_sequence = stopSeq;
}
}
}
const tool_defs = getToolDefinitions(options.tools);
// Use prefill when there is a schema and tools are not being used
if (supportsJSONPrefill && options.result_schema && !tool_defs) {
prompt.messages = converseJSONprefill(prompt.messages);
}
const request = {
messages: prompt.messages,
system: prompt.system,
modelId: options.model,
inferenceConfig: {
maxTokens: model_options.max_tokens,
temperature: model_options.temperature,
topP: model_options.top_p,
stopSequences: model_options.stop_sequence,
},
additionalModelRequestFields: {
...additionalField,
}
};
//Only add tools if they are defined and not empty
if (tool_defs?.length) {
request.toolConfig = {
tools: tool_defs,
};
}
return request;
}
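    // Illustrative request produced for a Claude 3.7 model with thinking_mode enabled
    // (field values are examples derived from the branches above, not defaults):
    // {
    //     messages: [...], system: [...],
    //     modelId: "...claude-3-7-sonnet...",
    //     inferenceConfig: { maxTokens: 48000, temperature: ..., topP: ..., stopSequences: ... },
    //     additionalModelRequestFields: {
    //         reasoning_config: { type: "enabled", budget_tokens: 32000 },
    //         top_k: ...
    //     }
    // }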
async requestImageGeneration(prompt, options) {
if (options.output_modality !== Modalities.image) {
throw new Error(`Image generation requires image output_modality`);
}
if (options.model_options?._option_id !== "bedrock-nova-canvas") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
const model_options = options.model_options;
const executor = this.getExecutor();
const taskType = model_options.taskType ?? NovaImageGenerationTaskType.TEXT_IMAGE;
this.logger.info("Task type: " + taskType);
if (typeof prompt === "string") {
throw new Error("Bad prompt format");
}
const payload = await formatNovaImageGenerationPayload(taskType, prompt, options);
const res = await executor.invokeModel({
modelId: options.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(payload),
}, {
requestTimeout: 60000 * 5
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const bedrockResult = JSON.parse(body);
return {
error: bedrockResult.error,
            result: (bedrockResult.images ?? []).map((image) => ({
type: "image",
value: image
}))
};
}
async startTraining(dataset, options) {
//convert options.params to Record<string, string>
const params = {};
for (const [key, value] of Object.entries(options.params || {})) {
params[key] = String(value);
}
if (!this.options.training_bucket) {
throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options");
}
const s3 = new S3Client({ region: this.options.region, credentials: this.options.credentials });
const stream = await dataset.getStream();
const upload = await forceUploadFile(s3, stream, this.options.training_bucket, dataset.name);
const service = this.getService();
const response = await service.send(new CreateModelCustomizationJobCommand({
jobName: options.name + "-job",
customModelName: options.name,
roleArn: this.options.training_role_arn || undefined,
baseModelIdentifier: options.model,
clientRequestToken: "llumiverse-" + Date.now(),
trainingDataConfig: {
s3Uri: `s3://${upload.Bucket}/${upload.Key}`,
},
outputDataConfig: undefined,
hyperParameters: params,
//TODO not supported?
//customizationType: "FINE_TUNING",
}));
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: response.jobArn
}));
return jobInfo(job, response.jobArn);
}
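    // Illustrative call (the option names are the ones consumed above; the bucket and
    // role come from the driver options, and the hyperparameter name is an example):
    //   await driver.startTraining(dataset, {
    //       name: "my-custom-model",               // used for customModelName and jobName
    //       model: "amazon.titan-text-express-v1", // example baseModelIdentifier
    //       params: { epochCount: 2 },             // stringified into hyperParameters
    //   });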
async cancelTraining(jobId) {
const service = this.getService();
await service.send(new StopModelCustomizationJobCommand({
jobIdentifier: jobId
}));
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: jobId
}));
return jobInfo(job, jobId);
}
async getTrainingJob(jobId) {
const service = this.getService();
const job = await service.send(new GetModelCustomizationJobCommand({
jobIdentifier: jobId
}));
return jobInfo(job, jobId);
}
// ===================== management API ==================
async validateConnection() {
const service = this.getService();
this.logger.debug("[Bedrock] validating connection", service.config.credentials.name);
        //Return true: if the client has been initialized, the connection is considered valid.
return true;
}
async listTrainableModels() {
this.logger.debug("[Bedrock] listing trainable models");
return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false);
}
async listModels() {
this.logger.debug("[Bedrock] listing models");
// exclude trainable models since they are not executable
// exclude embedding models, not to be used for typical completions.
const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false;
return this._listModels(filter);
}
async _listModels(foundationFilter) {
const service = this.getService();
const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([
service.listFoundationModels({}).catch(() => {
this.logger.warn("[Bedrock] Can't list foundation models. Check if the user has the right permissions.");
return undefined;
}),
service.listCustomModels({}).catch(() => {
this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions.");
return undefined;
}),
service.listInferenceProfiles({}).catch(() => {
this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions.");
return undefined;
}),
]);
if (!foundationModelsList?.modelSummaries) {
throw new Error("Foundation models not found");
}
let foundationModels = foundationModelsList.modelSummaries || [];
if (foundationFilter) {
foundationModels = foundationModels.filter(foundationFilter);
}
const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21",
"mistral", "meta", "deepseek", "writer",
"openai", "twelvelabs", "qwen"];
const unsupportedModelsByPublisher = {
amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"],
anthropic: [],
cohere: ["rerank", "embed"],
ai21: [],
mistral: [],
meta: [],
deepseek: [],
writer: [],
openai: [],
twelvelabs: ["marengo"],
qwen: [],
};
// Helper function to check if model should be filtered out
const shouldIncludeModel = (modelId, providerName) => {
if (!modelId || !providerName)
return false;
const normalizedProvider = providerName.toLowerCase();
// Check if provider is supported
const isProviderSupported = supportedPublishers.some(provider => normalizedProvider.includes(provider));
if (!isProviderSupported)
return false;
// Check if model is in the unsupported list for its provider
for (const provider of supportedPublishers) {
if (normalizedProvider.includes(provider)) {
const unsupportedModels = unsupportedModelsByPublisher[provider] || [];
return !unsupportedModels.some(unsupported => modelId.toLowerCase().includes(unsupported));
}
}
return true;
};
foundationModels = foundationModels.filter(m => shouldIncludeModel(m.modelId, m.providerName));
const aiModels = foundationModels.map((m) => {
if (!m.modelId) {
throw new Error("modelId not found");
}
const modelCapability = getModelCapabilities(m.modelArn ?? m.modelId, this.provider);
const model = {
id: m.modelArn ?? m.modelId,
name: `${m.providerName} ${m.modelName}`,
provider: this.provider,
owner: m.providerName,
can_stream: m.responseStreamingSupported ?? false,
input_modalities: m.inputModalities ? formatAmazonModalities(m.inputModalities) : modelModalitiesToArray(modelCapability.input),
                output_modalities: m.outputModalities ? formatAmazonModalities(m.outputModalities) : modelModalitiesToArray(modelCapability.output),
tool_support: modelCapability.tool_support,
};
return model;
});
//add custom models
if (customModelsList?.modelSummaries) {
customModelsList.modelSummaries.forEach((m) => {
if (!m.modelArn) {
throw new Error("Model ID not found");
}
const modelCapability = getModelCapabilities(m.modelArn, this.provider);
const model = {
id: m.modelArn,
name: m.modelName ?? m.modelArn,
provider: this.provider,
owner: "custom",
description: `Custom model from ${m.baseModelName}`,
is_custom: true,
input_modalities: modelModalitiesToArray(modelCapability.input),
output_modalities: modelModalitiesToArray(modelCapability.output),
tool_support: modelCapability.tool_support,
};
aiModels.push(model);
});
}
//add inference profiles
if (inferenceProfilesList?.inferenceProfileSummaries) {
inferenceProfilesList.inferenceProfileSummaries.forEach((p) => {
if (!p.inferenceProfileArn) {
throw new Error("Profile ARN not found");
}
// Apply the same filtering logic to inference profiles based on their name
const profileId = p.inferenceProfileId || "";
const profileName = p.inferenceProfileName || "";
// Extract provider name from profile name or ID
let providerName = "";
for (const provider of supportedPublishers) {
if (profileName.toLowerCase().includes(provider) || profileId.toLowerCase().includes(provider)) {
providerName = provider;
break;
}
}
const modelCapability = getModelCapabilities(p.inferenceProfileArn ?? p.inferenceProfileId, this.provider);
if (providerName && shouldIncludeModel(profileId, providerName)) {
const model = {
id: p.inferenceProfileArn ?? p.inferenceProfileId,
name: p.inferenceProfileName ?? p.inferenceProfileArn,
provider: this.provider,
owner: providerName,
input_modalities: modelModalitiesToArray(modelCapability.input),
output_modalities: modelModalitiesToArray(modelCapability.output),
tool_support: modelCapability.tool_support,
};
aiModels.push(model);
}
});
}
return aiModels;
}
async generateEmbeddings({ text, image, model }) {
this.logger.info("[Bedrock] Generating embeddings with model " + model);
// Handle TwelveLabs Marengo models
if (model?.includes("twelvelabs.marengo")) {
return this.generateTwelvelabsMarengoEmbeddings({ text, image, model });
}
// Handle other Bedrock embedding models
const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0";
const modelID = model ?? defaultModel;
const invokeBody = {
inputText: text,
inputImage: image
};
const executor = this.getExecutor();
const res = await executor.invokeModel({
modelId: modelID,
contentType: "application/json",
body: JSON.stringify(invokeBody),
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const result = JSON.parse(body);
if (!result.embedding) {
throw new Error("Embeddings not found");
}
return {
values: result.embedding,
model: modelID,
token_count: result.inputTextTokenCount
};
}
async generateTwelvelabsMarengoEmbeddings({ text, image, model }) {
const executor = this.getExecutor();
// Prepare the request payload for TwelveLabs Marengo
let invokeBody = {
inputType: "text"
};
if (text) {
invokeBody.inputText = text;
invokeBody.inputType = "text";
}
if (image) {
// For the embeddings interface, image is expected to be base64
invokeBody.mediaSource = {
base64String: image
};
invokeBody.inputType = "image";
}
const res = await executor.invokeModel({
modelId: model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(invokeBody),
});
const decoder = new TextDecoder();
const body = decoder.decode(res.body);
const result = JSON.parse(body);
// TwelveLabs Marengo returns embedding data
if (!result.embedding) {
throw new Error("Embeddings not found in TwelveLabs Marengo response");
}
return {
values: result.embedding,
model: model,
// TwelveLabs Marengo doesn't return token count in the same way
token_count: undefined
};
}
}
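// Illustrative usage of methods defined above (the import path is inferred from the
// package name; values are examples):
//   import { BedrockDriver } from "@llumiverse/drivers";
//   const driver = new BedrockDriver({ region: "us-east-1" });
//   const canStream = await driver.canStream({ model: "anthropic.claude-3-haiku-20240307-v1:0" });
//   const { values } = await driver.generateEmbeddings({ text: "hello world" }); // defaults to amazon.titan-embed-text-v2:0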
function jobInfo(job, jobId) {
const jobStatus = job.status;
let status = TrainingJobStatus.running;
let details;
if (jobStatus === ModelCustomizationJobStatus.COMPLETED) {
status = TrainingJobStatus.succeeded;
}
else if (jobStatus === ModelCustomizationJobStatus.FAILED) {
status = TrainingJobStatus.failed;
details = job.failureMessage || "error";
}
else if (jobStatus === ModelCustomizationJobStatus.STOPPED) {
status = TrainingJobStatus.cancelled;
}
else {
status = TrainingJobStatus.running;
details = jobStatus;
}
return {
id: jobId,
model: job.outputModelArn,
status,
details
};
}
function getToolDefinitions(tools) {
return tools ? tools.map(getToolDefinition) : undefined;
}
function getToolDefinition(tool) {
return {
toolSpec: {
name: tool.name,
description: tool.description,
inputSchema: {
json: tool.input_schema,
}
}
};
}
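// Illustrative conversion:
//   getToolDefinition({ name: "get_weather", description: "Look up the weather", input_schema: { type: "object" } })
//   -> { toolSpec: { name: "get_weather", description: "Look up the weather", inputSchema: { json: { type: "object" } } } }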
/**
 * Update the conversation messages by appending the prompt messages.
 * @param conversation the existing conversation, if any
 * @param prompt the new prompt whose messages, system, and modelId are merged in
 * @returns the merged conversation
 */
function updateConversation(conversation, prompt) {
const combinedMessages = [...(conversation?.messages || []), ...(prompt.messages || [])];
const combinedSystem = prompt.system || conversation?.system;
return {
modelId: prompt?.modelId || conversation?.modelId,
messages: combinedMessages.length > 0 ? combinedMessages : [],
system: combinedSystem && combinedSystem.length > 0 ? combinedSystem : undefined,
};
}
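// Illustrative merge (prompt values take precedence, messages are concatenated):
//   updateConversation({ modelId: "a", messages: [m1], system: [s1] }, { modelId: "b", messages: [m2] })
//   -> { modelId: "b", messages: [m1, m2], system: [s1] }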
function formatAmazonModalities(modalities) {
const standardizedModalities = [];
for (const modality of modalities) {
if (modality === ModelModality.TEXT) {
standardizedModalities.push("text");
}
else if (modality === ModelModality.IMAGE) {
standardizedModalities.push("image");
}
else if (modality === ModelModality.EMBEDDING) {
standardizedModalities.push("embedding");
}
else if (modality == "SPEECH") {
standardizedModalities.push("audio");
}
else if (modality == "VIDEO") {
standardizedModalities.push("video");
}
else {
// Handle other modalities as needed
standardizedModalities.push(modality.toString().toLowerCase());
}
}
return standardizedModalities;
}
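// Illustrative mapping (assuming the SDK enum string values "TEXT" | "IMAGE" | "EMBEDDING"):
//   formatAmazonModalities(["TEXT", "IMAGE", "SPEECH", "VIDEO"]) -> ["text", "image", "audio", "video"]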
//# sourceMappingURL=index.js.map