UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

306 lines 15.7 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.ImagenModelDefinition = exports.ImagenMaskMode = exports.ImagenTaskType = void 0; const core_1 = require("@llumiverse/core"); // Import the helper module for converting arbitrary protobuf.Value objects const aiplatform_1 = require("@google-cloud/aiplatform"); var ImagenTaskType; (function (ImagenTaskType) { ImagenTaskType["TEXT_IMAGE"] = "TEXT_IMAGE"; ImagenTaskType["EDIT_MODE_INPAINT_REMOVAL"] = "EDIT_MODE_INPAINT_REMOVAL"; ImagenTaskType["EDIT_MODE_INPAINT_INSERTION"] = "EDIT_MODE_INPAINT_INSERTION"; ImagenTaskType["EDIT_MODE_BGSWAP"] = "EDIT_MODE_BGSWAP"; ImagenTaskType["EDIT_MODE_OUTPAINT"] = "EDIT_MODE_OUTPAINT"; ImagenTaskType["CUSTOMIZATION_SUBJECT"] = "CUSTOMIZATION_SUBJECT"; ImagenTaskType["CUSTOMIZATION_STYLE"] = "CUSTOMIZATION_STYLE"; ImagenTaskType["CUSTOMIZATION_CONTROLLED"] = "CUSTOMIZATION_CONTROLLED"; ImagenTaskType["CUSTOMIZATION_INSTRUCT"] = "CUSTOMIZATION_INSTRUCT"; })(ImagenTaskType || (exports.ImagenTaskType = ImagenTaskType = {})); var ImagenMaskMode; (function (ImagenMaskMode) { ImagenMaskMode["MASK_MODE_USER_PROVIDED"] = "MASK_MODE_USER_PROVIDED"; ImagenMaskMode["MASK_MODE_BACKGROUND"] = "MASK_MODE_BACKGROUND"; ImagenMaskMode["MASK_MODE_FOREGROUND"] = "MASK_MODE_FOREGROUND"; ImagenMaskMode["MASK_MODE_SEMANTIC"] = "MASK_MODE_SEMANTIC"; })(ImagenMaskMode || (exports.ImagenMaskMode = ImagenMaskMode = {})); function getImagenParameters(taskType, options) { const commonParameters = { sampleCount: options?.number_of_images, seed: options?.seed, safetySetting: options?.safety_setting, personGeneration: options?.person_generation, negativePrompt: taskType ? undefined : "", //Filled in later from the prompt //TODO: Add more safety and prompt rejection information //includeSafetyAttributes: true, //includeRaiReason: true, }; switch (taskType) { case ImagenTaskType.EDIT_MODE_INPAINT_REMOVAL: return { ...commonParameters, editMode: "EDIT_MODE_INPAINT_REMOVAL", editConfig: { baseSteps: options?.edit_steps, }, }; case ImagenTaskType.EDIT_MODE_INPAINT_INSERTION: return { ...commonParameters, editMode: "EDIT_MODE_INPAINT_INSERTION", editConfig: { baseSteps: options?.edit_steps, }, }; case ImagenTaskType.EDIT_MODE_BGSWAP: return { ...commonParameters, editMode: "EDIT_MODE_BGSWAP", editConfig: { baseSteps: options?.edit_steps, }, }; case ImagenTaskType.EDIT_MODE_OUTPAINT: return { ...commonParameters, editMode: "EDIT_MODE_OUTPAINT", editConfig: { baseSteps: options?.edit_steps, }, }; case ImagenTaskType.TEXT_IMAGE: return { ...commonParameters, // You can't use a seed value and watermark at the same time. addWatermark: options?.add_watermark, aspectRatio: options?.aspect_ratio, enhancePrompt: options?.enhance_prompt, }; case ImagenTaskType.CUSTOMIZATION_SUBJECT: case ImagenTaskType.CUSTOMIZATION_CONTROLLED: case ImagenTaskType.CUSTOMIZATION_INSTRUCT: case ImagenTaskType.CUSTOMIZATION_STYLE: return { ...commonParameters, }; default: throw new Error("Task type not supported"); } } class ImagenModelDefinition { model; constructor(modelId) { this.model = { id: modelId, name: modelId, provider: 'vertexai', type: core_1.ModelType.Image, can_stream: false, }; } async createPrompt(_driver, segments, options) { const splits = options.model.split("/"); const modelName = splits[splits.length - 1]; options = { ...options, model: modelName }; const prompt = { prompt: "", }; //Collect text prompts, Imagen does not support roles, so everything gets merged together // however we still respect our typical pattern. System First, Safety Last. const system = []; const user = []; const safety = []; const negative = []; const mask_mode = options.model_options?.mask_mode; const imagenOptions = options.model_options; for (const msg of segments) { if (msg.role === core_1.PromptRole.safety) { safety.push(msg.content); } else if (msg.role === core_1.PromptRole.system) { system.push(msg.content); } else if (msg.role === core_1.PromptRole.negative) { negative.push(msg.content); } else { //Everything else is assumed to be user or user adjacent. user.push(msg.content); } if (msg.files) { //Get images from messages if (!prompt.referenceImages) { prompt.referenceImages = []; } //Always required, but only used by customisation. //Each ref ID refers to a single "reference", i.e. object. To provide multiple images of a single ref, //include multiple images in one prompt. const refId = prompt.referenceImages.length + 1; for (const img of msg.files) { if (img.mime_type?.includes("image")) { if (msg.role !== core_1.PromptRole.mask) { //Editing based mode requires a reference image if (imagenOptions?.edit_mode?.includes("EDIT_MODE")) { prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_RAW", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), } }); //If mask is auto-generated, add a mask reference if (mask_mode !== ImagenMaskMode.MASK_MODE_USER_PROVIDED) { prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_MASK", referenceId: refId, maskImageConfig: { maskMode: mask_mode, dilation: imagenOptions?.mask_dilation, } }); } } else if (options.model_options?.edit_mode === ImagenTaskType.CUSTOMIZATION_SUBJECT) { //First image is always the control image if (refId == 1) { //Customization subject mode requires a control image prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_CONTROL", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, controlImageConfig: { controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY", enableControlImageComputation: imagenOptions?.controlImageComputation, } }); } else { // Subject images prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_SUBJECT", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, subjectImageConfig: { subjectDescription: prompt.subjectDescription ?? msg.content, subjectType: imagenOptions?.subjectType ?? "SUBJECT_TYPE_DEFAULT", } }); } } else if (options.model_options?.edit_mode === ImagenTaskType.CUSTOMIZATION_STYLE) { // Style images prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_STYLE", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, styleImageConfig: { styleDescription: prompt.subjectDescription ?? msg.content, } }); } else if (options.model_options?.edit_mode === ImagenTaskType.CUSTOMIZATION_CONTROLLED) { // Control images prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_CONTROL", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, controlImageConfig: { controlType: imagenOptions?.controlType === "CONTROL_TYPE_FACE_MESH" ? "CONTROL_TYPE_FACE_MESH" : "CONTROL_TYPE_CANNY", enableControlImageComputation: imagenOptions?.controlImageComputation, } }); } else if (options.model_options?.edit_mode === ImagenTaskType.CUSTOMIZATION_INSTRUCT) { // Control images prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_RAW", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, }); } } //If mask is user-provided, add a mask reference if (msg.role === core_1.PromptRole.mask && mask_mode === ImagenMaskMode.MASK_MODE_USER_PROVIDED) { prompt.referenceImages.push({ referenceType: "REFERENCE_TYPE_MASK", referenceId: refId, referenceImage: { bytesBase64Encoded: await (0, core_1.readStreamAsBase64)(await img.getStream()), }, maskImageConfig: { maskMode: mask_mode, dilation: imagenOptions?.mask_dilation, } }); } } } } } //Extract the text from the segments prompt.prompt += [system.join("\n\n"), user.join("\n\n"), safety.join("\n\n")].join("\n\n"); //Negative prompt if (negative.length > 0) { prompt.negativePrompt = negative.join(", "); } console.log(prompt); return prompt; } async requestImageGeneration(driver, prompt, options) { if (options.model_options?._option_id !== "vertexai-imagen") { driver.logger.warn({ options: options.model_options }, "Invalid model options"); } options.model_options = options.model_options; if (options.output_modality !== core_1.Modalities.image) { throw new Error(`Image generation requires image output_modality`); } const taskType = options.model_options?.edit_mode ?? ImagenTaskType.TEXT_IMAGE; driver.logger.info("Task type: " + taskType); const modelName = options.model.split("/").pop() ?? ''; // Configure the parent resource // TODO: make location configurable, fixed to us-central1 for now const endpoint = `projects/${driver.options.project}/locations/us-central1/publishers/google/models/${modelName}`; const instanceValue = aiplatform_1.helpers.toValue(prompt); if (!instanceValue) { throw new Error('No instance value found'); } const instances = [instanceValue]; let parameter = getImagenParameters(taskType, options.model_options ?? { _option_id: "vertexai-imagen" }); parameter.negativePrompt = prompt.negativePrompt ?? undefined; const numberOfImages = options.model_options?.number_of_images ?? 1; // Remove all undefined values parameter = Object.fromEntries(Object.entries(parameter).filter(([_, v]) => v !== undefined)); const parameters = aiplatform_1.helpers.toValue(parameter); const request = { endpoint, instances, parameters, }; const client = await driver.getImagenClient(); // Predict request const [response] = await client.predict(request, { timeout: 120000 * numberOfImages }); //Extended timeout for image generation const predictions = response.predictions; if (!predictions) { throw new Error('No predictions found'); } // Extract base64 encoded images from predictions const images = predictions.map(prediction => prediction.structValue?.fields?.bytesBase64Encoded?.stringValue ?? ''); return { result: images.map(image => ({ type: "image", value: image })), }; } } exports.ImagenModelDefinition = ImagenModelDefinition; //# sourceMappingURL=imagen.js.map