ai-utils.js
Version:
Build AI applications, chatbots, and agents with JavaScript and TypeScript.
134 lines (133 loc) • 4.71 kB
JavaScript
;
Object.defineProperty(exports, "__esModule", { value: true });
exports.StabilityImageGenerationModel = void 0;
const zod_1 = require("zod");
const AbstractModel_js_1 = require("../../model-function/AbstractModel.cjs");
const callWithRetryAndThrottle_js_1 = require("../../util/api/callWithRetryAndThrottle.cjs");
const postToApi_js_1 = require("../../util/api/postToApi.cjs");
const StabilityError_js_1 = require("./StabilityError.cjs");
/**
 * Create an image generation model that calls the Stability AI image generation API.
 *
 * @see https://api.stability.ai/docs#tag/v1generation/operation/textToImage
 *
 * @example
 * const { image } = await generateImage(
 *   new StabilityImageGenerationModel({
 *     model: "stable-diffusion-512-v2-1",
 *     cfgScale: 7,
 *     clipGuidancePreset: "FAST_BLUE",
 *     height: 512,
 *     width: 512,
 *     samples: 1,
 *     steps: 30,
 *   }),
 *   [
 *     { text: "the wicked witch of the west" },
 *     { text: "style of early 19th century painting", weight: 0.5 },
 *   ]
 * );
 */
class StabilityImageGenerationModel extends AbstractModel_js_1.AbstractModel {
    /**
     * @param settings - Model settings. `model` is used as the Stability engine id;
     *   `apiKey` (optional) takes precedence over the STABILITY_API_KEY environment
     *   variable; `retry` / `throttle` are handed to callWithRetryAndThrottle; the
     *   remaining fields are passed through to callStabilityImageGenerationAPI.
     */
    constructor(settings) {
        super({ settings });
        Object.defineProperty(this, "provider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "stability"
        });
    }
    /** The Stability engine id, taken from `settings.model`. */
    get modelName() {
        return this.settings.model;
    }
    /**
     * The API key from the model settings, falling back to the STABILITY_API_KEY
     * environment variable.
     *
     * @throws {Error} when neither source provides a key.
     */
    get apiKey() {
        const apiKey = this.settings.apiKey ?? process.env.STABILITY_API_KEY;
        if (apiKey == null) {
            throw new Error("No API key provided. Either pass an API key to the constructor or set the STABILITY_API_KEY environment variable.");
        }
        return apiKey;
    }
    /**
     * Call the Stability text-to-image endpoint with retry/throttle handling.
     *
     * @param input - The text prompts (e.g. `[{ text, weight? }]`), sent as `textPrompts`.
     * @param options - Optional. `options.run.abortSignal` is forwarded as the abort
     *   signal; `options.settings` overrides `this.settings` for this call only.
     *   The engine id and the retry/throttle policies are always read from `this.settings`.
     * @returns The result of callStabilityImageGenerationAPI, wrapped by callWithRetryAndThrottle.
     * @throws {Error} when no API key is available from the call settings, the model
     *   settings, or the environment.
     */
    async callAPI(input, options) {
        const run = options?.run;
        const settings = options?.settings;
        // Resolve the key explicitly and apply it after the spreads. Otherwise an
        // own `apiKey: undefined` in either settings object would overwrite the
        // resolved key (object spread / Object.assign copy undefined-valued
        // properties). Checking the per-call key first also means the getter no
        // longer throws when only the call settings carry a key.
        const apiKey = settings?.apiKey ?? this.apiKey;
        const callSettings = {
            ...this.settings,
            ...settings,
            apiKey,
            abortSignal: run?.abortSignal,
            engineId: this.settings.model,
            textPrompts: input,
        };
        return (0, callWithRetryAndThrottle_js_1.callWithRetryAndThrottle)({
            retry: this.settings.retry,
            throttle: this.settings.throttle,
            call: async () => callStabilityImageGenerationAPI(callSettings),
        });
    }
    /** Generate images for `prompt`; delegates to callAPI. */
    generateImageResponse(prompt, options) {
        return this.callAPI(prompt, options);
    }
    /**
     * Return the base64 payload of the first artifact in a response.
     *
     * Only the first artifact is used, even if more were requested via `samples`.
     * `finishReason` is not inspected, so the artifact is returned regardless of
     * its finish reason.
     *
     * @throws {Error} when the response contains no artifacts.
     */
    extractBase64Image(response) {
        const [firstArtifact] = response.artifacts;
        if (firstArtifact == null) {
            throw new Error("Stability AI response did not contain any image artifacts.");
        }
        return firstArtifact.base64;
    }
    /** Return a new model whose settings are this model's settings merged with `additionalSettings`. */
    withSettings(additionalSettings) {
        return new StabilityImageGenerationModel(Object.assign({}, this.settings, additionalSettings));
    }
}
exports.StabilityImageGenerationModel = StabilityImageGenerationModel;
/**
 * Shape of one entry in the `artifacts` array of a text-to-image response:
 * a `base64` string payload, a numeric `seed`, and a `finishReason` that is
 * one of "SUCCESS", "ERROR" or "CONTENT_FILTERED".
 */
const stabilityImageArtifactSchema = zod_1.z.object({
    base64: zod_1.z.string(),
    seed: zod_1.z.number(),
    finishReason: zod_1.z.enum(["SUCCESS", "ERROR", "CONTENT_FILTERED"]),
});
/**
 * Schema used to validate successful text-to-image responses returned by
 * callStabilityImageGenerationAPI: an object holding an array of artifacts.
 */
const stabilityImageGenerationResponseSchema = zod_1.z.object({
    artifacts: zod_1.z.array(stabilityImageArtifactSchema),
});
/**
 * Call the Stability AI API for image generation.
 *
 * Sends a POST request to `${baseUrl}/generation/${engineId}/text-to-image`.
 * The camelCase settings are translated into the snake_case JSON body the
 * endpoint expects. A successful response is parsed with
 * stabilityImageGenerationResponseSchema; failures are handled by
 * failedStabilityCallResponseHandler.
 *
 * @see https://api.stability.ai/docs#tag/v1generation/operation/textToImage
 *
 * @param callSettings.baseUrl - API root; defaults to "https://api.stability.ai/v1".
 * @param callSettings.abortSignal - Optional signal forwarded to the request.
 * @param callSettings.apiKey - Stability API key.
 * @param callSettings.engineId - Engine id placed in the URL path.
 * @param callSettings.textPrompts - Prompts sent as `text_prompts`.
 * @param callSettings.height / width / cfgScale / clipGuidancePreset / sampler /
 *   samples / seed / steps / stylePreset - Optional generation parameters.
 *   Undefined values are included in the body object as `undefined`.
 * @returns The value produced by postJsonToApi with the JSON response handler.
 *
 * @example
 * const imageResponse = await callStabilityImageGenerationAPI({
 *   apiKey: STABILITY_API_KEY,
 *   engineId: "stable-diffusion-512-v2-1",
 *   textPrompts: [
 *     { text: "the wicked witch of the west" },
 *     { text: "style of early 19th century painting", weight: 0.5 },
 *   ],
 *   cfgScale: 7,
 *   clipGuidancePreset: "FAST_BLUE",
 *   height: 512,
 *   width: 512,
 *   samples: 1,
 *   steps: 30,
 * });
 *
 * imageResponse.artifacts.forEach((image, index) => {
 *   fs.writeFileSync(
 *     `./stability-image-example-${index}.png`,
 *     Buffer.from(image.base64, "base64")
 *   );
 * });
 */
async function callStabilityImageGenerationAPI(callSettings) {
    const {
        baseUrl = "https://api.stability.ai/v1",
        abortSignal,
        apiKey,
        engineId,
    } = callSettings;
    // Map onto the REST field names. Keys are listed in the same order as the
    // API docs / previous implementation so the serialized JSON is unchanged.
    const requestBody = {
        height: callSettings.height,
        width: callSettings.width,
        text_prompts: callSettings.textPrompts,
        cfg_scale: callSettings.cfgScale,
        clip_guidance_preset: callSettings.clipGuidancePreset,
        sampler: callSettings.sampler,
        samples: callSettings.samples,
        seed: callSettings.seed,
        steps: callSettings.steps,
        style_preset: callSettings.stylePreset,
    };
    const endpoint = `${baseUrl}/generation/${engineId}/text-to-image`;
    const { postJsonToApi, createJsonResponseHandler } = postToApi_js_1;
    return postJsonToApi({
        url: endpoint,
        apiKey,
        body: requestBody,
        failedResponseHandler: StabilityError_js_1.failedStabilityCallResponseHandler,
        successfulResponseHandler: createJsonResponseHandler(stabilityImageGenerationResponseSchema),
        abortSignal,
    });
}