UNPKG

@agenite/bedrock

Version:
442 lines (437 loc) 13.1 kB
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand } from '@aws-sdk/client-bedrock-runtime';
import { BaseLLMProvider, convertStringToMessages } from '@agenite/llm';

// src/utils.ts

/** Bedrock Converse stop reason -> agenite stop reason. */
const STOP_REASON_MAP = Object.freeze({
  max_tokens: "maxTokens",
  stop_sequence: "stopSequence",
  end_turn: "endTurn",
  tool_use: "toolUse",
  // Guardrail / content-filter interventions are surfaced as an ordinary end of turn.
  guardrail_intervened: "endTurn",
  content_filtered: "endTurn"
});

/** Image formats passed through when mapping a Bedrock image block; anything else becomes webp. */
const SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "gif", "webp"];

/** Matches text that is empty or whitespace-only. */
const BLANK_TEXT = /^\s*$/;

/**
 * Translate a Bedrock stop reason into the agenite vocabulary.
 *
 * @param bedrockStopReason - `stopReason` from a Converse response or a `messageStop` stream event.
 * @returns The mapped reason, or `undefined` when the input is missing or not a known reason.
 */
const mapStopReason = (bedrockStopReason) => {
  if (!bedrockStopReason) return void 0;
  // Own-property lookup so inherited keys such as "toString" can never be returned.
  return Object.prototype.hasOwnProperty.call(STOP_REASON_MAP, bedrockStopReason)
    ? STOP_REASON_MAP[bedrockStopReason]
    : void 0;
};

/**
 * Map one Bedrock content block to an agenite content block.
 *
 * @returns The mapped block, or `null` for blank text blocks that should be dropped.
 * @throws Error when a tool-use block has no `toolUseId`, or the block shape is not recognised.
 */
const mapContentBlock = (block) => {
  // A type check (not truthiness) ensures an empty "" text block is dropped
  // instead of falling through to the "Unsupported content block type" error.
  if (typeof block.text === "string") {
    return BLANK_TEXT.test(block.text) ? null : { type: "text", text: block.text };
  }
  if (block.toolUse) {
    const toolUseId = block.toolUse.toolUseId;
    if (!toolUseId) {
      throw new Error("Tool use ID is required");
    }
    return {
      type: "toolUse",
      toolName: block.toolUse.name,
      input: block.toolUse.input || {},
      id: toolUseId,
      name: block.toolUse.name || "unknown"
    };
  }
  if (block.image) {
    const format = block.image.format || "webp";
    const mediaFormat = SUPPORTED_IMAGE_FORMATS.includes(format) ? format : "webp";
    return {
      type: "image",
      source: {
        type: "base64",
        // NOTE(review): only the `$unknown` union slot is read here; a source delivered
        // as raw bytes by the SDK would map to "" — confirm against the Bedrock SDK types.
        data: block.image.source?.$unknown?.[1] || "",
        media_type: `image/${mediaFormat}`
      }
    };
  }
  if (block.reasoningContent) {
    return {
      type: "thinking",
      thinking: block.reasoningContent.reasoningText?.text || "",
      signature: block.reasoningContent.reasoningText?.signature || ""
    };
  }
  throw new Error(
    `Unsupported content block type: ${JSON.stringify(block, null, 2)}`
  );
};

/**
 * Map a list of Bedrock content blocks to agenite content blocks, dropping blank text.
 *
 * @param bedrockContent - Content blocks from a Converse response or accumulated stream state.
 */
const mapContent = (bedrockContent) => {
  return bedrockContent
    .map((block) => mapContentBlock(block))
    .filter((block) => block !== null);
};

/**
 * Convert one agenite content block (or bare string) into a Bedrock Converse content block.
 *
 * @throws Error for block types this provider does not support.
 */
const toBedrockContentBlock = (block) => {
  if (typeof block === "string") {
    return { text: block, $unknown: void 0 };
  }
  switch (block.type) {
    case "text":
      return { text: block.text, $unknown: void 0 };
    case "toolUse":
      return {
        toolUse: { toolUseId: block.id, name: block.name, input: block.input }
      };
    case "toolResult":
      return {
        toolResult: {
          toolUseId: block.toolUseId,
          // Tool output is always serialised to JSON text.
          content: [{ text: JSON.stringify(block.content) }],
          status: block.isError ? "error" : "success"
        }
      };
    case "image":
      // NOTE(review): the agenite source object is tunnelled through the SDK's `$unknown`
      // union member rather than a typed ImageSource member — confirm Bedrock accepts this.
      return {
        image: {
          source: { $unknown: ["source", block.source] },
          format: block.source.type === "url"
            ? "url"
            : block.source.media_type.split("/")[1] || "webp"
        },
        $unknown: void 0
      };
    case "thinking":
      return {
        reasoningContent: {
          reasoningText: { text: block.thinking, signature: block.signature }
        }
      };
    case "document":
      // NOTE(review): URL sources are labelled "pdf" and everything else "txt" — verify intent.
      return {
        document: {
          format: block.source?.type === "url" ? "pdf" : "txt",
          name: String(block.name || block.title),
          source: { $unknown: ["url", block.source] }
        }
      };
    default:
      throw new Error(
        `Unsupported content block type: ${JSON.stringify(block, null, 2)}`
      );
  }
};

/**
 * Convert agenite messages into Bedrock Converse messages.
 * Only `user` and `assistant` messages are forwarded; other roles are dropped.
 *
 * @returns The converted messages, or `undefined` when `messages` is nullish.
 */
const convertToMessageFormat = (messages) => {
  return messages
    ?.filter((message) => message.role === "user" || message.role === "assistant")
    .map((message) => ({
      role: message.role,
      content: message.content.map((block) => toBedrockContentBlock(block))
    }));
};

// src/tool-adapter.ts

/** Converts agenite tool definitions into Bedrock `toolSpec` entries. */
class BedrockToolAdapter {
  convertToProviderTool(tool) {
    return {
      toolSpec: {
        name: tool.name,
        description: tool.description || tool.name,
        inputSchema: {
          json: {
            type: "object",
            properties: tool.inputSchema.properties,
            required: tool.inputSchema.required
          }
        }
      }
    };
  }
}

// src/provider.ts
const DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0";
const DEFAULT_MAX_TOKENS = 4096;
/** `budget_tokens` sent when reasoning is enabled and no budget is configured. */
const DEFAULT_REASONING_BUDGET_TOKENS = 1024;
/** Buffered streamed text/thinking is emitted once it grows past this many characters. */
const STREAM_FLUSH_THRESHOLD = 10;

/**
 * Take the shared stream buffer if it has grown past the flush threshold.
 *
 * @returns The buffered text (and clears the buffer), or `null` if it is still short.
 */
const drainBufferIfFull = (state) => {
  if (state.buffer.length <= STREAM_FLUSH_THRESHOLD) return null;
  const chunk = state.buffer;
  state.buffer = "";
  return chunk;
};

/**
 * Finalise accumulated tool-use input.
 *
 * Streamed input arrives as JSON string fragments. Blank or missing input (a tool
 * called with no arguments) becomes `{}`. Input that is already a non-string value
 * is returned unchanged, so it is not re-parsed as "[object Object]".
 *
 * @throws SyntaxError when the accumulated string is not valid JSON.
 */
const parseToolInput = (input) => {
  if (typeof input !== "string") {
    return input ?? {};
  }
  return input.trim() === "" ? {} : JSON.parse(input);
};

/** LLM provider backed by the AWS Bedrock Converse / ConverseStream APIs. */
class BedrockProvider extends BaseLLMProvider {
  client;
  config;
  toolAdapter;
  name = "Bedrock";
  version = "1.0";

  constructor(config) {
    super();
    this.config = config;
    // `bedrockClientConfig` is spread last, so it can override region/credentials.
    this.client = new BedrockRuntimeClient({
      region: config.region,
      credentials: config.credentials,
      ...config.bedrockClientConfig
    });
    this.toolAdapter = new BedrockToolAdapter();
  }

  /** Build the Converse request body shared by `stream` and `generate`. */
  createRequestBody(input, options) {
    const transformedMessages = convertToMessageFormat(convertStringToMessages(input));
    const providerTools = options?.tools?.map(
      (tool) => this.toolAdapter.convertToProviderTool(tool)
    );
    const reasoningEnabled = Boolean(this.config.enableThinking || this.config.enableReasoning);
    // Temperature is pinned to 1 whenever reasoning is on; caller/config values are ignored then
    // (presumably because the model rejects other temperatures with thinking enabled).
    const temperature = reasoningEnabled
      ? 1
      : (options?.temperature ?? this.config.temperature ?? 0.7);

    return {
      modelId: this.config.model || DEFAULT_MODEL,
      system: options?.systemPrompt ? [{ text: options.systemPrompt }] : void 0,
      messages: transformedMessages,
      inferenceConfig: {
        maxTokens: options?.maxTokens ?? DEFAULT_MAX_TOKENS,
        temperature,
        stopSequences: options?.stopSequences
      },
      toolConfig: providerTools?.length
        ? { tools: providerTools, toolChoice: { auto: {} } }
        : void 0,
      additionalModelRequestFields: reasoningEnabled
        ? {
            reasoning_config: {
              type: "enabled",
              budget_tokens: this.config.reasoningBudgetTokens || DEFAULT_REASONING_BUDGET_TOKENS
            }
          }
        : void 0
    };
  }

  /**
   * Log and rethrow a failure as a prefixed Error.
   * The original error is attached as `cause`, so callers can still inspect it.
   *
   * @throws Error always.
   */
  handleError(error) {
    console.error("Bedrock generation failed:", error);
    const message = error instanceof Error
      ? `Bedrock generation failed: ${error.message}`
      : "Bedrock generation failed with unknown error";
    throw Object.assign(new Error(message), { cause: error });
  }

  /** Assemble the final response object with mapped stop reason and token usage. */
  createGenerateResponse(content, stopReason, inputTokens, outputTokens) {
    return {
      content,
      stopReason: mapStopReason(stopReason),
      tokenUsage: {
        model: this.config.model || DEFAULT_MODEL,
        inputTokens,
        outputTokens,
        // TODO: introduce cost LLM
        inputCost: 0,
        outputCost: 0
      }
    };
  }

  /**
   * Record usage from `metadata` events and dispatch block start/delta events.
   *
   * @returns A chunk to yield to the consumer, or `null`.
   */
  handleStreamEvent(event, state) {
    if ("metadata" in event && event.metadata) {
      state.inputTokens = event.metadata.usage?.inputTokens || state.inputTokens;
      state.outputTokens = event.metadata.usage?.outputTokens || state.outputTokens;
    }
    if ("contentBlockStart" in event && event.contentBlockStart) {
      this.handleContentBlockStart(event.contentBlockStart, state.contentBlocks);
      return null;
    }
    if ("contentBlockDelta" in event && event.contentBlockDelta) {
      return this.handleContentBlockDelta(event.contentBlockDelta, state);
    }
    return null;
  }

  /** Seed a tool-use block (id/name) when Bedrock announces it. */
  handleContentBlockStart({ contentBlockIndex = 0, start }, contentBlocks) {
    if (!start?.toolUse) return;
    const existing = contentBlocks[contentBlockIndex];
    contentBlocks[contentBlockIndex] = {
      ...existing,
      toolUse: { ...start.toolUse, ...existing?.toolUse }
    };
  }

  /** Route a delta to the reasoning, text or tool-use accumulator. */
  handleContentBlockDelta({ delta, contentBlockIndex = 0 }, state) {
    if (!delta) return null;
    if (delta.reasoningContent) {
      return this.handleReasoningDelta(delta.reasoningContent, contentBlockIndex, state);
    }
    if (delta.text) {
      return this.handleTextDelta(delta.text, contentBlockIndex, state);
    }
    if (delta.toolUse) {
      this.handleToolUseDelta(delta.toolUse, contentBlockIndex, state.contentBlocks);
    }
    return null;
  }

  /** Append streamed text to its block and the shared buffer; emit when the buffer is full. */
  handleTextDelta(text, contentBlockIndex, state) {
    const chunk = text || "";
    state.buffer += chunk;
    const existing = state.contentBlocks[contentBlockIndex];
    state.contentBlocks[contentBlockIndex] = {
      ...existing,
      text: (existing?.text || "") + chunk
    };
    const flushed = drainBufferIfFull(state);
    return flushed === null ? null : { type: "text", text: flushed };
  }

  /** Append streamed reasoning text/signature to its block; emit thinking when the buffer is full. */
  handleReasoningDelta(reasoningContent, contentBlockIndex, state) {
    const chunk = reasoningContent.text || "";
    state.buffer += chunk;
    const existing = state.contentBlocks[contentBlockIndex];
    const previous = existing?.reasoningContent?.reasoningText;
    state.contentBlocks[contentBlockIndex] = {
      ...existing,
      reasoningContent: {
        ...existing?.reasoningContent,
        reasoningText: {
          text: (previous?.text || "") + chunk,
          // Keep an already-received signature when a later delta omits it.
          signature: reasoningContent.signature ?? previous?.signature
        }
      }
    };
    const flushed = drainBufferIfFull(state);
    return flushed === null ? null : { type: "thinking", thinking: flushed };
  }

  /** Concatenate streamed tool-input JSON fragments onto the block's `toolUse.input`. */
  handleToolUseDelta(toolUse, contentBlockIndex, contentBlocks) {
    const existingToolUse = contentBlocks[contentBlockIndex]?.toolUse;
    contentBlocks[contentBlockIndex] = {
      toolUse: {
        ...existingToolUse,
        input: (existingToolUse?.input || "") + (toolUse.input || "")
      }
    };
  }

  /**
   * On `contentBlockStop`, finalise the block.
   * - Tool use: parse its input and emit the mapped tool use.
   * - Text / thinking: flush the remaining buffer with `isEnd: true`.
   *
   * @returns A chunk to yield, or `null`.
   */
  handleContentBlockStop(event, state) {
    const blockIndex = event.contentBlockStop?.contentBlockIndex;
    if (blockIndex === void 0) return null;
    const block = state.contentBlocks[blockIndex];

    // Emit for every tool-use block, including ones called with no (or empty) input.
    if (block?.toolUse) {
      block.toolUse.input = parseToolInput(block.toolUse.input);
      return { type: "toolUse", toolUse: mapContent([block])[0], isEnd: true };
    }

    const finalBuffer = state.buffer;
    state.buffer = "";
    if (block?.text) {
      return { type: "text", text: finalBuffer, isEnd: true };
    }
    if (block?.reasoningContent) {
      return { type: "thinking", thinking: finalBuffer, isEnd: true };
    }
    return null;
  }

  /**
   * Emit an `isStart` marker on the first text/reasoning delta of a block not yet seen.
   * The index defaults to 0, matching `handleContentBlockDelta`.
   *
   * @returns A start chunk, or `null`.
   */
  handleTextBlockStart(event, state) {
    const contentBlockDelta = event.contentBlockDelta;
    const contentBlockIndex = contentBlockDelta?.contentBlockIndex ?? 0;
    if (state.contentBlocks[contentBlockIndex] !== void 0) {
      return null;
    }
    if (contentBlockDelta?.delta?.text) {
      return { type: "text", text: "", isStart: true };
    }
    if (contentBlockDelta?.delta?.reasoningContent) {
      return { type: "thinking", thinking: "", isStart: true };
    }
    return null;
  }

  /**
   * Stream a completion via ConverseStream.
   * Yields start/delta/end chunks and returns the full response once the stream ends.
   */
  async *stream(input, options) {
    try {
      const requestBody = this.createRequestBody(input, options);
      const response = await this.client.send(
        new ConverseStreamCommand({ ...requestBody, ...this.config.converseCommandConfig })
      );
      if (!response.stream) {
        throw new Error("No stream found in response");
      }
      const state = { buffer: "", inputTokens: 0, outputTokens: 0, contentBlocks: [] };
      let finalStopReason;
      for await (const event of response.stream) {
        if ("messageStop" in event) {
          finalStopReason = event.messageStop?.stopReason;
          continue;
        }
        // Order matters: the start check must see the block table before the delta is applied.
        const startChunk = this.handleTextBlockStart(event, state);
        if (startChunk) yield startChunk;
        const stopChunk = this.handleContentBlockStop(event, state);
        if (stopChunk) yield stopChunk;
        const deltaChunk = this.handleStreamEvent(event, state);
        if (deltaChunk) yield deltaChunk;
      }
      return this.createGenerateResponse(
        mapContent(state.contentBlocks),
        finalStopReason,
        state.inputTokens,
        state.outputTokens
      );
    } catch (error) {
      this.handleError(error);
    }
  }

  /** Produce a single, non-streamed completion via Converse. */
  async generate(input, options) {
    try {
      const requestBody = this.createRequestBody(input, options);
      const response = await this.client.send(
        new ConverseCommand({ ...requestBody, ...this.config.converseCommandConfig })
      );
      return this.createGenerateResponse(
        mapContent(response.output?.message?.content || []),
        response.stopReason,
        response.usage?.inputTokens || 0,
        response.usage?.outputTokens || 0
      );
    } catch (error) {
      this.handleError(error);
    }
  }
}

export { BedrockProvider };
//# sourceMappingURL=index.js.map