UNPKG

@aj-archipelago/cortex

Version:

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.

579 lines (511 loc) 23.6 kB
// replicateApiPlugin.js
import ModelPlugin from "./modelPlugin.js";
import CortexResponse from "../../lib/cortexResponse.js";
import logger from "../../lib/logger.js";
import axios from "axios";
import mime from "mime-types";

// File extensions that mark a prediction output URL as an image artifact.
const IMAGE_EXTENSION_PATTERN = /\.(jpg|jpeg|png|gif|webp)$/i;

/**
 * Drop the query string and fragment from a URL so only the path remains.
 *
 * @param {string} url - URL (or plain path) to trim.
 * @returns {string} Everything before the first `?` or `#`.
 */
const stripQueryAndFragment = (url) => url.split("?")[0].split("#")[0];

/**
 * Recursively flatten an image parameter into `accumulator`.
 * Falsy candidates are skipped; arrays are walked depth-first.
 *
 * @param {*} candidate - A single image value or a (nested) array of them.
 * @param {Array} accumulator - Receives every non-array, truthy candidate.
 */
const collectImages = (candidate, accumulator) => {
  if (!candidate) return;
  if (Array.isArray(candidate)) {
    candidate.forEach((item) => collectImages(item, accumulator));
    return;
  }
  accumulator.push(candidate);
};

/**
 * Reduce an image entry to a single value.
 * Strings are returned unchanged; for non-array objects the first truthy
 * property among `value`, `url`, `path` is returned.
 *
 * @param {*} entry - Candidate image entry.
 * @returns {*} The extracted value, or `null` when nothing usable is found.
 */
const normalizeImageEntry = (entry) => {
  if (!entry) return null;
  if (typeof entry === "string") {
    return entry;
  }
  if (typeof entry === "object") {
    if (Array.isArray(entry)) {
      return null;
    }
    if (entry.value) {
      return entry.value;
    }
    if (entry.url) {
      return entry.url;
    }
    if (entry.path) {
      return entry.path;
    }
  }
  return null;
};

/**
 * Copy an object without the keys whose value is `undefined` or `null`.
 *
 * @param {Object} obj - Source object.
 * @returns {Object} Shallow copy containing only the defined entries.
 */
const omitUndefined = (obj) =>
  Object.fromEntries(
    Object.entries(obj).filter(([, value]) => value !== undefined && value !== null),
  );

/**
 * Gather every image supplied through the known image parameter names and
 * return them as a flat list of strings.
 *
 * @param {Object} combinedParameters - Merged prompt and request parameters.
 * @param {string[]} [additionalFields=[]] - Extra parameter names to scan
 *   after the default ones.
 * @returns {string[]} Normalized image entries; non-string results are dropped.
 */
const collectNormalizedImages = (combinedParameters, additionalFields = []) => {
  const imageCandidates = [];
  const defaultFields = [
    "image",
    "images",
    "input_image",
    "input_images",
    "input_image_1",
    "input_image_2",
    "input_image_3",
    "image_1",
    "image_2",
  ];
  const allFields = [...defaultFields, ...additionalFields];

  allFields.forEach((field) => {
    collectImages(combinedParameters[field], imageCandidates);
  });

  return imageCandidates
    .map((candidate) => normalizeImageEntry(candidate))
    .filter((candidate) => candidate && typeof candidate === "string");
};

/**
 * Model plugin that builds Replicate prediction requests, polls until the
 * prediction settles, and wraps the result in a CortexResponse.
 */
class ReplicateApiPlugin extends ModelPlugin {
  /**
   * Build the request body for the model named in the combined parameters.
   *
   * @param {string} text - Input text forwarded to `getCompiledPrompt`.
   * @param {Object} parameters - Request parameters; override `this.promptParameters`.
   * @param {Object} prompt - Prompt object forwarded to `getCompiledPrompt`.
   * @returns {Object} `{ input: {...} }` for a recognised model, `{}` otherwise.
   */
  getRequestParameters(text, parameters, prompt) {
    const combinedParameters = { ...this.promptParameters, ...parameters };
    const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);

    let requestParameters = {};

    switch (combinedParameters.model) {
      case "replicate-flux-11-pro":
        requestParameters = {
          input: {
            aspect_ratio: combinedParameters.aspectRatio || "1:1",
            output_format: combinedParameters.outputFormat || "webp",
            output_quality: combinedParameters.outputQuality || 80,
            prompt: modelPromptText,
            prompt_upsampling: combinedParameters.promptUpsampling || false,
            safety_tolerance: combinedParameters.safety_tolerance || 3,
            go_fast: true,
            megapixels: "1",
            width: combinedParameters.width,
            height: combinedParameters.height,
            size: combinedParameters.size || "1024x1024",
            style: combinedParameters.style || "realistic_image",
            // Seed is sent for any non-zero integer (negative values included).
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed)
              ? { seed: combinedParameters.seed }
              : {}),
          },
        };
        break;

      case "replicate-recraft-v3": {
        const validStyles = [
          "any",
          "realistic_image",
          "digital_illustration",
          "digital_illustration/pixel_art",
          "digital_illustration/hand_drawn",
          "digital_illustration/grain",
          "digital_illustration/infantile_sketch",
          "digital_illustration/2d_art_poster",
          "digital_illustration/handmade_3d",
          "digital_illustration/hand_drawn_outline",
          "digital_illustration/engraving_color",
          "digital_illustration/2d_art_poster_2",
          "realistic_image/b_and_w",
          "realistic_image/hard_flash",
          "realistic_image/hdr",
          "realistic_image/natural_light",
          "realistic_image/studio_portrait",
          "realistic_image/enterprise",
          "realistic_image/motion_blur",
        ];

        requestParameters = {
          input: {
            prompt: modelPromptText,
            size: combinedParameters.size || "1024x1024",
            style: validStyles.includes(combinedParameters.style)
              ? combinedParameters.style
              : "realistic_image",
          },
        };
        break;
      }

      case "replicate-flux-1-schnell": {
        const validRatios = [
          "1:1", "16:9", "21:9", "3:2", "2:3", "4:5",
          "5:4", "3:4", "4:3", "9:16", "9:21",
        ];

        requestParameters = {
          input: {
            aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
              ? combinedParameters.aspectRatio
              : "1:1",
            output_format: combinedParameters.outputFormat || "webp",
            output_quality: combinedParameters.outputQuality || 80,
            prompt: modelPromptText,
            go_fast: true,
            megapixels: "1",
            num_outputs: combinedParameters.numberResults,
            num_inference_steps: combinedParameters.steps || 4,
            disable_safety_checker: true,
          },
        };
        break;
      }

      case "replicate-qwen-image": {
        // Each option accepts the snake_case name first, then the camelCase alias.
        const aspectRatio = combinedParameters.aspect_ratio ?? combinedParameters.aspectRatio ?? "16:9";
        const imageSize = combinedParameters.image_size ?? combinedParameters.imageSize ?? "optimize_for_quality";
        const outputFormat = combinedParameters.output_format ?? combinedParameters.outputFormat ?? "webp";
        const outputQuality = combinedParameters.output_quality ?? combinedParameters.outputQuality ?? 80;
        const loraScale = combinedParameters.lora_scale ?? combinedParameters.loraScale ?? 1;
        const enhancePrompt = combinedParameters.enhance_prompt ?? combinedParameters.enhancePrompt ?? false;
        const negativePrompt = combinedParameters.negative_prompt ?? combinedParameters.negativePrompt ?? " ";
        const numInferenceSteps = combinedParameters.num_inference_steps ?? combinedParameters.steps ?? 50;
        const goFast = combinedParameters.go_fast ?? combinedParameters.goFast ?? true;
        const guidance = combinedParameters.guidance ?? 4;
        const strength = combinedParameters.strength ?? 0.9;
        const numOutputs = combinedParameters.num_outputs ?? combinedParameters.numberResults;
        const disableSafetyChecker =
          combinedParameters.disable_safety_checker ?? combinedParameters.disableSafetyChecker ?? false;

        requestParameters = {
          input: {
            prompt: modelPromptText,
            go_fast: goFast,
            guidance,
            strength,
            image_size: imageSize,
            lora_scale: loraScale,
            aspect_ratio: aspectRatio,
            output_format: outputFormat,
            enhance_prompt: enhancePrompt,
            output_quality: outputQuality,
            negative_prompt: negativePrompt,
            num_inference_steps: numInferenceSteps,
            disable_safety_checker: disableSafetyChecker,
            ...(numOutputs ? { num_outputs: numOutputs } : {}),
            // Seed is sent for any non-zero integer (negative values included).
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed)
              ? { seed: combinedParameters.seed }
              : {}),
            ...(combinedParameters.image ? { image: combinedParameters.image } : {}),
            ...(combinedParameters.input_image ? { input_image: combinedParameters.input_image } : {}),
          },
        };
        break;
      }

      case "replicate-qwen-image-edit-plus": {
        const aspectRatio =
          combinedParameters.aspect_ratio ?? combinedParameters.aspectRatio ?? "match_input_image";
        const outputFormat = combinedParameters.output_format ?? combinedParameters.outputFormat ?? "webp";
        const outputQuality = combinedParameters.output_quality ?? combinedParameters.outputQuality ?? 95;
        const goFast = combinedParameters.go_fast ?? combinedParameters.goFast ?? true;
        const disableSafetyChecker =
          combinedParameters.disable_safety_checker ?? combinedParameters.disableSafetyChecker ?? false;

        const normalizedImages = collectNormalizedImages(combinedParameters);

        const basePayload = omitUndefined({
          prompt: modelPromptText,
          go_fast: goFast,
          aspect_ratio: aspectRatio,
          output_format: outputFormat,
          output_quality: outputQuality,
          disable_safety_checker: disableSafetyChecker,
        });

        // For qwen-image-edit-plus, always include the image array if we have images
        const inputPayload = {
          ...basePayload,
          ...(normalizedImages.length > 0 ? { image: normalizedImages } : {}),
        };

        requestParameters = {
          input: inputPayload,
        };
        break;
      }

      case "replicate-qwen-image-edit-2511": {
        const validRatios = ["1:1", "16:9", "9:16", "4:3", "3:4", "match_input_image"];
        const validOutputFormats = ["webp", "jpg", "png"];

        const requestedRatio = combinedParameters.aspect_ratio ?? combinedParameters.aspectRatio;
        const aspectRatio = validRatios.includes(requestedRatio) ? requestedRatio : "match_input_image";
        const requestedFormat = combinedParameters.output_format ?? combinedParameters.outputFormat;
        const outputFormat = validOutputFormats.includes(requestedFormat) ? requestedFormat : "webp";
        const outputQuality = combinedParameters.output_quality ?? combinedParameters.outputQuality ?? 95;
        const goFast = combinedParameters.go_fast ?? combinedParameters.goFast ?? true;
        const disableSafetyChecker =
          combinedParameters.disable_safety_checker ?? combinedParameters.disableSafetyChecker ?? false;

        const normalizedImages = collectNormalizedImages(combinedParameters);

        const basePayload = omitUndefined({
          prompt: modelPromptText,
          go_fast: goFast,
          aspect_ratio: aspectRatio,
          output_format: outputFormat,
          // Clamp quality into the 0-100 range.
          output_quality: Math.max(0, Math.min(100, outputQuality)),
          disable_safety_checker: disableSafetyChecker,
          ...(Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
            ? { seed: combinedParameters.seed }
            : {}),
        });

        // For qwen-image-edit-2511, format images as array of strings (not objects)
        const inputPayload = {
          ...basePayload,
          ...(normalizedImages.length > 0 ? { image: normalizedImages } : {}),
        };

        requestParameters = {
          input: inputPayload,
        };
        break;
      }

      case "replicate-flux-kontext-pro":
      case "replicate-flux-kontext-max": {
        const validRatios = [
          "1:1", "16:9", "21:9", "3:2", "2:3", "4:5",
          "5:4", "3:4", "4:3", "9:16", "9:21", "match_input_image",
        ];

        // When an input image is supplied, the tolerance is capped at 2.
        let safetyTolerance = combinedParameters.safety_tolerance || 3;
        if (combinedParameters.input_image) {
          safetyTolerance = Math.min(safetyTolerance, 2);
        }

        requestParameters = {
          input: {
            prompt: modelPromptText,
            input_image: combinedParameters.input_image,
            aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
              ? combinedParameters.aspectRatio
              : "1:1",
            safety_tolerance: safetyTolerance,
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
              ? { seed: combinedParameters.seed }
              : {}),
          },
        };
        break;
      }

      case "replicate-multi-image-kontext-max": {
        const validRatios = [
          "1:1", "16:9", "21:9", "3:2", "2:3", "4:5",
          "5:4", "3:4", "4:3", "9:16", "9:21", "match_input_image",
        ];

        // When an input image is supplied, the tolerance is capped at 2.
        let safetyTolerance = combinedParameters.safety_tolerance || 3;
        if (combinedParameters.input_image_1 || combinedParameters.input_image) {
          safetyTolerance = Math.min(safetyTolerance, 2);
        }

        requestParameters = {
          input: {
            prompt: modelPromptText,
            input_image_1: combinedParameters.input_image_1 || combinedParameters.input_image,
            input_image_2: combinedParameters.input_image_2,
            aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
              ? combinedParameters.aspectRatio
              : "1:1",
            safety_tolerance: safetyTolerance,
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
              ? { seed: combinedParameters.seed }
              : {}),
          },
        };
        break;
      }

      case "replicate-seedance-1-pro": {
        const validResolutions = ["480p", "1080p"];
        const validRatios = ["16:9", "4:3", "9:16", "1:1", "3:4", "21:9", "9:21"];
        const validFps = [24];

        requestParameters = {
          input: {
            prompt: modelPromptText,
            resolution: validResolutions.includes(combinedParameters.resolution)
              ? combinedParameters.resolution
              : "1080p",
            aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
              ? combinedParameters.aspectRatio
              : "16:9",
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
              ? { seed: combinedParameters.seed }
              : {}),
            fps: validFps.includes(combinedParameters.fps) ? combinedParameters.fps : 24,
            camera_fixed: combinedParameters.camera_fixed || false,
            duration: combinedParameters.duration || 5,
            ...(combinedParameters.image ? { image: combinedParameters.image } : {}),
          },
        };
        break;
      }

      case "replicate-seedance-1.5-pro": {
        const validRatios = ["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "9:21"];

        requestParameters = {
          input: {
            prompt: modelPromptText,
            aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
              ? combinedParameters.aspectRatio
              : "16:9",
            // Duration defaults to 5 and is clamped to the 2-12 range.
            duration: Math.min(12, Math.max(2, combinedParameters.duration || 5)),
            fps: 24,
            camera_fixed: combinedParameters.camera_fixed || false,
            generate_audio: combinedParameters.generate_audio || false,
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
              ? { seed: combinedParameters.seed }
              : {}),
            ...(combinedParameters.image ? { image: combinedParameters.image } : {}),
            // A last frame is only forwarded together with a first-frame image.
            ...(combinedParameters.image && combinedParameters.last_frame_image
              ? { last_frame_image: combinedParameters.last_frame_image }
              : {}),
          },
        };
        break;
      }

      case "replicate-seedream-4": {
        const validSizes = ["1K", "2K", "4K", "custom"];
        const validRatios = ["1:1", "4:3", "3:4", "16:9", "9:16", "match_input_image"];
        const validSequentialModes = ["disabled", "auto"];

        const normalizedImages = collectNormalizedImages(combinedParameters, ["imageInput"]);

        const basePayload = omitUndefined({
          prompt: modelPromptText,
          size: validSizes.includes(combinedParameters.size) ? combinedParameters.size : "2K",
          width: combinedParameters.width || 2048,
          height: combinedParameters.height || 2048,
          max_images: combinedParameters.maxImages || combinedParameters.numberResults || 1,
          aspect_ratio: validRatios.includes(combinedParameters.aspectRatio)
            ? combinedParameters.aspectRatio
            : "4:3",
          sequential_image_generation: validSequentialModes.includes(combinedParameters.sequentialImageGeneration)
            ? combinedParameters.sequentialImageGeneration
            : "disabled",
          ...(Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
            ? { seed: combinedParameters.seed }
            : {}),
        });

        // For seedream-4, include the image_input array if we have images
        const inputPayload = {
          ...basePayload,
          ...(normalizedImages.length > 0 ? { image_input: normalizedImages } : {}),
        };

        requestParameters = {
          input: inputPayload,
        };
        break;
      }

      case "replicate-flux-2-pro": {
        const validResolutions = ["match_input_image", "0.5 MP", "1 MP", "2 MP", "4 MP"];
        const validRatios = [
          "match_input_image", "custom", "1:1", "16:9", "3:2", "2:3",
          "4:5", "5:4", "9:16", "3:4", "4:3",
        ];
        const validOutputFormats = ["webp", "jpg", "png"];

        const normalizedImages = collectNormalizedImages(combinedParameters).slice(0, 8); // Maximum 8 images

        const requestedRatio = combinedParameters.aspect_ratio ?? combinedParameters.aspectRatio;
        const aspectRatio = validRatios.includes(requestedRatio) ? requestedRatio : "1:1";
        const resolution = validResolutions.includes(combinedParameters.resolution)
          ? combinedParameters.resolution
          : "1 MP";
        const requestedFormat = combinedParameters.output_format ?? combinedParameters.outputFormat;
        const outputFormat = validOutputFormats.includes(requestedFormat) ? requestedFormat : "webp";
        const outputQuality = combinedParameters.output_quality ?? combinedParameters.outputQuality ?? 80;
        const safetyTolerance = combinedParameters.safety_tolerance ?? combinedParameters.safetyTolerance ?? 2;

        // Validate and round width/height to multiples of 32 if provided
        let width = combinedParameters.width;
        let height = combinedParameters.height;

        if (width !== undefined && width !== null) {
          width = Math.max(256, Math.min(2048, Math.round(width / 32) * 32));
        }
        if (height !== undefined && height !== null) {
          height = Math.max(256, Math.min(2048, Math.round(height / 32) * 32));
        }

        const basePayload = omitUndefined({
          prompt: modelPromptText,
          aspect_ratio: aspectRatio,
          resolution: resolution,
          output_format: outputFormat,
          output_quality: Math.max(0, Math.min(100, outputQuality)),
          safety_tolerance: Math.max(1, Math.min(5, safetyTolerance)),
          ...(width !== undefined && width !== null ? { width } : {}),
          ...(height !== undefined && height !== null ? { height } : {}),
          ...(Number.isInteger(combinedParameters.seed) && combinedParameters.seed > 0
            ? { seed: combinedParameters.seed }
            : {}),
        });

        // Include input_images array if we have images
        const inputPayload = {
          ...basePayload,
          ...(normalizedImages.length > 0 ? { input_images: normalizedImages } : {}),
        };

        requestParameters = {
          input: inputPayload,
        };
        break;
      }
    }

    return requestParameters;
  }

  /**
   * Start a prediction and, if it has not already succeeded, poll its status
   * URL until it succeeds, fails, is canceled, or the attempt budget runs out.
   *
   * @param {string} text - Input text for the prompt.
   * @param {Object} parameters - Request parameters.
   * @param {Object} prompt - Prompt object.
   * @param {Object} cortexRequest - Request object; its `data` and `params`
   *   are overwritten here and its `headers` are reused for polling.
   * @returns {Promise<CortexResponse>} The completed prediction.
   * @throws {Error} When no prediction id or polling URL is returned, when the
   *   prediction fails or is canceled, or when polling times out.
   */
  async execute(text, parameters, prompt, cortexRequest) {
    const requestParameters = this.getRequestParameters(text, parameters, prompt);

    cortexRequest.data = requestParameters;
    // NOTE(review): getRequestParameters never sets `params`, so this assigns
    // undefined for every model handled above.
    cortexRequest.params = requestParameters.params;

    // Make initial request to start prediction
    const response = await this.executeRequest(cortexRequest);

    // Parse the response to get the actual Replicate data
    const parsedResponse = JSON.parse(response.output_text);

    // If we got a completed response, return it as CortexResponse
    if (parsedResponse?.status === "succeeded") {
      return this.createCortexResponse(response);
    }

    logger.info("Replicate API returned a non-completed response.");

    if (!parsedResponse?.id) {
      throw new Error("No prediction ID returned from Replicate API");
    }

    // Get the prediction ID and polling URL
    const predictionId = parsedResponse.id;
    const pollUrl = parsedResponse.urls?.get;

    if (!pollUrl) {
      throw new Error("No polling URL returned from Replicate API");
    }

    // Poll for results
    const maxAttempts = 60; // 5 minutes with 5 second intervals
    const pollInterval = 5000;

    for (let attempt = 0; attempt < maxAttempts; attempt++) {
      try {
        const pollResponse = await axios.get(pollUrl, {
          headers: cortexRequest.headers,
        });

        logger.info(`Polling Replicate API - attempt ${attempt}`);
        const status = pollResponse.data?.status;

        if (status === "succeeded") {
          logger.info("Replicate API returned a completed response after polling");
          // Parse the polled response to extract artifacts
          const polledResult = this.parseResponse(pollResponse.data);
          return this.createCortexResponse(polledResult);
        } else if (status === "failed" || status === "canceled") {
          throw new Error(`Prediction ${status}: ${pollResponse.data?.error || "Unknown error"}`);
        }

        // Wait before next poll
        await new Promise((resolve) => setTimeout(resolve, pollInterval));
      } catch (error) {
        // Any polling failure (network error or terminal status) is logged and
        // rethrown unchanged; it is never retried.
        logger.error(`Error polling prediction ${predictionId}: ${error.message}`);
        throw error;
      }
    }

    throw new Error(`Prediction ${predictionId} timed out after ${maxAttempts * pollInterval / 1000} seconds`);
  }

  /**
   * Serialize a prediction payload and extract image artifacts from its
   * `output`.
   *
   * `output` may be an array of URLs or a single URL string. A URL counts as
   * an image when its path (query string and fragment ignored) ends in a
   * known image extension.
   *
   * @param {Object} data - The prediction payload, or a wrapper exposing it
   *   under `data`.
   * @returns {{output_text: string, artifacts: Array<{type: string, url: string, mimeType: string}>}}
   */
  parseResponse(data) {
    const responseData = data.data || data;
    const stringifiedResponse = JSON.stringify(responseData);

    // Accept both array output and a lone string output.
    const output = responseData?.output;
    let outputItems = [];
    if (Array.isArray(output)) {
      outputItems = output;
    } else if (typeof output === "string") {
      outputItems = [output];
    }

    // Extract image URLs from Replicate response for artifacts
    const imageArtifacts = [];
    for (const outputItem of outputItems) {
      if (typeof outputItem !== "string") {
        continue;
      }
      // Test the extension on the bare path so signed/query-string URLs match,
      // consistent with getMimeTypeFromUrl.
      if (IMAGE_EXTENSION_PATTERN.test(stripQueryAndFragment(outputItem))) {
        imageArtifacts.push({
          type: "image",
          url: outputItem,
          mimeType: this.getMimeTypeFromUrl(outputItem),
        });
      }
    }

    return {
      output_text: stringifiedResponse,
      artifacts: imageArtifacts,
    };
  }

  /**
   * Wrap a parsed response in a CortexResponse.
   *
   * @param {string|Object} parsedResponse - Either raw output text, or an
   *   object carrying `output_text` and optionally `artifacts`.
   * @returns {CortexResponse}
   * @throws {Error} When `parsedResponse` is neither a string nor a non-null object.
   */
  createCortexResponse(parsedResponse) {
    if (typeof parsedResponse === "string") {
      // Handle string response (backward compatibility)
      return new CortexResponse({
        output_text: parsedResponse,
        artifacts: [],
      });
    } else if (parsedResponse && typeof parsedResponse === "object") {
      // Handle object response with artifacts
      return new CortexResponse({
        output_text: parsedResponse.output_text,
        artifacts: parsedResponse.artifacts || [],
      });
    } else {
      throw new Error("Unexpected response format");
    }
  }

  /**
   * Determine a MIME type from a URL's file extension.
   *
   * @param {string} url - URL to inspect; query string and fragment are ignored.
   * @returns {string} The looked-up MIME type, or "image/jpeg" when unknown.
   */
  getMimeTypeFromUrl(url) {
    // Extract path from URL (remove query params and fragments)
    const urlPath = stripQueryAndFragment(url);
    return mime.lookup(urlPath) || "image/jpeg"; // Default fallback for images
  }

  /**
   * Log the prompt sent to the model and the serialized response, and append
   * the request payload to the prompt's debug info when it is tracking one.
   *
   * @param {Object} data - Request payload; `data.input.prompt` is logged.
   * @param {Object} responseData - Response payload handed to `parseResponse`.
   * @param {Object} [prompt] - Prompt whose truthy `debugInfo` string is appended to.
   */
  logRequestData(data, responseData, prompt) {
    const modelInput = data?.input?.prompt;

    logger.verbose(`${modelInput}`);
    // parseResponse returns an object; log its text rather than "[object Object]".
    logger.verbose(`${this.parseResponse(responseData).output_text}`);

    if (prompt?.debugInfo) {
      prompt.debugInfo += `\n${JSON.stringify(data)}`;
    }
  }
}

export default ReplicateApiPlugin;