UNPKG

@aj-archipelago/cortex

Version:

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.

252 lines (220 loc) 9.36 kB
// replicateApiPlugin.js
import ModelPlugin from "./modelPlugin.js";
import logger from "../../lib/logger.js";
import axios from "axios";

// Aspect ratios accepted by the FLUX schnell endpoint; anything else falls back to "1:1".
const FLUX_ASPECT_RATIOS = [
  '1:1', '16:9', '21:9', '3:2', '2:3', '4:5', '5:4', '3:4', '4:3', '9:16', '9:21',
];

// The Kontext endpoints accept the same ratios plus 'match_input_image'.
const KONTEXT_ASPECT_RATIOS = [...FLUX_ASPECT_RATIOS, 'match_input_image'];

// Styles accepted for replicate-recraft-v3; anything else falls back to "realistic_image".
const RECRAFT_STYLES = [
  'any',
  'realistic_image',
  'digital_illustration',
  'digital_illustration/pixel_art',
  'digital_illustration/hand_drawn',
  'digital_illustration/grain',
  'digital_illustration/infantile_sketch',
  'digital_illustration/2d_art_poster',
  'digital_illustration/handmade_3d',
  'digital_illustration/hand_drawn_outline',
  'digital_illustration/engraving_color',
  'digital_illustration/2d_art_poster_2',
  'realistic_image/b_and_w',
  'realistic_image/hard_flash',
  'realistic_image/hdr',
  'realistic_image/natural_light',
  'realistic_image/studio_portrait',
  'realistic_image/enterprise',
  'realistic_image/motion_blur',
];

// Options accepted for replicate-seedance-1-pro.
const SEEDANCE_RESOLUTIONS = ["480p", "1080p"];
const SEEDANCE_ASPECT_RATIOS = ["16:9", "4:3", "9:16", "1:1", "3:4", "21:9", "9:21"];
const SEEDANCE_FPS = [24];

// Kontext safety tolerance: default when unset, and the cap applied when an input image is sent.
const DEFAULT_SAFETY_TOLERANCE = 3;
const MAX_SAFETY_TOLERANCE_WITH_IMAGE = 2;

// Polling budget for asynchronous predictions: 60 attempts x 5 s = 5 minutes.
const MAX_POLL_ATTEMPTS = 60;
const POLL_INTERVAL_MS = 5000;

/**
 * Return `value` if it is one of `allowed`, otherwise `fallback`.
 * @param {*} value - Caller-supplied option.
 * @param {Array} allowed - Accepted values.
 * @param {*} fallback - Value used when `value` is not accepted.
 * @returns {*}
 */
function pickAllowed(value, allowed, fallback) {
  return allowed.includes(value) ? value : fallback;
}

/**
 * Build the optional `seed` entry for a Replicate `input` object.
 * @param {*} seed - Caller-supplied seed.
 * @returns {{seed: number} | {}} `{ seed }` when `seed` is a positive integer, otherwise `{}`.
 */
function positiveSeedInput(seed) {
  return Number.isInteger(seed) && seed > 0 ? { seed } : {};
}

/**
 * Resolve the safety tolerance sent to the Kontext models.
 * @param {*} requested - Caller-supplied `safety_tolerance`.
 * @param {boolean} hasInputImage - Whether an input image accompanies the request.
 * @returns {*} The tolerance: defaulted when missing, capped when an image is supplied.
 */
function kontextSafetyTolerance(requested, hasInputImage) {
  // 0 is a legitimate (strictest) setting, so only a missing value falls back to the default;
  // the previous `|| 3` silently turned 0 into 3.
  const isMissing = requested == null || requested === "";
  const tolerance = isMissing ? DEFAULT_SAFETY_TOLERANCE : requested;
  return hasInputImage ? Math.min(tolerance, MAX_SAFETY_TOLERANCE_WITH_IMAGE) : tolerance;
}

class ReplicateApiPlugin extends ModelPlugin {
  /**
   * Build the Replicate request body for the configured model.
   * @param {*} text - Input text passed to the prompt compiler.
   * @param {Object} parameters - Request parameters; these override `this.promptParameters`.
   * @param {*} prompt - Prompt passed to the prompt compiler.
   * @returns {{input?: Object}} `{ input }` for a known model, or `{}` for an unrecognized one.
   */
  getRequestParameters(text, parameters, prompt) {
    const combinedParameters = { ...this.promptParameters, ...parameters };
    const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);

    switch (combinedParameters.model) {
      case "replicate-flux-11-pro":
        return {
          input: {
            aspect_ratio: combinedParameters.aspectRatio || "1:1",
            output_format: combinedParameters.outputFormat || "webp",
            output_quality: combinedParameters.outputQuality || 80,
            prompt: modelPromptText,
            prompt_upsampling: combinedParameters.promptUpsampling || false,
            safety_tolerance: combinedParameters.safety_tolerance || 3,
            go_fast: true,
            megapixels: "1",
            width: combinedParameters.width,
            height: combinedParameters.height,
            size: combinedParameters.size || "1024x1024",
            style: combinedParameters.style || "realistic_image",
            // Unchanged semantics for this model: any non-zero integer seed is forwarded.
            ...(combinedParameters.seed && Number.isInteger(combinedParameters.seed)
              ? { seed: combinedParameters.seed }
              : {}),
          },
        };

      case "replicate-recraft-v3":
        return {
          input: {
            prompt: modelPromptText,
            size: combinedParameters.size || "1024x1024",
            style: pickAllowed(combinedParameters.style, RECRAFT_STYLES, "realistic_image"),
          },
        };

      case "replicate-flux-1-schnell":
        return {
          input: {
            aspect_ratio: pickAllowed(combinedParameters.aspectRatio, FLUX_ASPECT_RATIOS, "1:1"),
            output_format: combinedParameters.outputFormat || "webp",
            output_quality: combinedParameters.outputQuality || 80,
            prompt: modelPromptText,
            go_fast: true,
            megapixels: "1",
            num_outputs: combinedParameters.numberResults,
            num_inference_steps: combinedParameters.steps || 4,
            disable_safety_checker: true,
          },
        };

      case "replicate-flux-kontext-pro":
      case "replicate-flux-kontext-max":
        return {
          input: {
            prompt: modelPromptText,
            input_image: combinedParameters.input_image,
            aspect_ratio: pickAllowed(combinedParameters.aspectRatio, KONTEXT_ASPECT_RATIOS, "1:1"),
            safety_tolerance: kontextSafetyTolerance(
              combinedParameters.safety_tolerance,
              Boolean(combinedParameters.input_image),
            ),
            // Previously `Number.isInteger(seed && seed > 0)`, which is never true, so seeds were dropped.
            ...positiveSeedInput(combinedParameters.seed),
          },
        };

      case "replicate-multi-image-kontext-max":
        return {
          input: {
            prompt: modelPromptText,
            // `input_image` is accepted as an alias for the first image.
            input_image_1: combinedParameters.input_image_1 || combinedParameters.input_image,
            input_image_2: combinedParameters.input_image_2,
            aspect_ratio: pickAllowed(combinedParameters.aspectRatio, KONTEXT_ASPECT_RATIOS, "1:1"),
            safety_tolerance: kontextSafetyTolerance(
              combinedParameters.safety_tolerance,
              Boolean(combinedParameters.input_image_1 || combinedParameters.input_image),
            ),
            ...positiveSeedInput(combinedParameters.seed),
          },
        };

      case "replicate-seedance-1-pro":
        return {
          input: {
            prompt: modelPromptText,
            resolution: pickAllowed(combinedParameters.resolution, SEEDANCE_RESOLUTIONS, "1080p"),
            aspect_ratio: pickAllowed(combinedParameters.aspectRatio, SEEDANCE_ASPECT_RATIOS, "16:9"),
            ...positiveSeedInput(combinedParameters.seed),
            fps: pickAllowed(combinedParameters.fps, SEEDANCE_FPS, 24),
            camera_fixed: combinedParameters.camera_fixed || false,
            duration: combinedParameters.duration || 5,
            ...(combinedParameters.image ? { image: combinedParameters.image } : {}),
          },
        };

      default:
        // Unrecognized model: an empty request body is sent, as before.
        return {};
    }
  }

  /**
   * Execute the request to the Replicate API, polling until the prediction finishes.
   * @param {*} text - Input text passed to getRequestParameters.
   * @param {Object} parameters - Request parameters passed to getRequestParameters.
   * @param {*} prompt - Prompt passed to getRequestParameters.
   * @param {Object} cortexRequest - Request object; `data` and `params` are overwritten here,
   *   and `headers` is reused for the polling GETs.
   * @returns {Promise<string>} JSON string of the completed prediction.
   * @throws {Error} If no prediction id or poll URL is returned, if the prediction fails or is
   *   canceled, if a poll request errors, or if polling exceeds the time budget.
   */
  async execute(text, parameters, prompt, cortexRequest) {
    const requestParameters = this.getRequestParameters(text, parameters, prompt);
    cortexRequest.data = requestParameters;
    // NOTE(review): no branch of getRequestParameters sets `params`, so this is always undefined.
    cortexRequest.params = requestParameters.params;

    // Make initial request to start prediction
    const stringifiedResponse = await this.executeRequest(cortexRequest);
    const parsedResponse = JSON.parse(stringifiedResponse);

    // If we got a completed response, return it
    if (parsedResponse?.status === "succeeded") {
      return stringifiedResponse;
    }

    logger.info("Replicate API returned a non-completed response.");

    if (!parsedResponse?.id) {
      throw new Error("No prediction ID returned from Replicate API");
    }

    // Get the prediction ID and polling URL
    const predictionId = parsedResponse.id;
    const pollUrl = parsedResponse.urls?.get;
    if (!pollUrl) {
      throw new Error("No polling URL returned from Replicate API");
    }

    for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt++) {
      try {
        const pollResponse = await axios.get(pollUrl, { headers: cortexRequest.headers });
        logger.info(`Polling Replicate API - attempt ${attempt}`);

        const status = pollResponse.data?.status;
        if (status === "succeeded") {
          logger.info("Replicate API returned a completed response after polling");
          return JSON.stringify(pollResponse.data);
        }
        if (status === "failed" || status === "canceled") {
          throw new Error(`Prediction ${status}: ${pollResponse.data?.error || "Unknown error"}`);
        }

        // Wait before next poll
        await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS));
      } catch (error) {
        // Terminal prediction failures and transport errors are both logged, then rethrown.
        logger.error(`Error polling prediction ${predictionId}: ${error.message}`);
        throw error;
      }
    }

    throw new Error(
      `Prediction ${predictionId} timed out after ${MAX_POLL_ATTEMPTS * POLL_INTERVAL_MS / 1000} seconds`,
    );
  }

  /**
   * Stringify the response from the Replicate API.
   * @param {Object} data - Response; when it has a truthy `data` field, only that field is stringified.
   * @returns {string}
   */
  parseResponse(data) {
    if (data.data) {
      return JSON.stringify(data.data);
    }
    return JSON.stringify(data);
  }

  /**
   * Override the logging function to display the request and response.
   * Logs the prompt and parsed response at verbose level; if `prompt.debugInfo` is truthy,
   * the serialized request is appended to it.
   */
  logRequestData(data, responseData, prompt) {
    const modelInput = data?.input?.prompt;
    logger.verbose(`${modelInput}`);
    logger.verbose(`${this.parseResponse(responseData)}`);
    if (prompt && prompt.debugInfo) {
      prompt.debugInfo += `\n${JSON.stringify(data)}`;
    }
  }
}

export default ReplicateApiPlugin;