@ai-sdk/amazon-bedrock
Version:
The **[Amazon Bedrock provider](https://ai-sdk.dev/providers/ai-sdk-providers/amazon-bedrock)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the Amazon Bedrock [converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html).
232 lines (212 loc) • 7.09 kB
text/typescript
import {
TooManyEmbeddingValuesForCallError,
type EmbeddingModelV4,
} from '@ai-sdk/provider';
import {
combineHeaders,
createJsonErrorResponseHandler,
createJsonResponseHandler,
parseProviderOptions,
postJsonToApi,
resolve,
serializeModelOptions,
WORKFLOW_SERIALIZE,
WORKFLOW_DESERIALIZE,
type FetchFunction,
type Resolvable,
} from '@ai-sdk/provider-utils';
import {
amazonBedrockEmbeddingModelOptionsSchema,
type AmazonBedrockEmbeddingModelId,
} from './amazon-bedrock-embedding-model-options';
import { AmazonBedrockErrorSchema } from './amazon-bedrock-error';
import { z } from 'zod/v4';
/** Connection settings used by every request the embedding model sends. */
type AmazonBedrockEmbeddingConfig = {
  /** Optional `fetch` implementation forwarded to the HTTP helper. */
  fetch?: FetchFunction;
  /** Extra request headers; resolved with `resolve()` before each request. */
  headers?: Resolvable<Record<string, string | undefined>>;
  /** Returns the runtime base URL; called once per `doEmbed` invocation. */
  baseUrl: () => string;
};

/** The awaited (non-promise) result type of `EmbeddingModelV4['doEmbed']`. */
type DoEmbedResponse = Awaited<ReturnType<EmbeddingModelV4['doEmbed']>>;
/**
 * Embedding model backed by the Amazon Bedrock runtime `InvokeModel`
 * endpoint.
 *
 * Titan, Cohere and Nova embedding models are all served through `doEmbed`.
 * The request body and the response parsing are picked from the model id, so
 * the public interface is the same for every model family.
 */
export class AmazonBedrockEmbeddingModel implements EmbeddingModelV4 {
  readonly specificationVersion = 'v4';
  readonly provider = 'amazon-bedrock';
  readonly supportsParallelCalls = true;

  /**
   * Largest number of input values a single `doEmbed` call accepts:
   * 96 for Cohere embedding models, 1 for every other model.
   */
  get maxEmbeddingsPerCall() {
    if (isCohereEmbeddingModel(this.modelId)) {
      return 96;
    }
    return 1;
  }

  /** Reduces a model instance to the options needed to rebuild it. */
  static [WORKFLOW_SERIALIZE](model: AmazonBedrockEmbeddingModel) {
    const serializable = { modelId: model.modelId, config: model.config };
    return serializeModelOptions(serializable);
  }

  /** Rebuilds a model instance from previously serialized options. */
  static [WORKFLOW_DESERIALIZE](options: {
    modelId: string;
    config: AmazonBedrockEmbeddingConfig;
  }) {
    const { modelId, config } = options;
    return new AmazonBedrockEmbeddingModel(modelId, config);
  }

  constructor(
    readonly modelId: AmazonBedrockEmbeddingModelId,
    private readonly config: AmazonBedrockEmbeddingConfig,
  ) {}

  /** Builds the `.../model/{id}/invoke` URL, URI-encoding the model id. */
  private getUrl(modelId: string): string {
    const encodedId = encodeURIComponent(modelId);
    const base = this.config.baseUrl();
    return `${base}/model/${encodedId}/invoke`;
  }

  /**
   * Embeds `values` with a single `InvokeModel` request.
   *
   * @throws TooManyEmbeddingValuesForCallError when `values` is longer than
   *   `maxEmbeddingsPerCall`.
   * @returns the embeddings, a token count and an empty warnings list. The
   *   token count comes from the response body for Titan/Nova when present,
   *   and otherwise from the `x-amzn-bedrock-input-token-count` response
   *   header (`NaN` if that header is absent).
   */
  async doEmbed({
    values,
    headers,
    abortSignal,
    providerOptions,
  }: Parameters<EmbeddingModelV4['doEmbed']>[0]): Promise<DoEmbedResponse> {
    const maxEmbeddingsPerCall = this.maxEmbeddingsPerCall;
    if (values.length > maxEmbeddingsPerCall) {
      throw new TooManyEmbeddingValuesForCallError({
        provider: this.provider,
        modelId: this.modelId,
        maxEmbeddingsPerCall,
        values,
      });
    }

    // `amazonBedrock` is the preferred provider-options key; the legacy
    // `bedrock` key is parsed only when the preferred one yields nothing.
    const parseOptionsUnder = (providerKey: string) =>
      parseProviderOptions({
        provider: providerKey,
        providerOptions,
        schema: amazonBedrockEmbeddingModelOptionsSchema,
      });
    const options =
      (await parseOptionsUnder('amazonBedrock')) ??
      (await parseOptionsUnder('bedrock')) ??
      {};

    // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html
    //
    // Each embedding family (Titan, Cohere, Nova) expects its own payload, so
    // the body is adapted here based on the model id.
    let requestBody: Record<string, unknown>;
    if (isNovaEmbeddingModel(this.modelId)) {
      requestBody = {
        taskType: 'SINGLE_EMBEDDING',
        singleEmbeddingParams: {
          embeddingPurpose: options.embeddingPurpose ?? 'GENERIC_INDEX',
          embeddingDimension: options.embeddingDimension ?? 1024,
          text: {
            truncationMode: options.truncate ?? 'END',
            value: values[0],
          },
        },
      };
    } else if (isCohereEmbeddingModel(this.modelId)) {
      requestBody = {
        // Cohere embedding models on Bedrock require `input_type`.
        // Without it, the service attempts other schema branches and rejects the request.
        input_type: options.inputType ?? 'search_query',
        texts: values,
        truncate: options.truncate,
        output_dimension: options.outputDimension,
      };
    } else {
      // Titan-style payload: one text per request.
      requestBody = {
        inputText: values[0],
        dimensions: options.dimensions,
        normalize: options.normalize,
      };
    }

    const url = this.getUrl(this.modelId);
    const configHeaders = this.config.headers
      ? await resolve(this.config.headers)
      : undefined;

    const { value: response, responseHeaders } = await postJsonToApi({
      url,
      headers: await resolve(combineHeaders(configHeaders, headers)),
      body: requestBody,
      failedResponseHandler: createJsonErrorResponseHandler({
        errorSchema: AmazonBedrockErrorSchema,
        errorToMessage: error => `${error.type}: ${error.message}`,
      }),
      successfulResponseHandler: createJsonResponseHandler(
        AmazonBedrockEmbeddingResponseSchema,
      ),
      fetch: this.config.fetch,
      abortSignal,
    });

    // Pull the vectors out of whichever response shape came back.
    let embeddings: number[][];
    if ('embedding' in response) {
      // Titan: a single vector under `embedding`.
      embeddings = [response.embedding];
    } else if (!Array.isArray(response.embeddings)) {
      // Cohere v4: vectors grouped by type; the `float` group is returned.
      embeddings = response.embeddings.float;
    } else {
      const [first] = response.embeddings;
      if (
        typeof first === 'object' &&
        first !== null &&
        'embeddingType' in first
      ) {
        // Nova: typed entries; only the first entry's vector is used.
        embeddings = [first.embedding];
      } else {
        // Cohere v3: a bare list of vectors.
        embeddings = response.embeddings as number[][];
      }
    }

    // Token usage: body field when the family reports one, else the header.
    let tokens: number;
    if ('inputTextTokenCount' in response) {
      tokens = response.inputTextTokenCount; // Titan
    } else if ('inputTokenCount' in response) {
      tokens = response.inputTokenCount ?? 0; // Nova
    } else {
      tokens = Number(responseHeaders?.['x-amzn-bedrock-input-token-count']);
    }

    return { embeddings, usage: { tokens }, warnings: [] };
  }
}
/**
 * Reports whether `modelId` names a Cohere embedding model.
 *
 * A substring search (not a prefix check) is used so that cross-region
 * inference profile ids such as `us.cohere.embed-v4:0` or
 * `global.cohere.embed-v4:0` are recognised too.
 */
function isCohereEmbeddingModel(modelId: string) {
  return modelId.indexOf('cohere.embed-') !== -1;
}
/**
 * Reports whether `modelId` names an Amazon Nova embedding model.
 *
 * `amazon.nova-` is matched anywhere in the id, not only as a prefix. That way
 * cross-region inference profile ids (e.g.
 * `us.amazon.nova-2-multimodal-embeddings-v1:0`,
 * `global.amazon.nova-2-multimodal-embeddings-v1:0`) and foundation-model ARNs
 * are detected as well. With a prefix-only check they would not be classified
 * as Nova and would be sent a non-Nova request payload.
 */
function isNovaEmbeddingModel(modelId: string) {
  return modelId.includes('amazon.nova-') && modelId.includes('embed');
}
// Titan-style response: a single vector plus its input token count.
const titanEmbeddingResponseSchema = z.object({
  embedding: z.array(z.number()),
  inputTextTokenCount: z.number(),
});

// Nova-style response: typed embedding entries; the token count may be absent.
const novaEmbeddingResponseSchema = z.object({
  embeddings: z.array(
    z.object({
      embeddingType: z.string(),
      embedding: z.array(z.number()),
    }),
  ),
  inputTokenCount: z.number().optional(),
});

// Cohere v3-style response: a bare list of vectors.
const cohereV3EmbeddingResponseSchema = z.object({
  embeddings: z.array(z.array(z.number())),
});

// Cohere v4-style response: vectors grouped by embedding type.
const cohereV4EmbeddingResponseSchema = z.object({
  embeddings: z.object({
    float: z.array(z.array(z.number())),
  }),
});

/**
 * Accepts any of the embedding response bodies handled by this provider.
 *
 * The member order is significant and must stay as is: an empty `embeddings`
 * array satisfies both the Nova and the Cohere v3 shapes.
 */
const AmazonBedrockEmbeddingResponseSchema = z.union([
  titanEmbeddingResponseSchema,
  novaEmbeddingResponseSchema,
  cohereV3EmbeddingResponseSchema,
  cohereV4EmbeddingResponseSchema,
]);