/**
 * @genkit-ai/googleai — Genkit AI framework plugin for Google AI APIs,
 * including the Gemini and Imagen APIs.
 * (111 lines • 3.33 kB, JavaScript)
 */
import { GenkitError, z } from "genkit";
import {
getBasicUsageStats,
modelRef
} from "genkit/model";
import { getApiKeyFromEnvVar } from "./common.js";
import { predictModel } from "./predict.js";
/**
 * Request-config schema shared by all Imagen models.
 *
 * Every field is optional; unknown keys are allowed (`.passthrough()`) and are
 * forwarded verbatim to the predict API by `toParameters`.
 */
const ImagenConfigSchema = z
  .object({
    numberOfImages: z
      .number()
      .describe(
        "The number of images to generate, from 1 to 4 (inclusive). The default is 1."
      )
      .optional(),
    aspectRatio: z
      .enum(["1:1", "9:16", "16:9", "3:4", "4:3"])
      .describe("Desired aspect ratio of the output image.")
      .optional(),
    personGeneration: z
      .enum(["dont_allow", "allow_adult", "allow_all"])
      .describe(
        "Control if/how images of people will be generated by the model."
      )
      .optional()
  })
  .passthrough();
/**
 * Maps a Genkit generate request's config into Imagen predict `parameters`.
 *
 * The Genkit-facing `numberOfImages` field is translated to the API's
 * `sampleCount` (defaulting to 1) and removed from the output; all other
 * config keys pass through unchanged. Keys whose value is null/undefined
 * are dropped.
 *
 * @param {object} request - Genkit GenerateRequest (only `config` is read).
 * @returns {object} Parameters object for the Imagen predict call.
 */
function toParameters(request) {
  const out = {
    sampleCount: request.config?.numberOfImages ?? 1,
    ...request?.config
  };
  // Already mapped to `sampleCount`; the API does not recognize the
  // Genkit-facing name.
  delete out.numberOfImages;
  for (const k in out) {
    // Drop only null/undefined — `false` or `0` may be meaningful
    // passthrough values for the API.
    if (out[k] == null) delete out[k];
  }
  return out;
}
/**
 * Concatenates the text parts of the last message into a single prompt string.
 * Non-text parts (e.g. media) contribute nothing.
 *
 * Guards the empty-messages case with `?.` (matching `extractBaseImage`)
 * instead of throwing a TypeError.
 *
 * @param {object} request - Genkit GenerateRequest.
 * @returns {string} The combined prompt text ("" when there are no messages).
 */
function extractText(request) {
  const content = request.messages.at(-1)?.content ?? [];
  return content.map((part) => part.text || "").join("");
}
/**
 * Pulls the base-64 payload out of the first media part of the last message,
 * if any. Media URLs here are data URIs ("data:<mime>;base64,<data>"), so the
 * text after the first comma is the encoded image.
 *
 * @param {object} request - Genkit GenerateRequest.
 * @returns {string|undefined} Base-64 image data, or undefined when absent.
 */
function extractBaseImage(request) {
  const lastMessage = request.messages.at(-1);
  return lastMessage?.content
    .find((part) => Boolean(part.media))
    ?.media?.url.split(",")[1];
}
/**
 * Baseline ModelInfo for Imagen models: single-turn, media-capable input,
 * media-only output, no tool calling and no system role.
 */
const GENERIC_IMAGEN_INFO = {
  label: "Google AI - Generic Imagen",
  supports: {
    media: true,
    multiturn: false,
    tools: false,
    systemRole: false,
    output: ["media"]
  }
};
/**
 * Defines and registers a Genkit model backed by the Imagen predict API.
 *
 * @param {object} ai - Genkit instance whose `defineModel` registers the model.
 * @param {string} name - Bare model name (registered as `googleai/<name>`).
 * @param {string|false|undefined} apiKey - Explicit API key; `false` skips the
 *   key check entirely (and is forwarded as-is to `predictModel` — presumably
 *   auth is supplied elsewhere in that case; verify against `predict.js`);
 *   a nullish value falls back to the GEMINI_API_KEY / GOOGLE_API_KEY env vars.
 * @returns {object} The defined Genkit model action.
 * @throws {GenkitError} FAILED_PRECONDITION when no API key can be resolved.
 */
function defineImagenModel(ai, name, apiKey) {
  // Resolve into a local instead of reassigning the parameter.
  let resolvedApiKey = apiKey;
  if (resolvedApiKey !== false) {
    resolvedApiKey = resolvedApiKey || getApiKeyFromEnvVar();
    if (!resolvedApiKey) {
      throw new GenkitError({
        status: "FAILED_PRECONDITION",
        message: "Please pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.\nFor more details see https://genkit.dev/docs/plugins/google-genai"
      });
    }
  }
  const modelName = `googleai/${name}`;
  const model = modelRef({
    name: modelName,
    info: {
      ...GENERIC_IMAGEN_INFO,
      label: `Google AI - ${name}`
    },
    configSchema: ImagenConfigSchema
  });
  return ai.defineModel(
    {
      name: modelName,
      ...model.info,
      configSchema: ImagenConfigSchema
    },
    async (request) => {
      // Imagen takes a single "instance" per request: the prompt text plus an
      // optional base image (for image-editing style requests).
      const instance = {
        prompt: extractText(request)
      };
      const baseImage = extractBaseImage(request);
      if (baseImage) {
        instance.image = { bytesBase64Encoded: baseImage };
      }
      const predictClient = predictModel(model.version || name, resolvedApiKey, "predict");
      const response = await predictClient([instance], toParameters(request));
      if (!response.predictions || response.predictions.length === 0) {
        throw new Error(
          "Model returned no predictions. Possibly due to content filters."
        );
      }
      // Each prediction becomes one media part carried as a data URI.
      const message = {
        role: "model",
        content: response.predictions.map((prediction) => ({
          media: {
            url: `data:${prediction.mimeType};base64,${prediction.bytesBase64Encoded}`,
            contentType: prediction.mimeType
          }
        }))
      };
      return {
        finishReason: "stop",
        message,
        usage: getBasicUsageStats(request.messages, message),
        custom: response
      };
    }
  );
}
// Public surface: the generic model info, the shared config schema, and the
// model factory consumed by the plugin entry point.
export {
  GENERIC_IMAGEN_INFO,
  ImagenConfigSchema,
  defineImagenModel
};
//# sourceMappingURL=imagen.mjs.map