/*
 * Package: @genkit-ai/compat-oai
 * Version: (not specified)
 * Description: Genkit AI framework plugin for OpenAI APIs.
 * Source listing: 263 lines, 7.57 kB, JavaScript.
 */
import {
embedderActionMetadata,
embedderRef,
modelActionMetadata
} from "genkit";
import {
defineCompatOpenAISpeechModel,
defineCompatOpenAITranscriptionModel
} from "../audio.mjs";
import { defineCompatOpenAIEmbedder } from "../embedder.mjs";
import {
defineCompatOpenAIImageModel
} from "../image.mjs";
import { openAICompatible } from "../index.mjs";
import { defineCompatOpenAIModel } from "../model.mjs";
import {
gptImage1RequestBuilder,
openAIImageModelRef,
SUPPORTED_IMAGE_MODELS
} from "./dalle.mjs";
import {
SUPPORTED_EMBEDDING_MODELS,
TextEmbeddingConfigSchema
} from "./embedder.mjs";
import {
openAIModelRef,
SUPPORTED_GPT_MODELS
} from "./gpt.mjs";
import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from "./stt.mjs";
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from "./tts.mjs";
import {
defineOpenAIWhisperModel,
openAIWhisperModelRef,
SUPPORTED_WHISPER_MODELS
} from "./whisper.mjs";
// Substrings marking model families this plugin excludes: any listed model
// whose id contains one of these is filtered out by `filterOpenAiModels`.
const UNSUPPORTED_MODEL_MATCHERS = [
  "babbage",
  "davinci",
  "codex",
];
function createResolver(pluginOptions) {
return async (client, actionType, actionName) => {
if (actionType === "embedder") {
return defineCompatOpenAIEmbedder({
name: actionName,
client,
pluginOptions
});
} else if (actionName.includes("gpt-image-1") || actionName.includes("dall-e")) {
const modelRef = openAIImageModelRef({ name: actionName });
return defineCompatOpenAIImageModel({
name: modelRef.name,
client,
pluginOptions,
modelRef
});
} else if (actionName.includes("tts")) {
const modelRef = openAISpeechModelRef({ name: actionName });
return defineCompatOpenAISpeechModel({
name: modelRef.name,
client,
pluginOptions,
modelRef
});
} else if (actionName.includes("whisper")) {
const modelRef = openAIWhisperModelRef({ name: actionName });
return defineOpenAIWhisperModel({
name: modelRef.name,
client,
pluginOptions,
modelRef
});
} else if (actionName.includes("transcribe")) {
const modelRef = openAITranscriptionModelRef({
name: actionName
});
return defineCompatOpenAITranscriptionModel({
name: modelRef.name,
client,
pluginOptions,
modelRef
});
} else {
const modelRef = openAIModelRef({ name: actionName });
return defineCompatOpenAIModel({
name: modelRef.name,
client,
pluginOptions,
modelRef
});
}
};
}
/**
 * Predicate for `Array.prototype.filter` over listed models.
 *
 * @param {{ id: string }} model2 - A model entry carrying an `id` string.
 * @returns {boolean} `true` when the id contains none of the
 *   `UNSUPPORTED_MODEL_MATCHERS` substrings (i.e. the model is kept).
 */
function filterOpenAiModels(model2) {
  const { id } = model2;
  return UNSUPPORTED_MODEL_MATCHERS.every((marker) => !id.includes(marker));
}
/**
 * Picks the model reference used to describe a listed (non-embedding) model.
 *
 * Families are matched by substring in the same order the resolver uses:
 * image ("gpt-image-1" / "dall-e"), speech ("tts"), Whisper ("whisper"),
 * transcription ("transcribe"), then chat/GPT as the fallback. A predefined
 * entry from the matching SUPPORTED_* table wins; otherwise a reference is
 * built from the id.
 *
 * @param {string} id - Model id as reported by the models endpoint.
 * @returns {object} Model reference exposing `name`, `info` and `configSchema`.
 */
function lookupModelRef(id) {
  if (id.includes("gpt-image-1") || id.includes("dall-e")) {
    return SUPPORTED_IMAGE_MODELS[id] ?? openAIImageModelRef({ name: id });
  }
  if (id.includes("tts")) {
    return SUPPORTED_TTS_MODELS[id] ?? openAISpeechModelRef({ name: id });
  }
  if (id.includes("whisper")) {
    return SUPPORTED_WHISPER_MODELS[id] ?? openAIWhisperModelRef({ name: id });
  }
  if (id.includes("transcribe")) {
    return SUPPORTED_STT_MODELS[id] ?? openAITranscriptionModelRef({ name: id });
  }
  return SUPPORTED_GPT_MODELS[id] ?? openAIModelRef({ name: id });
}

/**
 * Lists the actions available through `client`.
 *
 * Fetches the model list, drops entries rejected by `filterOpenAiModels`, and
 * maps each remaining id to embedder metadata (ids containing "embedding") or
 * model metadata (everything else).
 *
 * @param {object} client - OpenAI-compatible client; only `client.models.list()` is used.
 * @returns {Promise<object[]>} Action metadata, in the order the API returned the models.
 */
const listActions = async (client) => {
  const response = await client.models.list();
  return response.data.filter(filterOpenAiModels).map(({ id }) => {
    if (id.includes("embedding")) {
      return embedderActionMetadata({
        name: id,
        configSchema: TextEmbeddingConfigSchema,
        info: SUPPORTED_EMBEDDING_MODELS[id]?.info
      });
    }
    const modelRef = lookupModelRef(id);
    return modelActionMetadata({
      name: modelRef.name,
      info: modelRef.info,
      configSchema: modelRef.configSchema
    });
  });
};
/**
 * Creates the OpenAI plugin on top of `openAICompatible`.
 *
 * Every entry in the SUPPORTED_* tables is defined up front by the
 * initializer. Other names are handled on demand by the resolver and
 * discovered through `listActions`.
 *
 * @param {object} [options] - Caller options; spread after the default
 *   `name: "openai"`, so a caller-supplied `name` overrides it.
 * @returns {*} Whatever `openAICompatible` returns for this configuration.
 */
function openAIPlugin(options) {
  const pluginOptions = { name: "openai", ...options };

  const initializer = async (client) => {
    const shared = { client, pluginOptions };
    // Defines every reference in `table` via `define`, passing it as `modelRef`.
    const defineAll = (table, define) =>
      Object.values(table).map((ref) => define({ name: ref.name, ...shared, modelRef: ref }));

    const chatModels = defineAll(SUPPORTED_GPT_MODELS, defineCompatOpenAIModel);
    const embedders = Object.values(SUPPORTED_EMBEDDING_MODELS).map((ref) =>
      defineCompatOpenAIEmbedder({ name: ref.name, ...shared, embedderRef: ref })
    );
    const speechModels = defineAll(SUPPORTED_TTS_MODELS, defineCompatOpenAISpeechModel);
    const whisperModels = defineAll(SUPPORTED_WHISPER_MODELS, defineOpenAIWhisperModel);
    const transcriptionModels = defineAll(SUPPORTED_STT_MODELS, defineCompatOpenAITranscriptionModel);
    const imageModels = Object.values(SUPPORTED_IMAGE_MODELS).map((ref) => {
      // Only gpt-image-1 family references get the dedicated request builder.
      const requestBuilder = ref.name.includes("gpt-image-1") ? gptImage1RequestBuilder : undefined;
      return defineCompatOpenAIImageModel({ name: ref.name, ...shared, modelRef: ref, requestBuilder });
    });

    return [
      ...chatModels,
      ...embedders,
      ...speechModels,
      ...whisperModels,
      ...transcriptionModels,
      ...imageModels
    ];
  };

  return openAICompatible({
    name: "openai",
    ...options,
    initializer,
    resolver: createResolver(pluginOptions),
    listActions
  });
}
/**
 * Builds a model reference for `name`, choosing the reference factory by
 * substring. The first matching family wins, in this order: image
 * ("gpt-image-1" / "dall-e"), speech ("tts"), Whisper ("whisper"),
 * transcription ("transcribe"). Anything else becomes a chat/GPT reference.
 *
 * @param {string} name - Model name.
 * @param {object} [config] - Optional default config attached to the reference.
 * @returns {object} The reference produced by the selected factory.
 */
const model = (name, config) => {
  const families = [
    [["gpt-image-1", "dall-e"], openAIImageModelRef],
    [["tts"], openAISpeechModelRef],
    [["whisper"], openAIWhisperModelRef],
    [["transcribe"], openAITranscriptionModelRef]
  ];
  const hit = families.find(([needles]) => needles.some((needle) => name.includes(needle)));
  const makeRef = hit ? hit[1] : openAIModelRef;
  return makeRef({ name, config });
};
/**
 * Builds an embedder reference in the "openai" namespace that uses
 * `TextEmbeddingConfigSchema` as its config schema.
 *
 * @param {string} name - Embedder name.
 * @param {object} [config] - Optional default config attached to the reference.
 * @returns {object} The value returned by `embedderRef`.
 */
const embedder = (name, config) => {
  const refSpec = {
    name,
    config,
    configSchema: TextEmbeddingConfigSchema,
    namespace: "openai"
  };
  return embedderRef(refSpec);
};
// Public entry point: the plugin factory itself, with the `model` and
// `embedder` reference helpers attached as properties. `Object.assign`
// mutates `openAIPlugin` in place, so `openAI === openAIPlugin`.
const openAI = Object.assign(openAIPlugin, {
  model,
  embedder
});
// Binding exported as the module's default export.
const openai_default = openAI;
export {
openai_default as default,
openAI,
openAIPlugin
};
//# sourceMappingURL=index.mjs.map