/*
 * @subscribe.dev/replicate-frontend-proxy
 * Version:
 * AWS Lambda function that serves as a secure proxy for the Replicate API.
 * 1,509 lines (1,483 loc) • 54.4 kB
 * Language: JavaScript
 */
// Bundler interop helpers, aliased once so the generated code below stays compact.
var __create = Object.create;
var __getProtoOf = Object.getPrototypeOf;
var __defProp = Object.defineProperty;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __hasOwnProp = Object.prototype.hasOwnProperty;
/**
 * Wraps a CommonJS export value as an ES-module-namespace-like object.
 *
 * @param {*} mod - the CommonJS `module.exports` value (may be null/undefined or a primitive).
 * @param {boolean} [isNodeMode] - when truthy, always expose `mod` itself as `default`.
 * @param {object} [target] - ignored on input; it is reassigned internally (kept for signature compatibility).
 * @returns {object} an object whose `default` is `mod` (unless `mod` is flagged `__esModule` and
 *   `isNodeMode` is falsy) and which re-exposes each own property of `mod` through a live getter.
 */
var __toESM = (mod, isNodeMode, target) => {
  target = mod != null ? __create(__getProtoOf(mod)) : {};
  const to = isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target;
  // Object.getOwnPropertyNames throws on null/undefined, so only mirror properties when `mod` has any.
  if (mod != null) {
    for (const key of __getOwnPropNames(mod)) {
      if (!__hasOwnProp.call(to, key)) {
        __defProp(to, key, { get: () => mod[key], enumerable: true });
      }
    }
  }
  return to;
};
/**
 * Turns a CommonJS module body `cb(exports, module)` into a lazy, memoised `require` thunk.
 * The body runs on the first call only; later calls return the cached `module.exports`.
 * The module record is created before the body runs, so a re-entrant call sees partial exports.
 */
var __commonJS = (cb, mod) => () => {
  if (!mod) {
    mod = { exports: {} };
    cb(mod.exports, mod);
  }
  return mod.exports;
};
// node_modules/replicate/lib/error.js
var require_error = __commonJS((exports, module) => {
/**
 * Error raised for a failed Replicate API call.
 * Carries the `Request` that was sent and the `Response` that came back.
 */
class ApiError extends Error {
  /**
   * @param {string} message - human-readable description.
   * @param {Request} request - the request that failed.
   * @param {Response} response - the non-ok response.
   */
  constructor(message, request, response) {
    super(message);
    Object.assign(this, { name: "ApiError", request, response });
  }
}
module.exports = ApiError;
});
// node_modules/replicate/lib/identifier.js
var require_identifier = __commonJS((exports, module) => {
/**
 * A parsed model reference of the form `owner/name` or `owner/name:version`.
 */
class ModelVersionIdentifier {
  /**
   * @param {string} owner
   * @param {string} name
   * @param {string|null} [version=null]
   */
  constructor(owner, name, version = null) {
    Object.assign(this, { owner, name, version });
  }

  /**
   * Parses a reference string.
   * @param {string} ref - `owner/name` or `owner/name:version`.
   * @returns {ModelVersionIdentifier} the identifier; `version` is null when omitted.
   * @throws {Error} when `ref` does not match either form.
   */
  static parse(ref) {
    const groups = ref.match(/^(?<owner>[^/]+)\/(?<name>[^/:]+)(:(?<version>.+))?$/)?.groups;
    if (!groups) {
      throw new Error(`Invalid reference to model version: ${ref}. Expected format: owner/name or owner/name:version`);
    }
    return new ModelVersionIdentifier(groups.owner, groups.name, groups.version);
  }
}
module.exports = ModelVersionIdentifier;
});
// node_modules/replicate/lib/files.js
var require_files = __commonJS((exports, module) => {
/**
 * Uploads a file to Replicate's files API (`POST /files`) as multipart form data.
 *
 * @this {object} client exposing `request(route, options)` (bound by the Replicate constructor).
 * @param {Blob|Buffer} file - contents to upload; a Blob keeps its `name` when it has one.
 * @param {object} [metadata={}] - JSON-serialisable metadata sent alongside the file.
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} the parsed JSON response.
 * @throws {Error} when `file` is neither a Blob nor a Buffer.
 */
async function createFile(file, metadata = {}, { signal } = {}) {
  const isBlob = file instanceof Blob;
  if (!isBlob && !Buffer.isBuffer(file)) {
    throw new Error("Invalid file argument, must be a Blob, File or Buffer");
  }
  const filename = isBlob ? file.name || `blob_${Date.now()}` : `buffer_${Date.now()}`;
  const blob = isBlob
    ? file
    : new Blob([new Uint8Array(file)], { type: "application/octet-stream", name: filename });
  const form = new FormData();
  form.append("content", blob, filename);
  form.append("metadata", new Blob([JSON.stringify(metadata)], { type: "application/json" }));
  // The client's request() removes this Content-Type again for FormData bodies.
  const response = await this.request("/files", {
    method: "POST",
    data: form,
    headers: { "Content-Type": "multipart/form-data" },
    signal,
  });
  return response.json();
}
/** Sends a body-less request for the files API and resolves to the raw Response. */
function sendFilesRequest(client, route, method, signal) {
  return client.request(route, { method, signal });
}
/**
 * Lists uploaded files (`GET /files`).
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function listFiles({ signal } = {}) {
  return (await sendFilesRequest(this, "/files", "GET", signal)).json();
}
/**
 * Fetches one file's metadata (`GET /files/:id`).
 * @param {string} file_id
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function getFile(file_id, { signal } = {}) {
  return (await sendFilesRequest(this, `/files/${file_id}`, "GET", signal)).json();
}
/**
 * Deletes a file (`DELETE /files/:id`).
 * @param {string} file_id
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<boolean>} true when the API answered 204 No Content.
 */
async function deleteFile(file_id, { signal } = {}) {
  const { status } = await sendFilesRequest(this, `/files/${file_id}`, "DELETE", signal);
  return status === 204;
}
module.exports = {
create: createFile,
list: listFiles,
get: getFile,
delete: deleteFile
};
});
// node_modules/replicate/lib/util.js
var require_util = __commonJS((exports, module) => {
var ApiError = require_error();
var { create: createFile } = require_files();
/**
 * Verifies a Replicate webhook signature: base64 HMAC-SHA256 over `${id}.${timestamp}.${body}`.
 *
 * Two call forms are supported:
 *   - validateWebhook(request, secret, crypto?) where `request` has `headers` (a Headers-like
 *     object with `get`, or a plain object) and a truthy `body`;
 *   - validateWebhook({ id, timestamp, signature, body, secret }, crypto?).
 *
 * @param {object} requestData
 * @param {string|object} [secretOrCrypto] - the signing secret (request form) or a WebCrypto implementation (object form).
 * @param {object} [customCrypto] - WebCrypto implementation for the request form; defaults to globalThis.crypto.
 * @returns {Promise<boolean>} true when any entry of the space-separated signature list matches.
 * @throws {Error} on a non-string secret (request form), an unreadable or invalid body, or a
 *   missing id/timestamp/signature/body/secret/crypto.
 */
async function validateWebhook(requestData, secretOrCrypto, customCrypto) {
  let id;
  let body;
  let timestamp;
  let signature;
  let secret;
  let crypto = globalThis.crypto;
  if (requestData && requestData.headers && requestData.body) {
    if (typeof requestData.headers.get === "function") {
      id = requestData.headers.get("webhook-id");
      timestamp = requestData.headers.get("webhook-timestamp");
      signature = requestData.headers.get("webhook-signature");
    } else {
      id = requestData.headers["webhook-id"];
      timestamp = requestData.headers["webhook-timestamp"];
      signature = requestData.headers["webhook-signature"];
    }
    body = requestData.body;
    if (typeof secretOrCrypto !== "string") {
      throw new Error("Unexpected value for secret passed to validateWebhook, expected a string");
    }
    secret = secretOrCrypto;
    if (customCrypto) {
      crypto = customCrypto;
    }
  } else {
    id = requestData.id;
    body = requestData.body;
    timestamp = requestData.timestamp;
    signature = requestData.signature;
    secret = requestData.secret;
    if (secretOrCrypto) {
      crypto = secretOrCrypto;
    }
  }
  // In object form the body may be absent; skip normalisation (which would dereference it)
  // so the "Missing required body" check below reports it instead of a TypeError.
  if (body != null) {
    if (body instanceof ReadableStream || body.readable) {
      try {
        body = await new Response(body).text();
      } catch (err) {
        throw new Error(`Error reading body: ${err.message}`);
      }
    } else if (isTypedArray(body)) {
      body = await new Blob([body]).text();
    } else if (typeof body === "object") {
      body = JSON.stringify(body);
    } else if (typeof body !== "string") {
      throw new Error("Invalid body type");
    }
  }
  if (!id || !timestamp || !signature) {
    throw new Error("Missing required webhook headers");
  }
  if (!body) {
    throw new Error("Missing required body");
  }
  if (!secret) {
    throw new Error("Missing required secret");
  }
  if (!crypto) {
    throw new Error('Missing `crypto` implementation. If using Node 18 pass in require("node:crypto").webcrypto');
  }
  const signedContent = `${id}.${timestamp}.${body}`;
  // The secret may carry a prefix (e.g. "whsec_"); the key material is the part after the last "_".
  const computedSignature = await createHMACSHA256(secret.split("_").pop(), signedContent, crypto);
  // Each entry looks like "<version>,<base64 signature>".
  const expectedSignatures = signature.split(" ").map((sig) => sig.split(",")[1]);
  return expectedSignatures.some((expectedSignature) => timingSafeEqualStrings(expectedSignature, computedSignature));
}
/**
 * Compares two strings in time independent of where they first differ, so signature checks do
 * not leak how many leading characters an attacker guessed correctly.
 * @returns {boolean} false for non-strings or different lengths; otherwise true iff equal.
 */
function timingSafeEqualStrings(a, b) {
  if (typeof a !== "string" || typeof b !== "string" || a.length !== b.length) {
    return false;
  }
  let difference = 0;
  for (let i = 0; i < a.length; i += 1) {
    difference |= a.charCodeAt(i) ^ b.charCodeAt(i);
  }
  return difference === 0;
}
/**
 * Computes HMAC-SHA256 of `data` with a base64-encoded key using a WebCrypto implementation.
 *
 * @param {string} secret - base64-encoded raw key bytes.
 * @param {string} data - message to sign (UTF-8 encoded before signing).
 * @param {object} crypto - object exposing `subtle` (WebCrypto).
 * @returns {Promise<string>} the base64-encoded signature.
 */
async function createHMACSHA256(secret, data, crypto) {
  const { subtle } = crypto;
  const keyBytes = base64ToBytes(secret);
  const key = await subtle.importKey("raw", keyBytes, { name: "HMAC", hash: "SHA-256" }, false, ["sign"]);
  const message = new TextEncoder().encode(data);
  const mac = await subtle.sign("HMAC", key, message);
  return bytesToBase64(mac);
}
/**
 * Decodes a base64 string into raw bytes.
 * @param {string} base64
 * @returns {Uint8Array}
 */
function base64ToBytes(base64) {
  const binary = atob(base64);
  const bytes = new Uint8Array(binary.length);
  for (let i = 0; i < binary.length; i += 1) {
    bytes[i] = binary.codePointAt(i);
  }
  return bytes;
}
/**
 * Encodes bytes as base64.
 *
 * @param {ArrayBuffer|ArrayBufferView|ArrayLike<number>} bytes - for typed arrays and DataViews
 *   the underlying bytes are encoded, not the element values.
 * @returns {string}
 */
function bytesToBase64(bytes) {
  // View the raw bytes: copying via `new Uint8Array(int16Array)` would truncate each element.
  const view = ArrayBuffer.isView(bytes)
    ? new Uint8Array(bytes.buffer, bytes.byteOffset, bytes.byteLength)
    : new Uint8Array(bytes);
  // Convert in slices: spreading a large buffer into one fromCharCode call exceeds the engine's
  // argument limit (RangeError) long before the 10 MB data-URI cap is reached.
  const CHUNK_SIZE = 0x8000;
  let binary = "";
  for (let offset = 0; offset < view.length; offset += CHUNK_SIZE) {
    binary += String.fromCharCode.apply(null, view.subarray(offset, offset + CHUNK_SIZE));
  }
  return btoa(binary);
}
/**
 * Calls `request` until it yields a response that is ok or not retryable. Between attempts it
 * backs off exponentially with random jitter.
 *
 * @param {() => Promise<Response>} request - performs one attempt.
 * @param {object} [options]
 * @param {(response: Response) => boolean} [options.shouldRetry] - whether a non-ok response
 *   should be retried (default: never).
 * @param {number} [options.maxRetries=5] - guarded attempts made before one final, unguarded attempt.
 * @param {number} [options.interval=500] - base delay in ms, doubled after every attempt.
 * @param {number} [options.jitter=100] - upper bound in ms of the random delay added to each wait.
 * @returns {Promise<Response>} the first acceptable response, or the result of the final attempt.
 * @throws an AbortError from `request` immediately; errors from the final attempt propagate.
 */
async function withAutomaticRetries(request, options = {}) {
  const shouldRetry = options.shouldRetry || (() => false);
  const maxRetries = options.maxRetries || 5;
  const interval = options.interval || 500;
  const jitter = options.jitter || 100;
  const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
  // A counted loop: the former do/while only advanced for integer `maxRetries` (spinning forever
  // otherwise), and it only slept when the jittered delay happened to be an integer, i.e. almost never.
  for (let attempt = 0; attempt < maxRetries; attempt += 1) {
    let delay = interval * 2 ** attempt + Math.random() * jitter;
    try {
      const response = await request();
      if (response.ok || !shouldRetry(response)) {
        return response;
      }
    } catch (error) {
      // An aborted request cannot succeed on retry; surface the abort right away.
      if (error && error.name === "AbortError") {
        throw error;
      }
      if (error instanceof ApiError) {
        const retryAfter = parseRetryAfter(error.response?.headers?.get("Retry-After"));
        if (retryAfter !== null) {
          delay = retryAfter;
        }
      }
    }
    if (delay > 0) {
      await sleep(delay);
    }
  }
  return request();
}
/**
 * Converts a Retry-After header value into a delay in milliseconds.
 * Accepts delta-seconds (e.g. "120") or an HTTP date. A header *value* is always a string, so the
 * old `Number.isInteger(retryAfter)` test could never recognise the seconds form.
 * @param {string|null|undefined} value
 * @returns {number|null} delay in ms (negative for a past date), or null when absent or unparseable.
 */
function parseRetryAfter(value) {
  if (value === null || value === undefined || value === "") {
    return null;
  }
  const text = String(value).trim();
  if (/^\d+$/.test(text)) {
    return Number(text) * 1000;
  }
  const timestamp = Date.parse(text);
  return Number.isNaN(timestamp) ? null : timestamp - Date.now();
}
/**
 * Replaces Blob/Buffer values in prediction inputs according to the client's file strategy.
 *
 * @param {object} client - Replicate client (used for uploads).
 * @param {*} inputs - prediction input tree.
 * @param {"data-uri"|"upload"|"default"} strategy - "data-uri" inlines files as base64 data URIs;
 *   "upload" uploads them and substitutes URLs; "default" tries upload and falls back to data URIs
 *   unless the upload was rejected with a 4xx ApiError.
 * @returns {Promise<*>} the transformed input tree.
 * @throws {Error} for an unknown strategy; rethrows 4xx ApiErrors from the upload in "default" mode.
 */
async function transformFileInputs(client, inputs, strategy) {
  switch (strategy) {
    case "data-uri":
      // The data-URI transformer takes only the inputs. Passing the client here made the whole
      // client object (including its auth token) the prediction input.
      return await transformFileInputsToBase64EncodedDataURIs(inputs);
    case "upload":
      return await transformFileInputsToReplicateFileURLs(client, inputs);
    case "default":
      try {
        return await transformFileInputsToReplicateFileURLs(client, inputs);
      } catch (error) {
        const status = error instanceof ApiError ? error.response?.status : undefined;
        // Client errors (bad request, auth) would fail the same way inline; don't mask them.
        if (status >= 400 && status < 500) {
          throw error;
        }
        return await transformFileInputsToBase64EncodedDataURIs(inputs);
      }
    default:
      throw new Error(`Unexpected file upload strategy: ${strategy}`);
  }
}
/**
 * Uploads every Blob/Buffer in the input tree via the files API and substitutes its `urls.get`.
 * @param {object} client - Replicate client used as `this` for the upload.
 * @param {*} inputs - prediction input tree.
 * @returns {Promise<*>} a copy of the tree with file values replaced by URLs.
 */
async function transformFileInputsToReplicateFileURLs(client, inputs) {
  const uploadIfFile = async (value) => {
    const isFile = value instanceof Blob || value instanceof Buffer;
    if (!isFile) {
      return value;
    }
    const { urls } = await createFile.call(client, value);
    return urls.get;
  };
  return await transform(inputs, uploadIfFile);
}
// Upper bound (bytes) on the combined size of all files inlined as data URIs in one prediction.
var MAX_DATA_URI_SIZE = 1e7;
/**
 * Replaces every Blob or typed array in the input tree with a base64 `data:` URI.
 * Blobs keep their MIME type; everything else is labelled application/octet-stream.
 * @param {*} inputs - prediction input tree.
 * @returns {Promise<*>} the transformed tree.
 * @throws {Error} once the running total of inlined bytes exceeds MAX_DATA_URI_SIZE.
 */
async function transformFileInputsToBase64EncodedDataURIs(inputs) {
  let inlinedBytes = 0;
  const toDataUri = async (value) => {
    const isBlob = value instanceof Blob;
    if (!isBlob && !isTypedArray(value)) {
      return value;
    }
    const bytes = isBlob ? await value.arrayBuffer() : value;
    inlinedBytes += bytes.byteLength;
    if (inlinedBytes > MAX_DATA_URI_SIZE) {
      throw new Error(`Combined filesize of prediction ${inlinedBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`);
    }
    const encoded = bytesToBase64(bytes);
    const mimeType = (isBlob ? value.type : undefined) || "application/octet-stream";
    return `data:${mimeType};base64,${encoded}`;
  };
  return await transform(inputs, toDataUri);
}
/**
 * Deep-copies arrays and plain objects, applying `mapper` to every leaf value.
 * Leaves are visited sequentially in key/index order, so stateful mappers see a stable order.
 *
 * @param {*} value - tree to transform.
 * @param {(leaf: *) => *|Promise<*>} mapper - applied to each non-array, non-plain-object value.
 * @returns {Promise<*>} the transformed copy.
 */
async function transform(value, mapper) {
  if (Array.isArray(value)) {
    const copy = [];
    for (const item of value) {
      copy.push(await transform(item, mapper));
    }
    return copy;
  }
  if (isPlainObject(value)) {
    const copy = {};
    for (const key of Object.keys(value)) {
      const mapped = await transform(value[key], mapper);
      // `copy[key] = …` for key "__proto__" would re-parent `copy` instead of creating a key,
      // silently dropping that entry from the serialized request; define an own property instead.
      Object.defineProperty(copy, key, { value: mapped, enumerable: true, writable: true, configurable: true });
    }
    return copy;
  }
  return await mapper(value);
}
/**
 * Reports whether `arr` is a typed array of any kind, including BigInt64Array/BigUint64Array,
 * Buffer, and views from another realm, but not a DataView.
 * The previous hand-written instanceof list missed the BigInt arrays, and instanceof checks
 * fail across realms.
 * @param {*} arr
 * @returns {boolean}
 */
function isTypedArray(arr) {
  return ArrayBuffer.isView(arr) && !(arr instanceof DataView);
}
/**
 * Reports whether `value` is a plain object: an object literal, `new Object()`, or a
 * null-prototype object. Arrays, class instances and built-ins such as Date and Map are excluded.
 * @param {*} value
 * @returns {boolean}
 */
function isPlainObject(value) {
  if (typeof value !== "object" || value === null) {
    return false;
  }
  // Object.prototype.toString works for null-prototype objects (String(value) throws on them,
  // which made the `proto === null` branch unreachable) and ignores a user-supplied toString.
  if (Object.prototype.toString.call(value) !== "[object Object]") {
    return false;
  }
  const proto = Object.getPrototypeOf(value);
  if (proto === null) {
    return true;
  }
  const Ctor = Object.prototype.hasOwnProperty.call(proto, "constructor") && proto.constructor;
  // Compare source text so an `Object` constructor from another realm also qualifies.
  return typeof Ctor === "function" && Ctor instanceof Ctor && Function.prototype.toString.call(Ctor) === Function.prototype.toString.call(Object);
}
/**
 * Extracts the most recent tqdm-style progress line (e.g. " 50%|█████     | 5/10 [...]").
 *
 * @param {string|{logs?: string}|null|undefined} input - raw logs, or a prediction with `logs`.
 * @returns {{percentage: number, current: number, total: number}|null} percentage is a 0..1
 *   fraction; null when there are no logs or no progress line.
 */
function parseProgressFromLogs(input) {
  // `typeof null === "object"`, so test truthiness first; reading `null.logs` used to throw.
  const logs = input && typeof input === "object" && input.logs ? input.logs : input;
  if (!logs || typeof logs !== "string") {
    return null;
  }
  const pattern = /^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/;
  // Scan newest-first so the latest progress report wins.
  const lines = logs.split("\n").reverse();
  for (const line of lines) {
    const matches = line.match(pattern);
    if (matches && matches.length === 4) {
      return {
        percentage: parseInt(matches[1], 10) / 100,
        current: parseInt(matches[2], 10),
        total: parseInt(matches[3], 10),
      };
    }
  }
  return null;
}
/**
 * Iterates the chunks of a web ReadableStream with `for await`.
 * The reader lock is released when iteration ends, including on early exit or error.
 * @param {ReadableStream} stream
 * @yields {*} each chunk read from the stream.
 */
async function* streamAsyncIterator(stream) {
  const reader = stream.getReader();
  try {
    let result = await reader.read();
    while (!result.done) {
      yield result.value;
      result = await reader.read();
    }
  } finally {
    reader.releaseLock();
  }
}
module.exports = {
transform,
transformFileInputs,
validateWebhook,
withAutomaticRetries,
parseProgressFromLogs,
streamAsyncIterator
};
});
// node_modules/replicate/vendor/eventsource-parser/stream.js
var require_stream = __commonJS((exports, module) => {
var __defProp2 = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames2 = Object.getOwnPropertyNames;
var __hasOwnProp2 = Object.prototype.hasOwnProperty;
// Installs each thunk `all[name]` as a lazy, enumerable getter named `name` on `target`.
var __export = (target, all) => {
  for (const name in all) {
    __defProp2(target, name, { get: all[name], enumerable: true });
  }
};
// Mirrors each own property of `from` onto `to` as a live getter. It skips keys `to` already
// owns and the `except` key; enumerability follows the source property.
var __copyProps = (to, from, except, desc) => {
  const isCopyable = typeof from === "function" || (Boolean(from) && typeof from === "object");
  if (!isCopyable) {
    return to;
  }
  for (const key of __getOwnPropNames2(from)) {
    if (__hasOwnProp2.call(to, key) || key === except) {
      continue;
    }
    desc = __getOwnPropDesc(from, key);
    __defProp2(to, key, { get: () => from[key], enumerable: !desc || desc.enumerable });
  }
  return to;
};
// Wraps an ES-module namespace object in a CommonJS-shaped object flagged with `__esModule`.
var __toCommonJS = (mod) => __copyProps(__defProp2({}, "__esModule", { value: true }), mod);
var input_exports = {};
__export(input_exports, {
EventSourceParserStream: () => EventSourceParserStream
});
module.exports = __toCommonJS(input_exports);
/**
 * Incremental parser for the text/event-stream (server-sent events) format.
 *
 * Feed it decoded text in arbitrary chunks. It calls `onParse` with
 * `{ type: "event", id, event, data }` when a blank line terminates an event that carried data,
 * and with `{ type: "reconnect-interval", value }` for a numeric `retry:` field.
 * CR, LF and CRLF are all accepted as line terminators, including a CRLF split across chunks.
 *
 * @param {(parsed: object) => void} onParse
 * @returns {{feed: (chunk: string) => void, reset: () => void}}
 */
function createParser(onParse) {
  let isFirstChunk;
  let buffer;
  // Offset (relative to the start of the pending line) where scanning of a partial line resumes.
  let startingPosition;
  // Offset of the first ":" already seen in the pending partial line, or -1.
  let startingFieldLength;
  // True right after a "\r" terminator, so an immediately following "\n" is swallowed. It lives
  // on the parser rather than per feed() call so a CRLF split across two chunks is one terminator.
  let discardTrailingNewline;
  let eventId;
  let eventName;
  let data;
  reset();
  return {
    feed,
    reset,
  };

  function reset() {
    isFirstChunk = true;
    buffer = "";
    startingPosition = 0;
    startingFieldLength = -1;
    discardTrailingNewline = false;
    eventId = undefined;
    eventName = undefined;
    data = "";
  }

  function feed(chunk) {
    buffer = buffer ? buffer + chunk : chunk;
    if (isFirstChunk && hasBom(buffer)) {
      buffer = buffer.slice(BOM.length);
    }
    isFirstChunk = false;
    const length = buffer.length;
    let position = 0;
    while (position < length) {
      if (discardTrailingNewline) {
        if (buffer[position] === "\n") {
          ++position;
        }
        discardTrailingNewline = false;
      }
      let lineLength = -1;
      let fieldLength = startingFieldLength;
      // Scan only the current line. Restarting at index 0 (as before) re-read earlier lines:
      // O(n^2), and an earlier "\r" could set discardTrailingNewline and swallow a blank line.
      for (let index = position + startingPosition; lineLength < 0 && index < length; ++index) {
        const character = buffer[index];
        if (character === ":" && fieldLength < 0) {
          fieldLength = index - position;
        } else if (character === "\r") {
          discardTrailingNewline = true;
          lineLength = index - position;
        } else if (character === "\n") {
          lineLength = index - position;
        }
      }
      if (lineLength < 0) {
        // Incomplete line: remember how far we scanned and wait for more input.
        startingPosition = length - position;
        startingFieldLength = fieldLength;
        break;
      }
      startingPosition = 0;
      startingFieldLength = -1;
      parseEventStreamLine(buffer, position, fieldLength, lineLength);
      position += lineLength + 1;
    }
    // Keep only the unconsumed tail (the pending partial line).
    if (position === length) {
      buffer = "";
    } else if (position > 0) {
      buffer = buffer.slice(position);
    }
  }

  // Interprets one complete line starting at `index`; `fieldLength` < 0 means the line had no ":".
  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
    if (lineLength === 0) {
      // A blank line dispatches the accumulated event, if it carried any data.
      if (data.length > 0) {
        onParse({
          type: "event",
          id: eventId,
          event: eventName || undefined,
          data: data.slice(0, -1),
        });
        data = "";
        eventId = undefined;
      }
      eventName = undefined;
      return;
    }
    const noValue = fieldLength < 0;
    const field = lineBuffer.slice(index, index + (noValue ? lineLength : fieldLength));
    let step;
    if (noValue) {
      step = lineLength;
    } else if (lineBuffer[index + fieldLength + 1] === " ") {
      // A single space after the colon is not part of the value.
      step = fieldLength + 2;
    } else {
      step = fieldLength + 1;
    }
    const position = index + step;
    const valueLength = lineLength - step;
    const value = lineBuffer.slice(position, position + valueLength).toString();
    if (field === "data") {
      data += `${value}\n`;
    } else if (field === "event") {
      eventName = value;
    } else if (field === "id" && !value.includes("\x00")) {
      eventId = value;
    } else if (field === "retry") {
      const retry = parseInt(value, 10);
      if (!Number.isNaN(retry)) {
        onParse({
          type: "reconnect-interval",
          value: retry,
        });
      }
    }
  }
}
// The parser receives text that has already been decoded, so a byte order mark appears as the
// single code unit U+FEFF, not as the UTF-8 byte values 239,187,191 that the old check compared.
var BOM = "\uFEFF";
/**
 * Reports whether decoded stream text starts with a byte order mark.
 * @param {string} buffer
 * @returns {boolean}
 */
function hasBom(buffer) {
  return buffer.startsWith(BOM);
}
/**
 * TransformStream that turns decoded text chunks into parsed server-sent event objects.
 * Only `type: "event"` results are forwarded; reconnect-interval notices are dropped.
 */
var EventSourceParserStream = class extends TransformStream {
  constructor() {
    let sseParser;
    super({
      start: (controller) => {
        const forwardEvents = (parsed) => {
          if (parsed.type === "event") {
            controller.enqueue(parsed);
          }
        };
        sseParser = createParser(forwardEvents);
      },
      transform: (chunk) => {
        sseParser.feed(chunk);
      },
    });
  }
};
});
// node_modules/replicate/vendor/streams-text-encoding/text-decoder-stream.js
var require_text_decoder_stream = __commonJS((exports, module) => {
var __defProp2 = Object.defineProperty;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropNames2 = Object.getOwnPropertyNames;
var __hasOwnProp2 = Object.prototype.hasOwnProperty;
// Installs each thunk `all[name]` as a lazy, enumerable getter named `name` on `target`.
var __export = (target, all) => {
  for (const name in all) {
    __defProp2(target, name, { get: all[name], enumerable: true });
  }
};
// Mirrors each own property of `from` onto `to` as a live getter. It skips keys `to` already
// owns and the `except` key; enumerability follows the source property.
var __copyProps = (to, from, except, desc) => {
  const isCopyable = typeof from === "function" || (Boolean(from) && typeof from === "object");
  if (!isCopyable) {
    return to;
  }
  for (const key of __getOwnPropNames2(from)) {
    if (__hasOwnProp2.call(to, key) || key === except) {
      continue;
    }
    desc = __getOwnPropDesc(from, key);
    __defProp2(to, key, { get: () => from[key], enumerable: !desc || desc.enumerable });
  }
  return to;
};
// Wraps an ES-module namespace object in a CommonJS-shaped object flagged with `__esModule`.
var __toCommonJS = (mod) => __copyProps(__defProp2({}, "__esModule", { value: true }), mod);
var input_exports = {};
__export(input_exports, {
TextDecoderStream: () => TextDecoderStream
});
module.exports = __toCommonJS(input_exports);
// Symbol-keyed slots keep the decoder and the inner TransformStream off the public surface.
var decDecoder = Symbol("decDecoder");
var decTransform = Symbol("decTransform");
/**
 * Transformer that decodes BufferSource chunks to strings with one streaming TextDecoder, so a
 * multi-byte sequence split across chunks is decoded correctly. Empty strings are not enqueued.
 */
var TextDecodeTransformer = class {
  constructor(decoder) {
    this.decoder_ = decoder;
  }
  transform(chunk, controller) {
    if (!(chunk instanceof ArrayBuffer || ArrayBuffer.isView(chunk))) {
      throw new TypeError("Input data must be a BufferSource");
    }
    const text = this.decoder_.decode(chunk, { stream: true });
    if (text.length !== 0) {
      controller.enqueue(text);
    }
  }
  flush(controller) {
    // Emit whatever the decoder still buffers (e.g. a trailing incomplete sequence).
    const text = this.decoder_.decode();
    if (text.length !== 0) {
      controller.enqueue(text);
    }
  }
};
/**
 * Minimal TextDecoderStream ponyfill (writable bytes in, readable strings out). The stream module
 * above uses it only when the global TextDecoderStream is missing.
 * The unused encoder-side symbols that used to follow were dead code and have been removed.
 */
var TextDecoderStream = class {
  /**
   * @param {string} [label="utf-8"] - encoding label passed to TextDecoder.
   * @param {object} [options] - TextDecoder options (`fatal`, `ignoreBOM`).
   */
  constructor(label, options) {
    const decoder = new TextDecoder(label || "utf-8", options || {});
    this[decDecoder] = decoder;
    this[decTransform] = new TransformStream(new TextDecodeTransformer(decoder));
  }
  get encoding() {
    return this[decDecoder].encoding;
  }
  get fatal() {
    return this[decDecoder].fatal;
  }
  get ignoreBOM() {
    return this[decDecoder].ignoreBOM;
  }
  get readable() {
    return this[decTransform].readable;
  }
  get writable() {
    return this[decTransform].writable;
  }
};
});
// node_modules/replicate/lib/stream.js
var require_stream2 = __commonJS((exports, module) => {
var ApiError = require_error();
var { streamAsyncIterator } = require_util();
var {
EventSourceParserStream
} = require_stream();
var { TextDecoderStream } = typeof globalThis.TextDecoderStream === "undefined" ? require_text_decoder_stream() : globalThis;
/**
 * One event received from a prediction's SSE stream.
 * Stringifies to its data for "output" events and to "" for every other event type.
 */
class ServerSentEvent {
  /**
   * @param {string} event - event type (e.g. "output", "done").
   * @param {*} data - event payload (a string, or a FileOutput for file URLs).
   * @param {string} [id]
   * @param {number} [retry]
   */
  constructor(event, data, id, retry) {
    Object.assign(this, { event, data, id, retry });
  }
  toString() {
    return this.event === "output" ? this.data : "";
  }
}
/**
 * Opens a prediction's server-sent-event stream and exposes it as a ReadableStream of
 * ServerSentEvent objects.
 *
 * @param {object} params
 * @param {string} params.url - stream URL.
 * @param {Function} params.fetch - fetch implementation.
 * @param {object} [params.options] - extra fetch init. `headers` are merged with
 *   `Accept: text/event-stream`. `useFileOutput` (default true) wraps `https:`/`data:` payloads in
 *   FileOutput streams.
 * @returns {ReadableStream<ServerSentEvent>} closes after a "done" event or the end of input; errors
 *   with an ApiError for a non-ok response, or with Error(data) for an "error" event.
 */
function createReadableStream({ url, fetch, options = {} }) {
  const { useFileOutput = true, headers = {}, ...initOptions } = options;
  return new ReadableStream({
    async start(controller) {
      const init2 = {
        ...initOptions,
        headers: {
          ...headers,
          Accept: "text/event-stream",
        },
      };
      const response = await fetch(url, init2);
      if (!response.ok) {
        const text = await response.text();
        const request = new Request(url, init2);
        controller.error(new ApiError(`Request to ${url} failed with status ${response.status}: ${text}`, request, response));
        // The body is already consumed and the stream errored; nothing more to do.
        return;
      }
      const stream = response.body.pipeThrough(new TextDecoderStream()).pipeThrough(new EventSourceParserStream());
      for await (const event of streamAsyncIterator(stream)) {
        if (event.event === "error") {
          controller.error(new Error(event.data));
          // Return rather than break: close() below would throw on an errored stream.
          return;
        }
        let data = event.data;
        if (useFileOutput && typeof data === "string" && (data.startsWith("https:") || data.startsWith("data:"))) {
          // createFileOutput reads `url`; passing `{ data }` left every file output without a URL.
          data = createFileOutput({ url: data, fetch });
        }
        controller.enqueue(new ServerSentEvent(event.event, data, event.id));
        if (event.event === "done") {
          break;
        }
      }
      controller.close();
    },
  });
}
/**
 * Wraps a file URL (https: or data:) as a ReadableStream of its bytes that also offers `blob()`,
 * `url()` and `toString()` (which returns the URL).
 *
 * @param {object} params
 * @param {string} params.url - location of the file.
 * @param {Function} params.fetch - fetch implementation used to download it.
 * @returns {ReadableStream<Uint8Array>} a FileOutput stream. It errors with an ApiError when the
 *   download responds non-ok.
 */
function createFileOutput({ url, fetch }) {
  // Updated from the response's Content-Type once the download starts.
  let type = "application/octet-stream";
  class FileOutput extends ReadableStream {
    /** Reads the whole stream into a Blob typed with the response's Content-Type. */
    async blob() {
      const chunks = [];
      for await (const chunk of this) {
        chunks.push(chunk);
      }
      return new Blob(chunks, { type });
    }
    url() {
      return new URL(url);
    }
    toString() {
      return url;
    }
  }
  return new FileOutput({
    async start(controller) {
      const response = await fetch(url);
      if (!response.ok) {
        const text = await response.text();
        // This function takes no request init; referencing an undefined `init` threw a ReferenceError.
        const request = new Request(url);
        controller.error(new ApiError(`Request to ${url} failed with status ${response.status}: ${text}`, request, response));
        return;
      }
      const contentType = response.headers.get("Content-Type");
      if (contentType) {
        type = contentType;
      }
      // A bodiless response (e.g. 204) is an empty file, not an error.
      if (!response.body) {
        controller.close();
        return;
      }
      try {
        for await (const chunk of streamAsyncIterator(response.body)) {
          controller.enqueue(chunk);
        }
        controller.close();
      } catch (err) {
        controller.error(err);
      }
    },
  });
}
module.exports = {
createFileOutput,
createReadableStream,
ServerSentEvent
};
});
// node_modules/replicate/lib/accounts.js
var require_accounts = __commonJS((exports, module) => {
/**
 * Fetches the account that owns the client's API token (`GET /account`).
 * @this {object} client exposing `request(route, options)`.
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function getCurrentAccount({ signal } = {}) {
  return this.request("/account", { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
current: getCurrentAccount
};
});
// node_modules/replicate/lib/collections.js
var require_collections = __commonJS((exports, module) => {
/**
 * Fetches one model collection (`GET /collections/:slug`).
 * @param {string} collection_slug
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function getCollection(collection_slug, { signal } = {}) {
  const route = `/collections/${collection_slug}`;
  return this.request(route, { method: "GET", signal }).then((response) => response.json());
}
/**
 * Lists model collections (`GET /collections`).
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function listCollections({ signal } = {}) {
  return this.request("/collections", { method: "GET", signal }).then((response) => response.json());
}
module.exports = { get: getCollection, list: listCollections };
});
// node_modules/replicate/lib/deployments.js
var require_deployments = __commonJS((exports, module) => {
var { transformFileInputs } = require_util();
/**
 * Builds the `Prefer` header that asks the API to hold the create call open.
 * A numeric `wait` is rounded up to whole seconds (minimum 1). Any other truthy value asks for
 * the server's default wait. A falsy `wait` yields no header.
 */
function buildPreferHeaders(wait) {
  if (!wait) {
    return {};
  }
  if (typeof wait !== "number") {
    return { Prefer: "wait" };
  }
  const seconds = Math.max(1, Math.ceil(Number(wait)) || 1);
  return { Prefer: `wait=${seconds}` };
}
/** Throws "Invalid webhook URL" unless `webhook` parses as an absolute URL. */
function ensureValidDeploymentWebhook(webhook) {
  try {
    new URL(webhook);
  } catch (err) {
    throw new Error("Invalid webhook URL");
  }
}
/**
 * Creates a prediction on a deployment (`POST /deployments/:owner/:name/predictions`).
 * @param {string} deployment_owner
 * @param {string} deployment_name
 * @param {object} options - `input` (file values transformed per the client's strategy), `wait`,
 *   `signal`; all other keys (e.g. `webhook`) are sent in the request body.
 * @returns {Promise<object>} parsed JSON response.
 * @throws {Error} "Invalid webhook URL" for an unparsable `webhook`.
 */
async function createPrediction(deployment_owner, deployment_name, options) {
  const { input, wait, signal, ...data } = options;
  if (data.webhook) {
    ensureValidDeploymentWebhook(data.webhook);
  }
  const headers = buildPreferHeaders(wait);
  const transformedInput = await transformFileInputs(this, input, this.fileEncodingStrategy);
  const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}/predictions`, {
    method: "POST",
    headers,
    data: { ...data, input: transformedInput },
    signal,
  });
  return response.json();
}
/** Fetches a deployment (`GET /deployments/:owner/:name`); resolves to the parsed JSON. */
async function getDeployment(deployment_owner, deployment_name, { signal } = {}) {
  return this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "GET", signal }).then((response) => response.json());
}
/** Creates a deployment (`POST /deployments`) from `deployment_config`; resolves to the parsed JSON. */
async function createDeployment(deployment_config, { signal } = {}) {
  return this.request("/deployments", { method: "POST", data: deployment_config, signal }).then((response) => response.json());
}
/** Updates a deployment (`PATCH /deployments/:owner/:name`); resolves to the parsed JSON. */
async function updateDeployment(deployment_owner, deployment_name, deployment_config, { signal } = {}) {
  return this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "PATCH", data: deployment_config, signal }).then((response) => response.json());
}
/** Deletes a deployment (`DELETE /deployments/:owner/:name`); resolves to true on 204 No Content. */
async function deleteDeployment(deployment_owner, deployment_name, { signal } = {}) {
  return this.request(`/deployments/${deployment_owner}/${deployment_name}`, { method: "DELETE", signal }).then((response) => response.status === 204);
}
/** Lists deployments (`GET /deployments`); resolves to the parsed JSON. */
async function listDeployments({ signal } = {}) {
  return this.request("/deployments", { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
predictions: {
create: createPrediction
},
get: getDeployment,
create: createDeployment,
update: updateDeployment,
list: listDeployments,
delete: deleteDeployment
};
});
// node_modules/replicate/lib/hardware.js
var require_hardware = __commonJS((exports, module) => {
/**
 * Lists the hardware SKUs available for models (`GET /hardware`).
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function listHardware({ signal } = {}) {
  return this.request("/hardware", { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
list: listHardware
};
});
// node_modules/replicate/lib/models.js
var require_models = __commonJS((exports, module) => {
/** Issues a GET to `route` and resolves to the parsed JSON body. */
function getModelsJson(client, route, signal) {
  return client.request(route, { method: "GET", signal }).then((response) => response.json());
}
/** Fetches a model (`GET /models/:owner/:name`). */
async function getModel(model_owner, model_name, { signal } = {}) {
  return getModelsJson(this, `/models/${model_owner}/${model_name}`, signal);
}
/** Lists a model's versions (`GET /models/:owner/:name/versions`). */
async function listModelVersions(model_owner, model_name, { signal } = {}) {
  return getModelsJson(this, `/models/${model_owner}/${model_name}/versions`, signal);
}
/** Fetches one model version (`GET /models/:owner/:name/versions/:id`). */
async function getModelVersion(model_owner, model_name, version_id, { signal } = {}) {
  return getModelsJson(this, `/models/${model_owner}/${model_name}/versions/${version_id}`, signal);
}
/** Lists public models (`GET /models`). */
async function listModels({ signal } = {}) {
  return getModelsJson(this, "/models", signal);
}
/**
 * Creates a model (`POST /models`).
 * @param {string} model_owner
 * @param {string} model_name
 * @param {object} options - `signal` plus body fields (e.g. visibility, hardware) sent alongside owner/name.
 * @returns {Promise<object>} parsed JSON response.
 */
async function createModel(model_owner, model_name, options) {
  const { signal, ...rest } = options;
  const payload = { owner: model_owner, name: model_name, ...rest };
  return this.request("/models", { method: "POST", data: payload, signal }).then((response) => response.json());
}
/**
 * Searches models with a plain-text query using the HTTP QUERY method on `/models`.
 * @param {string} query
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function search(query, { signal } = {}) {
  const requestOptions = {
    method: "QUERY",
    headers: { "Content-Type": "text/plain" },
    data: query,
    signal,
  };
  return this.request("/models", requestOptions).then((response) => response.json());
}
module.exports = {
get: getModel,
list: listModels,
create: createModel,
versions: { list: listModelVersions, get: getModelVersion },
search
};
});
// node_modules/replicate/lib/predictions.js
var require_predictions = __commonJS((exports, module) => {
var { transformFileInputs } = require_util();
/**
 * Builds the `Prefer` header that asks the API to hold the create call open.
 * A numeric `wait` is rounded up to whole seconds (minimum 1). Any other truthy value asks for
 * the server's default wait. A falsy `wait` yields no header.
 */
function buildPredictionPreferHeaders(wait) {
  if (!wait) {
    return {};
  }
  if (typeof wait !== "number") {
    return { Prefer: "wait" };
  }
  const seconds = Math.max(1, Math.ceil(Number(wait)) || 1);
  return { Prefer: `wait=${seconds}` };
}
/**
 * Creates a prediction for a specific version (`POST /predictions`) or for the latest version of
 * a model (`POST /models/:owner/:name/predictions`).
 * @param {object} options - `version` or `model` (version wins), `input` (file values transformed
 *   per the client's strategy), `wait`, `signal`; other keys (e.g. `webhook`) go in the body.
 * @returns {Promise<object>} parsed JSON response.
 * @throws {Error} "Invalid webhook URL", or "Either model or version must be specified".
 */
async function createPrediction(options) {
  const { model, version, input, wait, signal, ...data } = options;
  if (data.webhook) {
    try {
      new URL(data.webhook);
    } catch (err) {
      throw new Error("Invalid webhook URL");
    }
  }
  const headers = buildPredictionPreferHeaders(wait);
  let route;
  let body;
  if (version) {
    route = "/predictions";
    body = { ...data, input: await transformFileInputs(this, input, this.fileEncodingStrategy), version };
  } else if (model) {
    route = `/models/${model}/predictions`;
    body = { ...data, input: await transformFileInputs(this, input, this.fileEncodingStrategy) };
  } else {
    throw new Error("Either model or version must be specified");
  }
  const response = await this.request(route, { method: "POST", headers, data: body, signal });
  return response.json();
}
/** Fetches a prediction (`GET /predictions/:id`); resolves to the parsed JSON. */
async function getPrediction(prediction_id, { signal } = {}) {
  return this.request(`/predictions/${prediction_id}`, { method: "GET", signal }).then((response) => response.json());
}
/** Cancels a prediction (`POST /predictions/:id/cancel`); resolves to the parsed JSON. */
async function cancelPrediction(prediction_id, { signal } = {}) {
  return this.request(`/predictions/${prediction_id}/cancel`, { method: "POST", signal }).then((response) => response.json());
}
/** Lists predictions (`GET /predictions`); resolves to the parsed JSON. */
async function listPredictions({ signal } = {}) {
  return this.request("/predictions", { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
create: createPrediction,
get: getPrediction,
cancel: cancelPrediction,
list: listPredictions
};
});
// node_modules/replicate/lib/trainings.js
var require_trainings = __commonJS((exports, module) => {
/** Throws "Invalid webhook URL" unless `webhook` parses as an absolute URL. */
function ensureValidTrainingWebhook(webhook) {
  try {
    new URL(webhook);
  } catch (err) {
    throw new Error("Invalid webhook URL");
  }
}
/**
 * Starts a training from a model version (`POST /models/:owner/:name/versions/:id/trainings`).
 * @param {string} model_owner
 * @param {string} model_name
 * @param {string} version_id
 * @param {object} options - `signal` plus body fields (e.g. destination, input, webhook).
 * @returns {Promise<object>} parsed JSON response.
 * @throws {Error} "Invalid webhook URL" for an unparsable `webhook`.
 */
async function createTraining(model_owner, model_name, version_id, options) {
  const { signal, ...data } = options;
  if (data.webhook) {
    ensureValidTrainingWebhook(data.webhook);
  }
  const route = `/models/${model_owner}/${model_name}/versions/${version_id}/trainings`;
  return this.request(route, { method: "POST", data, signal }).then((response) => response.json());
}
/** Fetches a training (`GET /trainings/:id`); resolves to the parsed JSON. */
async function getTraining(training_id, { signal } = {}) {
  return this.request(`/trainings/${training_id}`, { method: "GET", signal }).then((response) => response.json());
}
/** Cancels a training (`POST /trainings/:id/cancel`); resolves to the parsed JSON. */
async function cancelTraining(training_id, { signal } = {}) {
  return this.request(`/trainings/${training_id}/cancel`, { method: "POST", signal }).then((response) => response.json());
}
/** Lists trainings (`GET /trainings`); resolves to the parsed JSON. */
async function listTrainings({ signal } = {}) {
  return this.request("/trainings", { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
create: createTraining,
get: getTraining,
cancel: cancelTraining,
list: listTrainings
};
});
// node_modules/replicate/lib/webhooks.js
var require_webhooks = __commonJS((exports, module) => {
/**
 * Fetches the account's default webhook signing secret (`GET /webhooks/default/secret`).
 * @param {{signal?: AbortSignal}} [options]
 * @returns {Promise<object>} parsed JSON response.
 */
async function getDefaultWebhookSecret({ signal } = {}) {
  const route = "/webhooks/default/secret";
  return this.request(route, { method: "GET", signal }).then((response) => response.json());
}
module.exports = {
default: {
secret: {
get: getDefaultWebhookSecret
}
}
};
});
// node_modules/replicate/package.json
var require_package = __commonJS((exports, module) => {
// Inlined copy of the replicate package's package.json metadata. Within the code shown here only
// `version` is read: the Replicate constructor uses it to build the default User-Agent.
module.exports = {
name: "replicate",
version: "1.1.0",
description: "JavaScript client for Replicate",
repository: "github:replicate/replicate-javascript",
homepage: "https://github.com/replicate/replicate-javascript#readme",
bugs: "https://github.com/replicate/replicate-javascript/issues",
license: "Apache-2.0",
main: "index.js",
type: "commonjs",
types: "index.d.ts",
files: [
"CONTRIBUTING.md",
"LICENSE",
"README.md",
"index.d.ts",
"index.js",
"lib/**/*.js",
"vendor/**/*",
"package.json"
],
engines: {
node: ">=18.0.0",
npm: ">=7.19.0",
git: ">=2.11.0",
yarn: ">=1.7.0"
},
scripts: {
check: "tsc",
format: "biome format . --write",
"lint-biome": "biome lint .",
"lint-publint": "publint",
lint: "npm run lint-biome && npm run lint-publint",
test: "jest"
},
optionalDependencies: {
"readable-stream": ">=4.0.0"
},
devDependencies: {
"@biomejs/biome": "^1.4.1",
"@types/jest": "^29.5.3",
"@typescript-eslint/eslint-plugin": "^5.56.0",
"cross-fetch": "^3.1.5",
jest: "^29.7.0",
nock: "^14.0.0-beta.6",
publint: "^0.2.7",
"ts-jest": "^29.1.0",
typescript: "^5.0.2"
}
};
});
// node_modules/replicate/index.js
var require_replicate = __commonJS((exports, module) => {
var ApiError = require_error();
var ModelVersionIdentifier = require_identifier();
var { createReadableStream, createFileOutput } = require_stream2();
var {
transform,
withAutomaticRetries,
validateWebhook,
parseProgressFromLogs,
streamAsyncIterator
} = require_util();
var accounts = require_accounts();
var collections = require_collections();
var deployments = require_deployments();
var files = require_files();
var hardware = require_hardware();
var models = require_models();
var predictions = require_predictions();
var trainings = require_trainings();
var webhooks = require_webhooks();
var packageJSON = require_package();
class Replicate {
constructor(options = {}) {
this.auth = options.auth || (typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null);
this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`;
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
this.fetch = options.fetch || globalThis.fetch;
this.fileEncodingStrategy = options.fileEncodingStrategy || "default";
this.useFileOutput = options.useFileOutput === false ? false : true;
this.accounts = {
current: accounts.current.bind(this)
};
this.collections = {
list: collections.list.bind(this),
get: collections.get.bind(this)
};
this.deployments = {
get: deployments.get.bind(this),
create: deployments.create.bind(this),
update: deployments.update.bind(this),
delete: deployments.delete.bind(this),
list: deployments.list.bind(this),
predictions: {
create: deployments.predictions.create.bind(this)
}
};
this.files = {
create: files.create.bind(this),
get: files.get.bind(this),
list: files.list.bind(this),
delete: files.delete.bind(this)
};
this.hardware = {
list: hardware.list.bind(this)
};
this.models = {
get: models.get.bind(this),
list: models.list.bind(this),
create: models.create.bind(this),
versions: {
list: models.versions.list.bind(this),
get: models.versions.get.bind(this)
},
search: models.search.bind(this)
};
this.predictions = {
create: predictions.create.bind(this),
get: predictions.get.bind(this),
cancel: predictions.cancel.bind(this),
list: predictions.list.bind(this)
};
this.trainings = {
create: trainings.create.bind(this),
get: trainings.get.bind(this),
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this)
};
this.webhooks = {
default: {
secret: {
get: webhooks.default.secret.get.bind(this)
}
}
};
}
async run(ref, options, progress) {
const { wait = { mode: "block" }, signal, ...data } = options;
const identifier = ModelVersionIdentifier.parse(ref);
let prediction;
if (identifier.version) {
prediction = await this.predictions.create({
...data,
version: identifier.version,
wait: wait.mode === "block" ? wait.timeout ?? true : false
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
wait: wait.mode === "block" ? wait.timeout ?? true : false
});
} else {
throw new Error("Invalid model version identifier");
}
if (progress) {
progress(prediction);
}
const isDone = wait.mode === "block" && prediction.status !== "starting";
if (!isDone) {
prediction = await this.wait(prediction, { interval: wait.mode === "poll" ? wait.interval : undefined }, async (updatedPrediction) => {
if (progress) {
progress(updatedPrediction);
}
if (signal && signal.aborted) {
return true;
}
return false;
});
}
if (signal && signal.aborted) {
prediction = await this.predictions.cancel(prediction.id);
}
if (progress) {
progress(prediction);
}
if (prediction.status === "failed") {
throw new Error(`Prediction failed: ${prediction.error}`);
}
return transform(prediction.output, (value) => {
if (typeof value === "string" && (value.startsWith("https:") || value.startsWith("data:"))) {
return this.useFileOutput ? createFileOutput({ url: value, fetch: this.fetch }) : value;
}
return value;
});
}
async request(route, options) {
const { auth, baseUrl, userAgent } = this;
let url;
if (route instanceof URL) {
url = route;
} else {
url = new URL(route.startsWith("/") ? route.slice(1) : route, baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`);
}
const { method = "GET", params = {}, data, signal } = options;
for (const [key, value] of Object.entries(params)) {
url.searchParams.append(key, value);
}
const headers = {
"Content-Type": "application/json",
"User-Agent": userAgent
};
if (auth) {
headers["Authorization"] = `Bearer ${auth}`;
}
if (options.headers) {
for (const [key, value] of Object.entries(options.headers)) {
headers[key] = value;
}
}
let body = undefined;
if (data instanceof FormData) {
body = data;
delete headers["Content-Type"];
} else if (data) {
body = JSON.stringify(data);
}
const init2 = {
method,
headers,
body,
signal
};
const shouldRetry = method === "GET" ? (response2) => response2.status === 429 || response2.status >= 500 : (response2) => response2.status === 429;
const _fetch = this.fetch;
const response = await withAutomaticRetries(async () => _fetch(url, init2), {
shouldRetry
});
if (!response.ok) {
const request = new Request(url, init2);
const responseText = await response.text();
throw new ApiError(`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, request, response);
}
return response;
}
async* stream(ref, options) {
const { wait, signal, ...data } = options;
const identifier = ModelVersionIdentifier.parse(ref);
let prediction;
if (identifier.version) {
prediction = await this.predictions.create({
...data,
version: identifier.version
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`
});
} else {
throw new Error("Invalid model version identifier");
}
if (prediction.urls && prediction.urls.stream) {
const stream = createReadableStream({
url: prediction.urls.stream,
fetch: this.fetch,
...signal ? { options: { signal } } : {}
});
yield* streamAsyncIterator(stream);
} else {
throw new Error("Prediction does not support streaming");
}
}
async* paginate(endpoint, options = {}) {
const response = await endpoint();
yield response.results;
if (response.next && !(options.signal && options.signal.aborted)) {
const nextPage = () => this.request(response.next, {
method: "GET",
signal: options.signal
}).then((r) => r.json());
yield* this.paginate(nextPage, options);
}
}
async wait(prediction, options, stop) {
const { id } = prediction;
if (!id) {
throw new Error("Invalid prediction");
}
if (prediction.status === "succeeded" || prediction.status === "failed" || prediction.status === "canceled") {
return prediction;
}
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
const interval = options && options.interval || 500;
let updatedPrediction = await this.predictions.get(id);
while (updatedPrediction.status !== "succeeded" && updatedPrediction.status !== "failed" && updatedPrediction.status !== "canceled") {
if (stop && await stop(updatedPrediction) === true) {
break;
}
await sleep(interval);
updatedPrediction = await this.predictions.get(prediction.id);
}
if (updatedPrediction.status === "failed") {
throw new Error(`Prediction failed: ${updatedPrediction.error}`);
}
return updatedPrediction;
}
}
module.exports = Replicate;
module.exports.validateWebhook = validateWebhook;
module.exports.parseProgressFromLogs = parseProgressFromLogs;
});
// src/proxy.ts
var import_replicate = __toESM(require_replicate(), 1);
// src/config.ts
/**
 * Read the proxy configuration from environment variables.
 *
 * Numeric settings fall back to their defaults when unset, malformed or
 * non-positive. A NaN/<=0 size limit would otherwise reject every request, and
 * a zero timeout would fire immediately. Origins are trimmed, empty entries are
 * dropped, and the list falls back to ["*"] when nothing usable remains.
 *
 * @returns {{maxRequestSize: number, replicateTimeout: number,
 *   corsAllowedOrigins: string[], logLevel: string, enableStackTraces: boolean}}
 */
var getConfig = () => {
  // Hard-coded by the build; presumably replaced by a bundler define — confirm.
  const isProduction = false;
  const readPositiveInt = (raw, fallback) => {
    const parsed = Number.parseInt(raw ?? "", 10);
    return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
  };
  const origins = (process.env.CORS_ALLOWED_ORIGINS ?? "")
    .split(",")
    .map((origin) => origin.trim())
    .filter((origin) => origin.length > 0);
  return {
    maxRequestSize: readPositiveInt(process.env.MAX_REQUEST_SIZE, 1048576),
    replicateTimeout: readPositiveInt(process.env.REPLICATE_TIMEOUT, 300000),
    corsAllowedOrigins: origins.length > 0 ? origins : ["*"],
    logLevel: process.env.LOG_LEVEL || (isProduction ? "warn" : "debug"),
    // Explicit "true" always enables; otherwise enabled unless explicitly "false" (outside production).
    enableStackTraces: process.env.ENABLE_STACK_TRACES === "true" || (process.env.ENABLE_STACK_TRACES !== "false" && !isProduction)
  };
};
// src/responses.ts
/**
 * Build the CORS (and JSON content-type) headers for a response.
 *
 * With a wildcard (or empty) allowlist the origin is always "*". Otherwise a
 * listed request origin is echoed back, an unlisted one gets the first allowed
 * origin, and a missing origin gets "*". Because the value then depends on the
 * request's Origin, `Vary: Origin` is added so shared caches (e.g. CloudFront)
 * do not replay one origin's header to another.
 *
 * @param {string} [origin] - the request's Origin header value.
 * @returns {Object<string, string>} response headers.
 */
var getCorsHeaders = (origin) => {
  const config = getConfig();
  const allowedOrigins = config.corsAllowedOrigins;
  const isWildcard = allowedOrigins.length === 0 || allowedOrigins.includes("*");
  let corsOrigin = "*";
  if (origin && !isWildcard) {
    corsOrigin = allowedOrigins.includes(origin) ? origin : allowedOrigins[0];
  }
  const headers = {
    "Access-Control-Allow-Origin": corsOrigin,
    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Authorization",
    "Content-Type": "application/json"
  };
  if (!isWildcard) {
    headers["Vary"] = "Origin";
  }
  return headers;
};
/**
 * Assemble a Lambda proxy response object.
 *
 * If the caller already supplies an Access-Control-Allow-Origin header (any
 * casing), the caller's headers are used verbatim; otherwise the computed
 * CORS headers are merged underneath them. Non-string bodies are JSON-encoded.
 *
 * @param {number} statusCode
 * @param {*} body - string passed through, anything else JSON.stringify'd.
 * @param {Object<string, string>} [headers]
 * @param {string} [origin] - request Origin, used for CORS headers.
 * @returns {{statusCode: number, headers: object, body: string}}
 */
var createResponse = (statusCode, body, headers = {}, origin) => {
  const defaults = getCorsHeaders(origin);
  const callerSetsOrigin = Object.keys(headers).some((name) => name.toLowerCase() === "access-control-allow-origin");
  let mergedHeaders = headers;
  if (!callerSetsOrigin) {
    mergedHeaders = { ...defaults, ...headers };
  }
  const serializedBody = typeof body === "string" ? body : JSON.stringify(body);
  return {
    statusCode,
    headers: mergedHeaders,
    body: serializedBody
  };
};
// Standard response builders. Each delegates to createResponse with no extra
// headers; optional `details`/`stack` fields are included only when truthy.
var ok = (body, origin) => createResponse(200, body, {}, origin);
var badRequest = (error, details, origin) => {
  const payload = { error };
  if (details) {
    payload.details = details;
  }
  return createResponse(400, payload, {}, origin);
};
var unauthorized = (error = "API key is required", origin) => createResponse(401, { error }, {}, origin);
var notFound = (message, origin) => {
  const payload = { message, error: "Not Found", statusCode: 404 };
  return createResponse(404, payload, {}, origin);
};
var internalServerError = (error, details, stack, origin) => {
  const payload = { error };
  if (details) {
    payload.details = details;
  }
  if (stack) {
    payload.stack = stack;
  }
  return createResponse(500, payload, {}, origin);
};
var corsPreflightResponse = (origin) => createResponse(200, "", {}, origin);
var customError = (statusCode, error, details, origin) => {
  const payload = { error };
  if (details) {
    payload.details = details;
  }
  return createResponse(statusCode, payload, {}, origin);
};
// src/types.ts
/**
 * Check that an API key is a string of 8–200 UTF-16 code units with no
 * leading or trailing whitespace.
 *
 * @param {*} apiKey
 * @returns {boolean}
 */
var isValidApiKey = (apiKey) => {
  if (typeof apiKey !== "string") {
    return false;
  }
  const { length } = apiKey;
  const withinBounds = length >= 8 && length <= 200;
  return withinBounds && apiKey === apiKey.trim();
};
/**
 * Check that `model` looks like "owner/name" (at most 100 characters).
 *
 * Each segment must start with an ASCII letter or digit and may then contain
 * letters, digits, "-", "_" and "." — dots appear in real Replicate model names
 * such as "black-forest-labs/flux-1.1-pro". Version suffixes (":<version>") are
 * not accepted.
 *
 * @param {*} model
 * @returns {boolean}
 */
var isValidModelName = (model) => {
  // Cheap type/length checks first so the regex only runs on bounded strings.
  if (typeof model !== "string" || model.length === 0 || model.length > 100) {
    return false;
  }
  return /^[a-zA-Z0-9][a-zA-Z0-9._-]*\/[a-zA-Z0-9][a-zA-Z0-9._-]*$/.test(model);
};
/**
 * Validate a parsed proxy request body.
 *
 * The body must be a plain (non-array) object with a valid `model` and
 * `apiKey`; `input`, when present, must also be a non-null, non-array object
 * (arrays satisfy `typeof === "object"` but are not valid prediction inputs).
 *
 * @param {*} body
 * @returns {{isValid: boolean, error?: string}}
 */
var validateReplicateRequest = (body) => {
  if (!body || typeof body !== "object" || Array.isArray(body)) {
    return { isValid: false, error: "Request body must be a valid JSON object" };
  }
  if (!body.model || !isValidModelName(body.model)) {
    return { isValid: false, error: 'Model name is required and must be in format "owner/model"' };
  }
  if (!body.apiKey || !isValidApiKey(body.apiKey)) {
    return { isValid: false, error: "Valid API key is required (8-200 characters)" };
  }
  const { input } = body;
  if (input !== undefined && (typeof input !== "object" || input === null || Array.isArray(input))) {
    return { isValid: false, error: "Input must be an object if provided" };
  }
  return { isValid: true };
};
// src/utils.ts
/**
 * Race `promise` against a timer.
 *
 * The timer is cleared as soon as the race settles (either way), so a fast
 * result no longer leaves a pending timeout (up to `timeoutMs` long) holding
 * the event loop open.
 *
 * @param {Promise<*>|*} promise - the work to bound.
 * @param {number} timeoutMs - milliseconds before rejecting.
 * @param {string} [timeoutMessage] - message of the Error thrown on timeout.
 * @returns {Promise<*>} settles like `promise`, or rejects with
 *   `new Error(timeoutMessage)` if the timer fires first.
 */
var withTimeout = (promise, timeoutMs, timeoutMessage = "Operation timed out") => {
  let timer;
  const timeout = new Promise((_, reject) => {
    timer = setTimeout(() => {
      reject(new Error(timeoutMessage));
    }, timeoutMs);
  });
  return Promise.race([promise, timeout]).finally(() => clearTimeout(timer));
};
/**
 * Render a value as a short string for logging.
 *
 * Strings and JSON-serialised objects longer than `maxLength` are cut to
 * `maxLength` characters with "..." appended; objects that cannot be
 * serialised (e.g. circular) yield a placeholder; other values use String().
 *
 * @param {*} data
 * @param {number} [maxLength=100]
 * @returns {string}
 */
var sanitizeForLogs = (data, maxLength = 100) => {
  const clip = (text) => (text.length > maxLength ? `${text.substring(0, maxLength)}...` : text);
  if (typeof data === "string") {
    return clip(data);
  }
  if (data === null || typeof data !== "object") {
    return String(data);
  }
  try {
    return clip(JSON.stringify(data));
  } catch {
    return "[Object - could not serialize]";
  }
};
/**
 * Check that a string's UTF-8 encoding fits within `maxSizeBytes` bytes.
 *
 * @param {string} jsonString
 * @param {number} maxSizeBytes - inclusive upper bound.
 * @returns {boolean}
 */
var isValidJsonSize = (jsonString, maxSizeBytes) => {
  const encodedBytes = Buffer.byteLength(jsonString, "utf8");
  return encodedBytes <= maxSizeBytes;
};
// src/proxy.ts
/**
 * Type guard: true when `obj` is a non-null object exposing a callable `url`
 * property (as file-output objects do).
 *
 * @param {*} obj
 * @returns {boolean}
 */
function hasUrlMethod(obj) {
  if (obj === null || typeof obj !== "object") {
    return false;
  }
  if (!("url" in obj)) {
    return false;
  }
  return typeof obj.url === "function";
}
var handler = async (event, context) => {
const config = getConfig();
const requestId = context.awsRequestId;
const timestamp = new Date().toISOString();
let requestOrigin;
try {
let path;
let method;
let eventType;
let headers;
let body;
if (event.Records && event.Records[0]?.cf?.request) {
const cfRequest = event.Records[0].cf.request;
path = cfRequest.uri || "/";
method = cfRequest.method || "GET";
eventType = "CloudFront OAC";
headers = Object.fromEntries(Object.entries(cfRequest.headers || {}).map(([key, values]) => [
key,
Array.isArray(values) ? values[0].value : values
]));
body = cfRequest.body?.data ? Buffer.from(cfRequest.body.data, "base64").toString() : null;
} else if (event.requestContext?.http) {
path = event.rawPath || event.requestContext.http.path || "/";
method = event.requestContext.http.method || "GET";
eventType = "API Gateway v2";
headers = event.headers || {};
body = event.body;
} else if (event.requestContext && event.httpMethod) {
path = event.path || event.requestContext.path || "/";
method = event.httpMethod || "GET";
eventType = "API Gateway v1";
headers = event.headers || {};
body = event.body;
} else if (event.rawPath) {
path = event.rawPath || "/";
method = event.requestContext?.http?.method || "GET";
eventType = "Function URL";
headers = event.headers || {};
body = event.body;
} else {
path = "/";
method = "GET";
eventType = "Unknown";
headers = event.headers || {};
body = event.body || null;
}
requestOrigin = Object.keys(headers || {}).find((key) => key.toLowerCase() === "origin") ? headers[Object.keys(headers).find((key) => key.toLowerCase() === "origin")] : undefined;
console.log(`[${requestId}] Received ${method} request at ${timestamp}`);
console.log(`[${requestId}] Path: ${path}`);
console.log(`[${requestId}] Event Type: ${eventType}`);
console.log(`[${requestId}] Full Event Object:`, JSON.stringify(event, null, 2));
if (path === "/health" && method === "GET") {
const healthResponse = {
status: "ok",
message: "Replicate proxy server is running",
timestamp,
requestId
};
return ok(healthResponse, requestOrigin);
}