UNPKG

automatic1111-provider

Version:

A TypeScript provider for the Vercel AI SDK that enables image generation using AUTOMATIC1111 Stable Diffusion WebUI

151 lines (150 loc) 6.05 kB
import { InvalidResponseDataError, NoSuchModelError, } from '@ai-sdk/provider'; import { combineHeaders, createJsonResponseHandler, createJsonErrorResponseHandler, postJsonToApi, } from '@ai-sdk/provider-utils'; import { z } from 'zod/v4'; export class Automatic1111ImageModel { get provider() { return this.config.provider; } constructor(modelId, config) { this.modelId = modelId; this.config = config; this.specificationVersion = 'v2'; this.maxImagesPerCall = 1; } async doGenerate({ prompt, n, size, aspectRatio, seed, providerOptions, headers, abortSignal, }) { const warnings = []; if (aspectRatio != null) { warnings.push({ type: 'unsupported-setting', setting: 'aspectRatio', details: 'This model does not support the `aspectRatio` option. Use `size` instead.', }); } // Extract the provider options const { negative_prompt, styles, steps, cfg_scale, sampler_name, denoising_strength, check_model_exists, ...providerRequestOptions } = providerOptions.automatic1111 ?? {}; // Get the current date for timestamp const currentDate = this.config._internal?.currentDate?.() ?? new Date(); // Combine the headers const fullHeaders = combineHeaders(this.config.headers(), headers); // Check if the model exists if (check_model_exists) { // Get the available models to check (automatic1111 uses default model if not specified, so we need to check if the model is available) const availableModels = await fetch(this.getAutomatic1111ModelsUrl()); const availableModelsJson = await availableModels.json(); const model = Automatic1111ModelListSchema.parse(availableModelsJson).find((model) => model.model_name === this.modelId); if (!model) { throw new NoSuchModelError({ errorName: 'NoSuchModelError', modelId: this.modelId, modelType: 'imageModel', message: `Model ${this.modelId} not found`, }); } } const modelId = this.modelId; // Send the request to the API const { value: generationResponse, responseHeaders } = await postJsonToApi({ url: this.getAutomatic1111GenerationsUrl(), headers: fullHeaders, body: { prompt, negative_prompt, styles, seed, sampler_name, n_iter: n, steps, cfg_scale, denoising_strength, width: size?.split('x')[0] ?? 512, height: size?.split('x')[1] ?? 512, override_settings: { sd_model_checkpoint: modelId, }, ...providerRequestOptions, }, abortSignal, failedResponseHandler: this.createAutomatic1111ErrorHandler(), successfulResponseHandler: createJsonResponseHandler(Automatic1111GenerationResponseSchema), }); // Check if the response is valid if (generationResponse === null || generationResponse === undefined || generationResponse.images === null || generationResponse.images === undefined) { throw new InvalidResponseDataError({ data: generationResponse, message: 'Invalid response data', }); } // Convert the images to Uint8Array const images = generationResponse.images.map(image => this.base64ToUint8Array(image)); // Return the images return { images, warnings, response: { modelId: modelId, timestamp: currentDate, headers: responseHeaders, }, }; } // Create the error handler for the API createAutomatic1111ErrorHandler() { return createJsonErrorResponseHandler({ errorSchema: Automatic1111ErrorSchema, errorToMessage: (error) => error.detail[0].msg ?? 'Unknown error', }); } // Get the URL for the generations API getAutomatic1111GenerationsUrl() { return `${this.config.baseURL}/sdapi/v1/txt2img/`; } // Get the URL for the models API getAutomatic1111ModelsUrl() { return `${this.config.baseURL}/sdapi/v1/sd-models/`; } // Convert a base64 string to a Uint8Array base64ToUint8Array(base64String) { // Remove data URL prefix if present (e.g., "data:image/png;base64,") const base64Data = base64String.replace(/^data:image\/[a-z]+;base64,/, ''); // Convert base64 to binary string const binaryString = Buffer.from(base64Data, 'base64').toString('binary'); // Convert binary string to Uint8Array const bytes = new Uint8Array(binaryString.length); for (let i = 0; i < binaryString.length; i++) { bytes[i] = binaryString.charCodeAt(i); } return bytes; } } // Schema for the response from the API const Automatic1111GenerationResponseSchema = z.object({ images: z.array(z.string()), }); // Schema for the error response from the API const Automatic1111ErrorSchema = z.object({ detail: z.array(z.object({ loc: z.array(z.object({ where: z.string(), index: z.number(), })), msg: z.string(), type: z.string(), ctx: z .object({ msg: z.string(), doc: z.string(), pos: z.number(), lineno: z.number(), colno: z.number(), }) .nullish(), })), }); export const Automatic1111ModelListSchema = z.array(z.object({ title: z.string(), model_name: z.string(), hash: z.string(), sha256: z.string(), filename: z.string(), config: z.string().nullish(), }));