UNPKG

@spawn-so/spawn-node

Version:

Node.js client for the Spawn API.

595 lines (591 loc) 20.4 kB
'use strict';

const supabaseJs = require('@supabase/supabase-js');
const Pusher = require('pusher-client');

// Client-side endpoints of the hosted Spawn backend. The Supabase key is the
// project's "anon" role key (see the `role` claim in the JWT payload) and the
// Pusher value is a client app key; both were already embedded in this client.
const SUPABASE_URL = "https://lgwrsefyncubvpholtmh.supabase.co";
const SUPABASE_KEY =
  "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Imxnd3JzZWZ5bmN1YnZwaG9sdG1oIiwicm9sZSI6ImFub24iLCJpYXQiOjE2Njk0MDE0MzYsImV4cCI6MTk4NDk3NzQzNn0.o-QO3JKyJ5E-XzWRPC9WdWHY8WjzEFRRnDRSflLzHsc";
const PUSHER_APP_KEY = "ed00ed3037c02a5fd912";
const PUSHER_CLUSTER = "eu";

const DEFAULT_STABLE_DIFFUSION_SERVICE = "stable-diffusion-2-1-base";
const DEFAULT_PATCH_TRAINER_SERVICE = "patch_trainer_v1";

/** Default job callback used when the caller supplies none: log every event. */
const logToConsole = function (data) {
  console.log(data);
};

/**
 * True when `value` is a non-null object, i.e. safe to probe with `in`.
 * `"key" in value` throws a TypeError for primitives (strings, numbers, ...).
 */
const isObject = (value) => value !== null && typeof value === "object";

/**
 * Client for the Spawn API, backed by Supabase RPCs and Pusher job events.
 *
 * The client runs in one of two modes:
 *  - owner mode (`app_user_id === ""`): calls use the app owner's secret;
 *  - user mode (after `setUserID` succeeds): service/add-on listing and job
 *    posting go through the app-user RPCs with the user's id and token.
 *
 * Every method is assigned as an arrow-function property so it stays bound to
 * the instance even when passed around detached (e.g. as a callback).
 */
class SpawnClient {
  /**
   * @param {object} supabase - Supabase client (from `supabaseJs.createClient`).
   * @param {string} app_id - Spawn application id.
   * @param {string} key - Application key.
   * @param {string} secret - Application owner secret.
   * @param {object} [worker_filter] - Sent with job posts / worker counts;
   *   defaults to `{ branch: "prod" }`.
   */
  constructor(supabase, app_id, key, secret, worker_filter) {
    this.supabase = supabase;
    this.app_id = app_id;
    this.key = key;
    this.secret = secret;
    this.worker_filter = worker_filter || { branch: "prod" };
    this.services = [];
    this.add_ons = [];
    this.app_user_id = "";
    this.app_user_token = "";

    // ---- private helpers (closures, not exposed on the instance) ----------

    /** Owner mode is active until setUserID stores a user id. */
    const isOwnerMode = () => this.app_user_id === "";

    /** Unwrap a `{ data, error }` RPC result, throwing a friendly error on failure. */
    const unwrap = ({ data, error }) => {
      if (error) {
        this.handle_error(error);
      }
      return data;
    };

    /** Call an owner RPC and return its data (throws on error). */
    const ownerCall = async (fn, params) => unwrap(await this.owner_rpc(fn, params));

    /** Call the owner or the app-user variant of an RPC depending on the mode. */
    const scopedCall = async (ownerFn, userFn, params) =>
      unwrap(
        isOwnerMode()
          ? await this.owner_rpc(ownerFn, params)
          : await this.user_rpc(userFn, params)
      );

    /** Resolve an external user id to Spawn's id, then call an owner RPC for that user. */
    const ownerCallForUser = async (fn, app_user_external_id, extraParams = {}) => {
      const app_user_id = await this.getUserId(app_user_external_id);
      return ownerCall(fn, { ...extraParams, p_app_user_id: app_user_id });
    };

    /** Look up a cached add-on by name or throw. Uses `this.add_ons` as-is (no refresh). */
    const findAddOn = (add_on_name) => {
      const addOn = this.add_ons.find((candidate) => candidate.name === add_on_name);
      if (!addOn) {
        throw new Error(`The add-on ${add_on_name} does not exist`);
      }
      return addOn;
    };

    /** Look up a cached service by name and check it exposes `interfaceName`. */
    const requireService = (service_name, interfaceName) => {
      const service = this.services.find((candidate) => candidate.name === service_name);
      if (!service) {
        throw new Error(`The service ${service_name} does not exist`);
      }
      if (service.interface !== interfaceName) {
        throw new Error(
          `The service ${service_name} does not have the ${interfaceName} interface`
        );
      }
      return service;
    };

    /** Check every requested patch names a known add-on usable with `service_name`. */
    const validatePatches = (service_name, patches) => {
      for (const patch of patches || []) {
        const addOn = this.add_ons.find((candidate) => candidate.name === patch.name);
        if (!addOn) {
          throw new Error(`The add-on ${patch.name} does not exist`);
        }
        if (!addOn.service_name.includes(service_name)) {
          throw new Error(
            `The service ${service_name} does not have the add-on ${patch.name}`
          );
        }
      }
    };

    /**
     * Build a stable-diffusion job config. `??` (not `||`) is used so explicit
     * falsy values such as `negative_prompt: ""` or `guidance_scale: 0` are
     * honoured instead of being replaced by the defaults.
     */
    const buildStableDiffusionConfig = (prompt, args) => ({
      steps: args?.steps ?? 28,
      skip_steps: args?.skip_steps ?? 0,
      batch_size: args?.batch_size ?? 1,
      sampler: args?.sampler ?? "k_euler",
      guidance_scale: args?.guidance_scale ?? 10,
      width: args?.width ?? 512,
      height: args?.height ?? 512,
      prompt: prompt ?? "banana in the kitchen",
      negative_prompt: args?.negative_prompt ?? "ugly",
      image_format: args?.image_format ?? "jpeg",
      translate_prompt: args?.translate_prompt ?? false,
      nsfw_filter: args?.nsfw_filter ?? false,
      // undefined when no patches are given; JSON.stringify then omits the key.
      add_ons: args?.patches?.map((patch) => this.patchConfigToAddonConfig(patch)),
    });

    /** Build a patch-trainer job config (explicit falsy values are honoured). */
    const buildTrainerConfig = (dataset, patch_name, args) => ({
      dataset,
      patch_name,
      description: args?.description ?? "",
      learning_rate: args?.learning_rate ?? 1e-4,
      steps: args?.steps ?? 100,
      rank: args?.rank ?? 4,
    });

    /** Ask the backend what running `config` on `service` would cost. */
    const serviceConfigCost = async (service, config) =>
      unwrap(
        await this.supabase.rpc("get_service_config_cost_client", {
          p_service_id: service["id"],
          p_config: JSON.stringify(config),
        })
      );

    /** Subscribe `callback` to a freshly posted job's events, then return the post response. */
    const followJob = async (response, callback) => {
      if (isObject(response) && "job_id" in response) {
        await this.subscribeToJob(String(response["job_id"]), callback || logToConsole);
      }
      return response;
    };

    // ---- error handling & raw RPC -----------------------------------------

    /**
     * Translate a Supabase/PostgREST error into an Error with a user-facing
     * message. Always throws; the original error is attached as `cause`.
     */
    this.handle_error = (error) => {
      if (error.code === "") {
        throw new Error("The database cannot be reached. Contact the administrator.", { cause: error });
      }
      if (error.message === "Invalid API key") {
        throw new Error("The API key is invalid. Contact the administrator.", { cause: error });
      }
      if (error.code === "22P02") {
        throw new Error("The credentials are not correct.", { cause: error });
      }
      if (error.code === "P0001") {
        // Messages raised explicitly by the database functions are passed through.
        throw new Error(error.message, { cause: error });
      }
      if (error.code === "23505") {
        throw new Error("This object already exists.", { cause: error });
      }
      throw new Error(
        `An unexpected error occurred. Contact the administrator. ${error.message}`,
        { cause: error }
      );
    };

    /** Call an RPC with the owner credentials appended. Returns `{ data, error }`. */
    this.owner_rpc = async (fn, params) => {
      const { data, error } = await this.supabase.rpc(fn, {
        ...params,
        p_secret: this.secret,
        p_app_id: this.app_id,
        p_key: this.key,
      });
      return { data, error };
    };

    /** Call an RPC with the app-user credentials appended. Returns `{ data, error }`. */
    this.user_rpc = async (fn, params) => {
      const { data, error } = await this.supabase.rpc(fn, {
        ...params,
        p_app_id: this.app_id,
        p_key: this.key,
        p_app_user_id: this.app_user_id,
        p_app_user_token: this.app_user_token,
      });
      return { data, error };
    };

    // ---- session & connectivity -------------------------------------------

    /**
     * Switch the client to user mode for the given external user id.
     * If the backend returns no id, the client stays in owner mode.
     */
    this.setUserID = async (app_user_external_id) => {
      const data = await ownerCall("app_owner_get_user_id", {
        p_app_user_external_id: app_user_external_id,
      });
      if (data) {
        // Fetch the token before touching state so a failure cannot leave the
        // client half-switched (user id set, token empty).
        const token = await this.getAppUserToken(app_user_external_id);
        this.app_user_id = String(data);
        this.app_user_token = String(token);
      }
    };

    /** Round-trip an echo RPC; throws if the backend is unreachable or answers wrongly. */
    this.test_connection = async () => {
      const data = await ownerCall("app_owner_echo", { message_app_owner: "check" });
      if (data && String(data) !== "check") {
        throw new Error("There is a problem with the database. Contact the administrator.");
      }
    };

    /** Fetch the services visible in the current mode, cache them, and return them. */
    this.getServiceList = async () => {
      const data = await scopedCall("app_owner_get_services", "app_user_get_services", {});
      if (data) {
        this.services = data;
      }
      return data;
    };

    /** Refresh the cached add-on list for the current mode. */
    this.updateAddOnList = async () => {
      const data = await scopedCall("app_owner_get_add_ons", "app_user_get_add_ons", {});
      if (data) {
        this.add_ons = data;
      }
    };

    /** Return the cached add-on list (call updateAddOnList to refresh it). */
    this.getAddOnList = async () => this.add_ons;

    /** Echo `message` through the backend. */
    this.echo = async (message) => ownerCall("app_owner_echo", { message_app_owner: message });

    // ---- app users ----------------------------------------------------------

    this.createAppUser = async (external_id) =>
      ownerCall("app_owner_create_user", { p_external_id: external_id });

    this.getUserId = async (external_id) =>
      ownerCall("app_owner_get_user_id", { p_app_user_external_id: external_id });

    this.isUser = async (external_id) => ownerCallForUser("app_owner_is_user", external_id);

    this.createToken = async (app_user_external_id) =>
      ownerCallForUser("app_owner_create_user_token", app_user_external_id);

    /**
     * Revoke the token returned by `app_owner_get_token` for the user.
     * NOTE(review): despite the name, only that single token is revoked — confirm
     * whether the backend returns/revokes all tokens.
     */
    this.deleteAllTokenOfAppUser = async (app_user_external_id) => {
      const app_user_id = await this.getUserId(app_user_external_id);
      const token = await ownerCall("app_owner_get_token", { p_app_user_id: app_user_id });
      return ownerCall("app_owner_revoke_user_token", {
        p_app_user_id: app_user_id,
        p_token: token,
      });
    };

    this.getAppUserToken = async (app_user_external_id) =>
      ownerCallForUser("app_owner_get_user_token_value", app_user_external_id);

    this.setCredit = async (app_user_external_id, amount) =>
      ownerCallForUser("app_owner_set_user_credits", app_user_external_id, { p_amount: amount });

    this.getAppUserCredits = async (app_user_external_id) =>
      ownerCallForUser("app_owner_get_user_credits", app_user_external_id);

    this.getAppUserJobHistory = async (app_user_external_id, p_limit, p_offset) =>
      ownerCallForUser("app_owner_get_job_history_detail", app_user_external_id, {
        p_limit,
        p_offset,
      });

    // ---- add-ons ------------------------------------------------------------

    /** Share an add-on (looked up by name after a refresh) with an external user. */
    this.shareAddOn = async (add_on_name, app_user_external_id) => {
      await this.updateAddOnList();
      const addOn = findAddOn(add_on_name);
      return ownerCall("app_owner_share_add_on", {
        p_add_on_id: addOn.id,
        p_app_user_external_id: app_user_external_id,
      });
    };

    this.deleteAddOn = async (add_on_name) =>
      ownerCall("app_owner_delete_add_on", { p_add_on_id: findAddOn(add_on_name).id });

    /** Rename an add-on if the new name is available, then refresh the cache. */
    this.renameAddOn = async (add_on_name, new_add_on_name) => {
      const params = { p_add_on_id: findAddOn(add_on_name).id, p_new_name: new_add_on_name };
      const available = await ownerCall("app_owner_can_rename", params);
      if (!available) {
        throw new Error(`The name ${new_add_on_name} is not available`);
      }
      const data = await ownerCall("app_owner_rename_add_on", params);
      // Refresh only after a successful rename so a refresh failure cannot mask
      // the rename error.
      await this.updateAddOnList();
      return data;
    };

    this.publishAddOn = async (add_on_name) =>
      ownerCall("app_owner_publish_add_on", { p_add_on_id: findAddOn(add_on_name).id });

    this.unpublishAddOn = async (add_on_name) =>
      ownerCall("app_owner_unpublish_add_on", { p_add_on_id: findAddOn(add_on_name).id });

    // ---- jobs ---------------------------------------------------------------

    /** Post a job for a cached service (admin RPC in owner mode, user RPC otherwise). */
    this.postJob = async (service_name, job_config) => {
      const service = this.services.find((candidate) => candidate.name === service_name);
      if (!service) {
        throw new Error("Invalid model name");
      }
      return scopedCall("app_owner_post_job_admin", "post_job", {
        p_service_id: service["id"],
        p_job_config: JSON.stringify(job_config),
        p_worker_filter: this.worker_filter,
      });
    };

    this.getResult = async (job_id) => ownerCall("app_owner_get_result", { p_job_id: job_id });

    /**
     * Listen for `result` events of a job on Pusher channel `job-<job_id>` and
     * forward each to `callback`. Once an event carrying a `result` key arrives
     * the channel is unsubscribed and the connection closed — even if the
     * callback throws.
     */
    this.subscribeToJob = async (job_id, callback) => {
      const client = new Pusher(PUSHER_APP_KEY, { cluster: PUSHER_CLUSTER });
      // Replace pusher-client's "close" connection callback with a no-op.
      // NOTE(review): presumably to silence its default handling on disconnect — confirm.
      client.connection.connectionCallbacks["close"] = (_) => {};
      const channelName = `job-${job_id}`;
      const channel = client.subscribe(channelName);
      channel.bind("result", function (data) {
        try {
          callback(data);
        } finally {
          if (isObject(data) && "result" in data) {
            client.unsubscribe(channelName);
            client.disconnect();
          }
        }
      });
    };

    // ---- stable diffusion ---------------------------------------------------

    /** Estimate the cost of a stable-diffusion run without posting it. */
    this.costStableDiffusion = async (prompt, args) => {
      const service_name = args?.service_name || DEFAULT_STABLE_DIFFUSION_SERVICE;
      const service = requireService(service_name, "stable-diffusion");
      validatePatches(service_name, args?.patches);
      return serviceConfigCost(service, buildStableDiffusionConfig(prompt, args));
    };

    /** Post a stable-diffusion job and stream its events to `args.callback` (default: console.log). */
    this.runStableDiffusion = async (prompt, args) => {
      const service_name = args?.service_name || DEFAULT_STABLE_DIFFUSION_SERVICE;
      requireService(service_name, "stable-diffusion");
      validatePatches(service_name, args?.patches);
      const config = { ...buildStableDiffusionConfig(prompt, args), seed: args?.seed };
      const response = await this.postJob(service_name, config);
      return followJob(response, args?.callback);
    };

    /** Convert a user patch spec `{ name, alpha_unet, alpha_text_encoder, steps }` to the job's add-on format. */
    this.patchConfigToAddonConfig = (patch_config) => ({
      id: findAddOn(patch_config.name).id,
      config: {
        alpha_unet: patch_config.alpha_unet,
        alpha_text_encoder: patch_config.alpha_text_encoder,
        steps: patch_config.steps,
      },
    });

    // ---- patch trainer ------------------------------------------------------

    /** Estimate the cost of a patch-training run without posting it. */
    this.costPatchTrainer = async (dataset, patch_name, args) => {
      const service_name = args?.service_name || DEFAULT_PATCH_TRAINER_SERVICE;
      const service = requireService(service_name, "train-patch-stable-diffusion");
      return serviceConfigCost(service, buildTrainerConfig(dataset, patch_name, args));
    };

    /**
     * Post a patch-training job, refusing if an add-on with `patch_name`
     * already exists or is being created. Events go to `args.callback`.
     */
    this.runPatchTrainer = async (dataset, patch_name, args) => {
      const service_name = args?.service_name || DEFAULT_PATCH_TRAINER_SERVICE;
      requireService(service_name, "train-patch-stable-diffusion");
      await this.updateAddOnList();
      if (this.add_ons.some((addOn) => addOn.name === patch_name)) {
        throw new Error(`The add-on ${patch_name} already exists`);
      }
      // Errors from this check are surfaced instead of being silently ignored.
      const isCreating = await ownerCall("app_owner_is_creating_add_on", {
        p_add_on_name: patch_name,
      });
      if (isCreating) {
        throw new Error(`There is already an ${patch_name} add-on being created`);
      }
      const response = await this.postJob(
        service_name,
        buildTrainerConfig(dataset, patch_name, args)
      );
      return followJob(response, args?.callback);
    };

    /** Number of active workers matching `this.worker_filter`. */
    this.getCountActiveWorker = async () =>
      unwrap(
        await this.supabase.rpc("get_active_worker_count", {
          p_worker_filter: this.worker_filter,
        })
      );
  }
}

/**
 * Create a connected SpawnClient: verify connectivity, optionally switch to
 * user mode, and prime the service and add-on caches.
 *
 * @param {{app_id: string, key: string, secret: string, app_user_external_id?: string}} credentials
 * @param {object} [worker_filter] - See the SpawnClient constructor.
 * @returns {Promise<SpawnClient>}
 */
const createSpawnClient = async (credentials, worker_filter) => {
  const supabase = supabaseJs.createClient(SUPABASE_URL, SUPABASE_KEY, {
    auth: { persistSession: true },
  });
  const spawn = new SpawnClient(
    supabase,
    credentials.app_id,
    credentials.key,
    credentials.secret,
    worker_filter
  );
  await spawn.test_connection();
  if (credentials.app_user_external_id) {
    await spawn.setUserID(credentials.app_user_external_id);
  }
  await spawn.getServiceList();
  await spawn.updateAddOnList();
  return spawn;
};

exports.SpawnClient = SpawnClient;
exports.createSpawnClient = createSpawnClient;