@ai-sdk/google-vertex
Version:
The **[Google Vertex provider](https://ai-sdk.dev/providers/ai-sdk-providers/google-vertex)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the [Google Vertex AI](https://cloud.google.com/vertex-ai) APIs.
218 lines (186 loc) • 6.37 kB
text/typescript
import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal';
import { ImageModelV3, LanguageModelV3, ProviderV3 } from '@ai-sdk/provider';
import {
FetchFunction,
generateId,
loadOptionalSetting,
loadSetting,
normalizeHeaders,
resolve,
Resolvable,
withoutTrailingSlash,
withUserAgentSuffix,
} from '@ai-sdk/provider-utils';
import { VERSION } from './version';
import { GoogleVertexConfig } from './google-vertex-config';
import { GoogleVertexEmbeddingModel } from './google-vertex-embedding-model';
import { GoogleVertexEmbeddingModelId } from './google-vertex-embedding-options';
import { GoogleVertexImageModel } from './google-vertex-image-model';
import { GoogleVertexImageModelId } from './google-vertex-image-settings';
import { GoogleVertexModelId } from './google-vertex-options';
import { googleVertexTools } from './google-vertex-tools';
/**
 * Base URL used for API-key (express mode) requests when the caller does not
 * supply a custom `baseURL`.
 */
const EXPRESS_MODE_BASE_URL = 'https://aiplatform.googleapis.com/v1/publishers/google';
/**
 * Builds the fetch function used in express mode: every request is sent with
 * the `x-goog-api-key` header set to the given API key.
 *
 * @param apiKey - Value written to the `x-goog-api-key` header.
 * @param customFetch - Optional fetch implementation to delegate to; the
 * global `fetch` is used when it is omitted.
 * @returns A fetch function that forwards the stringified URL together with
 * the original request init plus the API key header.
 */
function createExpressModeFetch(
  apiKey: string,
  customFetch?: FetchFunction,
): FetchFunction {
  return async (url, init) => {
    const callerHeaders = init?.headers ? normalizeHeaders(init.headers) : {};

    // The API key is spread last so it replaces a caller-supplied entry stored
    // under that exact key.
    const requestInit: RequestInit = {
      ...init,
      headers: { ...callerHeaders, 'x-goog-api-key': apiKey },
    };

    const send = customFetch ?? fetch;
    return send(url.toString(), requestInit);
  };
}
/**
 * Google Vertex AI provider. The provider is callable and also exposes named
 * factory methods for the supported model kinds.
 */
export interface GoogleVertexProvider extends ProviderV3 {
  /**
   * Creates a model for text generation.
   */
  (modelId: GoogleVertexModelId): LanguageModelV3;

  /**
   * Creates a model for text generation.
   */
  languageModel: (modelId: GoogleVertexModelId) => LanguageModelV3;

  /**
   * Creates a model for image generation.
   */
  imageModel(modelId: GoogleVertexImageModelId): ImageModelV3;

  /**
   * Creates a model for image generation.
   */
  image(modelId: GoogleVertexImageModelId): ImageModelV3;

  /**
   * @deprecated Use `embeddingModel` instead.
   */
  textEmbeddingModel(
    modelId: GoogleVertexEmbeddingModelId,
  ): GoogleVertexEmbeddingModel;

  /**
   * The `googleVertexTools` collection exposed on the provider.
   */
  tools: typeof googleVertexTools;
}
/**
 * Settings accepted by `createVertex`. Every field is optional.
 */
export interface GoogleVertexProviderSettings {
  /**
   * The API key for the Google Cloud project. If provided, the provider will
   * use express mode with API key authentication. Defaults to the value of
   * the `GOOGLE_VERTEX_API_KEY` environment variable.
   */
  apiKey?: string;

  /**
   * Your Google Vertex project. Defaults to the environment variable
   * `GOOGLE_VERTEX_PROJECT`. Used to build the default endpoint when no API
   * key is configured.
   */
  project?: string;

  /**
   * Your Google Vertex location. Defaults to the environment variable
   * `GOOGLE_VERTEX_LOCATION`. Used to build the default endpoint when no API
   * key is configured.
   */
  location?: string;

  /**
   * Base URL for the Google Vertex API calls. Overrides the default endpoint.
   */
  baseURL?: string;

  /**
   * Headers to use for requests. Can be:
   * - A headers object
   * - A Promise that resolves to a headers object
   * - A function that returns a headers object
   * - A function that returns a Promise of a headers object
   */
  headers?: Resolvable<Record<string, string | undefined>>;

  /**
   * Custom fetch implementation. You can use it as a middleware to intercept
   * requests, or to provide a custom fetch implementation for e.g. testing.
   */
  fetch?: FetchFunction;

  /**
   * ID generator handed to the chat model; intended for testing. The shared
   * `generateId` helper is used when it is omitted.
   */
  generateId?: () => string;
}
/**
 * Create a Google Vertex AI provider instance.
 *
 * When an API key is available (the `apiKey` option or the
 * `GOOGLE_VERTEX_API_KEY` environment variable), the provider runs in express
 * mode: requests go through a fetch wrapper that sets the `x-goog-api-key`
 * header and, unless a custom `baseURL` is given, target the express-mode
 * endpoint. Without an API key the default endpoint is derived from the
 * configured project and location.
 *
 * A custom `baseURL` always takes precedence. Project and location are only
 * resolved (via `loadSetting`) when they are needed to build the default
 * endpoint, i.e. when a model is created with neither an API key nor a custom
 * `baseURL`.
 *
 * @param options - Provider settings; all fields are optional.
 * @returns A callable provider. Calling it (or `languageModel`) creates a chat
 * model; `embeddingModel` / `textEmbeddingModel` and `image` / `imageModel`
 * create embedding and image models, and `tools` exposes `googleVertexTools`.
 */
export function createVertex(
  options: GoogleVertexProviderSettings = {},
): GoogleVertexProvider {
  const apiKey = loadOptionalSetting({
    settingValue: options.apiKey,
    environmentVariableName: 'GOOGLE_VERTEX_API_KEY',
  });

  const loadVertexProject = () =>
    loadSetting({
      settingValue: options.project,
      settingName: 'project',
      environmentVariableName: 'GOOGLE_VERTEX_PROJECT',
      description: 'Google Vertex project',
    });

  const loadVertexLocation = () =>
    loadSetting({
      settingValue: options.location,
      settingName: 'location',
      environmentVariableName: 'GOOGLE_VERTEX_LOCATION',
      description: 'Google Vertex location',
    });

  const loadBaseURL = () => {
    // Resolve the explicit base URL first: when the caller supplies one, the
    // project and location are not part of the result, so they must not be
    // required. A nullish check (not truthiness) keeps the `??` semantics of
    // the previous implementation.
    const customBaseURL = withoutTrailingSlash(options.baseURL);
    if (customBaseURL != null) {
      return customBaseURL;
    }

    if (apiKey) {
      return EXPRESS_MODE_BASE_URL;
    }

    const region = loadVertexLocation();
    const project = loadVertexProject();

    // For global region, use aiplatform.googleapis.com directly
    // For other regions, use region-aiplatform.googleapis.com
    const baseHost = `${region === 'global' ? '' : region + '-'}aiplatform.googleapis.com`;

    return `https://${baseHost}/v1beta1/projects/${project}/locations/${region}/publishers/google`;
  };

  // Builds the shared model configuration; `name` becomes the suffix of the
  // provider identifier (`google.vertex.<name>`).
  const createConfig = (name: string): GoogleVertexConfig => {
    const getHeaders = async () => {
      const originalHeaders = await resolve(options.headers ?? {});
      return withUserAgentSuffix(
        originalHeaders,
        `ai-sdk/google-vertex/${VERSION}`,
      );
    };

    return {
      provider: `google.vertex.${name}`,
      headers: getHeaders,
      // In express mode the fetch is wrapped so the API key header is attached.
      fetch: apiKey
        ? createExpressModeFetch(apiKey, options.fetch)
        : options.fetch,
      baseURL: loadBaseURL(),
    };
  };

  const createChatModel = (modelId: GoogleVertexModelId) => {
    return new GoogleGenerativeAILanguageModel(modelId, {
      ...createConfig('chat'),
      generateId: options.generateId ?? generateId,
      supportedUrls: () => ({
        '*': [
          // HTTP URLs:
          /^https?:\/\/.*$/,
          // Google Cloud Storage URLs:
          /^gs:\/\/.*$/,
        ],
      }),
    });
  };

  const createEmbeddingModel = (modelId: GoogleVertexEmbeddingModelId) =>
    new GoogleVertexEmbeddingModel(modelId, createConfig('embedding'));

  const createImageModel = (modelId: GoogleVertexImageModelId) =>
    new GoogleVertexImageModel(modelId, createConfig('image'));

  const provider = function (modelId: GoogleVertexModelId) {
    // The provider is a plain factory function, not a constructor.
    if (new.target) {
      throw new Error(
        'The Google Vertex AI model function cannot be called with the new keyword.',
      );
    }

    return createChatModel(modelId);
  };

  provider.specificationVersion = 'v3' as const;
  provider.languageModel = createChatModel;
  provider.embeddingModel = createEmbeddingModel;
  provider.textEmbeddingModel = createEmbeddingModel;
  provider.image = createImageModel;
  provider.imageModel = createImageModel;
  provider.tools = googleVertexTools;

  return provider;
}