@ai-sdk/fireworks
Version: 2.0.13
The **[Fireworks provider](https://ai-sdk.dev/providers/ai-sdk-providers/fireworks)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model and image model support for the [Fireworks](https://fireworks.ai) platform.
229 lines (226 loc) • 7.22 kB
JavaScript
// src/fireworks-image-model.ts
import {
combineHeaders,
convertImageModelFileToDataUri,
createBinaryResponseHandler,
createStatusCodeErrorResponseHandler,
postJsonToApi
} from "@ai-sdk/provider-utils";
var modelToBackendConfig = {
"accounts/fireworks/models/flux-1-dev-fp8": {
urlFormat: "workflows"
},
"accounts/fireworks/models/flux-1-schnell-fp8": {
urlFormat: "workflows"
},
"accounts/fireworks/models/flux-kontext-pro": {
urlFormat: "workflows_edit",
supportsEditing: true
},
"accounts/fireworks/models/flux-kontext-max": {
urlFormat: "workflows_edit",
supportsEditing: true
},
"accounts/fireworks/models/playground-v2-5-1024px-aesthetic": {
urlFormat: "image_generation",
supportsSize: true
},
"accounts/fireworks/models/japanese-stable-diffusion-xl": {
urlFormat: "image_generation",
supportsSize: true
},
"accounts/fireworks/models/playground-v2-1024px-aesthetic": {
urlFormat: "image_generation",
supportsSize: true
},
"accounts/fireworks/models/stable-diffusion-xl-1024-v1-0": {
urlFormat: "image_generation",
supportsSize: true
},
"accounts/fireworks/models/SSD-1B": {
urlFormat: "image_generation",
supportsSize: true
}
};
/**
 * Resolves the endpoint URL for an image request.
 *
 * @param baseUrl       API base URL, without a trailing slash.
 * @param modelId       Full Fireworks model id; looked up in `modelToBackendConfig`.
 * @param hasInputImage Whether the request carries an input image.
 * @returns `<base>/image_generation/<model>` for image_generation models,
 *          `<base>/workflows/<model>` for workflows_edit models (and for any
 *          other editing-capable model when an input image is present), and
 *          `<base>/workflows/<model>/text_to_image` otherwise, including for
 *          model ids that have no table entry.
 */
function getUrlForModel(baseUrl, modelId, hasInputImage) {
  const backend = modelToBackendConfig[modelId];
  const format = backend == null ? void 0 : backend.urlFormat;

  if (format === "image_generation") {
    return `${baseUrl}/image_generation/${modelId}`;
  }

  if (format !== "workflows_edit") {
    // "workflows" and unknown models: plain text-to-image unless the model
    // is flagged as editing-capable and an input image was supplied.
    const editsInputImage = hasInputImage && backend != null && backend.supportsEditing;
    if (!editsInputImage) {
      return `${baseUrl}/workflows/${modelId}/text_to_image`;
    }
  }

  // Editing route: the workflow URL without a task suffix.
  return `${baseUrl}/workflows/${modelId}`;
}
/**
 * Image model backed by the Fireworks HTTP API.
 *
 * Each `doGenerate` call posts one JSON request and returns the response
 * payload (as produced by `createBinaryResponseHandler`) as the single entry
 * of `images`. The instance advertises `maxImagesPerCall` of 1.
 */
var FireworksImageModel = class {
  /**
   * @param modelId Fireworks model id, used for the URL and capability lookup.
   * @param config  Reads `provider`, `baseURL`, `headers()`, `fetch` and the
   *                optional `_internal.currentDate` hook.
   */
  constructor(modelId, config) {
    this.modelId = modelId;
    this.config = config;
    this.specificationVersion = "v3";
    this.maxImagesPerCall = 1;
  }

  /** Provider identifier taken from the config. */
  get provider() {
    return this.config.provider;
  }

  /**
   * Sends an image generation (or edit) request.
   *
   * Unsupported options are not rejected; each one is reported through the
   * returned `warnings` array and the request is still sent.
   *
   * @returns `{ images, warnings, response }` where `response` carries the
   *          timestamp, the model id and the HTTP response headers.
   */
  async doGenerate({
    prompt,
    n,
    size,
    aspectRatio,
    seed,
    providerOptions,
    headers,
    abortSignal,
    files,
    mask
  }) {
    const warnings = [];
    const backend = modelToBackendConfig[this.modelId];
    const sizeSupported = Boolean(backend == null ? void 0 : backend.supportsSize);

    // `size` and `aspectRatio` are mutually exclusive per model family.
    if (size != null && !sizeSupported) {
      warnings.push({
        type: "unsupported",
        feature: "size",
        details: "This model does not support the `size` option. Use `aspectRatio` instead."
      });
    }
    if (aspectRatio != null && sizeSupported) {
      warnings.push({
        type: "unsupported",
        feature: "aspectRatio",
        details: "This model does not support the `aspectRatio` option."
      });
    }

    // Only the first input file is forwarded; extras are reported and dropped.
    const hasInputImage = files != null && files.length > 0;
    const inputImage = hasInputImage ? convertImageModelFileToDataUri(files[0]) : void 0;
    if (hasInputImage && files.length > 1) {
      warnings.push({
        type: "other",
        message: "Fireworks only supports a single input image. Additional images are ignored."
      });
    }

    if (mask != null) {
      warnings.push({
        type: "unsupported",
        feature: "mask",
        details: "Fireworks Kontext models do not support explicit masks. Use the prompt to describe the areas to edit."
      });
    }

    // "WxH" -> ["W", "H"]; the parts are forwarded as strings.
    const dimensions = size == null ? void 0 : size.split("x");

    // Timestamp is taken before the request; `_internal.currentDate` lets the
    // caller substitute the clock.
    const internal = this.config._internal;
    const clock = internal == null ? void 0 : internal.currentDate;
    const clockValue = clock == null ? void 0 : clock.call(internal);
    const currentDate = clockValue != null ? clockValue : /* @__PURE__ */ new Date();

    const url = getUrlForModel(this.config.baseURL, this.modelId, hasInputImage);
    const requestHeaders = combineHeaders(this.config.headers(), headers);

    // Provider-specific options are spread last so they can override any of
    // the fields computed above.
    const fireworksOptions = providerOptions.fireworks;
    const body = {
      prompt,
      aspect_ratio: aspectRatio,
      seed,
      samples: n,
      ...(inputImage ? { input_image: inputImage } : {}),
      ...(dimensions ? { width: dimensions[0], height: dimensions[1] } : {}),
      ...(fireworksOptions != null ? fireworksOptions : {})
    };

    const apiResult = await postJsonToApi({
      url,
      headers: requestHeaders,
      body,
      failedResponseHandler: createStatusCodeErrorResponseHandler(),
      successfulResponseHandler: createBinaryResponseHandler(),
      abortSignal,
      fetch: this.config.fetch
    });
    const image = apiResult.value;
    const responseHeaders = apiResult.responseHeaders;

    return {
      images: [image],
      warnings,
      response: {
        timestamp: currentDate,
        modelId: this.modelId,
        headers: responseHeaders
      }
    };
  }
};
// src/fireworks-provider.ts
import {
OpenAICompatibleChatLanguageModel,
OpenAICompatibleCompletionLanguageModel,
OpenAICompatibleEmbeddingModel
} from "@ai-sdk/openai-compatible";
import {
loadApiKey,
withoutTrailingSlash,
withUserAgentSuffix
} from "@ai-sdk/provider-utils";
import { z } from "zod/v4";
// src/version.ts
// Package version; exported and embedded in the `ai-sdk/fireworks/<version>`
// user-agent suffix. The constant-true condition looks like a build-time
// substitution (TODO confirm), so the "0.0.0-test" fallback is never selected
// in this bundle.
var VERSION = true ? "2.0.13" : "0.0.0-test";
// src/fireworks-provider.ts
// Shape expected of a Fireworks error response body: one string `error` field.
var fireworksErrorSchema = z.object({ error: z.string() });

// Error-handling settings handed to the chat, completion and embedding models:
// the schema for error bodies and how to pull the message out of a parsed one.
var fireworksErrorStructure = {
  errorSchema: fireworksErrorSchema,
  errorToMessage: ({ error }) => error
};

// Base URL used when `createFireworks` is called without `options.baseURL`.
var defaultBaseURL = "https://api.fireworks.ai/inference/v1";
/**
 * Creates a Fireworks provider.
 *
 * The returned value is callable (`provider(modelId)` builds a chat model) and
 * also exposes factory methods for chat, completion, embedding and image
 * models.
 *
 * @param options          Optional settings.
 * @param options.baseURL  API base URL; a trailing slash is stripped. Falls
 *                         back to the default Fireworks inference URL.
 * @param options.apiKey   API key; when absent, `loadApiKey` is asked to use
 *                         the `FIREWORKS_API_KEY` environment variable.
 * @param options.headers  Extra headers, merged after `Authorization`.
 * @param options.fetch    Custom fetch implementation passed to the models.
 * @returns The provider function with its model factories attached.
 */
function createFireworks(options = {}) {
  const requestedBaseURL = options.baseURL;
  const baseURL = withoutTrailingSlash(
    requestedBaseURL != null ? requestedBaseURL : defaultBaseURL
  );

  // Built on every call, so the API key is only loaded when headers are
  // actually requested rather than when the provider is created.
  const getHeaders = () => {
    const apiKey = loadApiKey({
      apiKey: options.apiKey,
      environmentVariableName: "FIREWORKS_API_KEY",
      description: "Fireworks API key"
    });
    const baseHeaders = {
      Authorization: `Bearer ${apiKey}`,
      ...options.headers
    };
    return withUserAgentSuffix(baseHeaders, `ai-sdk/fireworks/${VERSION}`);
  };

  // Settings shared by every model type; `modelType` only affects the
  // provider label.
  const getCommonModelConfig = (modelType) => {
    return {
      provider: `fireworks.${modelType}`,
      url: ({ path }) => `${baseURL}${path}`,
      headers: getHeaders,
      fetch: options.fetch
    };
  };

  // Common settings plus the Fireworks error format, for the
  // OpenAI-compatible model classes.
  const getOpenAICompatibleConfig = (modelType) => {
    return {
      ...getCommonModelConfig(modelType),
      errorStructure: fireworksErrorStructure
    };
  };

  const createChatModel = (modelId) => {
    return new OpenAICompatibleChatLanguageModel(modelId, getOpenAICompatibleConfig("chat"));
  };
  const createCompletionModel = (modelId) => {
    return new OpenAICompatibleCompletionLanguageModel(modelId, getOpenAICompatibleConfig("completion"));
  };
  const createEmbeddingModel = (modelId) => {
    return new OpenAICompatibleEmbeddingModel(modelId, getOpenAICompatibleConfig("embedding"));
  };
  const createImageModel = (modelId) => {
    return new FireworksImageModel(modelId, {
      ...getCommonModelConfig("image"),
      baseURL: baseURL != null ? baseURL : defaultBaseURL
    });
  };

  // Calling the provider directly is shorthand for creating a chat model.
  const provider = (modelId) => createChatModel(modelId);

  return Object.assign(provider, {
    specificationVersion: "v3",
    completionModel: createCompletionModel,
    chatModel: createChatModel,
    languageModel: createChatModel,
    embeddingModel: createEmbeddingModel,
    textEmbeddingModel: createEmbeddingModel,
    image: createImageModel,
    imageModel: createImageModel
  });
}
// Default provider instance, created with no options. No API key is read
// here: `createFireworks` only calls `loadApiKey` when its headers function
// is invoked.
var fireworks = createFireworks();
export {
FireworksImageModel,
VERSION,
createFireworks,
fireworks
};
//# sourceMappingURL=index.mjs.map