UNPKG

@genkit-ai/compat-oai

Version:

Genkit AI framework plugin for OpenAI APIs.

263 lines 7.57 kB
import { embedderActionMetadata, embedderRef, modelActionMetadata } from "genkit";
import { defineCompatOpenAISpeechModel, defineCompatOpenAITranscriptionModel } from "../audio.mjs";
import { defineCompatOpenAIEmbedder } from "../embedder.mjs";
import { defineCompatOpenAIImageModel } from "../image.mjs";
import { openAICompatible } from "../index.mjs";
import { defineCompatOpenAIModel } from "../model.mjs";
import { gptImage1RequestBuilder, openAIImageModelRef, SUPPORTED_IMAGE_MODELS } from "./dalle.mjs";
import { SUPPORTED_EMBEDDING_MODELS, TextEmbeddingConfigSchema } from "./embedder.mjs";
import { openAIModelRef, SUPPORTED_GPT_MODELS } from "./gpt.mjs";
import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from "./stt.mjs";
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from "./tts.mjs";
import { defineOpenAIWhisperModel, openAIWhisperModelRef, SUPPORTED_WHISPER_MODELS } from "./whisper.mjs";

/** Substrings of model IDs that `listActions` filters out of the API's model list. */
const UNSUPPORTED_MODEL_MATCHERS = ["babbage", "davinci", "codex"];

/**
 * Defines an OpenAI image model action. Attaches `gptImage1RequestBuilder` when
 * the model ref's name contains "gpt-image-1"; otherwise `requestBuilder` is undefined.
 *
 * Used by both the initializer and the resolver, so pre-registered and
 * dynamically resolved gpt-image-1 models build their requests the same way.
 *
 * @param {{ name: string, client: object, pluginOptions: object, modelRef: object }} params
 *   Arguments forwarded to `defineCompatOpenAIImageModel`.
 * @returns {*} Whatever `defineCompatOpenAIImageModel` returns.
 */
function defineOpenAIImageModel(params) {
  const usesGptImage1 = params.modelRef.name.includes("gpt-image-1");
  return defineCompatOpenAIImageModel({
    ...params,
    requestBuilder: usesGptImage1 ? gptImage1RequestBuilder : void 0,
  });
}

/**
 * Classifies a (non-embedder) model name by substring and returns the helpers
 * for its family. Checks run in this order:
 * - image: "gpt-image-1" or "dall-e"
 * - speech: "tts"
 * - Whisper: "whisper"
 * - transcription: "transcribe"
 * - anything else is treated as a GPT model.
 *
 * This is a function rather than a module-level table so that imported
 * bindings are only read at call time.
 *
 * @param {string} name - Model ID or action name.
 * @returns {{ supported: object, createRef: Function, define: Function }}
 *   - `supported`: the family's SUPPORTED_* table of prebuilt refs, keyed by model ID.
 *   - `createRef`: builds a ref for an arbitrary name.
 *   - `define`: defines the model action.
 */
function findModelFamily(name) {
  if (name.includes("gpt-image-1") || name.includes("dall-e")) {
    return { supported: SUPPORTED_IMAGE_MODELS, createRef: openAIImageModelRef, define: defineOpenAIImageModel };
  }
  if (name.includes("tts")) {
    return { supported: SUPPORTED_TTS_MODELS, createRef: openAISpeechModelRef, define: defineCompatOpenAISpeechModel };
  }
  if (name.includes("whisper")) {
    return { supported: SUPPORTED_WHISPER_MODELS, createRef: openAIWhisperModelRef, define: defineOpenAIWhisperModel };
  }
  if (name.includes("transcribe")) {
    return {
      supported: SUPPORTED_STT_MODELS,
      createRef: openAITranscriptionModelRef,
      define: defineCompatOpenAITranscriptionModel,
    };
  }
  return { supported: SUPPORTED_GPT_MODELS, createRef: openAIModelRef, define: defineCompatOpenAIModel };
}

/**
 * Builds the resolver passed to `openAICompatible`. The resolver defines a
 * single action for a given action type and name.
 *
 * @param {object} pluginOptions - Forwarded to every define* call.
 * @returns {(client: object, actionType: string, actionName: string) => Promise<*>}
 *   Defines an embedder when `actionType` is "embedder". Otherwise defines a
 *   model of the family chosen by `findModelFamily(actionName)`.
 */
function createResolver(pluginOptions) {
  return async (client, actionType, actionName) => {
    if (actionType === "embedder") {
      return defineCompatOpenAIEmbedder({ name: actionName, client, pluginOptions });
    }
    const { createRef, define } = findModelFamily(actionName);
    const modelRef = createRef({ name: actionName });
    return define({ name: modelRef.name, client, pluginOptions, modelRef });
  };
}

/**
 * Keeps a listed model unless its ID contains one of UNSUPPORTED_MODEL_MATCHERS.
 *
 * @param {{ id: string }} modelInfo - One entry of the models list response.
 * @returns {boolean}
 */
function filterOpenAiModels(modelInfo) {
  return !UNSUPPORTED_MODEL_MATCHERS.some((matcher) => modelInfo.id.includes(matcher));
}

/**
 * Converts one listed model into Genkit action metadata.
 * - IDs containing "embedding" become embedder metadata.
 * - Every other ID becomes model metadata. It uses the family's prebuilt ref
 *   when the ID is in its SUPPORTED_* table, else a freshly built ref.
 *
 * @param {{ id: string }} modelInfo
 */
function toActionMetadata(modelInfo) {
  const { id } = modelInfo;
  if (id.includes("embedding")) {
    return embedderActionMetadata({
      name: id,
      configSchema: TextEmbeddingConfigSchema,
      info: SUPPORTED_EMBEDDING_MODELS[id]?.info,
    });
  }
  const { supported, createRef } = findModelFamily(id);
  const modelRef = supported[id] ?? createRef({ name: id });
  return modelActionMetadata({
    name: modelRef.name,
    info: modelRef.info,
    configSchema: modelRef.configSchema,
  });
}

/**
 * Lists the models reported by `client.models.list()` as action metadata.
 * IDs matched by UNSUPPORTED_MODEL_MATCHERS are skipped.
 *
 * NOTE(review): only `response.data` of the awaited list call is used. Confirm
 * the models endpoint returns every model in that single response.
 *
 * @param {object} client - The OpenAI client supplied by `openAICompatible`.
 */
const listActions = async (client) => {
  const response = await client.models.list();
  return response.data.filter(filterOpenAiModels).map(toActionMetadata);
};

/**
 * Creates the OpenAI Genkit plugin on top of `openAICompatible`.
 *
 * The initializer defines an action for every entry of the SUPPORTED_* tables.
 * The resolver from `createResolver` and `listActions` are passed alongside it.
 *
 * @param {object} [options] - Plugin options. They are spread after
 *   `name: "openai"`, so a caller-supplied `name` takes precedence.
 */
function openAIPlugin(options) {
  const pluginOptions = { name: "openai", ...options };
  return openAICompatible({
    name: "openai",
    ...options,
    initializer: async (client) => {
      const defineAll = (define, refs) =>
        Object.values(refs).map((modelRef) => define({ name: modelRef.name, client, pluginOptions, modelRef }));
      // Fixed definition order: GPT, embedders, TTS, Whisper, STT, image.
      return [
        ...defineAll(defineCompatOpenAIModel, SUPPORTED_GPT_MODELS),
        ...Object.values(SUPPORTED_EMBEDDING_MODELS).map((ref) =>
          defineCompatOpenAIEmbedder({ name: ref.name, client, pluginOptions, embedderRef: ref })
        ),
        ...defineAll(defineCompatOpenAISpeechModel, SUPPORTED_TTS_MODELS),
        ...defineAll(defineOpenAIWhisperModel, SUPPORTED_WHISPER_MODELS),
        ...defineAll(defineCompatOpenAITranscriptionModel, SUPPORTED_STT_MODELS),
        ...defineAll(defineOpenAIImageModel, SUPPORTED_IMAGE_MODELS),
      ];
    },
    resolver: createResolver(pluginOptions),
    listActions,
  });
}

/**
 * Returns a model reference for `name`. The ref builder is chosen by the same
 * substring rules the resolver uses (image, TTS, Whisper, transcription, else GPT).
 *
 * @param {string} name - Model name.
 * @param {object} [config] - Passed through to the ref builder.
 */
const model = (name, config) => {
  const { createRef } = findModelFamily(name);
  return createRef({ name, config });
};

/**
 * Returns an embedder reference in the "openai" namespace that uses
 * TextEmbeddingConfigSchema.
 *
 * @param {string} name - Embedder name.
 * @param {object} [config] - Passed through to `embedderRef`.
 */
const embedder = (name, config) =>
  embedderRef({ name, config, configSchema: TextEmbeddingConfigSchema, namespace: "openai" });

/** The plugin factory, augmented with the `model` and `embedder` ref helpers. */
const openAI = Object.assign(openAIPlugin, { model, embedder });
const openai_default = openAI;
export { openai_default as default, openAI, openAIPlugin };
//# sourceMappingURL=index.mjs.map