@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate, and vertexai.
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.ImagenModelDefinition = exports.ImagenMaskMode = exports.ImagenTaskType = void 0;
const core_1 = require("@llumiverse/core");
// Import the aiplatform helpers for converting plain JS values to protobuf Value objects
const aiplatform_1 = require("@google-cloud/aiplatform");
var ImagenTaskType;
(function (ImagenTaskType) {
ImagenTaskType["TEXT_IMAGE"] = "TEXT_IMAGE";
ImagenTaskType["EDIT_MODE_INPAINT_REMOVAL"] = "EDIT_MODE_INPAINT_REMOVAL";
ImagenTaskType["EDIT_MODE_INPAINT_INSERTION"] = "EDIT_MODE_INPAINT_INSERTION";
ImagenTaskType["EDIT_MODE_BGSWAP"] = "EDIT_MODE_BGSWAP";
ImagenTaskType["EDIT_MODE_OUTPAINT"] = "EDIT_MODE_OUTPAINT";
ImagenTaskType["CUSTOMIZATION_SUBJECT"] = "CUSTOMIZATION_SUBJECT";
ImagenTaskType["CUSTOMIZATION_STYLE"] = "CUSTOMIZATION_STYLE";
ImagenTaskType["CUSTOMIZATION_CONTROLLED"] = "CUSTOMIZATION_CONTROLLED";
ImagenTaskType["CUSTOMIZATION_INSTRUCT"] = "CUSTOMIZATION_INSTRUCT";
})(ImagenTaskType || (exports.ImagenTaskType = ImagenTaskType = {}));
var ImagenMaskMode;
(function (ImagenMaskMode) {
ImagenMaskMode["MASK_MODE_USER_PROVIDED"] = "MASK_MODE_USER_PROVIDED";
ImagenMaskMode["MASK_MODE_BACKGROUND"] = "MASK_MODE_BACKGROUND";
ImagenMaskMode["MASK_MODE_FOREGROUND"] = "MASK_MODE_FOREGROUND";
ImagenMaskMode["MASK_MODE_SEMANTIC"] = "MASK_MODE_SEMANTIC";
})(ImagenMaskMode || (exports.ImagenMaskMode = ImagenMaskMode = {}));
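/**
 * Build the Imagen `parameters` payload for a given task type.
 * Illustrative sketch only: every field is driven by the options object, and the
 * option values below are assumptions for the example, not defaults.
 *
 *   getImagenParameters(ImagenTaskType.TEXT_IMAGE, {
 *       _option_id: "vertexai-imagen",
 *       number_of_images: 2,
 *       aspect_ratio: "16:9",
 *   });
 *   // => { sampleCount: 2, aspectRatio: "16:9", ... } (undefined fields are stripped later)
 */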
function getImagenParameters(taskType, options) {
const commonParameters = {
sampleCount: options?.number_of_images,
seed: options?.seed,
safetySetting: options?.safety_setting,
personGeneration: options?.person_generation,
negativePrompt: taskType ? undefined : "", //Filled in later from the prompt
//TODO: Add more safety and prompt rejection information
//includeSafetyAttributes: true,
//includeRaiReason: true,
};
switch (taskType) {
        case ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL:
        case ImagenTaskType.EDIT_MODE_INPAINT_INSERTION:
        case ImagenTaskType.EDIT_MODE_BGSWAP:
        case ImagenTaskType.EDIT_MODE_OUTPAINT:
            //The enum values match the API editMode strings, so the task type passes through directly.
            return {
                ...commonParameters,
                editMode: taskType,
                editConfig: {
                    baseSteps: options?.edit_steps,
                },
            };
case ImagenTaskType.TEXT_IMAGE:
return {
...commonParameters,
// You can't use a seed value and watermark at the same time.
addWatermark: options?.add_watermark,
aspectRatio: options?.aspect_ratio,
enhancePrompt: options?.enhance_prompt,
};
case ImagenTaskType.CUSTOMIZATION_SUBJECT:
case ImagenTaskType.CUSTOMIZATION_CONTROLLED:
case ImagenTaskType.CUSTOMIZATION_INSTRUCT:
case ImagenTaskType.CUSTOMIZATION_STYLE:
return {
...commonParameters,
};
        default:
            throw new Error(`Unsupported Imagen task type: ${taskType}`);
}
}
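/**
 * Model definition for Google Vertex AI Imagen image-generation models.
 * Minimal usage sketch; the model path below is an assumption for illustration:
 *
 *   const imagen = new ImagenModelDefinition("publishers/google/models/imagen-3.0-generate-001");
 *   // imagen.model => { id, name, provider: 'vertexai', type: ModelType.Image, can_stream: false }
 */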
class ImagenModelDefinition {
model;
constructor(modelId) {
this.model = {
id: modelId,
name: modelId,
provider: 'vertexai',
type: core_1.ModelType.Image,
can_stream: false,
};
}
async createPrompt(_driver, segments, options) {
const splits = options.model.split("/");
const modelName = splits[splits.length - 1];
options = { ...options, model: modelName };
const prompt = {
prompt: "",
};
        //Collect the text prompts. Imagen does not support roles, so everything gets merged together;
        //however we still respect our typical ordering: system first, safety last.
const system = [];
const user = [];
const safety = [];
const negative = [];
const mask_mode = options.model_options?.mask_mode;
const imagenOptions = options.model_options;
for (const msg of segments) {
if (msg.role === core_1.PromptRole.safety) {
safety.push(msg.content);
}
else if (msg.role === core_1.PromptRole.system) {
system.push(msg.content);
}
else if (msg.role === core_1.PromptRole.negative) {
negative.push(msg.content);
}
else {
                //Everything else is treated as user or user-adjacent content.
user.push(msg.content);
}
if (msg.files) {
//Get images from messages
if (!prompt.referenceImages) {
prompt.referenceImages = [];
}
                //Reference IDs are always required, but only used by customization modes.
                //Each reference ID identifies a single "reference" (i.e. one object); to provide
                //multiple images of the same reference, include them in a single message.
const refId = prompt.referenceImages.length + 1;
for (const img of msg.files) {
if (img.mime_type?.includes("image")) {
if (msg.role !== core_1.PromptRole.mask) {
//Editing based mode requires a reference image
if (imagenOptions?.edit_mode?.includes("EDIT_MODE")) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_RAW",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
}
});
//If mask is auto-generated, add a mask reference
if (mask_mode !== ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_MASK",
referenceId: refId,
maskImageConfig: {
maskMode: mask_mode,
dilation: imagenOptions?.mask_dilation,
}
});
}
}
                            else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_SUBJECT) {
                                //First image is always the control image
                                if (refId === 1) {
//Customization subject mode requires a control image
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_CONTROL",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
controlImageConfig: {
controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
enableControlImageComputation: imagenOptions?.controlImageComputation,
}
});
}
else {
// Subject images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_SUBJECT",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
subjectImageConfig: {
subjectDescription: prompt.subjectDescription ?? msg.content,
subjectType: imagenOptions?.subjectType ?? "SUBJECT_TYPE_DEFAULT",
}
});
}
}
                            else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_STYLE) {
// Style images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_STYLE",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
styleImageConfig: {
styleDescription: prompt.subjectDescription ?? msg.content,
}
});
}
                            else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_CONTROLLED) {
// Control images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_CONTROL",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
controlImageConfig: {
controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
enableControlImageComputation: imagenOptions?.controlImageComputation,
}
});
}
                            else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_INSTRUCT) {
                                // Raw reference images for instruct customization
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_RAW",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
});
}
}
//If mask is user-provided, add a mask reference
if (msg.role === core_1.PromptRole.mask && mask_mode === ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_MASK",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()),
},
maskImageConfig: {
maskMode: mask_mode,
dilation: imagenOptions?.mask_dilation,
}
});
}
}
}
}
}
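        //At this point, for an edit-mode request with an auto-generated mask, prompt.referenceImages
        //might look like this (illustrative values only):
        //  [
        //      { referenceType: "REFERENCE_TYPE_RAW", referenceId: 1, referenceImage: { bytesBase64Encoded: "..." } },
        //      { referenceType: "REFERENCE_TYPE_MASK", referenceId: 1, maskImageConfig: { maskMode: "MASK_MODE_BACKGROUND", dilation: 0.01 } },
        //  ]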
        //Extract the text from the segments, skipping empty sections to avoid stray separators
        prompt.prompt += [system.join("\n\n"), user.join("\n\n"), safety.join("\n\n")]
            .filter((s) => s.length > 0)
            .join("\n\n");
//Negative prompt
if (negative.length > 0) {
prompt.negativePrompt = negative.join(", ");
}
return prompt;
}
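    /**
     * The predict request assembled below has this shape (a sketch; the actual
     * protobuf encoding is handled by the @google-cloud/aiplatform helpers):
     *
     *   {
     *       endpoint: "projects/<project>/locations/us-central1/publishers/google/models/<model>",
     *       instances: [helpers.toValue(prompt)],
     *       parameters: helpers.toValue(parameter),
     *   }
     */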
async requestImageGeneration(driver, prompt, options) {
        if (options.model_options?._option_id !== "vertexai-imagen") {
            driver.logger.warn({ options: options.model_options }, "Invalid model options for Imagen; expected _option_id 'vertexai-imagen'");
        }
if (options.output_modality !== core_1.Modalities.image) {
throw new Error(`Image generation requires image output_modality`);
}
const taskType = options.model_options?.edit_mode ?? ImagenTaskType.TEXT_IMAGE;
driver.logger.info("Task type: " + taskType);
const modelName = options.model.split("/").pop() ?? '';
// Configure the parent resource
// TODO: make location configurable, fixed to us-central1 for now
const endpoint = `projects/${driver.options.project}/locations/us-central1/publishers/google/models/${modelName}`;
const instanceValue = aiplatform_1.helpers.toValue(prompt);
if (!instanceValue) {
throw new Error('No instance value found');
}
const instances = [instanceValue];
let parameter = getImagenParameters(taskType, options.model_options ?? { _option_id: "vertexai-imagen" });
        parameter.negativePrompt = prompt.negativePrompt;
const numberOfImages = options.model_options?.number_of_images ?? 1;
// Remove all undefined values
parameter = Object.fromEntries(Object.entries(parameter).filter(([_, v]) => v !== undefined));
const parameters = aiplatform_1.helpers.toValue(parameter);
const request = {
endpoint,
instances,
parameters,
};
const client = await driver.getImagenClient();
// Predict request
const [response] = await client.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation
const predictions = response.predictions;
if (!predictions) {
throw new Error('No predictions found');
}
// Extract base64 encoded images from predictions
const images = predictions.map(prediction => prediction.structValue?.fields?.bytesBase64Encoded?.stringValue ?? '');
return {
result: images.map(image => ({
type: "image",
value: image
})),
};
}
}
exports.ImagenModelDefinition = ImagenModelDefinition;
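/*
 * End-to-end usage sketch. The driver wiring is an assumption based on how
 * `driver` is used above (logger, options.project, getImagenClient):
 *
 *   const imagen = new ImagenModelDefinition(options.model);
 *   const prompt = await imagen.createPrompt(driver, segments, options);
 *   const { result } = await imagen.requestImageGeneration(driver, prompt, options);
 *   // result: [{ type: "image", value: "<base64 image>" }, ...]
 */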
//# sourceMappingURL=imagen.js.map