UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

718 lines 30.7 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.BedrockDriver = void 0; const client_bedrock_1 = require("@aws-sdk/client-bedrock"); const client_bedrock_runtime_1 = require("@aws-sdk/client-bedrock-runtime"); const client_s3_1 = require("@aws-sdk/client-s3"); const core_1 = require("@llumiverse/core"); const async_1 = require("@llumiverse/core/async"); const formatters_1 = require("@llumiverse/core/formatters"); const mnemonist_1 = require("mnemonist"); const converse_js_1 = require("./converse.js"); const nova_image_payload_js_1 = require("./nova-image-payload.js"); const s3_js_1 = require("./s3.js"); const supportStreamingCache = new mnemonist_1.LRUCache(4096); var BedrockModelType; (function (BedrockModelType) { BedrockModelType["FoundationModel"] = "foundation-model"; BedrockModelType["InferenceProfile"] = "inference-profile"; BedrockModelType["CustomModel"] = "custom-model"; BedrockModelType["Unknown"] = "unknown"; })(BedrockModelType || (BedrockModelType = {})); ; function converseFinishReason(reason) { //Possible values: //end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered if (!reason) return undefined; switch (reason) { case 'end_turn': return "stop"; case 'max_tokens': return "length"; default: return reason; } } class BedrockDriver extends core_1.AbstractDriver { static PROVIDER = "bedrock"; provider = BedrockDriver.PROVIDER; _executor; _service; _service_region; constructor(options) { super(options); if (!options.region) { throw new Error("No region found. 
Set the region in the environment's endpoint URL."); } } getExecutor() { if (!this._executor) { this._executor = new client_bedrock_runtime_1.BedrockRuntime({ region: this.options.region, credentials: this.options.credentials, }); } return this._executor; } getService(region = this.options.region) { if (!this._service || this._service_region != region) { this._service = new client_bedrock_1.Bedrock({ region: region, credentials: this.options.credentials, }); this._service_region = region; } return this._service; } async formatPrompt(segments, opts) { if (opts.model.includes("canvas")) { return await (0, formatters_1.formatNovaPrompt)(segments, opts.result_schema); } return await (0, converse_js_1.formatConversePrompt)(segments, opts.result_schema); } static getExtractedExecution(result, _prompt) { return { result: result.output?.message?.content?.map(c => c.text).join("\n") ?? "", token_usage: { prompt: result.usage?.inputTokens, result: result.usage?.outputTokens, total: result.usage?.totalTokens, }, finish_reason: converseFinishReason(result.stopReason), }; } ; static getExtractedStream(result, _prompt) { let output = ""; let stop_reason = ""; let token_usage; if (result.contentBlockDelta) { output = result.contentBlockDelta.delta?.text ?? ""; } if (result.messageStop) { stop_reason = result.messageStop.stopReason ?? ""; } if (result.metadata) { token_usage = { prompt: result.metadata.usage?.inputTokens, result: result.metadata.usage?.outputTokens, total: result.metadata.usage?.totalTokens, }; } return { result: output, token_usage: token_usage, finish_reason: converseFinishReason(stop_reason), }; } ; async requestTextCompletion(prompt, options) { let conversation = updateConversation(options.conversation, prompt); const payload = this.preparePayload(conversation, options); const executor = this.getExecutor(); const res = await executor.converse({ ...payload, }); conversation = updateConversation(conversation, { messages: [res.output?.message ?? 
{ content: [{ text: "" }], role: "assistant" }], modelId: prompt.modelId, }); let tool_use = undefined; //Get tool requests if (res.stopReason == "tool_use") { tool_use = res.output?.message?.content?.reduce((tools, c) => { if (c.toolUse) { tools.push({ tool_name: c.toolUse.name ?? "", tool_input: c.toolUse.input, id: c.toolUse.toolUseId ?? "", }); } return tools; }, []); //If no tools were used, set to undefined if (tool_use && tool_use.length == 0) { tool_use = undefined; } } const completion = { ...BedrockDriver.getExtractedExecution(res, prompt), original_response: options.include_original_response ? res : undefined, conversation: conversation, tool_use: tool_use, }; return completion; } extractRegion(modelString, defaultRegion) { // Match region in full ARN pattern const arnMatch = modelString.match(/arn:aws[^:]*:bedrock:([^:]+):/); if (arnMatch) { return arnMatch[1]; } // Match common AWS regions directly in string const regionMatch = modelString.match(/(?:us|eu|ap|sa|ca|me|af)[-](east|west|central|south|north|southeast|southwest|northeast|northwest)[-][1-9]/); if (regionMatch) { return regionMatch[0]; } return defaultRegion; } async getCanStream(model, type) { let canStream = false; let error = null; const region = this.extractRegion(model, this.options.region); if (type == BedrockModelType.FoundationModel || type == BedrockModelType.Unknown) { try { const response = await this.getService(region).getFoundationModel({ modelIdentifier: model }); canStream = response.modelDetails?.responseStreamingSupported ?? false; return canStream; } catch (e) { error = e; } } if (type == BedrockModelType.InferenceProfile || type == BedrockModelType.Unknown) { try { const response = await this.getService(region).getInferenceProfile({ inferenceProfileIdentifier: model }); canStream = await this.getCanStream(response.models?.[0].modelArn ?? 
"", BedrockModelType.FoundationModel); return canStream; } catch (e) { error = e; } } if (type == BedrockModelType.CustomModel || type == BedrockModelType.Unknown) { try { const response = await this.getService(region).getCustomModel({ modelIdentifier: model }); canStream = await this.getCanStream(response.baseModelArn ?? "", BedrockModelType.FoundationModel); return canStream; } catch (e) { error = e; } } if (error) { console.warn("Error on canStream check for model: " + model + " region detected: " + region, error); } return canStream; } async canStream(options) { let canStream = supportStreamingCache.get(options.model); if (canStream == null) { let type = BedrockModelType.Unknown; if (options.model.includes("foundation-model")) { type = BedrockModelType.FoundationModel; } else if (options.model.includes("inference-profile")) { type = BedrockModelType.InferenceProfile; } else if (options.model.includes("custom-model")) { type = BedrockModelType.CustomModel; } canStream = await this.getCanStream(options.model, type); supportStreamingCache.set(options.model, canStream); } return canStream; } async requestTextCompletionStream(prompt, options) { const payload = this.preparePayload(prompt, options); const executor = this.getExecutor(); return executor.converseStream({ ...payload, }).then((res) => { const stream = res.stream; if (!stream) { throw new Error("[Bedrock] Stream not found in response"); } return (0, async_1.transformAsyncIterator)(stream, (stream) => { //const segment = JSON.parse(decoder.decode(stream.chunk?.bytes)); //console.log("Debug Segment for model " + options.model, JSON.stringify(segment)); return BedrockDriver.getExtractedStream(stream, prompt); }); }).catch((err) => { this.logger.error("[Bedrock] Failed to stream", err); throw err; }); } preparePayload(prompt, options) { const model_options = options.model_options ?? 
{ _option_id: "text-fallback" }; let additionalField = {}; if (options.model.includes("amazon")) { if (options.result_schema) { prompt.messages = (0, converse_js_1.converseJSONprefill)(prompt.messages); } //Titan models also exists but does not support any additional options if (options.model.includes("nova")) { additionalField = { inferenceConfig: { topK: model_options.top_k } }; } } else if (options.model.includes("claude")) { if (options.result_schema) { prompt.messages = (0, converse_js_1.converseJSONprefill)(prompt.messages); } if (options.model.includes("claude-3-7")) { const thinking_options = options.model_options; const thinking = thinking_options.thinking_mode ?? false; additionalField = { ...additionalField, reasoning_config: { type: thinking ? "enabled" : "disabled", budget_tokens: thinking_options.thinking_budget_tokens, } }; if (thinking && (thinking_options.thinking_budget_tokens ?? 0) > 64000) { additionalField = { ...additionalField, anthorpic_beta: ["output-128k-2025-02-19"] }; } } //Needs max_tokens to be set if (!model_options.max_tokens) { model_options.max_tokens = (0, core_1.getMaxTokensLimit)(options.model, model_options); } additionalField = { ...additionalField, top_k: model_options.top_k }; } else if (options.model.includes("meta")) { //LLaMA models support no additional options } else if (options.model.includes("mistral")) { //7B instruct and 8x7B instruct if (options.model.includes("7b")) { additionalField = { top_k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length != 0) { prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system)); prompt.system = undefined; prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages); } if (options.result_schema) { prompt.messages = (0, converse_js_1.converseJSONprefill)(prompt.messages); } } else { //Other models such as Mistral Small,Large and Large 2 //Support no additional fields. 
} } else if (options.model.includes("ai21")) { //Jamba models support no additional options //Jurassic 2 models do. if (options.model.includes("j2")) { additionalField = { presencePenalty: { scale: model_options.presence_penalty }, frequencyPenalty: { scale: model_options.frequency_penalty }, }; //Does not support system messages if (prompt.system && prompt.system?.length != 0) { prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system)); prompt.system = undefined; prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages); } } } else if (options.model.includes("cohere.command")) { // If last message is "```json", remove it. //Command R and R plus if (options.model.includes("cohere.command-r")) { additionalField = { k: model_options.top_k, frequency_penalty: model_options.frequency_penalty, presence_penalty: model_options.presence_penalty, }; } else { // Command non-R additionalField = { k: model_options.top_k }; //Does not support system messages if (prompt.system && prompt.system?.length != 0) { prompt.messages?.push((0, converse_js_1.converseSystemToMessages)(prompt.system)); prompt.system = undefined; prompt.messages = (0, converse_js_1.converseConcatMessages)(prompt.messages); } } } else if (options.model.includes("palmyra")) { const palmyraOptions = options.model_options; additionalField = { seed: palmyraOptions?.seed, presence_penalty: palmyraOptions?.presence_penalty, frequency_penalty: palmyraOptions?.frequency_penalty, min_tokens: palmyraOptions?.min_tokens, }; } else if (options.model.includes("deepseek")) { //DeepSeek models support no additional options } //If last message is "```json", add corresponding ``` as a stop sequence. 
if (prompt.messages && prompt.messages.length > 0) { if (prompt.messages[prompt.messages.length - 1].content?.[0].text === "```json") { let stopSeq = model_options.stop_sequence; if (!stopSeq) { model_options.stop_sequence = ["```"]; } else if (!stopSeq.includes("```")) { stopSeq.push("```"); model_options.stop_sequence = stopSeq; } } } const tool_defs = getToolDefinitions(options.tools); const request = { messages: prompt.messages, system: prompt.system, modelId: options.model, inferenceConfig: { maxTokens: model_options.max_tokens, temperature: model_options.temperature, topP: model_options.top_p, stopSequences: model_options.stop_sequence, }, additionalModelRequestFields: { ...additionalField, } }; //Only add tools if they are defined if (tool_defs) { request.toolConfig = { tools: tool_defs, }; } return request; } async requestImageGeneration(prompt, options) { if (options.output_modality !== core_1.Modalities.image) { throw new Error(`Image generation requires image output_modality`); } if (options.model_options?._option_id !== "bedrock-nova-canvas") { this.logger.warn("Invalid model options", { options: options.model_options }); } const model_options = options.model_options; const executor = this.getExecutor(); const taskType = model_options.taskType ?? 
nova_image_payload_js_1.NovaImageGenerationTaskType.TEXT_IMAGE; this.logger.info("Task type: " + taskType); if (typeof prompt === "string") { throw new Error("Bad prompt format"); } const payload = await (0, nova_image_payload_js_1.formatNovaImageGenerationPayload)(taskType, prompt, options); const res = await executor.invokeModel({ modelId: options.model, contentType: "application/json", accept: "application/json", body: JSON.stringify(payload), }, { requestTimeout: 60000 * 5 }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const result = JSON.parse(body); return { error: result.error, result: { images: result.images, } }; } async startTraining(dataset, options) { //convert options.params to Record<string, string> const params = {}; for (const [key, value] of Object.entries(options.params || {})) { params[key] = String(value); } if (!this.options.training_bucket) { throw new Error("Training cannot nbe used since the 'training_bucket' property was not specified in driver options"); } const s3 = new client_s3_1.S3Client({ region: this.options.region, credentials: this.options.credentials }); const stream = await dataset.getStream(); const upload = await (0, s3_js_1.forceUploadFile)(s3, stream, this.options.training_bucket, dataset.name); const service = this.getService(); const response = await service.send(new client_bedrock_1.CreateModelCustomizationJobCommand({ jobName: options.name + "-job", customModelName: options.name, roleArn: this.options.training_role_arn || undefined, baseModelIdentifier: options.model, clientRequestToken: "llumiverse-" + Date.now(), trainingDataConfig: { s3Uri: `s3://${upload.Bucket}/${upload.Key}`, }, outputDataConfig: undefined, hyperParameters: params, //TODO not supported? 
//customizationType: "FINE_TUNING", })); const job = await service.send(new client_bedrock_1.GetModelCustomizationJobCommand({ jobIdentifier: response.jobArn })); return jobInfo(job, response.jobArn); } async cancelTraining(jobId) { const service = this.getService(); await service.send(new client_bedrock_1.StopModelCustomizationJobCommand({ jobIdentifier: jobId })); const job = await service.send(new client_bedrock_1.GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } async getTrainingJob(jobId) { const service = this.getService(); const job = await service.send(new client_bedrock_1.GetModelCustomizationJobCommand({ jobIdentifier: jobId })); return jobInfo(job, jobId); } // ===================== management API ================== async validateConnection() { const service = this.getService(); this.logger.debug("[Bedrock] validating connection", service.config.credentials.name); //return true as if the client has been initialized, it means the connection is valid return true; } async listTrainableModels() { this.logger.debug("[Bedrock] listing trainable models"); return this._listModels(m => m.customizationsSupported ? m.customizationsSupported.includes("FINE_TUNING") : false); } async listModels() { this.logger.debug("[Bedrock] listing models"); // exclude trainable models since they are not executable // exclude embedding models, not to be used for typical completions. const filter = (m) => (m.inferenceTypesSupported?.includes("ON_DEMAND") && !m.outputModalities?.includes("EMBEDDING")) ?? false; return this._listModels(filter); } async _listModels(foundationFilter) { const service = this.getService(); const [foundationModelsList, customModelsList, inferenceProfilesList] = await Promise.all([ service.listFoundationModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list foundation models. 
Check if the user has the right permissions."); return undefined; }), service.listCustomModels({}).catch(() => { this.logger.warn("[Bedrock] Can't list custom models. Check if the user has the right permissions."); return undefined; }), service.listInferenceProfiles({}).catch(() => { this.logger.warn("[Bedrock] Can't list inference profiles. Check if the user has the right permissions."); return undefined; }), ]); if (!foundationModelsList?.modelSummaries) { throw new Error("Foundation models not found"); } let foundationModels = foundationModelsList.modelSummaries || []; if (foundationFilter) { foundationModels = foundationModels.filter(foundationFilter); } const supportedPublishers = ["amazon", "anthropic", "cohere", "ai21", "mistral", "meta", "deepseek", "writer"]; const unsupportedModelsByPublisher = { amazon: ["titan-image-generator", "nova-reel", "nova-sonic", "rerank"], anthropic: [], cohere: ["rerank"], ai21: [], mistral: [], meta: [], deepseek: [], writer: [], }; // Helper function to check if model should be filtered out const shouldIncludeModel = (modelId, providerName) => { if (!modelId || !providerName) return false; const normalizedProvider = providerName.toLowerCase(); // Check if provider is supported const isProviderSupported = supportedPublishers.some(provider => normalizedProvider.includes(provider)); if (!isProviderSupported) return false; // Check if model is in the unsupported list for its provider for (const provider of supportedPublishers) { if (normalizedProvider.includes(provider)) { const unsupportedModels = unsupportedModelsByPublisher[provider] || []; return !unsupportedModels.some(unsupported => modelId.toLowerCase().includes(unsupported)); } } return true; }; foundationModels = foundationModels.filter(m => shouldIncludeModel(m.modelId, m.providerName)); const aiModels = foundationModels.map((m) => { if (!m.modelId) { throw new Error("modelId not found"); } const modelCapability = (0, core_1.getModelCapabilities)(m.modelArn ?? 
m.modelId, this.provider); const model = { id: m.modelArn ?? m.modelId, name: `${m.providerName} ${m.modelName}`, provider: this.provider, //description: ``, owner: m.providerName, can_stream: m.responseStreamingSupported ?? false, input_modalities: m.inputModalities ? formatAmazonModalities(m.inputModalities) : (0, core_1.modelModalitiesToArray)(modelCapability.input), output_modalities: m.outputModalities ? formatAmazonModalities(m.outputModalities) : (0, core_1.modelModalitiesToArray)(modelCapability.input), tool_support: modelCapability.tool_support, }; return model; }); //add custom models if (customModelsList?.modelSummaries) { customModelsList.modelSummaries.forEach((m) => { if (!m.modelArn) { throw new Error("Model ID not found"); } const modelCapability = (0, core_1.getModelCapabilities)(m.modelArn, this.provider); const model = { id: m.modelArn, name: m.modelName ?? m.modelArn, provider: this.provider, description: `Custom model from ${m.baseModelName}`, is_custom: true, input_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.input), output_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.output), tool_support: modelCapability.tool_support, }; aiModels.push(model); this.validateConnection; }); } //add inference profiles if (inferenceProfilesList?.inferenceProfileSummaries) { inferenceProfilesList.inferenceProfileSummaries.forEach((p) => { if (!p.inferenceProfileArn) { throw new Error("Profile ARN not found"); } // Apply the same filtering logic to inference profiles based on their name const profileId = p.inferenceProfileId || ""; const profileName = p.inferenceProfileName || ""; // Extract provider name from profile name or ID let providerName = ""; for (const provider of supportedPublishers) { if (profileName.toLowerCase().includes(provider) || profileId.toLowerCase().includes(provider)) { providerName = provider; break; } } const modelCapability = (0, core_1.getModelCapabilities)(p.inferenceProfileArn ?? 
p.inferenceProfileId, this.provider); if (providerName && shouldIncludeModel(profileId, providerName)) { const model = { id: p.inferenceProfileArn ?? p.inferenceProfileId, name: p.inferenceProfileName ?? p.inferenceProfileArn, provider: this.provider, input_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.input), output_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.output), tool_support: modelCapability.tool_support, }; aiModels.push(model); } }); } return aiModels; } async generateEmbeddings({ text, image, model }) { this.logger.info("[Bedrock] Generating embeddings with model " + model); const defaultModel = image ? "amazon.titan-embed-image-v1" : "amazon.titan-embed-text-v2:0"; const modelID = model ?? defaultModel; const invokeBody = { inputText: text, inputImage: image }; const executor = this.getExecutor(); const res = await executor.invokeModel({ modelId: modelID, contentType: "application/json", body: JSON.stringify(invokeBody), }); const decoder = new TextDecoder(); const body = decoder.decode(res.body); const result = JSON.parse(body); if (!result.embedding) { throw new Error("Embeddings not found"); } return { values: result.embedding, model: modelID, token_count: result.inputTextTokenCount }; } } exports.BedrockDriver = BedrockDriver; function jobInfo(job, jobId) { const jobStatus = job.status; let status = core_1.TrainingJobStatus.running; let details; if (jobStatus === client_bedrock_1.ModelCustomizationJobStatus.COMPLETED) { status = core_1.TrainingJobStatus.succeeded; } else if (jobStatus === client_bedrock_1.ModelCustomizationJobStatus.FAILED) { status = core_1.TrainingJobStatus.failed; details = job.failureMessage || "error"; } else if (jobStatus === client_bedrock_1.ModelCustomizationJobStatus.STOPPED) { status = core_1.TrainingJobStatus.cancelled; } else { status = core_1.TrainingJobStatus.running; details = jobStatus; } job.baseModelArn; return { id: jobId, model: job.outputModelArn, status, details }; } function 
getToolDefinitions(tools) { return tools ? tools.map(getToolDefinition) : undefined; } function getToolDefinition(tool) { return { toolSpec: { name: tool.name, description: tool.description, inputSchema: { json: tool.input_schema, } } }; } /** * Update the conversation messages * @param prompt * @param response * @returns */ function updateConversation(conversation, prompt) { return { ...conversation, ...prompt, messages: [...(conversation?.messages || []), ...(prompt.messages || [])], system: prompt.system || conversation?.system, }; } function formatAmazonModalities(modalities) { const standardizedModalities = []; for (const modality of modalities) { if (modality === client_bedrock_1.ModelModality.TEXT) { standardizedModalities.push("text"); } else if (modality === client_bedrock_1.ModelModality.IMAGE) { standardizedModalities.push("image"); } else if (modality === client_bedrock_1.ModelModality.EMBEDDING) { standardizedModalities.push("embedding"); } else if (modality == "SPEECH") { standardizedModalities.push("audio"); } else if (modality == "VIDEO") { standardizedModalities.push("video"); } else { // Handle other modalities as needed standardizedModalities.push(modality.toString().toLowerCase()); } } return standardizedModalities; } //# sourceMappingURL=index.js.map