@genkit-ai/compat-oai
Version:
Genkit AI framework plugin for OpenAI APIs.
200 lines • 6.1 kB
JavaScript
import {
embedderActionMetadata,
embedderRef,
modelActionMetadata
} from "genkit";
import {
defineCompatOpenAISpeechModel,
defineCompatOpenAITranscriptionModel
} from "../audio.js";
import { defineCompatOpenAIEmbedder } from "../embedder.js";
import {
defineCompatOpenAIImageModel
} from "../image.js";
import { openAICompatible } from "../index.js";
import { defineCompatOpenAIModel } from "../model.js";
import {
gptImage1RequestBuilder,
openAIImageModelRef,
SUPPORTED_IMAGE_MODELS
} from "./dalle.js";
import {
SUPPORTED_EMBEDDING_MODELS,
TextEmbeddingConfigSchema
} from "./embedder.js";
import {
openAIModelRef,
SUPPORTED_GPT_MODELS
} from "./gpt.js";
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from "./tts.js";
import {
openAITranscriptionModelRef,
SUPPORTED_STT_MODELS
} from "./whisper.js";
// Substrings of model ids that `filterOpenAiModels` treats as unsupported:
// any listed model whose id contains one of these is dropped.
const UNSUPPORTED_MODEL_MATCHERS = [
  "babbage",
  "davinci",
  "codex"
];
/**
 * Registers, on demand, an action that was requested by name.
 *
 * The kind of action is picked from `actionType` and from substrings of
 * `actionName`; the registered name is always `openai/${actionName}`.
 *
 * @param ai - Forwarded unchanged to the `defineCompatOpenAI*` helpers
 *   (presumably the Genkit instance; confirm against `openAICompatible`).
 * @param client - Forwarded unchanged to the `defineCompatOpenAI*` helpers.
 * @param {string} actionType - `"embedder"` selects the embedder path; every
 *   other value is treated as a model and classified by name.
 * @param {string} actionName - Model id without the `openai/` prefix.
 * @returns {Promise<void>} Resolves once the matching helper has been called.
 */
const resolver = async (ai, client, actionType, actionName) => {
  const name = `openai/${actionName}`;

  if (actionType === "embedder") {
    defineCompatOpenAIEmbedder({ ai, name, client });
    return;
  }

  if (actionName.includes("gpt-image-1") || actionName.includes("dall-e")) {
    const modelRef = openAIImageModelRef({ name });
    const imageOptions = { ai, name: modelRef.name, client, modelRef };
    // Keep lazily resolved gpt-image-1 variants consistent with the plugin
    // initializer, which registers gpt-image-1 refs with this request builder.
    // The key is only added for those names so dall-e calls are unchanged.
    if (actionName.includes("gpt-image-1")) {
      imageOptions.requestBuilder = gptImage1RequestBuilder;
    }
    defineCompatOpenAIImageModel(imageOptions);
    return;
  }

  if (actionName.includes("tts")) {
    const modelRef = openAISpeechModelRef({ name });
    defineCompatOpenAISpeechModel({ ai, name: modelRef.name, client, modelRef });
    return;
  }

  if (actionName.includes("whisper") || actionName.includes("transcribe")) {
    const modelRef = openAITranscriptionModelRef({ name });
    defineCompatOpenAITranscriptionModel({
      ai,
      name: modelRef.name,
      client,
      modelRef
    });
    return;
  }

  // Anything not recognised above is registered as a regular model.
  const modelRef = openAIModelRef({ name });
  defineCompatOpenAIModel({ ai, name: modelRef.name, client, modelRef });
};
/**
 * Predicate used when listing models: keeps a model unless its `id` contains
 * one of the substrings in `UNSUPPORTED_MODEL_MATCHERS`.
 *
 * @param {{ id: string }} model2 - Entry from the models listing.
 * @returns {boolean} `true` when the model should be kept.
 */
function filterOpenAiModels(model2) {
  for (const unsupported of UNSUPPORTED_MODEL_MATCHERS) {
    if (model2.id.includes(unsupported)) {
      return false;
    }
  }
  return true;
}
/**
 * Lists the models reported by `client.models.list()` as action metadata.
 *
 * Entries rejected by `filterOpenAiModels` are dropped. Ids containing
 * "embedding" become embedder metadata; every other id becomes model metadata
 * built from the matching `SUPPORTED_*` entry when one exists, otherwise from
 * a ref created for `openai/<id>`.
 *
 * @param client - Object exposing `models.list()`, whose result has a `data`
 *   array of `{ id }` entries.
 * @returns {Promise<Array>} One metadata object per retained model, in order.
 */
const listActions = async (client) => {
  const response = await client.models.list();

  return response.data.filter(filterOpenAiModels).map((model2) => {
    const id = model2.id;
    const name = `openai/${id}`;

    if (id.includes("embedding")) {
      return embedderActionMetadata({
        name,
        configSchema: TextEmbeddingConfigSchema,
        info: SUPPORTED_EMBEDDING_MODELS[id]?.info
      });
    }

    // Classify by id substring, in the same precedence order as the resolver.
    let modelRef;
    if (id.includes("gpt-image-1") || id.includes("dall-e")) {
      modelRef = SUPPORTED_IMAGE_MODELS[id] ?? openAIImageModelRef({ name });
    } else if (id.includes("tts")) {
      modelRef = SUPPORTED_TTS_MODELS[id] ?? openAISpeechModelRef({ name });
    } else if (id.includes("whisper") || id.includes("transcribe")) {
      modelRef = SUPPORTED_STT_MODELS[id] ?? openAITranscriptionModelRef({ name });
    } else {
      modelRef = SUPPORTED_GPT_MODELS[id] ?? openAIModelRef({ name });
    }

    return modelActionMetadata({
      name: modelRef.name,
      info: modelRef.info,
      configSchema: modelRef.configSchema
    });
  });
};
/**
 * Creates the OpenAI plugin by delegating to `openAICompatible`.
 *
 * The configuration starts with `name: "openai"`, then spreads `options`
 * (so `options` may override `name`), and finally sets `initializer`,
 * `resolver` and `listActions`, which `options` cannot override.
 *
 * @param {object} [options] - Extra settings forwarded to `openAICompatible`.
 * @returns Whatever `openAICompatible` returns for this configuration.
 */
function openAIPlugin(options) {
  // Eagerly registers every entry of the SUPPORTED_* tables: chat models,
  // embedders, speech, transcription, then image models.
  const initializer = async (ai, client) => {
    for (const modelRef of Object.values(SUPPORTED_GPT_MODELS)) {
      defineCompatOpenAIModel({ ai, name: modelRef.name, client, modelRef });
    }

    for (const knownEmbedder of Object.values(SUPPORTED_EMBEDDING_MODELS)) {
      defineCompatOpenAIEmbedder({
        ai,
        name: knownEmbedder.name,
        client,
        embedderRef: knownEmbedder
      });
    }

    for (const modelRef of Object.values(SUPPORTED_TTS_MODELS)) {
      defineCompatOpenAISpeechModel({ ai, name: modelRef.name, client, modelRef });
    }

    for (const modelRef of Object.values(SUPPORTED_STT_MODELS)) {
      defineCompatOpenAITranscriptionModel({ ai, name: modelRef.name, client, modelRef });
    }

    for (const modelRef of Object.values(SUPPORTED_IMAGE_MODELS)) {
      defineCompatOpenAIImageModel({
        ai,
        name: modelRef.name,
        client,
        modelRef,
        // Only refs named after gpt-image-1 get the dedicated request builder.
        requestBuilder: modelRef.name.includes("gpt-image-1") ? gptImage1RequestBuilder : undefined
      });
    }
  };

  return openAICompatible({
    name: "openai",
    ...options,
    initializer,
    resolver,
    listActions
  });
}
/**
 * Builds a model reference named `openai/<name>`.
 *
 * The ref factory is chosen from substrings of `name`, checked in this order:
 * "gpt-image-1"/"dall-e" (image), "tts" (speech), "whisper"/"transcribe"
 * (transcription); anything else uses `openAIModelRef`.
 *
 * @param {string} name - Model id without the `openai/` prefix.
 * @param [config] - Passed through to the ref factory as `config`.
 * @returns The reference produced by the selected factory.
 */
const model = (name, config) => {
  let createRef = openAIModelRef;
  if (name.includes("gpt-image-1") || name.includes("dall-e")) {
    createRef = openAIImageModelRef;
  } else if (name.includes("tts")) {
    createRef = openAISpeechModelRef;
  } else if (name.includes("whisper") || name.includes("transcribe")) {
    createRef = openAITranscriptionModelRef;
  }
  return createRef({ name: `openai/${name}`, config });
};
/**
 * Builds an embedder reference named `openai/<name>` that carries
 * `TextEmbeddingConfigSchema` as its config schema.
 *
 * @param {string} name - Embedder id without the `openai/` prefix.
 * @param [config] - Passed through to `embedderRef` as `config`.
 * @returns The value returned by `embedderRef`.
 */
const embedder = (name, config) =>
  embedderRef({
    name: `openai/${name}`,
    config,
    configSchema: TextEmbeddingConfigSchema
  });
// Expose the ref helpers as properties of the plugin factory itself, so the
// same function object serves as `openAI(options)`, `openAI.model(...)` and
// `openAI.embedder(...)`.
openAIPlugin.model = model;
openAIPlugin.embedder = embedder;
const openAI = openAIPlugin;
// Alias used by the export list for the default export.
var openai_default = openAI;
export {
openai_default as default,
openAI,
openAIPlugin
};
//# sourceMappingURL=index.mjs.map