UNPKG

@ai-sdk/fireworks

Version:

The **[Fireworks provider](https://ai-sdk.dev/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model and image model support for the [Fireworks](https://fireworks.ai) platform.

229 lines (226 loc) 7.22 kB
import {
  combineHeaders,
  convertImageModelFileToDataUri,
  createBinaryResponseHandler,
  createStatusCodeErrorResponseHandler,
  loadApiKey,
  postJsonToApi,
  withoutTrailingSlash,
  withUserAgentSuffix
} from "@ai-sdk/provider-utils";
import {
  OpenAICompatibleChatLanguageModel,
  OpenAICompatibleCompletionLanguageModel,
  OpenAICompatibleEmbeddingModel
} from "@ai-sdk/openai-compatible";
import { z } from "zod/v4";

// src/fireworks-image-model.ts

/**
 * Per-model backend settings, keyed by the full Fireworks model id.
 *
 * - `urlFormat` selects the endpoint shape built by `getUrlForModel`.
 * - `supportsSize` marks models for which `size` is accepted and `aspectRatio` is warned about.
 * - `supportsEditing` marks models that are routed to the plain workflow URL.
 *
 * Model ids missing from this table fall back to the text-to-image workflow URL.
 */
const modelToBackendConfig = {
  "accounts/fireworks/models/flux-1-dev-fp8": {
    urlFormat: "workflows"
  },
  "accounts/fireworks/models/flux-1-schnell-fp8": {
    urlFormat: "workflows"
  },
  "accounts/fireworks/models/flux-kontext-pro": {
    urlFormat: "workflows_edit",
    supportsEditing: true
  },
  "accounts/fireworks/models/flux-kontext-max": {
    urlFormat: "workflows_edit",
    supportsEditing: true
  },
  "accounts/fireworks/models/playground-v2-5-1024px-aesthetic": {
    urlFormat: "image_generation",
    supportsSize: true
  },
  "accounts/fireworks/models/japanese-stable-diffusion-xl": {
    urlFormat: "image_generation",
    supportsSize: true
  },
  "accounts/fireworks/models/playground-v2-1024px-aesthetic": {
    urlFormat: "image_generation",
    supportsSize: true
  },
  "accounts/fireworks/models/stable-diffusion-xl-1024-v1-0": {
    urlFormat: "image_generation",
    supportsSize: true
  },
  "accounts/fireworks/models/SSD-1B": {
    urlFormat: "image_generation",
    supportsSize: true
  }
};

/**
 * Builds the request URL for an image model.
 *
 * @param baseUrl Base URL without a trailing slash.
 * @param modelId Full Fireworks model id; looked up in `modelToBackendConfig`.
 * @param hasInputImage Whether the call carries an input image.
 * @returns `<base>/image_generation/<id>` for `image_generation` models,
 *   `<base>/workflows/<id>` for `workflows_edit` models (or any other model
 *   flagged `supportsEditing` when an input image is present), and
 *   `<base>/workflows/<id>/text_to_image` in every remaining case, including
 *   unknown model ids.
 */
function getUrlForModel(baseUrl, modelId, hasInputImage) {
  const backend = modelToBackendConfig[modelId];
  const urlFormat = backend == null ? undefined : backend.urlFormat;

  if (urlFormat === "image_generation") {
    return `${baseUrl}/image_generation/${modelId}`;
  }

  const workflowUrl = `${baseUrl}/workflows/${modelId}`;
  if (urlFormat === "workflows_edit") {
    return workflowUrl;
  }

  // "workflows" and unknown formats share this path.
  const canEdit = backend == null ? undefined : backend.supportsEditing;
  if (hasInputImage && canEdit) {
    return workflowUrl;
  }
  return `${workflowUrl}/text_to_image`;
}

/**
 * Image model backed by the Fireworks HTTP API.
 *
 * Reads `provider`, `baseURL`, `headers()`, `fetch` and the optional
 * `_internal.currentDate` hook from the config object it is given.
 */
class FireworksImageModel {
  /**
   * @param modelId Full Fireworks model id.
   * @param config Model configuration (see class description for the fields read).
   */
  constructor(modelId, config) {
    this.modelId = modelId;
    this.config = config;
    this.specificationVersion = "v3";
    this.maxImagesPerCall = 1;
  }

  /** Provider name taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Sends one generation (or edit) request and returns the binary response as a single image.
   *
   * Unsupported options do not throw; they are reported through `warnings`:
   * `size` on models without `supportsSize`, `aspectRatio` on models with it,
   * any input image beyond the first, and any `mask`.
   *
   * @returns `{ images, warnings, response }` where `images` holds the one
   *   response payload and `response` carries the timestamp, model id and
   *   response headers.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask
  }) {
    const warnings = [];
    const backend = modelToBackendConfig[this.modelId];
    const supportsSize = backend == null ? undefined : backend.supportsSize;

    if (size != null && !supportsSize) {
      warnings.push({
        type: "unsupported",
        feature: "size",
        details: "This model does not support the `size` option. Use `aspectRatio` instead."
      });
    }
    if (aspectRatio != null && supportsSize) {
      warnings.push({
        type: "unsupported",
        feature: "aspectRatio",
        details: "This model does not support the `aspectRatio` option."
      });
    }

    // Only the first file is forwarded; extras are reported and dropped.
    const hasInputImage = files != null && files.length > 0;
    const inputImage = hasInputImage
      ? convertImageModelFileToDataUri(files[0])
      : undefined;
    if (hasInputImage && files.length > 1) {
      warnings.push({
        type: "other",
        message: "Fireworks only supports a single input image. Additional images are ignored."
      });
    }

    if (mask != null) {
      warnings.push({
        type: "unsupported",
        feature: "mask",
        details: "Fireworks Kontext models do not support explicit masks. Use the prompt to describe the areas to edit."
      });
    }

    // "<width>x<height>" is split as-is; both parts are forwarded as strings.
    let sizeFields;
    if (size != null) {
      const parts = size.split("x");
      sizeFields = { width: parts[0], height: parts[1] };
    }
    const imageFields = inputImage ? { input_image: inputImage } : undefined;

    // Prefer the injected clock when the config provides one.
    const internal = this.config._internal;
    const clock = internal == null ? undefined : internal.currentDate;
    const injectedDate = clock == null ? undefined : clock.call(internal);
    const currentDate = injectedDate != null ? injectedDate : new Date();

    const url = getUrlForModel(this.config.baseURL, this.modelId, hasInputImage);
    const requestHeaders = combineHeaders(this.config.headers(), headers);
    const providerFields = providerOptions.fireworks;
    const body = {
      prompt,
      aspect_ratio: aspectRatio,
      seed,
      samples: n,
      ...imageFields,
      ...sizeFields,
      // Provider options go last so they can override any field above.
      ...(providerFields != null ? providerFields : {})
    };

    const { value: response, responseHeaders } = await postJsonToApi({
      url,
      headers: requestHeaders,
      body,
      failedResponseHandler: createStatusCodeErrorResponseHandler(),
      successfulResponseHandler: createBinaryResponseHandler(),
      abortSignal,
      fetch: this.config.fetch
    });

    return {
      images: [response],
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders
      }
    };
  }
}

// src/version.ts

// NOTE(review): the constant `true` condition looks like a build-time substitution — confirm against the build config.
const VERSION = true ? "2.0.13" : "0.0.0-test";

// src/fireworks-provider.ts

/** Shape of an error payload: an object with a string `error` field. */
const fireworksErrorSchema = z.object({
  error: z.string()
});

/** Error schema plus message extractor handed to the OpenAI-compatible models. */
const fireworksErrorStructure = {
  errorSchema: fireworksErrorSchema,
  errorToMessage: (data) => data.error
};

/** Base URL used when `options.baseURL` is not supplied. */
const defaultBaseURL = "https://api.fireworks.ai/inference/v1";

/**
 * Creates a Fireworks provider.
 *
 * @param options Optional settings. Fields read: `baseURL`, `apiKey`,
 *   `headers`, `fetch`. When `apiKey` is absent, `loadApiKey` is pointed at
 *   the `FIREWORKS_API_KEY` environment variable.
 * @returns A function that creates a chat model for a model id, with
 *   `completionModel`, `chatModel`, `languageModel`, `embeddingModel`,
 *   `textEmbeddingModel`, `image` and `imageModel` factories attached.
 */
function createFireworks(options = {}) {
  const requestedBaseURL = options.baseURL;
  const baseURL = withoutTrailingSlash(
    requestedBaseURL != null ? requestedBaseURL : defaultBaseURL
  );

  // Evaluated on every call, so the API key and extra headers are resolved lazily.
  const getHeaders = () => {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: "FIREWORKS_API_KEY",
      description: "Fireworks API key"
    });
    return withUserAgentSuffix(
      {
        Authorization: `Bearer ${apiKey}`,
        ...options.headers
      },
      `ai-sdk/fireworks/${VERSION}`
    );
  };

  // Settings shared by every model kind.
  const getCommonModelConfig = (modelType) => ({
    provider: `fireworks.${modelType}`,
    url: ({ path }) => `${baseURL}${path}`,
    headers: getHeaders,
    fetch: options.fetch
  });

  // Common settings plus the Fireworks error structure, for the OpenAI-compatible models.
  const getOpenAICompatibleConfig = (modelType) => ({
    ...getCommonModelConfig(modelType),
    errorStructure: fireworksErrorStructure
  });

  const createChatModel = (modelId) =>
    new OpenAICompatibleChatLanguageModel(
      modelId,
      getOpenAICompatibleConfig("chat")
    );

  const createCompletionModel = (modelId) =>
    new OpenAICompatibleCompletionLanguageModel(
      modelId,
      getOpenAICompatibleConfig("completion")
    );

  const createEmbeddingModel = (modelId) =>
    new OpenAICompatibleEmbeddingModel(
      modelId,
      getOpenAICompatibleConfig("embedding")
    );

  const createImageModel = (modelId) =>
    new FireworksImageModel(modelId, {
      ...getCommonModelConfig("image"),
      baseURL: baseURL != null ? baseURL : defaultBaseURL
    });

  // Calling the provider directly yields a chat model.
  const provider = (modelId) => createChatModel(modelId);

  return Object.assign(provider, {
    specificationVersion: "v3",
    completionModel: createCompletionModel,
    chatModel: createChatModel,
    languageModel: createChatModel,
    embeddingModel: createEmbeddingModel,
    textEmbeddingModel: createEmbeddingModel,
    image: createImageModel,
    imageModel: createImageModel
  });
}

/** Provider instance created with default options. */
const fireworks = createFireworks();

export {
  FireworksImageModel,
  VERSION,
  createFireworks,
  fireworks
};
//# sourceMappingURL=index.mjs.map