UNPKG

@genkit-ai/googleai

Version:

Genkit AI framework plugin for Google AI APIs, including Gemini APIs.

111 lines 3.33 kB
import { GenkitError, z } from "genkit";
import { getBasicUsageStats, modelRef } from "genkit/model";
import { getApiKeyFromEnvVar } from "./common.js";
import { predictModel } from "./predict.js";

/**
 * Configuration options accepted by Imagen models.
 * `.passthrough()` lets callers supply additional, undeclared API options;
 * those flow through to the predict request untouched (see toParameters).
 */
const ImagenConfigSchema = z
  .object({
    numberOfImages: z
      .number()
      .describe(
        "The number of images to generate, from 1 to 4 (inclusive). The default is 1."
      )
      .optional(),
    aspectRatio: z
      .enum(["1:1", "9:16", "16:9", "3:4", "4:3"])
      .describe("Desired aspect ratio of the output image.")
      .optional(),
    personGeneration: z
      .enum(["dont_allow", "allow_adult", "allow_all"])
      .describe(
        "Control if/how images of people will be generated by the model."
      )
      .optional(),
  })
  .passthrough();

/**
 * Builds the `parameters` payload for the predict endpoint from a generate
 * request: `sampleCount` defaults to 1, and every remaining config key is
 * forwarded as-is.
 *
 * FIX: the previous cleanup (`if (!out[k]) delete out[k]`) removed every
 * falsy value, silently dropping meaningful passthrough options such as a
 * boolean explicitly set to `false` or a numeric `0`. Only unset
 * (null/undefined) entries are stripped now.
 *
 * @param {object} request - Genkit GenerateRequest (config is optional).
 * @returns {object} Parameters object for the predict call.
 */
function toParameters(request) {
  const out = {
    sampleCount: request.config?.numberOfImages ?? 1,
    ...request?.config,
  };
  for (const k in out) {
    // `== null` matches both null and undefined, nothing else.
    if (out[k] == null) delete out[k];
  }
  return out;
}

/**
 * Concatenates the text parts of the last message into a single prompt
 * string. Assumes at least one message is present (Genkit supplies the
 * messages array — TODO confirm it is never empty at this point).
 */
function extractText(request) {
  return request.messages
    .at(-1)
    .content.map((c) => c.text || "")
    .join("");
}

/**
 * Returns the base64 payload of the first media part of the last message
 * (used as the base image for image-to-image editing), or undefined when
 * no media part exists. The split drops the `data:<mime>;base64,` prefix.
 */
function extractBaseImage(request) {
  return request.messages
    .at(-1)
    ?.content.find((p) => !!p.media)
    ?.media?.url.split(",")[1];
}

/** Capability metadata shared by all Imagen model registrations. */
const GENERIC_IMAGEN_INFO = {
  label: `Google AI - Generic Imagen`,
  supports: {
    media: true,
    multiturn: false,
    tools: false,
    systemRole: false,
    output: ["media"],
  },
};

/**
 * Registers an Imagen model with the given Genkit instance.
 *
 * @param {object} ai - Genkit instance used to define the model.
 * @param {string} name - Model name; registered as `googleai/<name>`.
 * @param {string|false|undefined} apiKey - API key; `false` skips the key
 *   check entirely, undefined falls back to environment variables.
 * @returns {object} The defined Genkit model.
 * @throws {GenkitError} FAILED_PRECONDITION when no API key can be resolved.
 */
function defineImagenModel(ai, name, apiKey) {
  // `apiKey === false` is an explicit opt-out of the precondition check.
  if (apiKey !== false) {
    apiKey = apiKey || getApiKeyFromEnvVar();
    if (!apiKey) {
      throw new GenkitError({
        status: "FAILED_PRECONDITION",
        message:
          "Please pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.\nFor more details see https://genkit.dev/docs/plugins/google-genai",
      });
    }
  }
  const modelName = `googleai/${name}`;
  const model = modelRef({
    name: modelName,
    info: {
      ...GENERIC_IMAGEN_INFO,
      label: `Google AI - ${name}`,
    },
    configSchema: ImagenConfigSchema,
  });
  return ai.defineModel(
    {
      name: modelName,
      ...model.info,
      configSchema: ImagenConfigSchema,
    },
    async (request) => {
      const instance = {
        prompt: extractText(request),
      };
      const baseImage = extractBaseImage(request);
      if (baseImage) {
        instance.image = { bytesBase64Encoded: baseImage };
      }
      const predictClient = predictModel(model.version || name, apiKey, "predict");
      const response = await predictClient([instance], toParameters(request));
      if (!response.predictions || response.predictions.length === 0) {
        throw new Error(
          "Model returned no predictions. Possibly due to content filters."
        );
      }
      const message = {
        role: "model",
        content: [],
      };
      // Each prediction carries its image as base64 plus a mime type; wrap
      // them as data-URL media parts for the Genkit message.
      response.predictions.forEach((p) => {
        const b64data = p.bytesBase64Encoded;
        const mimeType = p.mimeType;
        message.content.push({
          media: {
            url: `data:${mimeType};base64,${b64data}`,
            contentType: mimeType,
          },
        });
      });
      return {
        finishReason: "stop",
        message,
        usage: getBasicUsageStats(request.messages, message),
        custom: response,
      };
    }
  );
}

export { GENERIC_IMAGEN_INFO, ImagenConfigSchema, defineImagenModel };
//# sourceMappingURL=imagen.mjs.map