UNPKG

@genkit-ai/compat-oai

Version:

Genkit AI framework plugin for OpenAI APIs.

253 lines 6.78 kB
import { GenerationCommonConfigSchema, Message, modelRef, z } from "genkit";

/** Default model info for speech-to-text (transcription) models: audio in, text/json out, single turn. */
const TRANSCRIPTION_MODEL_INFO = {
  supports: {
    media: true,
    output: ["text", "json"],
    multiturn: false,
    systemRole: false,
    tools: false,
  },
};

/** Default model info for text-to-speech models: text in, media out, single turn. */
const SPEECH_MODEL_INFO = {
  supports: {
    media: false,
    output: ["media"],
    multiturn: false,
    systemRole: false,
    tools: false,
  },
};

const ChunkingStrategySchema = z.object({
  type: z.string(),
  prefix_padding_ms: z.number().int().optional(),
  silence_duration_ms: z.number().int().optional(),
  threshold: z.number().min(0).max(1).optional(),
});

/** Config accepted by transcription models (Genkit `temperature` plus provider-specific fields). */
const TranscriptionConfigSchema = GenerationCommonConfigSchema.pick({
  temperature: true,
}).extend({
  chunking_strategy: z.union([z.literal("auto"), ChunkingStrategySchema]).optional(),
  include: z.array(z.any()).optional(),
  language: z.string().optional(),
  timestamp_granularities: z.array(z.enum(["word", "segment"])).optional(),
  response_format: z.enum(["json", "text", "srt", "verbose_json", "vtt"]).optional(),
  // TODO stream support
});

/** Config accepted by speech (TTS) models. */
const SpeechConfigSchema = z.object({
  voice: z.enum(["alloy", "echo", "fable", "onyx", "nova", "shimmer"]).default("alloy"),
  speed: z.number().min(0.25).max(4).optional(),
  response_format: z.enum(["mp3", "opus", "aac", "flac", "wav", "pcm"]).optional(),
});

/** Maps a TTS `response_format` value to the MIME type used in the returned data URL. */
const RESPONSE_FORMAT_MEDIA_TYPES = {
  mp3: "audio/mpeg",
  opus: "audio/opus",
  aac: "audio/aac",
  flac: "audio/flac",
  wav: "audio/wav",
  pcm: "audio/L16",
};

/** MIME type used when the response format has no entry in RESPONSE_FORMAT_MEDIA_TYPES. */
const FALLBACK_MEDIA_TYPE = "application/octet-stream";

/**
 * Deletes every own key whose value is `undefined`, in place, so it is not sent to the API.
 * @param {object} options - request options object (mutated)
 * @returns {object} the same object, for chaining
 */
function removeUndefinedFields(options) {
  for (const key of Object.keys(options)) {
    if (options[key] === void 0) {
      delete options[key];
    }
  }
  return options;
}

/**
 * Resolves the MIME type for a TTS response format.
 * Uses an own-property check so unknown values (e.g. set via a custom requestBuilder) or
 * inherited keys like "constructor" never leak into the data URL.
 * @param {string} responseFormat
 * @returns {string}
 */
function mediaTypeForFormat(responseFormat) {
  return Object.hasOwn(RESPONSE_FORMAT_MEDIA_TYPES, responseFormat)
    ? RESPONSE_FORMAT_MEDIA_TYPES[responseFormat]
    : FALLBACK_MEDIA_TYPE;
}

/**
 * Builds the options for `client.audio.speech.create` from a Genkit request.
 *
 * Common Genkit generation fields (temperature, maxOutputTokens, stopSequences, topK, topP)
 * are dropped; `version` overrides the model name.
 * @param {string} modelName - model name without the plugin prefix
 * @param {object} request - Genkit generate request; only the first message's text is spoken
 * @param {Function} [requestBuilder] - if given, called as `requestBuilder(request, options)` to
 *   customize `options` in place; remaining config is then NOT passed through automatically
 * @returns {object} request options with undefined values removed
 */
function toTTSRequest(modelName, request, requestBuilder) {
  // The named (unused) bindings exist only to exclude those keys from `restOfConfig`.
  const {
    voice,
    version: modelVersion,
    temperature,
    maxOutputTokens,
    stopSequences,
    topK,
    topP,
    ...restOfConfig
  } = request.config ?? {};
  let options = {
    model: modelVersion ?? modelName,
    input: new Message(request.messages[0]).text,
    voice: voice ?? "alloy",
  };
  if (requestBuilder) {
    requestBuilder(request, options);
  } else {
    options = {
      ...options,
      ...restOfConfig, // passthrough rest of the config
    };
  }
  return removeUndefinedFields(options);
}

/**
 * Converts a speech API response into a Genkit response containing the audio as a base64 data URL.
 * @param {object} response - value returned by `client.audio.speech.create`; must expose `arrayBuffer()`
 * @param {string} [responseFormat="mp3"] - the `response_format` that was requested
 * @returns {Promise<object>} Genkit generate response with a single media part
 */
async function toGenerateResponse(response, responseFormat = "mp3") {
  const audioBytes = Buffer.from(await response.arrayBuffer());
  const mediaType = mediaTypeForFormat(responseFormat);
  return {
    message: {
      role: "model",
      content: [
        {
          media: {
            contentType: mediaType,
            url: `data:${mediaType};base64,${audioBytes.toString("base64")}`,
          },
        },
      ],
    },
    finishReason: "stop",
    raw: response,
  };
}

/**
 * Defines a Genkit text-to-speech model backed by an OpenAI-compatible client.
 * @param {object} params
 * @param {object} params.ai - Genkit instance (its `defineModel` is used)
 * @param {string} params.name - model name; anything up to and including the first "/" is
 *   stripped to obtain the API model name
 * @param {object} params.client - OpenAI-compatible client exposing `audio.speech.create`
 * @param {object} [params.modelRef] - supplies `info` and `configSchema` for the definition
 * @param {Function} [params.requestBuilder] - see `toTTSRequest`
 * @returns the model defined by `ai.defineModel`
 */
function defineCompatOpenAISpeechModel(params) {
  const { ai, name, client, modelRef: modelRef2, requestBuilder } = params;
  const modelName = name.substring(name.indexOf("/") + 1);
  return ai.defineModel(
    {
      name,
      apiVersion: "v2",
      ...modelRef2?.info,
      configSchema: modelRef2?.configSchema,
    },
    async (request, { abortSignal }) => {
      const ttsRequest = toTTSRequest(modelName, request, requestBuilder);
      const result = await client.audio.speech.create(ttsRequest, {
        signal: abortSignal,
      });
      return await toGenerateResponse(result, ttsRequest.response_format);
    }
  );
}

/**
 * Creates a model reference for a speech model, defaulting to SPEECH_MODEL_INFO and SpeechConfigSchema.
 * @param {{ name: string, info?: object, configSchema?: object, config?: object }} params
 */
function compatOaiSpeechModelRef(params) {
  const { name, info = SPEECH_MODEL_INFO, configSchema, config = void 0 } = params;
  return modelRef({
    name,
    configSchema: configSchema || SpeechConfigSchema,
    info,
    config,
  });
}

/**
 * Decodes a base64 `data:` URL media part into a `File` for upload.
 * Remote URLs (http, gs, …) are rejected instead of being base64-"decoded" into garbage bytes.
 * @param {{ url: string, contentType?: string }} media
 * @returns {File}
 * @throws {Error} if `media.url` is not a data URL
 */
function mediaToFile(media) {
  const { url } = media;
  const commaIndex = url.indexOf(",");
  if (!/^data:/i.test(url) || commaIndex === -1) {
    throw new Error("Media must be provided as a base64-encoded data: URL");
  }
  // Header is e.g. "audio/mpeg;base64"; the MIME type is the part before the first ";".
  const header = url.slice("data:".length, commaIndex);
  const contentType = media.contentType ?? header.split(";")[0];
  const bytes = Buffer.from(url.slice(commaIndex + 1), "base64");
  return new File([bytes], "input", { type: contentType });
}

/**
 * Builds the options for `client.audio.transcriptions.create` from a Genkit request.
 *
 * The first message's media is the audio and its text becomes the `prompt`.
 * `response_format` is taken from config, else from the requested output format, else "text".
 * @param {string} modelName - model name without the plugin prefix
 * @param {object} request - Genkit generate request
 * @param {Function} [requestBuilder] - if given, called as `requestBuilder(request, options)`
 * @returns {object} request options with undefined values removed
 * @throws {Error} if no media is present, the media is not a data URL, or the output/response
 *   formats are incompatible
 */
function toSttRequest(modelName, request, requestBuilder) {
  const message = new Message(request.messages[0]);
  const media = message.media;
  if (!media?.url) {
    throw new Error("No media found in the request");
  }
  const mediaFile = mediaToFile(media);
  // The named (unused) bindings exist only to exclude those keys from `restOfConfig`.
  const {
    temperature,
    version: modelVersion,
    maxOutputTokens,
    stopSequences,
    topK,
    topP,
    ...restOfConfig
  } = request.config ?? {};
  let options = {
    model: modelVersion ?? modelName,
    file: mediaFile,
    prompt: message.text,
    temperature,
  };
  if (requestBuilder) {
    requestBuilder(request, options);
  } else {
    options = {
      ...options,
      ...restOfConfig, // passthrough rest of the config
    };
  }
  const outputFormat = request.output?.format;
  const customFormat = request.config?.response_format;
  // JSON output can only be produced from the JSON-returning transcription formats.
  if (
    outputFormat === "json" &&
    customFormat &&
    customFormat !== "json" &&
    customFormat !== "verbose_json"
  ) {
    throw new Error(
      `Custom response format ${customFormat} is not compatible with output format ${outputFormat}`
    );
  }
  if (outputFormat === "media") {
    throw new Error(`Output format ${outputFormat} is not supported.`);
  }
  options.response_format = customFormat || outputFormat || "text";
  return removeUndefinedFields(options);
}

/**
 * Wraps a transcription result (a plain string, or an object with `text`) as a Genkit response.
 * @param {string | { text: string }} result
 * @returns {object}
 */
function transcriptionToGenerateResponse(result) {
  return {
    message: {
      role: "model",
      content: [
        {
          text: typeof result === "string" ? result : result.text,
        },
      ],
    },
    finishReason: "stop",
    raw: result,
  };
}

/**
 * Defines a Genkit speech-to-text model backed by an OpenAI-compatible client.
 * Requests are always sent with `stream: false`.
 * @param {object} params
 * @param {object} params.ai - Genkit instance (its `defineModel` is used)
 * @param {string} params.name - model name; anything up to and including the first "/" is
 *   stripped to obtain the API model name
 * @param {object} params.client - OpenAI-compatible client exposing `audio.transcriptions.create`
 * @param {object} [params.modelRef] - supplies `info` and `configSchema` for the definition
 * @param {Function} [params.requestBuilder] - see `toSttRequest`
 * @returns the model defined by `ai.defineModel`
 */
function defineCompatOpenAITranscriptionModel(params) {
  const { ai, name, client, modelRef: modelRef2, requestBuilder } = params;
  // Computed once, matching defineCompatOpenAISpeechModel.
  const modelName = name.substring(name.indexOf("/") + 1);
  return ai.defineModel(
    {
      name,
      apiVersion: "v2",
      ...modelRef2?.info,
      configSchema: modelRef2?.configSchema,
    },
    async (request, { abortSignal }) => {
      const params2 = toSttRequest(modelName, request, requestBuilder);
      const result = await client.audio.transcriptions.create(
        {
          ...params2,
          stream: false,
        },
        { signal: abortSignal }
      );
      return transcriptionToGenerateResponse(result);
    }
  );
}

/**
 * Creates a model reference for a transcription model, defaulting to TRANSCRIPTION_MODEL_INFO
 * and TranscriptionConfigSchema.
 * @param {{ name: string, info?: object, configSchema?: object, config?: object }} params
 */
function compatOaiTranscriptionModelRef(params) {
  const { name, info = TRANSCRIPTION_MODEL_INFO, configSchema, config = void 0 } = params;
  return modelRef({
    name,
    configSchema: configSchema || TranscriptionConfigSchema,
    info,
    config,
  });
}

export {
  RESPONSE_FORMAT_MEDIA_TYPES,
  SPEECH_MODEL_INFO,
  SpeechConfigSchema,
  TRANSCRIPTION_MODEL_INFO,
  TranscriptionConfigSchema,
  compatOaiSpeechModelRef,
  compatOaiTranscriptionModelRef,
  defineCompatOpenAISpeechModel,
  defineCompatOpenAITranscriptionModel,
};
//# sourceMappingURL=audio.mjs.map