@llumiverse/drivers
LLM driver implementations. Currently supported include: openai, huggingface, bedrock, replicate, and vertexai. This file contains the Vertex AI Imagen model definition.
import {
AIModel, Completion, ExecutionOptions, Modalities,
ModelType, PromptRole, PromptSegment, readStreamAsBase64, ImagenOptions
} from "@llumiverse/core";
import { VertexAIDriver } from "../index.js";
// Import the helper module for converting arbitrary protobuf.Value objects
import { protos, helpers } from '@google-cloud/aiplatform';
interface ImagenBaseReference {
referenceType: "REFERENCE_TYPE_RAW" | "REFERENCE_TYPE_MASK" | "REFERENCE_TYPE_SUBJECT" |
"REFERENCE_TYPE_CONTROL" | "REFERENCE_TYPE_STYLE";
referenceId: number;
referenceImage: {
bytesBase64Encoded: string; //10MB max
}
}
export enum ImagenTaskType {
TEXT_IMAGE = "TEXT_IMAGE",
EDIT_MODE_INPAINT_REMOVAL = "EDIT_MODE_INPAINT_REMOVAL",
EDIT_MODE_INPAINT_INSERTION = "EDIT_MODE_INPAINT_INSERTION",
EDIT_MODE_BGSWAP = "EDIT_MODE_BGSWAP",
EDIT_MODE_OUTPAINT = "EDIT_MODE_OUTPAINT",
CUSTOMIZATION_SUBJECT = "CUSTOMIZATION_SUBJECT",
CUSTOMIZATION_STYLE = "CUSTOMIZATION_STYLE",
CUSTOMIZATION_CONTROLLED = "CUSTOMIZATION_CONTROLLED",
CUSTOMIZATION_INSTRUCT = "CUSTOMIZATION_INSTRUCT",
}
export enum ImagenMaskMode {
MASK_MODE_USER_PROVIDED = "MASK_MODE_USER_PROVIDED",
MASK_MODE_BACKGROUND = "MASK_MODE_BACKGROUND",
MASK_MODE_FOREGROUND = "MASK_MODE_FOREGROUND",
MASK_MODE_SEMANTIC = "MASK_MODE_SEMANTIC",
}
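// Illustrative sketch (not part of the API): the task type is selected through
// ImagenOptions.edit_mode, see requestImageGeneration below. The full ImagenOptions
// shape is defined in @llumiverse/core; only fields referenced in this file are
// shown here, and the values are made up.
//
//   const exampleOptions = {
//       _option_id: "vertexai-imagen",
//       edit_mode: ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL,
//       mask_mode: ImagenMaskMode.MASK_MODE_FOREGROUND,
//       mask_dilation: 0.01,
//       edit_steps: 50,
//       number_of_images: 1,
//   };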
interface ImagenReferenceRaw extends ImagenBaseReference {
referenceType: "REFERENCE_TYPE_RAW";
}
interface ImagenReferenceMask extends Omit<ImagenBaseReference, "referenceImage"> {
referenceType: "REFERENCE_TYPE_MASK";
maskImageConfig: {
maskMode?: ImagenMaskMode;
maskClasses?: number[]; //Used for MASK_MODE_SEMANTIC, based on https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api-customization#segment-ids
dilation?: number; //Recommendation depends on mode: Inpaint: 0.01, BGSwap: 0.0, Outpaint: 0.01-0.03
}
referenceImage?: { //Only used for MASK_MODE_USER_PROVIDED
bytesBase64Encoded: string; //10MB max
}
}
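// Illustrative example of a semantic mask reference (values are assumptions; the
// segment class IDs come from the Google Cloud documentation linked above):
//
//   const exampleMask: ImagenReferenceMask = {
//       referenceType: "REFERENCE_TYPE_MASK",
//       referenceId: 1,
//       maskImageConfig: {
//           maskMode: ImagenMaskMode.MASK_MODE_SEMANTIC,
//           maskClasses: [8], // a single segment class ID, chosen arbitrarily here
//           dilation: 0.01,
//       },
//   };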
interface ImagenReferenceSubject extends ImagenBaseReference {
referenceType: "REFERENCE_TYPE_SUBJECT";
subjectImageConfig: {
subjectDescription: string;
subjectType: "SUBJECT_TYPE_PERSON" | "SUBJECT_TYPE_ANIMAL" | "SUBJECT_TYPE_PRODUCT" | "SUBJECT_TYPE_DEFAULT";
}
}
interface ImagenReferenceControl extends ImagenBaseReference {
referenceType: "REFERENCE_TYPE_CONTROL";
controlImageConfig: {
controlType: "CONTROL_TYPE_FACE_MESH" | "CONTROL_TYPE_CANNY" | "CONTROL_TYPE_SCRIBBLE";
enableControlImageComputation?: boolean; //If true, the model will compute the control image
}
}
interface ImagenReferenceStyle extends ImagenBaseReference {
referenceType: "REFERENCE_TYPE_STYLE";
styleImageConfig: {
styleDescription?: string;
}
}
type ImagenMessage = ImagenReferenceRaw | ImagenReferenceMask | ImagenReferenceSubject | ImagenReferenceControl | ImagenReferenceStyle;
export interface ImagenPrompt {
prompt: string;
referenceImages?: ImagenMessage[];
subjectDescription?: string; //Used by customization modes to describe the subject in the reference image
negativePrompt?: string; //Negative prompt, merged from segments with the negative role
}
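// Illustrative example: a minimal ImagenPrompt for plain text-to-image generation.
// The prompt and negativePrompt strings are normally assembled from PromptSegments
// by createPrompt below; the values here are made up.
//
//   const examplePrompt: ImagenPrompt = {
//       prompt: "A watercolor painting of a lighthouse at dawn",
//       negativePrompt: "blurry, low quality",
//   };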
function getImagenParameters(taskType: string, options: ImagenOptions) {
const commonParameters = {
sampleCount: options?.number_of_images,
seed: options?.seed,
safetySetting: options?.safety_setting,
personGeneration: options?.person_generation,
negativePrompt: undefined, //Placeholder; filled in later from the prompt's negative segments
//TODO: Add more safety and prompt rejection information
//includeSafetyAttributes: true,
//includeRaiReason: true,
};
switch (taskType) {
case ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL:
return {
...commonParameters,
editMode: "EDIT_MODE_INPAINT_REMOVAL",
editConfig: {
baseSteps: options?.edit_steps,
},
}
case ImagenTaskType.EDIT_MODE_INPAINT_INSERTION:
return {
...commonParameters,
editMode: "EDIT_MODE_INPAINT_INSERTION",
editConfig: {
baseSteps: options?.edit_steps,
},
}
case ImagenTaskType.EDIT_MODE_BGSWAP:
return {
...commonParameters,
editMode: "EDIT_MODE_BGSWAP",
editConfig: {
baseSteps: options?.edit_steps,
},
}
case ImagenTaskType.EDIT_MODE_OUTPAINT:
return {
...commonParameters,
editMode: "EDIT_MODE_OUTPAINT",
editConfig: {
baseSteps: options?.edit_steps,
},
}
case ImagenTaskType.TEXT_IMAGE:
return {
...commonParameters,
// You can't use a seed value and watermark at the same time.
addWatermark: options?.add_watermark,
aspectRatio: options?.aspect_ratio,
enhancePrompt: options?.enhance_prompt,
};
case ImagenTaskType.CUSTOMIZATION_SUBJECT:
case ImagenTaskType.CUSTOMIZATION_CONTROLLED:
case ImagenTaskType.CUSTOMIZATION_INSTRUCT:
case ImagenTaskType.CUSTOMIZATION_STYLE:
return {
...commonParameters,
}
default:
throw new Error("Task type not supported");
}
}
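// Illustrative sketch of what getImagenParameters produces for a plain TEXT_IMAGE
// request. Undefined entries are stripped later in requestImageGeneration; the
// option values are assumptions, not API defaults.
//
//   getImagenParameters(ImagenTaskType.TEXT_IMAGE, {
//       _option_id: "vertexai-imagen",
//       number_of_images: 2,
//       aspect_ratio: "1:1",
//   } as ImagenOptions);
//   // => { sampleCount: 2, seed: undefined, ..., aspectRatio: "1:1", enhancePrompt: undefined }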
export class ImagenModelDefinition {
model: AIModel
constructor(modelId: string) {
this.model = {
id: modelId,
name: modelId,
provider: 'vertexai',
type: ModelType.Image,
can_stream: false,
};
}
async createPrompt(_driver: VertexAIDriver, segments: PromptSegment[], options: ExecutionOptions): Promise<ImagenPrompt> {
const splits = options.model.split("/");
const modelName = splits[splits.length - 1];
options = { ...options, model: modelName };
const prompt: ImagenPrompt = {
prompt: "",
}
//Collect text prompts. Imagen does not support roles, so everything gets merged together;
//however we still respect our typical ordering: system first, then user, then safety last.
const system: string[] = [];
const user: string[] = [];
const safety: string[] = [];
const negative: string[] = [];
const imagenOptions = options.model_options as ImagenOptions;
const mask_mode = imagenOptions?.mask_mode;
for (const msg of segments) {
if (msg.role === PromptRole.safety) {
safety.push(msg.content);
} else if (msg.role === PromptRole.system) {
system.push(msg.content);
} else if (msg.role === PromptRole.negative) {
negative.push(msg.content);
} else {
//Everything else is assumed to be user or user adjacent.
user.push(msg.content);
}
if (msg.files) {
//Get images from messages
if (!prompt.referenceImages) {
prompt.referenceImages = [];
}
//Reference IDs are always required, but only meaningful for customization modes.
//Each reference ID refers to a single "reference", i.e. one object. To provide multiple
//images of a single reference, include multiple images in one prompt segment.
const refId = prompt.referenceImages.length + 1;
for (const img of msg.files) {
if (img.mime_type?.includes("image")) {
if (msg.role !== PromptRole.mask) {
//Editing-based modes require a raw reference image
if (imagenOptions?.edit_mode?.includes("EDIT_MODE")) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_RAW",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
}
});
//If mask is auto-generated, add a mask reference
if (mask_mode !== ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_MASK",
referenceId: refId,
maskImageConfig: {
maskMode: mask_mode,
dilation: imagenOptions?.mask_dilation,
}
});
}
}
else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_SUBJECT) {
//First image is always the control image
if (refId === 1) {
//Customization subject mode requires a control image
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_CONTROL",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
controlImageConfig: {
controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
enableControlImageComputation: imagenOptions?.controlImageComputation,
}
});
} else {
// Subject images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_SUBJECT",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
subjectImageConfig: {
subjectDescription: prompt.subjectDescription ?? msg.content,
subjectType: imagenOptions?.subjectType ?? "SUBJECT_TYPE_DEFAULT",
}
});
}
} else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_STYLE) {
// Style images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_STYLE",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
styleImageConfig: {
styleDescription: prompt.subjectDescription ?? msg.content,
}
});
} else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_CONTROLLED) {
// Control images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_CONTROL",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
controlImageConfig: {
controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY",
enableControlImageComputation: imagenOptions?.controlImageComputation,
}
});
} else if (imagenOptions?.edit_mode === ImagenTaskType.CUSTOMIZATION_INSTRUCT) {
// Instruct customization uses raw reference images
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_RAW",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
});
}
}
//If mask is user-provided, add a mask reference
if (msg.role === PromptRole.mask && mask_mode === ImagenMaskMode.MASK_MODE_USER_PROVIDED) {
prompt.referenceImages.push({
referenceType: "REFERENCE_TYPE_MASK",
referenceId: refId,
referenceImage: {
bytesBase64Encoded: await readStreamAsBase64(await img.getStream()),
},
maskImageConfig: {
maskMode: mask_mode,
dilation: imagenOptions?.mask_dilation,
}
});
}
}
}
}
}
//Merge the collected text: system first, then user, then safety; skip empty groups
prompt.prompt += [system.join("\n\n"), user.join("\n\n"), safety.join("\n\n")].filter(Boolean).join("\n\n");
//Negative prompt
if (negative.length > 0) {
prompt.negativePrompt = negative.join(", ");
}
return prompt;
}
async requestImageGeneration(driver: VertexAIDriver, prompt: ImagenPrompt, options: ExecutionOptions): Promise<Completion> {
if (options.model_options?._option_id !== "vertexai-imagen") {
driver.logger.warn("Invalid model options", { options: options.model_options });
}
options.model_options = options.model_options as ImagenOptions | undefined; //Narrow model_options for the rest of this method
if (options.output_modality !== Modalities.image) {
throw new Error(`Image generation requires image output_modality`);
}
const taskType: string = options.model_options?.edit_mode ?? ImagenTaskType.TEXT_IMAGE;
driver.logger.info("Task type: " + taskType);
const modelName = options.model.split("/").pop() ?? '';
// Configure the parent resource
// TODO: make location configurable, fixed to us-central1 for now
const endpoint = `projects/${driver.options.project}/locations/us-central1/publishers/google/models/${modelName}`;
const instanceValue = helpers.toValue(prompt);
if (!instanceValue) {
throw new Error('No instance value found');
}
const instances = [instanceValue];
let parameter: any = getImagenParameters(taskType, options.model_options ?? {_option_id: "vertexai-imagen"});
parameter.negativePrompt = prompt.negativePrompt ?? undefined;
const numberOfImages = options.model_options?.number_of_images ?? 1;
// Remove all undefined values
parameter = Object.fromEntries(
Object.entries(parameter).filter(([_, v]) => v !== undefined)
) as any;
const parameters = helpers.toValue(parameter);
const request: protos.google.cloud.aiplatform.v1.IPredictRequest = {
endpoint,
instances,
parameters,
};
const client = driver.getImagenClient();
// Predict request
const [response] = await client.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation
const predictions = response.predictions;
if (!predictions) {
throw new Error('No predictions found');
}
// Extract base64 encoded images from predictions
const images: string[] = predictions.map(prediction =>
prediction.structValue?.fields?.bytesBase64Encoded?.stringValue ?? ''
);
return {
result: images.map(image => ({
type: "image" as const,
value: image
})),
};
}
}
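// Hedged usage sketch (not part of the module): how a caller might drive this class.
// The VertexAIDriver constructor options, the model id, and the exact PromptSegment /
// ExecutionOptions shapes are assumptions based on how they are used in this file.
//
//   const driver = new VertexAIDriver({ project: "my-gcp-project" } as any);
//   const definition = new ImagenModelDefinition("publishers/google/models/imagen-3.0-generate-001");
//   const segments: PromptSegment[] = [
//       { role: PromptRole.system, content: "Photorealistic style." },
//       { role: PromptRole.user, content: "A red bicycle leaning against a brick wall." },
//       { role: PromptRole.negative, content: "text, watermark" },
//   ];
//   const options = {
//       model: definition.model.id,
//       output_modality: Modalities.image,
//       model_options: { _option_id: "vertexai-imagen", number_of_images: 1 },
//   } as ExecutionOptions;
//   const prompt = await definition.createPrompt(driver, segments, options);
//   const completion = await definition.requestImageGeneration(driver, prompt, options);
//   // completion.result is an array of { type: "image", value: <base64 string> } entries.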