@lobehub/chat
Version:
Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.
89 lines (74 loc) • 2.7 kB
text/typescript
import { fal } from '@fal-ai/client';
import debug from 'debug';
import { pick } from 'lodash-es';
import { ClientOptions } from 'openai';
import { RuntimeImageGenParamsValue } from '@/libs/standard-parameters/meta-schema';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { CreateImagePayload, CreateImageResponse } from '../types/image';
import { AgentRuntimeError } from '../utils/createError';
// Debug logger for this module, created under the 'lobe-image:fal' namespace.
const log = debug('lobe-image:fal');
// Type of the `data` field of the awaited `fal.subscribe` result for the
// 'fal-ai/flux/dev' endpoint.
// NOTE(review): LobeFalAI.createImage casts the response of *every* endpoint it calls to
// this type — presumably all supported fal image models return the same `images` shape; confirm.
type FluxDevOutput = Awaited<ReturnType<typeof fal.subscribe<'fal-ai/flux/dev'>>>['data'];
/**
 * Image-generation runtime backed by fal.ai.
 *
 * Configures the shared `fal` client with the supplied API key on construction and
 * exposes `createImage`, which maps the runtime's generic image parameters onto the
 * input expected by the `fal-ai/<model>` endpoint.
 */
export class LobeFalAI implements LobeRuntimeAI {
  /** Runtime parameter names that must be renamed before being sent to fal.ai. */
  private static readonly paramNameMap = new Map<RuntimeImageGenParamsValue, string>([
    ['steps', 'num_inference_steps'],
    ['cfg', 'guidance_scale'],
    ['imageUrl', 'image_url'],
  ]);

  /**
   * @param options - Client options; only `apiKey` is read.
   * @throws An `InvalidProviderAPIKey` runtime error when no API key is provided.
   */
  constructor({ apiKey }: ClientOptions = {}) {
    if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);

    fal.config({
      credentials: apiKey,
    });
    // The key is guaranteed to be set past the guard above; never log its value.
    log('FalAI initialized with apiKey: %s', '*****');
  }

  /**
   * Generates a single image through `fal.subscribe` on the `fal-ai/<model>` endpoint.
   *
   * @param payload - The model id (appended to `fal-ai/`) and the generation params.
   * @returns The URL of the first returned image plus its `width`/`height` when present.
   * @throws An `InvalidProviderAPIKey` runtime error when the request fails with status 401.
   * @throws A `ProviderBizError` runtime error for any other failure, including a
   *   response that contains no images.
   */
  async createImage(payload: CreateImagePayload): Promise<CreateImageResponse> {
    const { model, params } = payload;
    log('Creating image with model: %s and params: %O', model, params);

    const defaultInput: Record<string, unknown> = {
      enable_safety_checker: false,
      num_images: 1,
    };

    // Drop unset (null/undefined) values so they are neither forwarded to the endpoint
    // nor mistaken for a real width/height pair below, then translate the key names.
    const userInput: Record<string, unknown> = Object.fromEntries(
      (Object.entries(params) as [keyof typeof params, unknown][])
        .filter(([, value]) => value !== null && value !== undefined)
        .map(([key, value]) => [LobeFalAI.paramNameMap.get(key) ?? key, value]),
    );

    // A width/height pair is sent as a single `image_size` object instead of two fields.
    if ('width' in userInput && 'height' in userInput) {
      userInput.image_size = {
        height: userInput.height,
        width: userInput.width,
      };
      delete userInput.width;
      delete userInput.height;
    }

    const endpoint = `fal-ai/${model}`;
    // Built once so the logged input is exactly what gets sent; user values override defaults.
    const input: Record<string, unknown> = { ...defaultInput, ...userInput };
    log('Calling fal.subscribe with endpoint: %s and input: %O', endpoint, input);

    try {
      const { data } = await fal.subscribe(endpoint, { input });

      const [image] = (data as FluxDevOutput).images ?? [];
      if (!image) {
        // Thrown inside the try so it surfaces as ProviderBizError, like any other failure.
        throw new Error(`fal.ai endpoint "${endpoint}" returned no images`);
      }
      log('Received image data: %O', image);

      return {
        imageUrl: image.url,
        ...pick(image, ['width', 'height']),
      };
    } catch (error) {
      const isUnauthorized =
        error instanceof Error &&
        'status' in error &&
        (error as { status?: unknown }).status === 401;

      if (isUnauthorized) {
        throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey, {
          error,
        });
      }

      throw AgentRuntimeError.createError(AgentRuntimeErrorType.ProviderBizError, { error });
    }
  }
}