@genkit-ai/compat-oai
Version:
Genkit AI framework plugin for OpenAI APIs.
143 lines • 3.83 kB
JavaScript
import { GenerationCommonConfigSchema, Message, modelRef, z } from "genkit";
import { model } from "genkit/plugin";
import { maybeCreateRequestScopedOpenAIClient, toModelName } from "./utils.mjs";
/**
 * Default capability descriptor for translation models.
 *
 * Translation models take a single media part (plus optional text) in one
 * turn and answer with text or JSON. They do not accept a system role or
 * tool calls.
 */
const TRANSLATION_MODEL_INFO = {
  supports: {
    media: true, // the audio to translate arrives as a media part
    output: ["text", "json"],
    multiturn: false, // only the first message of a request is used
    systemRole: false,
    tools: false,
  },
};
/**
 * Config schema for translation models.
 *
 * It keeps only `temperature` from Genkit's common generation config and adds
 * the endpoint-specific `response_format` option.
 */
const TranslationConfigSchema = GenerationCommonConfigSchema
  .pick({ temperature: true })
  .extend({
    response_format: z
      .enum(["json", "text", "srt", "verbose_json", "vtt"])
      .optional(),
  });
/**
 * Decodes the audio attached to a Genkit media part into an uploadable `File`.
 *
 * Two inputs are accepted:
 * - a `data:` URL. When the part has no explicit `contentType`, the MIME type
 *   is read from the data-URL header.
 * - a bare base64 payload with no scheme.
 *
 * Any other scheme (`https:`, `gs:`, ...) is rejected. Those bytes are not
 * available here, and base64-decoding the URL text itself would only upload
 * garbage.
 *
 * @param {{ url: string, contentType?: string }} media - media part of the message.
 * @returns {File} a file named "input" holding the decoded bytes.
 * @throws {Error} for a data URL without a "," separator or a non-data URL scheme.
 */
function mediaToTranslationFile(media) {
  const { url } = media;
  const isDataUrl = /^data:/i.test(url);
  const commaIndex = url.indexOf(",");
  const isMalformedDataUrl = isDataUrl && commaIndex === -1;
  const hasOtherScheme = !isDataUrl && /^[a-z][a-z\d+.-]*:/i.test(url);
  if (isMalformedDataUrl || hasOtherScheme) {
    throw new Error(
      "Translation media must be inline base64 data (e.g. data:audio/mpeg;base64,...)"
    );
  }
  // The data-URL header ("data:<mime>[;param...]") ends at the first comma.
  // The MIME type is its first ";"-separated segment. Splitting the header,
  // rather than slicing up to the first ";" in the whole URL, also works when
  // the header has no parameters.
  const headerMimeType = isDataUrl
    ? url.slice("data:".length, commaIndex).split(";")[0]
    : "";
  // For a bare base64 payload commaIndex is -1, so the whole string is kept.
  const bytes = Buffer.from(url.slice(commaIndex + 1), "base64");
  return new File([bytes], "input", {
    type: media.contentType ?? headerMimeType,
  });
}

/**
 * Builds the parameters for an OpenAI-compatible `audio.translations.create`
 * call from a Genkit generate request.
 *
 * Only `request.messages[0]` is read. Its first media part becomes `file` and
 * its text becomes `prompt`.
 *
 * Config handling:
 * - `temperature` is mapped directly.
 * - `config.version` overrides the model name.
 * - `maxOutputTokens`, `stopSequences`, `topK` and `topP` are dropped.
 * - Without a `requestBuilder`, every other config key is passed through
 *   verbatim. With one, the builder may mutate the options object instead.
 *
 * @param {string} modelName - model id used when `config.version` is unset.
 * @param {object} request - Genkit generate request.
 * @param {(request: object, options: object) => void} [requestBuilder] -
 *   optional hook that customizes the options object in place.
 * @returns {object} request parameters with all `undefined` values removed.
 * @throws {Error} in any of these cases:
 *   - the request carries no media, or the media is not inline base64 data;
 *   - the output format is "media";
 *   - `config.response_format` conflicts with a JSON output format.
 */
function toTranslationRequest(modelName, request, requestBuilder) {
  // Check before constructing a Message, so an empty request fails with the
  // same error as a message that carries no media.
  const firstMessage = request.messages?.[0];
  if (!firstMessage) {
    throw new Error("No media found in the request");
  }
  const message = new Message(firstMessage);
  const media = message.media;
  if (!media?.url) {
    throw new Error("No media found in the request");
  }
  const mediaFile = mediaToTranslationFile(media);

  // Generation knobs this endpoint does not understand are destructured only
  // so they are excluded from `restOfConfig`.
  const {
    temperature,
    version: modelVersion,
    maxOutputTokens,
    stopSequences,
    topK,
    topP,
    ...restOfConfig
  } = request.config ?? {};
  let options = {
    model: modelVersion ?? modelName,
    file: mediaFile,
    prompt: message.text,
    temperature,
  };
  if (requestBuilder) {
    requestBuilder(request, options);
  } else {
    // Pass the remaining config keys through to the API untouched.
    options = { ...options, ...restOfConfig };
  }

  const outputFormat = request.output?.format;
  const customFormat = request.config?.response_format;
  // A JSON output can only be parsed from the "json" / "verbose_json" formats.
  if (
    outputFormat === "json" &&
    customFormat &&
    customFormat !== "json" &&
    customFormat !== "verbose_json"
  ) {
    throw new Error(
      `Custom response format ${customFormat} is not compatible with output format ${outputFormat}`
    );
  }
  if (outputFormat === "media") {
    throw new Error(`Output format ${outputFormat} is not supported.`);
  }
  options.response_format = customFormat || outputFormat || "text";

  // Drop unset fields so they are not sent as empty form values.
  for (const key of Object.keys(options)) {
    if (options[key] === undefined) {
      delete options[key];
    }
  }
  return options;
}
/**
 * Wraps an audio-translation API result in a Genkit generate response.
 *
 * `result` is either a plain string or an object exposing a `text` field.
 * Presumably the string form comes from text/srt/vtt formats and the object
 * form from json/verbose_json; confirm against the SDK. Either way, the text
 * becomes a single model text part. The untouched result is kept as `raw`.
 *
 * @param {string | { text: string }} result - value returned by the API client.
 * @returns {object} Genkit response with `finishReason` "stop".
 */
function translationToGenerateResponse(result) {
  const translatedText = typeof result === "string" ? result : result.text;
  const message = { role: "model", content: [{ text: translatedText }] };
  return { message, finishReason: "stop", raw: result };
}
/**
 * Defines a Genkit model action backed by an OpenAI-compatible
 * `audio.translations` endpoint.
 *
 * @param {object} params
 * @param {string} params.name - model name, normalized via `toModelName`.
 * @param {object} params.client - client used unless
 *   `maybeCreateRequestScopedOpenAIClient` supplies a request-scoped one.
 * @param {object} [params.pluginOptions] - plugin options. Their `name`
 *   prefixes the action name and defaults to "compat-oai".
 * @param {object} [params.modelRef] - model reference. Its `info` is spread
 *   into the action options and its `configSchema` becomes the action's schema.
 * @param {Function} [params.requestBuilder] - forwarded to `toTranslationRequest`.
 * @returns the action created by `model(...)`.
 */
function defineCompatOpenAITranslationModel(params) {
  const {
    name,
    client: defaultClient,
    pluginOptions,
    modelRef: ref,
    requestBuilder,
  } = params;
  const pluginName = pluginOptions?.name;
  const modelName = toModelName(name, pluginName);

  const actionOptions = {
    name: `${pluginName ?? "compat-oai"}/${modelName}`,
    ...ref?.info,
    configSchema: ref?.configSchema,
  };

  // Per call: build the API params, pick a client, and forward the abort
  // signal so a cancelled generate also cancels the HTTP request.
  const runTranslation = async (request, { abortSignal }) => {
    const apiParams = toTranslationRequest(modelName, request, requestBuilder);
    const client = maybeCreateRequestScopedOpenAIClient(
      pluginOptions,
      request,
      defaultClient
    );
    const result = await client.audio.translations.create(apiParams, {
      signal: abortSignal,
    });
    return translationToGenerateResponse(result);
  };

  return model(actionOptions, runTranslation);
}
/**
 * Creates a Genkit model reference for a translation model.
 *
 * Defaults:
 * - `info` defaults to `TRANSLATION_MODEL_INFO` when omitted.
 * - `configSchema` falls back to `TranslationConfigSchema` when falsy.
 *
 * `config` and `namespace` are forwarded unchanged.
 *
 * @param {object} params - `{ name, info?, configSchema?, config?, namespace? }`.
 * @returns the reference produced by Genkit's `modelRef`.
 */
function compatOaiTranslationModelRef(params) {
  const {
    name,
    namespace,
    config,
    configSchema,
    info = TRANSLATION_MODEL_INFO,
  } = params;
  const schema = configSchema || TranslationConfigSchema;
  return modelRef({ name, configSchema: schema, info, config, namespace });
}
export {
TRANSLATION_MODEL_INFO,
TranslationConfigSchema,
compatOaiTranslationModelRef,
defineCompatOpenAITranslationModel,
toTranslationRequest,
translationToGenerateResponse
};
//# sourceMappingURL=translate.mjs.map