// ai-utils.js: Build AI applications, chatbots, and agents with JavaScript and TypeScript.
import { z } from "zod";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { callWithRetryAndThrottle } from "../../util/api/callWithRetryAndThrottle.js";
import { createJsonResponseHandler, postJsonToApi, } from "../../util/api/postToApi.js";
import { failedStabilityCallResponseHandler } from "./StabilityError.js";
/**
 * Create an image generation model that calls the Stability AI image generation API.
 *
 * @see https://api.stability.ai/docs#tag/v1generation/operation/textToImage
 *
 * @example
 * const { image } = await generateImage(
 *   new StabilityImageGenerationModel({
 *     model: "stable-diffusion-512-v2-1",
 *     cfgScale: 7,
 *     clipGuidancePreset: "FAST_BLUE",
 *     height: 512,
 *     width: 512,
 *     samples: 1,
 *     steps: 30,
 *   }),
 *   [
 *     { text: "the wicked witch of the west" },
 *     { text: "style of early 19th century painting", weight: 0.5 },
 *   ]
 * );
 */
export class StabilityImageGenerationModel extends AbstractModel {
    constructor(settings) {
        super({ settings });
        // Expose the provider identifier as an instance property.
        Object.defineProperty(this, "provider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "stability"
        });
    }
    get modelName() {
        return this.settings.model;
    }
    get apiKey() {
        const apiKey = this.settings.apiKey ?? process.env.STABILITY_API_KEY;
        if (apiKey == null) {
            throw new Error("No API key provided. Either pass an API key to the constructor or set the STABILITY_API_KEY environment variable.");
        }
        return apiKey;
    }
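    /**
     * Calls the Stability text-to-image endpoint for this model. Call-time settings are
     * merged over the model settings, the model name is sent as the engine id, and the
     * request is wrapped with the configured retry and throttle behavior.
     */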
    async callAPI(input, options) {
        const run = options?.run;
        const settings = options?.settings;
        const callSettings = Object.assign({
            apiKey: this.apiKey,
        }, this.settings, settings, {
            abortSignal: run?.abortSignal,
            engineId: this.settings.model,
            textPrompts: input,
        });
        return callWithRetryAndThrottle({
            retry: this.settings.retry,
            throttle: this.settings.throttle,
            call: async () => callStabilityImageGenerationAPI(callSettings),
        });
    }
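    // Returns the raw Stability API response for the given prompt.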
    generateImageResponse(prompt, options) {
        return this.callAPI(prompt, options);
    }
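    // The generated image is returned as base64-encoded data in the first artifact.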
    extractBase64Image(response) {
        return response.artifacts[0].base64;
    }
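    // Returns a copy of the model with the additional settings merged over the current ones.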
    withSettings(additionalSettings) {
        return new StabilityImageGenerationModel(Object.assign({}, this.settings, additionalSettings));
    }
}
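// Example (sketch): `withSettings` returns a new model instance with the additional
// settings merged over the existing ones, e.g. to reuse one configuration with a
// higher step count. The variable names below are illustrative.
//
//   const baseModel = new StabilityImageGenerationModel({
//     model: "stable-diffusion-512-v2-1",
//     width: 512,
//     height: 512,
//   });
//   const detailedModel = baseModel.withSettings({ steps: 50, cfgScale: 10 });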
const stabilityImageGenerationResponseSchema = z.object({
    artifacts: z.array(z.object({
        base64: z.string(),
        seed: z.number(),
        finishReason: z.enum(["SUCCESS", "ERROR", "CONTENT_FILTERED"]),
    })),
});
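// Shape of a successful, parsed response (values below are illustrative):
// {
//   artifacts: [
//     {
//       base64: "iVBORw0KGgo...",
//       seed: 1229191277,
//       finishReason: "SUCCESS",
//     },
//   ],
// }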
/**
 * Call the Stability AI API for image generation.
 *
 * @see https://api.stability.ai/docs#tag/v1generation/operation/textToImage
 *
 * @example
 * const imageResponse = await callStabilityImageGenerationAPI({
 *   apiKey: STABILITY_API_KEY,
 *   engineId: "stable-diffusion-512-v2-1",
 *   textPrompts: [
 *     { text: "the wicked witch of the west" },
 *     { text: "style of early 19th century painting", weight: 0.5 },
 *   ],
 *   cfgScale: 7,
 *   clipGuidancePreset: "FAST_BLUE",
 *   height: 512,
 *   width: 512,
 *   samples: 1,
 *   steps: 30,
 * });
 *
 * imageResponse.artifacts.forEach((image, index) => {
 *   fs.writeFileSync(
 *     `./stability-image-example-${index}.png`,
 *     Buffer.from(image.base64, "base64")
 *   );
 * });
 */
async function callStabilityImageGenerationAPI({
    baseUrl = "https://api.stability.ai/v1",
    abortSignal,
    apiKey,
    engineId,
    height,
    width,
    textPrompts,
    cfgScale,
    clipGuidancePreset,
    sampler,
    samples,
    seed,
    steps,
    stylePreset,
}) {
    return postJsonToApi({
        url: `${baseUrl}/generation/${engineId}/text-to-image`,
        apiKey,
        body: {
            height,
            width,
            text_prompts: textPrompts,
            cfg_scale: cfgScale,
            clip_guidance_preset: clipGuidancePreset,
            sampler,
            samples,
            seed,
            steps,
            style_preset: stylePreset,
        },
        failedResponseHandler: failedStabilityCallResponseHandler,
        successfulResponseHandler: createJsonResponseHandler(stabilityImageGenerationResponseSchema),
        abortSignal,
    });
}