/*
 * @agenite/bedrock
 * Version: (not stated)
 * AWS Bedrock provider for Agenite
 * 442 lines (437 loc) • 13.1 kB
 * JavaScript
 */
import { BedrockRuntimeClient, ConverseStreamCommand, ConverseCommand } from '@aws-sdk/client-bedrock-runtime';
import { BaseLLMProvider, convertStringToMessages } from '@agenite/llm';
// src/provider.ts
// src/utils.ts
/**
 * Translate a Bedrock Converse `stopReason` into Agenite's camelCase form.
 *
 * Guardrail and content-filter stops are both reported as a normal "endTurn".
 *
 * @param bedrockStopReason Stop reason string reported by Bedrock (may be absent).
 * @returns The Agenite stop reason, or `undefined` when the input is missing/empty
 *   or not one of the known Bedrock values.
 */
var mapStopReason = (bedrockStopReason) => {
  const translations = {
    end_turn: "endTurn",
    tool_use: "toolUse",
    max_tokens: "maxTokens",
    stop_sequence: "stopSequence",
    // Safety interventions end the turn from the caller's point of view.
    guardrail_intervened: "endTurn",
    content_filtered: "endTurn"
  };
  return bedrockStopReason ? translations[bedrockStopReason] : undefined;
};
/**
 * Convert Bedrock Converse content blocks into Agenite content blocks.
 *
 * - `text` blocks become `{ type: "text" }`. Empty or whitespace-only text is dropped.
 * - `toolUse` blocks become `{ type: "toolUse" }`. Input defaults to `{}` and name to
 *   "unknown" (for `name` only; `toolName` is copied as-is).
 * - `image` blocks become base64 `{ type: "image" }` blocks. Data is read from
 *   `source.$unknown[1]` (else ""). The format falls back to "webp" when absent or
 *   not one of jpeg/png/gif/webp.
 * - `reasoningContent` blocks become `{ type: "thinking" }` with text and signature
 *   (each defaulting to "").
 *
 * @param bedrockContent Array of Bedrock content blocks. Holes in a sparse array are
 *   skipped.
 * @returns The converted blocks, with dropped text blocks removed.
 * @throws Error when a toolUse block has no `toolUseId`, or a block has no recognized
 *   content.
 */
var mapContent = (bedrockContent) => {
  return bedrockContent.map((block) => {
    // Test the type, not truthiness. An empty-string text block is still a text block
    // and must be dropped like whitespace-only text, not reported as unsupported.
    if (typeof block.text === "string") {
      if (/^\s*$/.test(block.text)) {
        return null;
      }
      return {
        type: "text",
        text: block.text
      };
    }
    if (block.toolUse) {
      const toolUseId = block.toolUse.toolUseId;
      if (!toolUseId) {
        throw new Error("Tool use ID is required");
      }
      return {
        type: "toolUse",
        toolName: block.toolUse.name,
        input: block.toolUse.input || {},
        id: toolUseId,
        name: block.toolUse.name || "unknown"
      };
    }
    if (block.image) {
      const format = block.image.format || "webp";
      const validFormat = ["jpeg", "png", "gif", "webp"].includes(format) ? format : "webp";
      return {
        type: "image",
        source: {
          type: "base64",
          data: block.image.source?.$unknown?.[1] || "",
          media_type: `image/${validFormat}`
        }
      };
    }
    if (block.reasoningContent) {
      return {
        type: "thinking",
        thinking: block.reasoningContent.reasoningText?.text || "",
        signature: block.reasoningContent.reasoningText?.signature || ""
      };
    }
    throw new Error(
      `Unsupported content block type: ${JSON.stringify(block, null, 2)}`
    );
  }).filter((block) => block !== null);
};
/**
 * Convert Agenite messages into Bedrock Converse `messages` entries.
 *
 * Only "user" and "assistant" messages are kept; any other role is filtered out.
 * Each content block is mapped to the matching Converse content block. Plain strings
 * become text blocks.
 *
 * @param messages Agenite messages (may be undefined, in which case undefined is
 *   returned).
 * @returns Converse messages, one per kept input message.
 * @throws Error for a content block whose `type` is not handled here.
 */
var convertToMessageFormat = (messages) => {
  return messages?.filter(
    (message) => ["user", "assistant"].includes(message.role)
  ).map((message) => ({
    role: message.role,
    content: message.content.map((block) => {
      if (typeof block === "string") {
        return {
          text: block,
          $unknown: void 0
        };
      }
      switch (block.type) {
        case "text":
          return {
            text: block.text,
            $unknown: void 0
          };
        case "toolUse":
          return {
            toolUse: {
              toolUseId: block.id,
              name: block.name,
              input: block.input
            }
          };
        case "toolResult":
          return {
            toolResult: {
              toolUseId: block.toolUseId,
              content: [
                {
                  // Forward string results verbatim. JSON-encoding a string would wrap it in
                  // quotes and escape its newlines. Structured results are still sent as JSON.
                  text: typeof block.content === "string" ? block.content : JSON.stringify(block.content)
                }
              ],
              status: block.isError ? "error" : "success"
            }
          };
        case "image":
          if (block.source?.type === "base64") {
            // Converse's ImageSource carries raw bytes, which the SDK base64-encodes on the
            // wire. Decode here instead of sending an `$unknown` member.
            return {
              image: {
                format: block.source.media_type.split("/")[1] || "webp",
                source: {
                  bytes: Uint8Array.from(atob(block.source.data), (ch) => ch.charCodeAt(0))
                }
              },
              $unknown: void 0
            };
          }
          // NOTE(review): non-base64 sources (e.g. URLs) keep the previous mapping. Bedrock
          // does not appear to accept URL image sources, so such requests likely fail.
          return {
            image: {
              source: {
                $unknown: ["source", block.source]
              },
              format: block.source.type === "url" ? "url" : block.source.media_type.split("/")[1] || "webp"
            },
            $unknown: void 0
          };
        case "thinking":
          return {
            reasoningContent: {
              reasoningText: {
                text: block.thinking,
                signature: block.signature
              }
            }
          };
        case "document":
          // NOTE(review): this document mapping also uses an `$unknown` source member.
          // Confirm it against the Converse DocumentBlock shape.
          return {
            document: {
              format: block.source?.type === "url" ? "pdf" : "txt",
              name: String(block.name || block.title),
              source: {
                $unknown: ["url", block.source]
              }
            }
          };
        default:
          throw new Error(
            `Unsupported content block type: ${JSON.stringify(block, null, 2)}`
          );
      }
    })
  }));
};
// src/tool-adapter.ts
/**
 * Adapts Agenite tool definitions to Bedrock Converse `toolSpec` entries.
 */
var BedrockToolAdapter = class {
  /**
   * Build a Bedrock `{ toolSpec }` for one tool.
   *
   * The JSON schema sent to Bedrock is always an object schema. Only `properties` and
   * `required` are copied from `tool.inputSchema`. An empty or missing description
   * falls back to the tool name.
   *
   * @param tool Agenite tool definition with `name`, optional `description`, and
   *   `inputSchema`.
   * @returns The Bedrock tool entry.
   */
  convertToProviderTool(tool) {
    const { name, description, inputSchema } = tool;
    const json = {
      type: "object",
      properties: inputSchema.properties,
      required: inputSchema.required
    };
    const toolSpec = {
      name,
      description: description || name,
      inputSchema: { json }
    };
    return { toolSpec };
  }
};
// src/provider.ts
var
  /** Bedrock model ID used when the provider config has no `model`. */
  DEFAULT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0",
  /** `inferenceConfig.maxTokens` used when the caller passes no `maxTokens`. */
  DEFAULT_MAX_TOKENS = 4 * 1024;
/**
 * Agenite LLM provider for AWS Bedrock, built on the Bedrock Runtime Converse API.
 *
 * `generate()` sends one `ConverseCommand` and maps the reply. `stream()` sends a
 * `ConverseStreamCommand` and yields incremental text / thinking / toolUse chunks. It
 * then returns the fully accumulated response when the stream ends.
 */
var BedrockProvider = class extends BaseLLMProvider {
  /** Bedrock Runtime client built from the provider config. */
  client;
  /** Provider configuration passed to the constructor. */
  config;
  /** Converts Agenite tool definitions into Bedrock `toolSpec` entries. */
  toolAdapter;
  name = "Bedrock";
  version = "1.0";

  /**
   * @param config Provider settings. The constructor reads `region` and `credentials`,
   *   then spreads `bedrockClientConfig` last, so it overrides both. Other methods read
   *   `model`, `temperature`, `enableThinking`, `enableReasoning`,
   *   `reasoningBudgetTokens`, and `converseCommandConfig`.
   */
  constructor(config) {
    super();
    this.config = config;
    this.client = new BedrockRuntimeClient({
      region: config.region,
      credentials: config.credentials,
      ...config.bedrockClientConfig
    });
    this.toolAdapter = new BedrockToolAdapter();
  }

  /**
   * Build the Converse request input shared by `generate()` and `stream()`.
   *
   * @param input Prompt in any form accepted by `convertStringToMessages`.
   * @param options Optional per-call settings: `tools`, `temperature`, `systemPrompt`,
   *   `maxTokens`, `stopSequences`.
   * @returns The Converse command input, before `converseCommandConfig` overrides.
   */
  createRequestBody(input, options) {
    const messageArray = convertStringToMessages(input);
    const transformedMessages = convertToMessageFormat(messageArray);
    const providerTools = options?.tools?.map(
      (tool) => this.toolAdapter.convertToProviderTool(tool)
    );
    const thinkingEnabled = this.config.enableThinking || this.config.enableReasoning;
    // Temperature is pinned to 1 while thinking is enabled. Presumably the model
    // rejects other values in that mode — confirm for non-Anthropic models.
    const temperature = thinkingEnabled ? 1 : options?.temperature ?? this.config.temperature ?? 0.7;
    return {
      modelId: this.config.model || DEFAULT_MODEL,
      system: options?.systemPrompt ? [{ text: options.systemPrompt }] : void 0,
      messages: transformedMessages,
      inferenceConfig: {
        maxTokens: options?.maxTokens ?? DEFAULT_MAX_TOKENS,
        temperature,
        stopSequences: options?.stopSequences
      },
      // Omit toolConfig entirely when there are no tools.
      toolConfig: providerTools?.length ? {
        tools: providerTools,
        toolChoice: { auto: {} }
      } : void 0,
      additionalModelRequestFields: thinkingEnabled ? {
        reasoning_config: {
          type: "enabled",
          budget_tokens: this.config.reasoningBudgetTokens || 1024
        }
      } : void 0
    };
  }

  /**
   * Log a failure and rethrow it as a "Bedrock generation failed" Error.
   *
   * The original error is attached as `cause`, so callers can still inspect it.
   *
   * @param error Whatever was thrown by the request or stream handling.
   * @throws Error Always.
   */
  handleError(error) {
    console.error("Bedrock generation failed:", error);
    const message = error instanceof Error ? `Bedrock generation failed: ${error.message}` : "Bedrock generation failed with unknown error";
    throw Object.assign(new Error(message), { cause: error });
  }

  /**
   * Assemble the public response object.
   *
   * Costs are always reported as 0 for now.
   *
   * @param content Agenite content blocks.
   * @param stopReason Raw Bedrock stop reason; it is mapped via `mapStopReason`.
   * @param inputTokens Prompt token count.
   * @param outputTokens Completion token count.
   */
  createGenerateResponse(content, stopReason, inputTokens, outputTokens) {
    return {
      content,
      stopReason: mapStopReason(stopReason),
      tokenUsage: {
        model: this.config.model || DEFAULT_MODEL,
        inputTokens,
        outputTokens,
        // TODO: introduce cost LLM
        inputCost: 0,
        outputCost: 0
      }
    };
  }

  /**
   * Fold one stream event into `state`.
   *
   * Returns a chunk to yield when a delta fills the buffer, otherwise null.
   * Metadata events update token counts, keeping the previous count when a value is
   * missing or 0.
   */
  handleStreamEvent(event, state) {
    if ("metadata" in event && event.metadata) {
      state.inputTokens = event.metadata.usage?.inputTokens || state.inputTokens;
      state.outputTokens = event.metadata.usage?.outputTokens || state.outputTokens;
    }
    if ("contentBlockStart" in event && event.contentBlockStart) {
      this.handleContentBlockStart(
        event.contentBlockStart,
        state.contentBlocks
      );
      return null;
    }
    if ("contentBlockDelta" in event && event.contentBlockDelta) {
      return this.handleContentBlockDelta(event.contentBlockDelta, state);
    }
    return null;
  }

  /** Record the tool-use header (id, name) carried by a contentBlockStart event. */
  handleContentBlockStart({ contentBlockIndex = 0, start }, contentBlocks) {
    if (start?.toolUse) {
      contentBlocks[contentBlockIndex] = {
        ...contentBlocks[contentBlockIndex],
        toolUse: {
          ...start.toolUse,
          ...contentBlocks[contentBlockIndex]?.toolUse
        }
      };
    }
  }

  /** Dispatch a contentBlockDelta to the reasoning, text, or tool-use handler. */
  handleContentBlockDelta({ delta, contentBlockIndex = 0 }, state) {
    if (!delta) return null;
    if (delta.reasoningContent) {
      return this.handleReasoningDelta(
        delta.reasoningContent,
        contentBlockIndex,
        state
      );
    }
    if (delta.text) {
      return this.handleTextDelta(delta.text, contentBlockIndex, state);
    } else if (delta.toolUse) {
      this.handleToolUseDelta(
        delta.toolUse,
        contentBlockIndex,
        state.contentBlocks
      );
    }
    return null;
  }

  /**
   * Append streamed text to the block and to the yield buffer.
   *
   * Flushes the buffer as a text chunk once it exceeds 10 characters.
   */
  handleTextDelta(text, contentBlockIndex, state) {
    const chunk = text || "";
    const existing = state.contentBlocks[contentBlockIndex];
    state.buffer += chunk;
    state.contentBlocks[contentBlockIndex] = {
      ...existing,
      text: (existing?.text || "") + chunk
    };
    if (state.buffer.length > 10) {
      const result = {
        type: "text",
        text: state.buffer
      };
      state.buffer = "";
      return result;
    }
    return null;
  }

  /**
   * Append streamed reasoning text to the block and to the yield buffer.
   *
   * Flushes the buffer as a thinking chunk once it exceeds 10 characters.
   */
  handleReasoningDelta(reasoningContent, contentBlockIndex, state) {
    const existingReasoning = state.contentBlocks[contentBlockIndex]?.reasoningContent;
    state.buffer += reasoningContent.text || "";
    state.contentBlocks[contentBlockIndex] = {
      ...state.contentBlocks[contentBlockIndex],
      reasoningContent: {
        ...existingReasoning,
        reasoningText: {
          text: (existingReasoning?.reasoningText?.text || "") + (reasoningContent.text || ""),
          // Deltas without a signature must not erase one received earlier.
          signature: reasoningContent.signature ?? existingReasoning?.reasoningText?.signature
        }
      }
    };
    if (state.buffer.length > 10) {
      const result = {
        type: "thinking",
        thinking: state.buffer
      };
      state.buffer = "";
      return result;
    }
    return null;
  }

  /**
   * Append a fragment of the tool's JSON arguments.
   *
   * The string is parsed at block stop.
   */
  handleToolUseDelta(toolUse, contentBlockIndex, contentBlocks) {
    contentBlocks[contentBlockIndex] = {
      toolUse: {
        ...contentBlocks[contentBlockIndex]?.toolUse,
        input: (contentBlocks[contentBlockIndex]?.toolUse?.input || "") + (toolUse.input || "")
      }
    };
  }

  /**
   * On a contentBlockStop event, emit the block's closing chunk.
   *
   * - Tool-use blocks have their accumulated JSON input parsed. Empty input becomes
   *   `{}`. The mapped toolUse is returned.
   * - Text/thinking blocks return whatever is still buffered.
   * - Otherwise the method returns null.
   *
   * @throws SyntaxError when the accumulated tool input is not valid JSON.
   */
  handleContentBlockStop(event, state) {
    const stopIndex = event.contentBlockStop?.contentBlockIndex;
    if (stopIndex === void 0) {
      return null;
    }
    const block = state.contentBlocks[stopIndex];
    if (block?.toolUse) {
      const rawInput = block.toolUse.input;
      // A tool called without arguments may stream no input fragments at all. It must
      // still be emitted, with empty input.
      block.toolUse.input = typeof rawInput === "string" ? rawInput.trim() ? JSON.parse(rawInput) : {} : rawInput ?? {};
      return {
        type: "toolUse",
        toolUse: mapContent([block])[0],
        isEnd: true
      };
    }
    const finalBuffer = state.buffer;
    state.buffer = "";
    if (block?.text) {
      return {
        type: "text",
        text: finalBuffer,
        isEnd: true
      };
    }
    if (block?.reasoningContent) {
      return {
        type: "thinking",
        thinking: finalBuffer,
        isEnd: true
      };
    }
    return null;
  }

  /**
   * Emit an empty `isStart` chunk for the first text/reasoning delta of a block.
   *
   * "First" means nothing has been recorded yet at that content block index.
   */
  handleTextBlockStart(event, state) {
    const contentBlockIndex = event.contentBlockDelta?.contentBlockIndex;
    const hasExistingBlock = contentBlockIndex !== void 0 && state.contentBlocks[contentBlockIndex] !== void 0;
    if (hasExistingBlock) {
      return null;
    }
    if (event.contentBlockDelta?.delta?.text) {
      return {
        type: "text",
        text: "",
        isStart: true
      };
    }
    if (event.contentBlockDelta?.delta?.reasoningContent) {
      return {
        type: "thinking",
        thinking: "",
        isStart: true
      };
    }
    return null;
  }

  /**
   * Stream a completion.
   *
   * Yields start, partial, and end chunks, then returns the full response built from
   * the accumulated content blocks.
   *
   * @throws Error (via handleError) when the request or stream processing fails.
   */
  async *stream(input, options) {
    try {
      const requestBody = this.createRequestBody(input, options);
      const response = await this.client.send(
        new ConverseStreamCommand({
          ...requestBody,
          ...this.config.converseCommandConfig
        })
      );
      if (!response.stream) {
        throw new Error("No stream found in response");
      }
      const state = {
        buffer: "",
        inputTokens: 0,
        outputTokens: 0,
        contentBlocks: []
      };
      let finalStopReason;
      for await (const event of response.stream) {
        if ("messageStop" in event) {
          finalStopReason = event.messageStop?.stopReason;
          continue;
        }
        // Order matters: the start check must run before handleStreamEvent records the block.
        const startBlock = this.handleTextBlockStart(event, state);
        if (startBlock) {
          yield startBlock;
        }
        const stopResult = this.handleContentBlockStop(event, state);
        if (stopResult) {
          yield stopResult;
        }
        const result = this.handleStreamEvent(event, state);
        if (result) yield result;
      }
      return this.createGenerateResponse(
        mapContent(state.contentBlocks),
        finalStopReason,
        state.inputTokens,
        state.outputTokens
      );
    } catch (error) {
      this.handleError(error);
    }
  }

  /**
   * Run a single non-streaming completion.
   *
   * @throws Error (via handleError) when the request fails.
   */
  async generate(input, options) {
    try {
      const requestBody = this.createRequestBody(input, options);
      const response = await this.client.send(
        new ConverseCommand({
          ...requestBody,
          ...this.config.converseCommandConfig
        })
      );
      return this.createGenerateResponse(
        mapContent(response.output?.message?.content || []),
        response.stopReason,
        response.usage?.inputTokens || 0,
        response.usage?.outputTokens || 0
      );
    } catch (error) {
      this.handleError(error);
    }
  }
};
export { BedrockProvider };
//# sourceMappingURL=index.js.map
//# sourceMappingURL=index.js.map