@genkit-ai/googleai
Version:
Genkit AI framework plugin for Google AI APIs, including Gemini APIs.
324 lines (305 loc) • 8.63 kB
text/typescript
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {
embedderActionMetadata,
embedderRef,
modelActionMetadata,
type ActionMetadata,
type EmbedderReference,
type Genkit,
type ModelReference,
type z,
} from 'genkit';
import { logger } from 'genkit/logging';
import { modelRef } from 'genkit/model';
import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin';
import type { ActionType } from 'genkit/registry';
import { getApiKeyFromEnvVar } from './common.js';
import {
SUPPORTED_MODELS as EMBEDDER_MODELS,
GeminiEmbeddingConfigSchema,
defineGoogleAIEmbedder,
textEmbedding004,
textEmbeddingGecko001,
type GeminiEmbeddingConfig,
} from './embedder.js';
import {
GeminiConfigSchema,
SUPPORTED_V15_MODELS,
defineGoogleAIModel,
gemini,
gemini10Pro,
gemini15Flash,
gemini15Flash8b,
gemini15Pro,
gemini20Flash,
gemini20FlashExp,
gemini20FlashLite,
gemini20ProExp0205,
gemini25FlashPreview0417,
gemini25ProExp0325,
gemini25ProPreview0325,
type GeminiConfig,
type GeminiVersionString,
} from './gemini.js';
import { listModels } from './list-models.js';
export {
gemini,
gemini10Pro,
gemini15Flash,
gemini15Flash8b,
gemini15Pro,
gemini20Flash,
gemini20FlashExp,
gemini20FlashLite,
gemini20ProExp0205,
gemini25FlashPreview0417,
gemini25ProExp0325,
gemini25ProPreview0325,
textEmbedding004,
textEmbeddingGecko001,
type GeminiConfig,
type GeminiVersionString,
};
/**
 * Options accepted by the Google AI (Gemini Developer API) plugin.
 */
export interface PluginOptions {
  /**
   * API key used to authenticate with the Gemini API. By default a key must
   * be supplied either here or through the `GEMINI_API_KEY` or
   * `GOOGLE_API_KEY` environment variables.
   *
   * Passing `false` explicitly configures the plugin to expect an `apiKey`
   * option in the model config at call time instead.
   */
  apiKey?: string | false;
  /**
   * API version (or list of versions) under which the built-in models are
   * registered. When omitted the plugin behaves as if `['v1']` was given;
   * `'v1'` and `'v1beta'` are the values the initializer acts on.
   */
  apiVersion?: string | string[];
  /**
   * Base URL forwarded to every model definition. When unset, model listing
   * falls back to `https://generativelanguage.googleapis.com`.
   */
  baseUrl?: string;
  /**
   * Extra models to register. Bare strings are turned into references with
   * `gemini(name)`; model references are used as given.
   */
  models?: Array<
    ModelReference</** @ignore */ typeof GeminiConfigSchema> | string
  >;
  /** Experimental: forwarded to model definitions as `debugTraces`. */
  experimental_debugTraces?: boolean;
}
/**
 * Normalizes the `apiVersion` plugin option into a list of versions.
 * Falls back to `['v1']` when the option is missing (or otherwise falsy);
 * an explicit array, even an empty one, is used as given.
 */
function normalizeApiVersions(apiVersion?: string | string[]): string[] {
  if (!apiVersion) {
    return ['v1'];
  }
  return Array.isArray(apiVersion) ? apiVersion : [apiVersion];
}

/**
 * Removes the leading plugin namespace segment (e.g. `googleai/`) from a
 * model reference name. Everything after the first `/` is kept, so nested
 * names such as `googleai/tunedModels/my-model` become
 * `tunedModels/my-model`; a name containing no `/` is returned unchanged.
 */
function stripPluginNamespace(name: string): string {
  const slash = name.indexOf('/');
  return slash === -1 ? name : name.substring(slash + 1);
}

/**
 * Plugin initializer. Registers the built-in Gemini models once per requested
 * API version, the built-in embedders when `v1` is requested, and finally any
 * additional models listed in `options.models`.
 *
 * @param ai Genkit instance the actions are defined on.
 * @param options Plugin options; `apiKey`, `baseUrl` and
 *   `experimental_debugTraces` are forwarded to each definition.
 */
async function initializer(ai: Genkit, options?: PluginOptions) {
  const apiVersions = normalizeApiVersions(options?.apiVersion);

  if (apiVersions.includes('v1beta')) {
    Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
      defineGoogleAIModel({
        ai,
        name,
        apiKey: options?.apiKey,
        apiVersion: 'v1beta',
        baseUrl: options?.baseUrl,
        debugTraces: options?.experimental_debugTraces,
      })
    );
  }
  if (apiVersions.includes('v1')) {
    // For `v1` no explicit apiVersion is passed; presumably
    // defineGoogleAIModel treats `undefined` as its default version — confirm.
    Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
      defineGoogleAIModel({
        ai,
        name,
        apiKey: options?.apiKey,
        apiVersion: undefined,
        baseUrl: options?.baseUrl,
        debugTraces: options?.experimental_debugTraces,
      })
    );
    Object.keys(EMBEDDER_MODELS).forEach((name) =>
      defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey })
    );
  }
  if (options?.models) {
    for (const modelOrRef of options.models) {
      // Strings are bare model names; references carry a namespaced name
      // (e.g. `googleai/...`) whose first segment must be dropped. Only the
      // first segment is removed so nested names stay intact.
      const modelName =
        typeof modelOrRef === 'string'
          ? modelOrRef
          : stripPluginNamespace(modelOrRef.name);
      // Named `ref` so the imported `modelRef` helper is not shadowed.
      const ref =
        typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
      defineGoogleAIModel({
        ai,
        name: modelName,
        apiKey: options.apiKey,
        baseUrl: options.baseUrl,
        info: {
          ...ref.info,
          label: `Google AI - ${modelName}`,
        },
        debugTraces: options.experimental_debugTraces,
      });
    }
  }
}
/**
 * On-demand action resolver. Defines the requested model or embedder by
 * name; requests for any other action type are ignored.
 */
async function resolver(
  ai: Genkit,
  actionType: ActionType,
  actionName: string,
  options?: PluginOptions
) {
  if (actionType === 'model') {
    resolveModel(ai, actionName, options);
  } else if (actionType === 'embedder') {
    resolveEmbedder(ai, actionName, options);
  }
  // Other action types have nothing to resolve here.
}
/**
 * Defines a single Gemini model on demand. The name and model info come from
 * `gemini(actionName)`; the label is derived from the requested action name.
 */
function resolveModel(ai: Genkit, actionName: string, options?: PluginOptions) {
  const { name, info } = gemini(actionName);
  defineGoogleAIModel({
    ai,
    name,
    apiKey: options?.apiKey,
    baseUrl: options?.baseUrl,
    info: { ...info, label: `Google AI - ${actionName}` },
    debugTraces: options?.experimental_debugTraces,
  });
}
/**
 * Defines a single Google AI embedder on demand, registered under the
 * `googleai/`-prefixed form of `actionName`.
 */
function resolveEmbedder(
  ai: Genkit,
  actionName: string,
  options?: PluginOptions
) {
  const qualifiedName = `googleai/${actionName}`;
  const embedderOptions = { apiKey: options?.apiKey };
  defineGoogleAIEmbedder(ai, qualifiedName, embedderOptions);
}
/** Drops a leading `models/` segment from an API model resource name. */
function stripModelsPrefix(resourceName: string): string {
  const prefix = 'models/';
  return resourceName.startsWith(prefix)
    ? resourceName.substring(prefix.length)
    : resourceName;
}

/**
 * Lists the models and embedders available to the configured API key, as
 * action metadata. Entries whose description mentions "deprecated" are
 * skipped. Model entries come first, followed by embedder entries.
 *
 * @returns Action metadata, or an empty list (after logging an error) when
 *   no API key is configured in options or the environment.
 */
async function listActions(options?: PluginOptions): Promise<ActionMetadata[]> {
  const apiKey = options?.apiKey || getApiKeyFromEnvVar();
  if (!apiKey) {
    // Listing should not throw over missing credentials: log and return
    // nothing. (Per the original note, the plugin initializer normally fails
    // earlier when no key is set.)
    logger.error(
      'Pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.'
    );
    return [];
  }

  const baseUrl =
    options?.baseUrl || 'https://generativelanguage.googleapis.com';
  const models = await listModels(baseUrl, apiKey);

  type ListedModel = (typeof models)[number];
  const isNotDeprecated = (m: ListedModel) =>
    !m.description || !m.description.includes('deprecated');

  const modelActions = models
    .filter((m) => m.supportedGenerationMethods.includes('generateContent'))
    .filter(isNotDeprecated)
    .map((m) => {
      const ref = gemini(stripModelsPrefix(m.name));
      return modelActionMetadata({
        name: ref.name,
        info: ref.info,
        configSchema: GeminiConfigSchema,
      });
    });

  const embedderActions = models
    .filter((m) => m.supportedGenerationMethods.includes('embedContent'))
    .filter(isNotDeprecated)
    .map((m) => {
      const name = `googleai/${stripModelsPrefix(m.name)}`;
      return embedderActionMetadata({
        name,
        configSchema: GeminiEmbeddingConfigSchema,
        info: {
          dimensions: 768,
          label: `Google Gen AI - ${name}`,
          supports: {
            input: ['text'],
          },
        },
      });
    });

  return [...modelActions, ...embedderActions];
}
/**
 * Google Gemini Developer API plugin.
 *
 * @param options Plugin configuration, shared by the initializer, the
 *   on-demand action resolver and the action listing.
 * @returns A Genkit plugin registered under the `googleai` namespace.
 */
export function googleAIPlugin(options?: PluginOptions): GenkitPlugin {
  // Caches the in-flight or completed listing so `listActions` runs at most
  // once per plugin instance, even when several callers ask concurrently.
  // A rejected lookup is discarded so a later call can retry.
  let listActionsCache: Promise<ActionMetadata[]> | undefined;
  const listActionsOnce = (): Promise<ActionMetadata[]> => {
    if (!listActionsCache) {
      listActionsCache = listActions(options).catch((err: unknown) => {
        listActionsCache = undefined;
        throw err;
      });
    }
    return listActionsCache;
  };

  return genkitPlugin(
    'googleai',
    async (ai: Genkit) => await initializer(ai, options),
    async (ai: Genkit, actionType: ActionType, actionName: string) =>
      await resolver(ai, actionType, actionName, options),
    listActionsOnce
  );
}
/**
 * Shape of the exported `googleAI` value: a callable plugin factory that also
 * carries helpers for building model and embedder references.
 */
export type GoogleAIPlugin = {
/** Creates the Google AI plugin from the given options. */
(params?: PluginOptions): GenkitPlugin;
/** Builds a Gemini model reference for `name`, optionally carrying `config`. */
model(
name: GeminiVersionString,
config?: z.infer<typeof GeminiConfigSchema>
): ModelReference<typeof GeminiConfigSchema>;
/** Builds an embedder reference for `name`, optionally carrying `config`. */
embedder(
name: string,
config?: GeminiEmbeddingConfig
): EmbedderReference<typeof GeminiEmbeddingConfigSchema>;
};
/**
 * Google Gemini Developer API plugin.
 *
 * Call `googleAI(options)` to create the plugin. Use `googleAI.model(name,
 * config?)` and `googleAI.embedder(name, config?)` to build references whose
 * names are namespaced under `googleai/`.
 */
export const googleAI: GoogleAIPlugin = Object.assign(googleAIPlugin, {
  // Attaching the helpers via Object.assign lets the compiler check that the
  // combined value satisfies GoogleAIPlugin. The previous `as` assertion
  // claimed `model`/`embedder` existed before they were assigned. Like the
  // original assignments, this mutates the `googleAIPlugin` function itself.
  model: (
    name: GeminiVersionString,
    config?: GeminiConfig
  ): ModelReference<typeof GeminiConfigSchema> =>
    modelRef({
      name: `googleai/${name}`,
      config,
      configSchema: GeminiConfigSchema,
    }),
  embedder: (
    name: string,
    config?: GeminiEmbeddingConfig
  ): EmbedderReference<typeof GeminiEmbeddingConfigSchema> =>
    embedderRef({
      name: `googleai/${name}`,
      config,
      configSchema: GeminiEmbeddingConfigSchema,
    }),
});

export default googleAI;