UNPKG

@subscribe.dev/replicate-frontend-proxy

Version:

AWS Lambda function that serves as a secure proxy for the Replicate API

1,509 lines (1,483 loc) 54.4 kB
var __create = Object.create; var __getProtoOf = Object.getPrototypeOf; var __defProp = Object.defineProperty; var __getOwnPropNames = Object.getOwnPropertyNames; var __hasOwnProp = Object.prototype.hasOwnProperty; var __toESM = (mod, isNodeMode, target) => { target = mod != null ? __create(__getProtoOf(mod)) : {}; const to = isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target; for (let key of __getOwnPropNames(mod)) if (!__hasOwnProp.call(to, key)) __defProp(to, key, { get: () => mod[key], enumerable: true }); return to; }; var __commonJS = (cb, mod) => () => (mod || cb((mod = { exports: {} }).exports, mod), mod.exports); // node_modules/replicate/lib/error.js var require_error = __commonJS((exports, module) => { class ApiError extends Error { constructor(message, request, response) { super(message); this.name = "ApiError"; this.request = request; this.response = response; } } module.exports = ApiError; }); // node_modules/replicate/lib/identifier.js var require_identifier = __commonJS((exports, module) => { class ModelVersionIdentifier { constructor(owner, name, version = null) { this.owner = owner; this.name = name; this.version = version; } static parse(ref) { const match = ref.match(/^(?<owner>[^/]+)\/(?<name>[^/:]+)(:(?<version>.+))?$/); if (!match) { throw new Error(`Invalid reference to model version: ${ref}. Expected format: owner/name or owner/name:version`); } const { owner, name, version } = match.groups; return new ModelVersionIdentifier(owner, name, version); } } module.exports = ModelVersionIdentifier; }); // node_modules/replicate/lib/files.js var require_files = __commonJS((exports, module) => { async function createFile(file, metadata = {}, { signal } = {}) { const form = new FormData; let filename; let blob; if (file instanceof Blob) { filename = file.name || `blob_${Date.now()}`; blob = file; } else if (Buffer.isBuffer(file)) { filename = `buffer_${Date.now()}`; const bytes = new Uint8Array(file); blob = new Blob([bytes], { type: "application/octet-stream", name: filename }); } else { throw new Error("Invalid file argument, must be a Blob, File or Buffer"); } form.append("content", blob, filename); form.append("metadata", new Blob([JSON.stringify(metadata)], { type: "application/json" })); const response = await this.request("/files", { method: "POST", data: form, headers: { "Content-Type": "multipart/form-data" }, signal }); return response.json(); } async function listFiles({ signal } = {}) { const response = await this.request("/files", { method: "GET", signal }); return response.json(); } async function getFile(file_id, { signal } = {}) { const response = await this.request(`/files/${file_id}`, { method: "GET", signal }); return response.json(); } async function deleteFile(file_id, { signal } = {}) { const response = await this.request(`/files/${file_id}`, { method: "DELETE", signal }); return response.status === 204; } module.exports = { create: createFile, list: listFiles, get: getFile, delete: deleteFile }; }); // node_modules/replicate/lib/util.js var require_util = __commonJS((exports, module) => { var ApiError = require_error(); var { create: createFile } = require_files(); async function validateWebhook(requestData, secretOrCrypto, customCrypto) { let id; let body; let timestamp; let signature; let secret; let crypto = globalThis.crypto; if (requestData && requestData.headers && requestData.body) { if (typeof requestData.headers.get === "function") { id = requestData.headers.get("webhook-id"); timestamp = requestData.headers.get("webhook-timestamp"); signature = requestData.headers.get("webhook-signature"); } else { id = requestData.headers["webhook-id"]; timestamp = requestData.headers["webhook-timestamp"]; signature = requestData.headers["webhook-signature"]; } body = requestData.body; if (typeof secretOrCrypto !== "string") { throw new Error("Unexpected value for secret passed to validateWebhook, expected a string"); } secret = secretOrCrypto; if (customCrypto) { crypto = customCrypto; } } else { id = requestData.id; body = requestData.body; timestamp = requestData.timestamp; signature = requestData.signature; secret = requestData.secret; if (secretOrCrypto) { crypto = secretOrCrypto; } } if (body instanceof ReadableStream || body.readable) { try { body = await new Response(body).text(); } catch (err) { throw new Error(`Error reading body: ${err.message}`); } } else if (isTypedArray(body)) { body = await new Blob([body]).text(); } else if (typeof body === "object") { body = JSON.stringify(body); } else if (typeof body !== "string") { throw new Error("Invalid body type"); } if (!id || !timestamp || !signature) { throw new Error("Missing required webhook headers"); } if (!body) { throw new Error("Missing required body"); } if (!secret) { throw new Error("Missing required secret"); } if (!crypto) { throw new Error('Missing `crypto` implementation. If using Node 18 pass in require("node:crypto").webcrypto'); } const signedContent = `${id}.${timestamp}.${body}`; const computedSignature = await createHMACSHA256(secret.split("_").pop(), signedContent, crypto); const expectedSignatures = signature.split(" ").map((sig) => sig.split(",")[1]); return expectedSignatures.some((expectedSignature) => expectedSignature === computedSignature); } async function createHMACSHA256(secret, data, crypto) { const encoder = new TextEncoder; const key = await crypto.subtle.importKey("raw", base64ToBytes(secret), { name: "HMAC", hash: "SHA-256" }, false, ["sign"]); const signature = await crypto.subtle.sign("HMAC", key, encoder.encode(data)); return bytesToBase64(signature); } function base64ToBytes(base64) { return Uint8Array.from(atob(base64), (m) => m.codePointAt(0)); } function bytesToBase64(bytes) { return btoa(String.fromCharCode.apply(null, new Uint8Array(bytes))); } async function withAutomaticRetries(request, options = {}) { const shouldRetry = options.shouldRetry || (() => false); const maxRetries = options.maxRetries || 5; const interval = options.interval || 500; const jitter = options.jitter || 100; const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); let attempts = 0; do { let delay = interval * 2 ** attempts + Math.random() * jitter; try { const response = await request(); if (response.ok || !shouldRetry(response)) { return response; } } catch (error) { if (error instanceof ApiError) { const retryAfter = error.response.headers.get("Retry-After"); if (retryAfter) { if (!Number.isInteger(retryAfter)) { const date = new Date(retryAfter); if (!Number.isNaN(date.getTime())) { delay = date.getTime() - new Date().getTime(); } } else { delay = retryAfter * 1000; } } } } if (Number.isInteger(maxRetries) && maxRetries > 0) { if (Number.isInteger(delay) && delay > 0) { await sleep(interval * 2 ** (options.maxRetries - maxRetries)); } attempts += 1; } } while (attempts < maxRetries); return request(); } async function transformFileInputs(client, inputs, strategy) { switch (strategy) { case "data-uri": return await transformFileInputsToBase64EncodedDataURIs(client, inputs); case "upload": return await transformFileInputsToReplicateFileURLs(client, inputs); case "default": try { return await transformFileInputsToReplicateFileURLs(client, inputs); } catch (error) { if (error instanceof ApiError && error.response.status >= 400 && error.response.status < 500) { throw error; } return await transformFileInputsToBase64EncodedDataURIs(inputs); } default: throw new Error(`Unexpected file upload strategy: ${strategy}`); } } async function transformFileInputsToReplicateFileURLs(client, inputs) { return await transform(inputs, async (value) => { if (value instanceof Blob || value instanceof Buffer) { const file = await createFile.call(client, value); return file.urls.get; } return value; }); } var MAX_DATA_URI_SIZE = 1e7; async function transformFileInputsToBase64EncodedDataURIs(inputs) { let totalBytes = 0; return await transform(inputs, async (value) => { let buffer; let mime; if (value instanceof Blob) { buffer = await value.arrayBuffer(); mime = value.type; } else if (isTypedArray(value)) { buffer = value; } else { return value; } totalBytes += buffer.byteLength; if (totalBytes > MAX_DATA_URI_SIZE) { throw new Error(`Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`); } const data = bytesToBase64(buffer); mime = mime || "application/octet-stream"; return `data:${mime};base64,${data}`; }); } async function transform(value, mapper) { if (Array.isArray(value)) { const copy = []; for (const val of value) { const transformed = await transform(val, mapper); copy.push(transformed); } return copy; } if (isPlainObject(value)) { const copy = {}; for (const key of Object.keys(value)) { copy[key] = await transform(value[key], mapper); } return copy; } return await mapper(value); } function isTypedArray(arr) { return arr instanceof Int8Array || arr instanceof Int16Array || arr instanceof Int32Array || arr instanceof Uint8Array || arr instanceof Uint8ClampedArray || arr instanceof Uint16Array || arr instanceof Uint32Array || arr instanceof Float32Array || arr instanceof Float64Array; } function isPlainObject(value) { const isObjectLike = typeof value === "object" && value !== null; if (!isObjectLike || String(value) !== "[object Object]") { return false; } const proto = Object.getPrototypeOf(value); if (proto === null) { return true; } const Ctor = Object.prototype.hasOwnProperty.call(proto, "constructor") && proto.constructor; return typeof Ctor === "function" && Ctor instanceof Ctor && Function.prototype.toString.call(Ctor) === Function.prototype.toString.call(Object); } function parseProgressFromLogs(input) { const logs = typeof input === "object" && input.logs ? input.logs : input; if (!logs || typeof logs !== "string") { return null; } const pattern = /^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/; const lines = logs.split(` `).reverse(); for (const line of lines) { const matches = line.match(pattern); if (matches && matches.length === 4) { return { percentage: parseInt(matches[1], 10) / 100, current: parseInt(matches[2], 10), total: parseInt(matches[3], 10) }; } } return null; } async function* streamAsyncIterator(stream) { const reader = stream.getReader(); try { while (true) { const { done, value } = await reader.read(); if (done) return; yield value; } } finally { reader.releaseLock(); } } module.exports = { transform, transformFileInputs, validateWebhook, withAutomaticRetries, parseProgressFromLogs, streamAsyncIterator }; }); // node_modules/replicate/vendor/eventsource-parser/stream.js var require_stream = __commonJS((exports, module) => { var __defProp2 = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames2 = Object.getOwnPropertyNames; var __hasOwnProp2 = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp2(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames2(from)) if (!__hasOwnProp2.call(to, key) && key !== except) __defProp2(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod) => __copyProps(__defProp2({}, "__esModule", { value: true }), mod); var input_exports = {}; __export(input_exports, { EventSourceParserStream: () => EventSourceParserStream }); module.exports = __toCommonJS(input_exports); function createParser(onParse) { let isFirstChunk; let buffer; let startingPosition; let startingFieldLength; let eventId; let eventName; let data; reset(); return { feed, reset }; function reset() { isFirstChunk = true; buffer = ""; startingPosition = 0; startingFieldLength = -1; eventId = undefined; eventName = undefined; data = ""; } function feed(chunk) { buffer = buffer ? buffer + chunk : chunk; if (isFirstChunk && hasBom(buffer)) { buffer = buffer.slice(BOM.length); } isFirstChunk = false; const length = buffer.length; let position = 0; let discardTrailingNewline = false; while (position < length) { if (discardTrailingNewline) { if (buffer[position] === ` `) { ++position; } discardTrailingNewline = false; } let lineLength = -1; let fieldLength = startingFieldLength; let character; for (let index = startingPosition;lineLength < 0 && index < length; ++index) { character = buffer[index]; if (character === ":" && fieldLength < 0) { fieldLength = index - position; } else if (character === "\r") { discardTrailingNewline = true; lineLength = index - position; } else if (character === ` `) { lineLength = index - position; } } if (lineLength < 0) { startingPosition = length - position; startingFieldLength = fieldLength; break; } else { startingPosition = 0; startingFieldLength = -1; } parseEventStreamLine(buffer, position, fieldLength, lineLength); position += lineLength + 1; } if (position === length) { buffer = ""; } else if (position > 0) { buffer = buffer.slice(position); } } function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) { if (lineLength === 0) { if (data.length > 0) { onParse({ type: "event", id: eventId, event: eventName || undefined, data: data.slice(0, -1) }); data = ""; eventId = undefined; } eventName = undefined; return; } const noValue = fieldLength < 0; const field = lineBuffer.slice(index, index + (noValue ? lineLength : fieldLength)); let step = 0; if (noValue) { step = lineLength; } else if (lineBuffer[index + fieldLength + 1] === " ") { step = fieldLength + 2; } else { step = fieldLength + 1; } const position = index + step; const valueLength = lineLength - step; const value = lineBuffer.slice(position, position + valueLength).toString(); if (field === "data") { data += value ? "".concat(value, ` `) : ` `; } else if (field === "event") { eventName = value; } else if (field === "id" && !value.includes("\x00")) { eventId = value; } else if (field === "retry") { const retry = parseInt(value, 10); if (!Number.isNaN(retry)) { onParse({ type: "reconnect-interval", value: retry }); } } } } var BOM = [239, 187, 191]; function hasBom(buffer) { return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode); } var EventSourceParserStream = class extends TransformStream { constructor() { let parser; super({ start(controller) { parser = createParser((event) => { if (event.type === "event") { controller.enqueue(event); } }); }, transform(chunk) { parser.feed(chunk); } }); } }; }); // node_modules/replicate/vendor/streams-text-encoding/text-decoder-stream.js var require_text_decoder_stream = __commonJS((exports, module) => { var __defProp2 = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames2 = Object.getOwnPropertyNames; var __hasOwnProp2 = Object.prototype.hasOwnProperty; var __export = (target, all) => { for (var name in all) __defProp2(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames2(from)) if (!__hasOwnProp2.call(to, key) && key !== except) __defProp2(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod) => __copyProps(__defProp2({}, "__esModule", { value: true }), mod); var input_exports = {}; __export(input_exports, { TextDecoderStream: () => TextDecoderStream }); module.exports = __toCommonJS(input_exports); var decDecoder = Symbol("decDecoder"); var decTransform = Symbol("decTransform"); var TextDecodeTransformer = class { constructor(decoder) { this.decoder_ = decoder; } transform(chunk, controller) { if (!(chunk instanceof ArrayBuffer || ArrayBuffer.isView(chunk))) { throw new TypeError("Input data must be a BufferSource"); } const text = this.decoder_.decode(chunk, { stream: true }); if (text.length !== 0) { controller.enqueue(text); } } flush(controller) { const text = this.decoder_.decode(); if (text.length !== 0) { controller.enqueue(text); } } }; var TextDecoderStream = class { constructor(label, options) { const decoder = new TextDecoder(label || "utf-8", options || {}); this[decDecoder] = decoder; this[decTransform] = new TransformStream(new TextDecodeTransformer(decoder)); } get encoding() { return this[decDecoder].encoding; } get fatal() { return this[decDecoder].fatal; } get ignoreBOM() { return this[decDecoder].ignoreBOM; } get readable() { return this[decTransform].readable; } get writable() { return this[decTransform].writable; } }; var encEncoder = Symbol("encEncoder"); var encTransform = Symbol("encTransform"); }); // node_modules/replicate/lib/stream.js var require_stream2 = __commonJS((exports, module) => { var ApiError = require_error(); var { streamAsyncIterator } = require_util(); var { EventSourceParserStream } = require_stream(); var { TextDecoderStream } = typeof globalThis.TextDecoderStream === "undefined" ? require_text_decoder_stream() : globalThis; class ServerSentEvent { constructor(event, data, id, retry) { this.event = event; this.data = data; this.id = id; this.retry = retry; } toString() { if (this.event === "output") { return this.data; } return ""; } } function createReadableStream({ url, fetch, options = {} }) { const { useFileOutput = true, headers = {}, ...initOptions } = options; return new ReadableStream({ async start(controller) { const init2 = { ...initOptions, headers: { ...headers, Accept: "text/event-stream" } }; const response = await fetch(url, init2); if (!response.ok) { const text = await response.text(); const request = new Request(url, init2); controller.error(new ApiError(`Request to ${url} failed with status ${response.status}: ${text}`, request, response)); } const stream = response.body.pipeThrough(new TextDecoderStream).pipeThrough(new EventSourceParserStream); for await (const event of streamAsyncIterator(stream)) { if (event.event === "error") { controller.error(new Error(event.data)); break; } let data = event.data; if (useFileOutput && typeof data === "string" && (data.startsWith("https:") || data.startsWith("data:"))) { data = createFileOutput({ data, fetch }); } controller.enqueue(new ServerSentEvent(event.event, data, event.id)); if (event.event === "done") { break; } } controller.close(); } }); } function createFileOutput({ url, fetch }) { let type = "application/octet-stream"; class FileOutput extends ReadableStream { async blob() { const chunks = []; for await (const chunk of this) { chunks.push(chunk); } return new Blob(chunks, { type }); } url() { return new URL(url); } toString() { return url; } } return new FileOutput({ async start(controller) { const response = await fetch(url); if (!response.ok) { const text = await response.text(); const request = new Request(url, init); controller.error(new ApiError(`Request to ${url} failed with status ${response.status}: ${text}`, request, response)); } if (response.headers.get("Content-Type")) { type = response.headers.get("Content-Type"); } try { for await (const chunk of streamAsyncIterator(response.body)) { controller.enqueue(chunk); } controller.close(); } catch (err) { controller.error(err); } } }); } module.exports = { createFileOutput, createReadableStream, ServerSentEvent }; }); // node_modules/replicate/lib/accounts.js var require_accounts = __commonJS((exports, module) => { async function getCurrentAccount({ signal } = {}) { const response = await this.request("/account", { method: "GET", signal }); return response.json(); } module.exports = { current: getCurrentAccount }; }); // node_modules/replicate/lib/collections.js var require_collections = __commonJS((exports, module) => { async function getCollection(collection_slug, { signal } = {}) { const response = await this.request(`/collections/${collection_slug}`, { method: "GET", signal }); return response.json(); } async function listCollections({ signal } = {}) { const response = await this.request("/collections", { method: "GET", signal }); return response.json(); } module.exports = { get: getCollection, list: listCollections }; }); // node_modules/replicate/lib/deployments.js var require_deployments = __commonJS((exports, module) => { var { transformFileInputs } = require_util(); async function createPrediction(deployment_owner, deployment_name, options) { const { input, wait, signal, ...data } = options; if (data.webhook) { try { new URL(data.webhook); } catch (err) { throw new Error("Invalid webhook URL"); } } const headers = {}; if (wait) { if (typeof wait === "number") { const n = Math.max(1, Math.ceil(Number(wait)) || 1); headers["Prefer"] = `wait=${n}`; } else { headers["Prefer"] = "wait"; } } const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}/predictions`, { method: "POST", headers, data: { ...data, input: await transformFileInputs(this, input, this.fileEncodingStrategy) }, signal }); return response.json(); } async function getDeployment(deployment_owner, deployment_name, { signal } = {}) { const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "GET", signal }); return response.json(); } async function createDeployment(deployment_config, { signal } = {}) { const response = await this.request("/deployments", { method: "POST", data: deployment_config, signal }); return response.json(); } async function updateDeployment(deployment_owner, deployment_name, deployment_config, { signal } = {}) { const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "PATCH", data: deployment_config, signal }); return response.json(); } async function deleteDeployment(deployment_owner, deployment_name, { signal } = {}) { const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "DELETE", signal }); return response.status === 204; } async function listDeployments({ signal } = {}) { const response = await this.request("/deployments", { method: "GET", signal }); return response.json(); } module.exports = { predictions: { create: createPrediction }, get: getDeployment, create: createDeployment, update: updateDeployment, list: listDeployments, delete: deleteDeployment }; }); // node_modules/replicate/lib/hardware.js var require_hardware = __commonJS((exports, module) => { async function listHardware({ signal } = {}) { const response = await this.request("/hardware", { method: "GET", signal }); return response.json(); } module.exports = { list: listHardware }; }); // node_modules/replicate/lib/models.js var require_models = __commonJS((exports, module) => { async function getModel(model_owner, model_name, { signal } = {}) { const response = await this.request(`/models/${model_owner}/${model_name}`, { method: "GET", signal }); return response.json(); } async function listModelVersions(model_owner, model_name, { signal } = {}) { const response = await this.request(`/models/${model_owner}/${model_name}/versions`, { method: "GET", signal }); return response.json(); } async function getModelVersion(model_owner, model_name, version_id, { signal } = {}) { const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}`, { method: "GET", signal }); return response.json(); } async function listModels({ signal } = {}) { const response = await this.request("/models", { method: "GET", signal }); return response.json(); } async function createModel(model_owner, model_name, options) { const { signal, ...rest } = options; const data = { owner: model_owner, name: model_name, ...rest }; const response = await this.request("/models", { method: "POST", data, signal }); return response.json(); } async function search(query, { signal } = {}) { const response = await this.request("/models", { method: "QUERY", headers: { "Content-Type": "text/plain" }, data: query, signal }); return response.json(); } module.exports = { get: getModel, list: listModels, create: createModel, versions: { list: listModelVersions, get: getModelVersion }, search }; }); // node_modules/replicate/lib/predictions.js var require_predictions = __commonJS((exports, module) => { var { transformFileInputs } = require_util(); async function createPrediction(options) { const { model, version, input, wait, signal, ...data } = options; if (data.webhook) { try { new URL(data.webhook); } catch (err) { throw new Error("Invalid webhook URL"); } } const headers = {}; if (wait) { if (typeof wait === "number") { const n = Math.max(1, Math.ceil(Number(wait)) || 1); headers["Prefer"] = `wait=${n}`; } else { headers["Prefer"] = "wait"; } } let response; if (version) { response = await this.request("/predictions", { method: "POST", headers, data: { ...data, input: await transformFileInputs(this, input, this.fileEncodingStrategy), version }, signal }); } else if (model) { response = await this.request(`/models/${model}/predictions`, { method: "POST", headers, data: { ...data, input: await transformFileInputs(this, input, this.fileEncodingStrategy) }, signal }); } else { throw new Error("Either model or version must be specified"); } return response.json(); } async function getPrediction(prediction_id, { signal } = {}) { const response = await this.request(`/predictions/${prediction_id}`, { method: "GET", signal }); return response.json(); } async function cancelPrediction(prediction_id, { signal } = {}) { const response = await this.request(`/predictions/${prediction_id}/cancel`, { method: "POST", signal }); return response.json(); } async function listPredictions({ signal } = {}) { const response = await this.request("/predictions", { method: "GET", signal }); return response.json(); } module.exports = { create: createPrediction, get: getPrediction, cancel: cancelPrediction, list: listPredictions }; }); // node_modules/replicate/lib/trainings.js var require_trainings = __commonJS((exports, module) => { async function createTraining(model_owner, model_name, version_id, options) { const { signal, ...data } = options; if (data.webhook) { try { new URL(data.webhook); } catch (err) { throw new Error("Invalid webhook URL"); } } const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, { method: "POST", data, signal }); return response.json(); } async function getTraining(training_id, { signal } = {}) { const response = await this.request(`/trainings/${training_id}`, { method: "GET", signal }); return response.json(); } async function cancelTraining(training_id, { signal } = {}) { const response = await this.request(`/trainings/${training_id}/cancel`, { method: "POST", signal }); return response.json(); } async function listTrainings({ signal } = {}) { const response = await this.request("/trainings", { method: "GET", signal }); return response.json(); } module.exports = { create: createTraining, get: getTraining, cancel: cancelTraining, list: listTrainings }; }); // node_modules/replicate/lib/webhooks.js var require_webhooks = __commonJS((exports, module) => { async function getDefaultWebhookSecret({ signal } = {}) { const response = await this.request("/webhooks/default/secret", { method: "GET", signal }); return response.json(); } module.exports = { default: { secret: { get: getDefaultWebhookSecret } } }; }); // node_modules/replicate/package.json var require_package = __commonJS((exports, module) => { module.exports = { name: "replicate", version: "1.1.0", description: "JavaScript client for Replicate", repository: "github:replicate/replicate-javascript", homepage: "https://github.com/replicate/replicate-javascript#readme", bugs: "https://github.com/replicate/replicate-javascript/issues", license: "Apache-2.0", main: "index.js", type: "commonjs", types: "index.d.ts", files: [ "CONTRIBUTING.md", "LICENSE", "README.md", "index.d.ts", "index.js", "lib/**/*.js", "vendor/**/*", "package.json" ], engines: { node: ">=18.0.0", npm: ">=7.19.0", git: ">=2.11.0", yarn: ">=1.7.0" }, scripts: { check: "tsc", format: "biome format . --write", "lint-biome": "biome lint .", "lint-publint": "publint", lint: "npm run lint-biome && npm run lint-publint", test: "jest" }, optionalDependencies: { "readable-stream": ">=4.0.0" }, devDependencies: { "@biomejs/biome": "^1.4.1", "@types/jest": "^29.5.3", "@typescript-eslint/eslint-plugin": "^5.56.0", "cross-fetch": "^3.1.5", jest: "^29.7.0", nock: "^14.0.0-beta.6", publint: "^0.2.7", "ts-jest": "^29.1.0", typescript: "^5.0.2" } }; }); // node_modules/replicate/index.js var require_replicate = __commonJS((exports, module) => { var ApiError = require_error(); var ModelVersionIdentifier = require_identifier(); var { createReadableStream, createFileOutput } = require_stream2(); var { transform, withAutomaticRetries, validateWebhook, parseProgressFromLogs, streamAsyncIterator } = require_util(); var accounts = require_accounts(); var collections = require_collections(); var deployments = require_deployments(); var files = require_files(); var hardware = require_hardware(); var models = require_models(); var predictions = require_predictions(); var trainings = require_trainings(); var webhooks = require_webhooks(); var packageJSON = require_package(); class Replicate { constructor(options = {}) { this.auth = options.auth || (typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null); this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; this.fileEncodingStrategy = options.fileEncodingStrategy || "default"; this.useFileOutput = options.useFileOutput === false ? false : true; this.accounts = { current: accounts.current.bind(this) }; this.collections = { list: collections.list.bind(this), get: collections.get.bind(this) }; this.deployments = { get: deployments.get.bind(this), create: deployments.create.bind(this), update: deployments.update.bind(this), delete: deployments.delete.bind(this), list: deployments.list.bind(this), predictions: { create: deployments.predictions.create.bind(this) } }; this.files = { create: files.create.bind(this), get: files.get.bind(this), list: files.list.bind(this), delete: files.delete.bind(this) }; this.hardware = { list: hardware.list.bind(this) }; this.models = { get: models.get.bind(this), list: models.list.bind(this), create: models.create.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this) }, search: models.search.bind(this) }; this.predictions = { create: predictions.create.bind(this), get: predictions.get.bind(this), cancel: predictions.cancel.bind(this), list: predictions.list.bind(this) }; this.trainings = { create: trainings.create.bind(this), get: trainings.get.bind(this), cancel: trainings.cancel.bind(this), list: trainings.list.bind(this) }; this.webhooks = { default: { secret: { get: webhooks.default.secret.get.bind(this) } } }; } async run(ref, options, progress) { const { wait = { mode: "block" }, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); let prediction; if (identifier.version) { prediction = await this.predictions.create({ ...data, version: identifier.version, wait: wait.mode === "block" ? wait.timeout ?? true : false }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, wait: wait.mode === "block" ? wait.timeout ?? true : false }); } else { throw new Error("Invalid model version identifier"); } if (progress) { progress(prediction); } const isDone = wait.mode === "block" && prediction.status !== "starting"; if (!isDone) { prediction = await this.wait(prediction, { interval: wait.mode === "poll" ? wait.interval : undefined }, async (updatedPrediction) => { if (progress) { progress(updatedPrediction); } if (signal && signal.aborted) { return true; } return false; }); } if (signal && signal.aborted) { prediction = await this.predictions.cancel(prediction.id); } if (progress) { progress(prediction); } if (prediction.status === "failed") { throw new Error(`Prediction failed: ${prediction.error}`); } return transform(prediction.output, (value) => { if (typeof value === "string" && (value.startsWith("https:") || value.startsWith("data:"))) { return this.useFileOutput ? createFileOutput({ url: value, fetch: this.fetch }) : value; } return value; }); } async request(route, options) { const { auth, baseUrl, userAgent } = this; let url; if (route instanceof URL) { url = route; } else { url = new URL(route.startsWith("/") ? route.slice(1) : route, baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`); } const { method = "GET", params = {}, data, signal } = options; for (const [key, value] of Object.entries(params)) { url.searchParams.append(key, value); } const headers = { "Content-Type": "application/json", "User-Agent": userAgent }; if (auth) { headers["Authorization"] = `Bearer ${auth}`; } if (options.headers) { for (const [key, value] of Object.entries(options.headers)) { headers[key] = value; } } let body = undefined; if (data instanceof FormData) { body = data; delete headers["Content-Type"]; } else if (data) { body = JSON.stringify(data); } const init2 = { method, headers, body, signal }; const shouldRetry = method === "GET" ? (response2) => response2.status === 429 || response2.status >= 500 : (response2) => response2.status === 429; const _fetch = this.fetch; const response = await withAutomaticRetries(async () => _fetch(url, init2), { shouldRetry }); if (!response.ok) { const request = new Request(url, init2); const responseText = await response.text(); throw new ApiError(`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, request, response); } return response; } async* stream(ref, options) { const { wait, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); let prediction; if (identifier.version) { prediction = await this.predictions.create({ ...data, version: identifier.version }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}` }); } else { throw new Error("Invalid model version identifier"); } if (prediction.urls && prediction.urls.stream) { const stream = createReadableStream({ url: prediction.urls.stream, fetch: this.fetch, ...signal ? { options: { signal } } : {} }); yield* streamAsyncIterator(stream); } else { throw new Error("Prediction does not support streaming"); } } async* paginate(endpoint, options = {}) { const response = await endpoint(); yield response.results; if (response.next && !(options.signal && options.signal.aborted)) { const nextPage = () => this.request(response.next, { method: "GET", signal: options.signal }).then((r) => r.json()); yield* this.paginate(nextPage, options); } } async wait(prediction, options, stop) { const { id } = prediction; if (!id) { throw new Error("Invalid prediction"); } if (prediction.status === "succeeded" || prediction.status === "failed" || prediction.status === "canceled") { return prediction; } const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); const interval = options && options.interval || 500; let updatedPrediction = await this.predictions.get(id); while (updatedPrediction.status !== "succeeded" && updatedPrediction.status !== "failed" && updatedPrediction.status !== "canceled") { if (stop && await stop(updatedPrediction) === true) { break; } await sleep(interval); updatedPrediction = await this.predictions.get(prediction.id); } if (updatedPrediction.status === "failed") { throw new Error(`Prediction failed: ${updatedPrediction.error}`); } return updatedPrediction; } } module.exports = Replicate; module.exports.validateWebhook = validateWebhook; module.exports.parseProgressFromLogs = parseProgressFromLogs; }); // src/proxy.ts var import_replicate = __toESM(require_replicate(), 1); // src/config.ts var getConfig = () => { const isProduction = false; return { maxRequestSize: parseInt(process.env.MAX_REQUEST_SIZE || "1048576"), replicateTimeout: parseInt(process.env.REPLICATE_TIMEOUT || "300000"), corsAllowedOrigins: process.env.CORS_ALLOWED_ORIGINS?.split(",") || ["*"], logLevel: process.env.LOG_LEVEL || (isProduction ? "warn" : "debug"), enableStackTraces: process.env.ENABLE_STACK_TRACES === "true" || process.env.ENABLE_STACK_TRACES !== "false" && !isProduction }; }; // src/responses.ts var getCorsHeaders = (origin) => { const config = getConfig(); const allowedOrigins = config.corsAllowedOrigins; let corsOrigin = "*"; if (origin && allowedOrigins.length > 0 && !allowedOrigins.includes("*")) { corsOrigin = allowedOrigins.includes(origin) ? origin : allowedOrigins[0]; } return { "Access-Control-Allow-Origin": corsOrigin, "Access-Control-Allow-Methods": "GET, POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type, Authorization", "Content-Type": "application/json" }; }; var createResponse = (statusCode, body, headers = {}, origin) => { const corsHeaders = getCorsHeaders(origin); const hasExistingCorsOrigin = Object.keys(headers).some((key) => key.toLowerCase() === "access-control-allow-origin"); const finalHeaders = hasExistingCorsOrigin ? headers : { ...corsHeaders, ...headers }; return { statusCode, headers: finalHeaders, body: typeof body === "string" ? body : JSON.stringify(body) }; }; var ok = (body, origin) => createResponse(200, body, {}, origin); var badRequest = (error, details, origin) => createResponse(400, { error, ...details && { details } }, {}, origin); var unauthorized = (error = "API key is required", origin) => createResponse(401, { error }, {}, origin); var notFound = (message, origin) => createResponse(404, { message, error: "Not Found", statusCode: 404 }, {}, origin); var internalServerError = (error, details, stack, origin) => createResponse(500, { error, ...details && { details }, ...stack && { stack } }, {}, origin); var corsPreflightResponse = (origin) => createResponse(200, "", {}, origin); var customError = (statusCode, error, details, origin) => createResponse(statusCode, { error, ...details && { details } }, {}, origin); // src/types.ts var isValidApiKey = (apiKey) => { return typeof apiKey === "string" && apiKey.length >= 8 && apiKey.length <= 200 && apiKey.trim() === apiKey; }; var isValidModelName = (model) => { return typeof model === "string" && model.length > 0 && /^[a-zA-Z0-9][a-zA-Z0-9-_]*\/[a-zA-Z0-9][a-zA-Z0-9-_]*$/.test(model) && model.length <= 100; }; var validateReplicateRequest = (body) => { if (!body || typeof body !== "object") { return { isValid: false, error: "Request body must be a valid JSON object" }; } if (!body.model || !isValidModelName(body.model)) { return { isValid: false, error: 'Model name is required and must be in format "owner/model"' }; } if (!body.apiKey || !isValidApiKey(body.apiKey)) { return { isValid: false, error: "Valid API key is required (8-200 characters)" }; } if (body.input !== undefined && (typeof body.input !== "object" || body.input === null)) { return { isValid: false, error: "Input must be an object if provided" }; } return { isValid: true }; }; // src/utils.ts var withTimeout = (promise, timeoutMs, timeoutMessage = "Operation timed out") => { return Promise.race([ promise, new Promise((_, reject) => { setTimeout(() => { reject(new Error(timeoutMessage)); }, timeoutMs); }) ]); }; var sanitizeForLogs = (data, maxLength = 100) => { if (typeof data === "string") { return data.length > maxLength ? data.substring(0, maxLength) + "..." : data; } if (typeof data === "object" && data !== null) { try { const jsonString = JSON.stringify(data); return jsonString.length > maxLength ? jsonString.substring(0, maxLength) + "..." : jsonString; } catch { return "[Object - could not serialize]"; } } return String(data); }; var isValidJsonSize = (jsonString, maxSizeBytes) => { return Buffer.byteLength(jsonString, "utf8") <= maxSizeBytes; }; // src/proxy.ts function hasUrlMethod(obj) { return obj !== null && typeof obj === "object" && "url" in obj && typeof obj.url === "function"; } var handler = async (event, context) => { const config = getConfig(); const requestId = context.awsRequestId; const timestamp = new Date().toISOString(); let requestOrigin; try { let path; let method; let eventType; let headers; let body; if (event.Records && event.Records[0]?.cf?.request) { const cfRequest = event.Records[0].cf.request; path = cfRequest.uri || "/"; method = cfRequest.method || "GET"; eventType = "CloudFront OAC"; headers = Object.fromEntries(Object.entries(cfRequest.headers || {}).map(([key, values]) => [ key, Array.isArray(values) ? values[0].value : values ])); body = cfRequest.body?.data ? Buffer.from(cfRequest.body.data, "base64").toString() : null; } else if (event.requestContext?.http) { path = event.rawPath || event.requestContext.http.path || "/"; method = event.requestContext.http.method || "GET"; eventType = "API Gateway v2"; headers = event.headers || {}; body = event.body; } else if (event.requestContext && event.httpMethod) { path = event.path || event.requestContext.path || "/"; method = event.httpMethod || "GET"; eventType = "API Gateway v1"; headers = event.headers || {}; body = event.body; } else if (event.rawPath) { path = event.rawPath || "/"; method = event.requestContext?.http?.method || "GET"; eventType = "Function URL"; headers = event.headers || {}; body = event.body; } else { path = "/"; method = "GET"; eventType = "Unknown"; headers = event.headers || {}; body = event.body || null; } requestOrigin = Object.keys(headers || {}).find((key) => key.toLowerCase() === "origin") ? headers[Object.keys(headers).find((key) => key.toLowerCase() === "origin")] : undefined; console.log(`[${requestId}] Received ${method} request at ${timestamp}`); console.log(`[${requestId}] Path: ${path}`); console.log(`[${requestId}] Event Type: ${eventType}`); console.log(`[${requestId}] Full Event Object:`, JSON.stringify(event, null, 2)); if (path === "/health" && method === "GET") { const healthResponse = { status: "ok", message: "Replicate proxy server is running", timestamp, requestId }; return ok(healthResponse, requestOrigin); }