@spawn-so/spawn-node
Version:
Client for the Spawn API using NodeJS
595 lines (591 loc) • 20.4 kB
JavaScript
'use strict';
const supabaseJs = require('@supabase/supabase-js');
const Pusher = require('pusher-client');
/**
 * Client for the Spawn API, backed by Supabase RPCs.
 *
 * Every public operation is an arrow-function property created in the
 * constructor, so it keeps its `this` binding when destructured
 * (e.g. `const { postJob } = client`).
 *
 * The client works in one of two modes:
 *  - owner mode (`app_user_id === ""`, the initial state): RPCs carry the app
 *    id, key and secret (`owner_rpc`);
 *  - user mode (after a successful `setUserID`): service listing, add-on
 *    listing and job posting use the app user id and token (`user_rpc`).
 *
 * Backend errors are converted into `Error`s by `handle_error`.
 */
class SpawnClient {
  /**
   * @param {object} supabase - Supabase client; only its `rpc(fn, params)` method is used.
   * @param {string} app_id - Application id sent with every authenticated RPC.
   * @param {string} key - Application key sent with every authenticated RPC.
   * @param {string} secret - Application secret, sent only by `owner_rpc`.
   * @param {object} [worker_filter] - Forwarded with job posts and worker counts;
   *   defaults to `{ branch: "prod" }`.
   */
  constructor(supabase, app_id, key, secret, worker_filter) {
    this.supabase = supabase;
    this.app_id = app_id;
    this.key = key;
    this.secret = secret;
    this.worker_filter = worker_filter || { branch: "prod" };
    // Cache filled by getServiceList().
    this.services = [];
    // Cache filled by updateAddOnList().
    this.add_ons = [];
    // "" means owner mode; setUserID() switches to user mode.
    this.app_user_id = "";
    this.app_user_token = "";

    // ---- Private helpers (closures; not exposed on the instance) ----

    // True while the client is not bound to an app user.
    const isOwnerMode = () => this.app_user_id === "";

    // Returns `data` from an RPC `{ data, error }` response, routing any error
    // through handle_error (which throws).
    const unwrap = ({ data, error }) => {
      if (error) {
        this.handle_error(error);
      }
      return data;
    };

    // Looks an add-on up by name in the cached list, or throws.
    const findAddOn = (add_on_name) => {
      const add_on = this.add_ons.find((candidate) => candidate.name === add_on_name);
      if (!add_on) {
        throw new Error(`The add-on ${add_on_name} does not exist`);
      }
      return add_on;
    };

    // Refreshes the add-on cache first so add-ons created since the last
    // refresh (e.g. by a finished patch training) can be found.
    const findFreshAddOn = async (add_on_name) => {
      await this.updateAddOnList();
      return findAddOn(add_on_name);
    };

    // Looks a cached service up by name and checks its `interface` field.
    const requireService = (service_name, expected_interface) => {
      const service = this.services.find((candidate) => candidate.name === service_name);
      if (!service) {
        throw new Error(`The service ${service_name} does not exist`);
      }
      if (service.interface !== expected_interface) {
        throw new Error(
          `The service ${service_name} does not have the ${expected_interface} interface`
        );
      }
      return service;
    };

    // Checks that every requested patch is a known add-on listing `service_name`.
    const requireCompatiblePatches = (service_name, patches) => {
      for (const patch of patches || []) {
        const add_on = findAddOn(patch.name);
        if (!add_on.service_name?.includes(service_name)) {
          throw new Error(
            `The service ${service_name} does not have the add-on ${patch.name}`
          );
        }
      }
    };

    // Builds the stable-diffusion job config shared by cost and run calls.
    // `??` is used where an explicit falsy value is meaningful (an empty
    // negative prompt, a guidance scale of 0); other fields keep `||` so an
    // invalid falsy value still falls back to the default.
    const buildStableDiffusionConfig = (prompt, args) => ({
      steps: args?.steps || 28,
      skip_steps: args?.skip_steps || 0,
      batch_size: args?.batch_size || 1,
      sampler: args?.sampler || "k_euler",
      guidance_scale: args?.guidance_scale ?? 10,
      width: args?.width || 512,
      height: args?.height || 512,
      prompt: prompt || "banana in the kitchen",
      negative_prompt: args?.negative_prompt ?? "ugly",
      image_format: args?.image_format || "jpeg",
      translate_prompt: args?.translate_prompt || false,
      nsfw_filter: args?.nsfw_filter || false,
      add_ons: args?.patches?.map((patch) => this.patchConfigToAddonConfig(patch)),
    });

    // Builds the patch-trainer job config shared by cost and run calls.
    const buildTrainerConfig = (dataset, patch_name, args) => ({
      dataset,
      patch_name,
      description: args?.description || "",
      learning_rate: args?.learning_rate || 1e-4,
      steps: args?.steps || 100,
      rank: args?.rank || 4,
    });

    // Asks the backend for the cost of running `config` on `service`.
    const getConfigCost = async (service, config) =>
      unwrap(
        await this.supabase.rpc("get_service_config_cost_client", {
          p_service_id: service["id"],
          p_config: JSON.stringify(config),
        })
      );

    // Runs an owner RPC that identifies the user by internal id, resolving the
    // external id first.
    const ownerRpcForUser = async (fn, app_user_external_id, params = {}) => {
      const app_user_id = await this.getUserId(app_user_external_id);
      return unwrap(await this.owner_rpc(fn, { ...params, p_app_user_id: app_user_id }));
    };

    const logToConsole = function (data) {
      console.log(data);
    };

    // Posts a job and, when the response carries a job id, subscribes
    // `callback` (default: console.log) to its result channel.
    const postAndSubscribe = async (service_name, config, callback) => {
      const response = await this.postJob(service_name, config);
      if (response !== null && typeof response === "object" && "job_id" in response) {
        await this.subscribeToJob(String(response["job_id"]), callback || logToConsole);
      }
      return response;
    };

    // ---- Error handling and transport ----

    /**
     * Converts a backend error into a thrown Error with a user-facing message.
     * The original error is attached as `cause`.
     * @param {{code?: string, message?: string}} error
     * @throws {Error} always
     */
    this.handle_error = (error) => {
      if (error.code === "") {
        throw new Error(
          "The database cannot be reached. Contact the administrator.",
          { cause: error }
        );
      }
      // Checked before the code mapping, so an invalid key wins over any code.
      if (error.message === "Invalid API key") {
        throw new Error("The API key is invalid. Contact the administrator.", { cause: error });
      }
      switch (error.code) {
        case "22P02": // PostgreSQL invalid_text_representation
          throw new Error("The credentials are not correct.", { cause: error });
        case "P0001": // PostgreSQL raise_exception: pass the raised message through
          throw new Error(error.message, { cause: error });
        case "23505": // PostgreSQL unique_violation
          throw new Error("This object already exists.", { cause: error });
        default:
          throw new Error(
            `An unexpected error occured. Contact the administrator. ${error.message}`,
            { cause: error }
          );
      }
    };

    /** Calls `fn` with owner credentials (app id, key, secret) merged into `params`. */
    this.owner_rpc = async (fn, params) => {
      const { data, error } = await this.supabase.rpc(fn, {
        ...params,
        p_secret: this.secret,
        p_app_id: this.app_id,
        p_key: this.key,
      });
      return { data, error };
    };

    /** Calls `fn` with app-user credentials (app id, key, user id, token) merged into `params`. */
    this.user_rpc = async (fn, params) => {
      const { data, error } = await this.supabase.rpc(fn, {
        ...params,
        p_app_id: this.app_id,
        p_key: this.key,
        p_app_user_id: this.app_user_id,
        p_app_user_token: this.app_user_token,
      });
      return { data, error };
    };

    // ---- Session / connectivity ----

    /**
     * Switches the client to user mode for the given external user id.
     * The user id and token are only stored once both lookups succeed, so a
     * failure leaves the client in its previous mode.
     * NOTE(review): an unknown external id (falsy lookup result) silently keeps
     * the client in its current mode.
     */
    this.setUserID = async (app_user_external_id) => {
      const { data, error } = await this.owner_rpc("app_owner_get_user_id", {
        p_app_user_external_id: app_user_external_id,
      });
      if (error) {
        throw new Error(error.message);
      }
      if (data) {
        const token = unwrap(
          await this.owner_rpc("app_owner_get_user_token_value", { p_app_user_id: data })
        );
        this.app_user_id = String(data);
        this.app_user_token = String(token);
      }
    };

    /** Round-trips "check" through app_owner_echo; throws if the echo differs. */
    this.test_connection = async () => {
      const data = unwrap(
        await this.owner_rpc("app_owner_echo", { message_app_owner: "check" })
      );
      if (data && String(data) !== "check") {
        throw new Error("There is a problem with the database. Contact the administrator.");
      }
    };

    /** Echoes `message` back through the owner echo RPC. */
    this.echo = async (message) =>
      unwrap(await this.owner_rpc("app_owner_echo", { message_app_owner: message }));

    // ---- Services and add-ons ----

    /** Fetches the service list (owner or user view), caches it when non-empty-ish, and returns it. */
    this.getServiceList = async () => {
      const response = isOwnerMode()
        ? await this.owner_rpc("app_owner_get_services", {})
        : await this.user_rpc("app_user_get_services", {});
      const data = unwrap(response);
      if (data) {
        this.services = data;
      }
      return data;
    };

    /** Refreshes the cached add-on list (owner or user view). */
    this.updateAddOnList = async () => {
      const response = isOwnerMode()
        ? await this.owner_rpc("app_owner_get_add_ons", {})
        : await this.user_rpc("app_user_get_add_ons", {});
      const data = unwrap(response);
      if (data) {
        this.add_ons = data;
      }
    };

    /** Returns the cached add-on list without contacting the backend. */
    this.getAddOnList = async () => this.add_ons;

    /** Shares the named add-on with an app user. */
    this.shareAddOn = async (add_on_name, app_user_external_id) => {
      const add_on = await findFreshAddOn(add_on_name);
      return unwrap(
        await this.owner_rpc("app_owner_share_add_on", {
          p_add_on_id: add_on.id,
          p_app_user_external_id: app_user_external_id,
        })
      );
    };

    /** Deletes the named add-on, then refreshes the cache so it no longer validates. */
    this.deleteAddOn = async (add_on_name) => {
      const add_on = await findFreshAddOn(add_on_name);
      const data = unwrap(
        await this.owner_rpc("app_owner_delete_add_on", { p_add_on_id: add_on.id })
      );
      await this.updateAddOnList();
      return data;
    };

    /**
     * Renames an add-on after checking the new name is available.
     * The cache is refreshed only after a successful rename, so a rename error
     * is never masked by a refresh error.
     */
    this.renameAddOn = async (add_on_name, new_add_on_name) => {
      const add_on = await findFreshAddOn(add_on_name);
      const can_rename = unwrap(
        await this.owner_rpc("app_owner_can_rename", {
          p_add_on_id: add_on.id,
          p_new_name: new_add_on_name,
        })
      );
      if (!can_rename) {
        throw new Error(`The name ${new_add_on_name} is not available`);
      }
      const data = unwrap(
        await this.owner_rpc("app_owner_rename_add_on", {
          p_add_on_id: add_on.id,
          p_new_name: new_add_on_name,
        })
      );
      await this.updateAddOnList();
      return data;
    };

    /** Publishes the named add-on. */
    this.publishAddOn = async (add_on_name) => {
      const add_on = await findFreshAddOn(add_on_name);
      return unwrap(
        await this.owner_rpc("app_owner_publish_add_on", { p_add_on_id: add_on.id })
      );
    };

    /** Unpublishes the named add-on. */
    this.unpublishAddOn = async (add_on_name) => {
      const add_on = await findFreshAddOn(add_on_name);
      return unwrap(
        await this.owner_rpc("app_owner_unpublish_add_on", { p_add_on_id: add_on.id })
      );
    };

    /**
     * Maps a patch request `{ name, alpha_unet, alpha_text_encoder, steps }` to
     * the `{ id, config }` shape sent in job configs.
     * @throws {Error} if the add-on is not in the cached list.
     */
    this.patchConfigToAddonConfig = (patch_config) => ({
      id: findAddOn(patch_config.name).id,
      config: {
        alpha_unet: patch_config.alpha_unet,
        alpha_text_encoder: patch_config.alpha_text_encoder,
        steps: patch_config.steps,
      },
    });

    // ---- App users (owner operations) ----

    /** Creates an app user with the given external id. */
    this.createAppUser = async (external_id) =>
      unwrap(await this.owner_rpc("app_owner_create_user", { p_external_id: external_id }));

    /** Resolves an external user id to the internal app user id. */
    this.getUserId = async (external_id) =>
      unwrap(
        await this.owner_rpc("app_owner_get_user_id", { p_app_user_external_id: external_id })
      );

    this.isUser = async (external_id) => ownerRpcForUser("app_owner_is_user", external_id);

    this.createToken = async (app_user_external_id) =>
      ownerRpcForUser("app_owner_create_user_token", app_user_external_id);

    /** Revokes the token returned by app_owner_get_token for the user. */
    this.deleteAllTokenOfAppUser = async (app_user_external_id) => {
      const app_user_id = await this.getUserId(app_user_external_id);
      const token = unwrap(
        await this.owner_rpc("app_owner_get_token", { p_app_user_id: app_user_id })
      );
      return unwrap(
        await this.owner_rpc("app_owner_revoke_user_token", {
          p_app_user_id: app_user_id,
          p_token: token,
        })
      );
    };

    this.getAppUserToken = async (app_user_external_id) =>
      ownerRpcForUser("app_owner_get_user_token_value", app_user_external_id);

    this.setCredit = async (app_user_external_id, amount) =>
      ownerRpcForUser("app_owner_set_user_credits", app_user_external_id, { p_amount: amount });

    this.getAppUserCredits = async (app_user_external_id) =>
      ownerRpcForUser("app_owner_get_user_credits", app_user_external_id);

    this.getAppUserJobHistory = async (app_user_external_id, p_limit, p_offset) =>
      ownerRpcForUser("app_owner_get_job_history_detail", app_user_external_id, {
        p_limit,
        p_offset,
      });

    // ---- Jobs ----

    /**
     * Posts `job_config` to the named (cached) service, as owner or as user
     * depending on the current mode.
     * @throws {Error} "Invalid model name" if the service is not cached.
     */
    this.postJob = async (service_name, job_config) => {
      const service = this.services.find((candidate) => candidate.name === service_name);
      if (!service) {
        throw new Error("Invalid model name");
      }
      const params = {
        p_service_id: service["id"],
        p_job_config: JSON.stringify(job_config),
        p_worker_filter: this.worker_filter,
      };
      const response = isOwnerMode()
        ? await this.owner_rpc("app_owner_post_job_admin", params)
        : await this.user_rpc("post_job", params);
      return unwrap(response);
    };

    /** Fetches a job result (always via the owner RPC). */
    this.getResult = async (job_id) =>
      unwrap(await this.owner_rpc("app_owner_get_result", { p_job_id: job_id }));

    /**
     * Subscribes `callback` to the Pusher channel `job-<job_id>`.
     * The connection is torn down once a payload containing `result` arrives,
     * even if the callback throws.
     * NOTE(review): subscribing happens after the job is posted, so a result
     * published before the subscription is established may be missed.
     */
    this.subscribeToJob = async (job_id, callback) => {
      const channel_name = `job-${job_id}`;
      const client = new Pusher("ed00ed3037c02a5fd912", { cluster: "eu" });
      // NOTE(review): replaces pusher-client's "close" connection callback with a
      // no-op; the reason is not visible here.
      client.connection.connectionCallbacks["close"] = (_) => {};
      const channel = client.subscribe(channel_name);
      channel.bind("result", (data) => {
        try {
          callback(data);
        } finally {
          if (data !== null && typeof data === "object" && "result" in data) {
            client.unsubscribe(channel_name);
            client.disconnect();
          }
        }
      });
    };

    /** Returns the cost of a stable-diffusion job without running it. */
    this.costStableDiffusion = async (prompt, args) => {
      const service_name = args?.service_name || "stable-diffusion-2-1-base";
      const service = requireService(service_name, "stable-diffusion");
      requireCompatiblePatches(service_name, args?.patches);
      return getConfigCost(service, buildStableDiffusionConfig(prompt, args));
    };

    /** Posts a stable-diffusion job and subscribes `args.callback` to its result. */
    this.runStableDiffusion = async (prompt, args) => {
      const service_name = args?.service_name || "stable-diffusion-2-1-base";
      requireService(service_name, "stable-diffusion");
      requireCompatiblePatches(service_name, args?.patches);
      const config = { ...buildStableDiffusionConfig(prompt, args), seed: args?.seed };
      return postAndSubscribe(service_name, config, args?.callback);
    };

    /** Returns the cost of a patch-training job without running it. */
    this.costPatchTrainer = async (dataset, patch_name, args) => {
      const service_name = args?.service_name || "patch_trainer_v1";
      const service = requireService(service_name, "train-patch-stable-diffusion");
      return getConfigCost(service, buildTrainerConfig(dataset, patch_name, args));
    };

    /**
     * Starts training a new patch add-on named `patch_name`.
     * @throws {Error} if the name is taken, already being created, or the
     *   "is creating" check itself fails.
     */
    this.runPatchTrainer = async (dataset, patch_name, args) => {
      const service_name = args?.service_name || "patch_trainer_v1";
      requireService(service_name, "train-patch-stable-diffusion");
      await this.updateAddOnList();
      if (this.add_ons.some((add_on) => add_on.name === patch_name)) {
        throw new Error(`The add-on ${patch_name} already exists`);
      }
      const is_creating = unwrap(
        await this.owner_rpc("app_owner_is_creating_add_on", { p_add_on_name: patch_name })
      );
      if (is_creating) {
        throw new Error(`There is already an ${patch_name} add-on being created`);
      }
      return postAndSubscribe(
        service_name,
        buildTrainerConfig(dataset, patch_name, args),
        args?.callback
      );
    };

    /** Counts active workers matching `worker_filter`. */
    this.getCountActiveWorker = async () =>
      unwrap(
        await this.supabase.rpc("get_active_worker_count", {
          p_worker_filter: this.worker_filter,
        })
      );
  }
}
/**
 * Creates a connected SpawnClient.
 *
 * Steps: build the Supabase client, check connectivity, optionally switch to
 * user mode, then load the service and add-on caches.
 *
 * @param {{app_id: string, key: string, secret: string, app_user_external_id?: string}} credentials
 * @param {object} [worker_filter] - Forwarded to SpawnClient (defaults to `{ branch: "prod" }` there).
 * @returns {Promise<SpawnClient>}
 * @throws {Error} if `credentials` is missing, or if any of the setup RPCs fail.
 */
const createSpawnClient = async (credentials, worker_filter) => {
  if (!credentials) {
    throw new Error("Missing Spawn credentials.");
  }
  const SUPABASE_URL = "https://lgwrsefyncubvpholtmh.supabase.co";
  // The JWT payload of this key declares role "anon": it is the project's public
  // anon key, not a privileged service key.
  const SUPABASE_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imxnd3JzZWZ5bmN1YnZwaG9sdG1oIiwicm9sZSI6ImFub24iLCJpYXQiOjE2Njk0MDE0MzYsImV4cCI6MTk4NDk3NzQzNn0.o-QO3JKyJ5E-XzWRPC9WdWHY8WjzEFRRnDRSflLzHsc";
  const supabase = supabaseJs.createClient(SUPABASE_URL, SUPABASE_KEY, {
    // SpawnClient only issues RPCs and never signs a Supabase user in, so there
    // is no auth session to persist (and no browser storage under Node).
    auth: { persistSession: false },
  });
  const spawn = new SpawnClient(
    supabase,
    credentials.app_id,
    credentials.key,
    credentials.secret,
    worker_filter
  );
  await spawn.test_connection();
  if (credentials.app_user_external_id) {
    await spawn.setUserID(credentials.app_user_external_id);
  }
  // Both caches depend only on the mode fixed above and not on each other.
  await Promise.all([spawn.getServiceList(), spawn.updateAddOnList()]);
  return spawn;
};
// Public API of the module.
// Kept as plain `exports.<name> = ...` assignments: Node's ESM loader detects the
// named exports of a CommonJS module by statically scanning for this pattern, which
// is presumably what lets ESM callers write `import { createSpawnClient } from ...`.
exports.SpawnClient = SpawnClient;
exports.createSpawnClient = createSpawnClient;