UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

464 lines 18.4 kB
import {
    AbstractDriver,
    ModelType,
    TrainingJobStatus,
    getModelCapabilities,
    modelModalitiesToArray,
    supportsToolUse,
} from "@llumiverse/core";
import { asyncMap } from "@llumiverse/core/async";
import { formatOpenAILikeMultimodalPrompt } from "./openai_format.js";

/**
 * Helper function to convert a string to CompletionResult[].
 * @param {string | undefined} text - Text to wrap.
 * @returns {Array<{type: "text", value: string}>} A single text entry, or an empty array when `text` is falsy.
 */
function textToCompletionResult(text) {
    return text ? [{ type: "text", value: text }] : [];
}

//TODO: Do we need a list?, replace with if statements and modernize?
const supportFineTunning = new Set([
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-0613",
    "babbage-002",
    "davinci-002",
    "gpt-4-0613",
]);

/**
 * Base driver for OpenAI-compatible chat completion services.
 * Expects `this.service` (an OpenAI-style client) and `this.logger` to be
 * provided by the subclass / base class; neither is created in this file.
 */
export class BaseOpenAIDriver extends AbstractDriver {
    constructor(opts) {
        super(opts);
        //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatible with Function call that we don't use here
        this.formatPrompt = formatOpenAILikeMultimodalPrompt;
    }

    /**
     * Converts a non-streaming chat completion response into a completion object.
     * @param {object} _options - Execution options (unused here).
     * @param {object} result - Raw chat completion response.
     * @returns {object} `{ result, token_usage, finish_reason, tool_use }`.
     * @throws {Error} When the response carries neither text content nor tool calls
     *   (including when `choices` is empty or the first choice has no message).
     */
    extractDataFromResponse(_options, result) {
        const tokenInfo = {
            prompt: result.usage?.prompt_tokens,
            result: result.usage?.completion_tokens,
            total: result.usage?.total_tokens,
        };
        // Optional chaining so that an empty `choices` array is reported through the
        // explicit validation error below instead of an opaque TypeError.
        const choice = result.choices?.[0];
        const tools = collectTools(choice?.message?.tool_calls);
        const data = choice?.message?.content ?? undefined;
        if (!data && !tools) {
            this.logger.error({ result }, "[OpenAI] Response is not valid");
            throw new Error("Response is not valid: no data");
        }
        return {
            result: textToCompletionResult(data || ''),
            token_usage: tokenInfo,
            finish_reason: openAiFinishReason(choice.finish_reason),
            tool_use: tools,
        };
    }

    /**
     * Requests a streamed chat completion and maps every chunk to a completion chunk.
     * Note: `prompt` is modified in place by role conversion and image-detail insertion.
     * @param {Array<object>} prompt - Chat messages.
     * @param {object} options - Execution options (model, model_options, tools, result_schema).
     * @returns {Promise<AsyncIterable<object>>} Mapped stream of `{ result, finish_reason, token_usage }`.
     */
    async requestTextCompletionStream(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options.model_options);
        const toolDefs = getToolDefinitions(options.tools);
        const useTools = toolDefs ? supportsToolUse(options.model, "openai", true) : false;
        const mapFn = (chunk) => {
            // The first choice may be absent (e.g. a chunk that only carries usage),
            // so every access below is optional.
            const choice = chunk.choices?.[0];
            let result;
            if (useTools && this.provider !== "xai" && options.result_schema) {
                result = choice?.delta?.tool_calls?.[0]?.function?.arguments ?? "";
            } else {
                result = choice?.delta?.content ?? "";
            }
            return {
                result: textToCompletionResult(result),
                finish_reason: openAiFinishReason(choice?.finish_reason ?? undefined), //Uses expected "stop" , "length" format
                token_usage: {
                    prompt: chunk.usage?.prompt_tokens,
                    result: chunk.usage?.completion_tokens,
                    total: (chunk.usage?.prompt_tokens ?? 0) + (chunk.usage?.completion_tokens ?? 0),
                },
            };
        };
        convertRoles(prompt, options.model);
        insert_image_detail(prompt, options.model_options?.image_detail ?? "auto");
        const stream = await this.service.chat.completions.create({
            stream: true,
            stream_options: { include_usage: true },
            ...buildChatCompletionParams(options, prompt, useTools ? toolDefs : undefined),
        });
        return asyncMap(stream, mapFn);
    }

    /**
     * Requests a non-streaming chat completion.
     * The prompt is appended to `options.conversation` (when given) before sending, and the
     * assistant reply is appended to the conversation returned on the completion.
     * Note: `prompt` is modified in place by role conversion and image-detail insertion.
     * @param {Array<object>} prompt - Chat messages.
     * @param {object} options - Execution options.
     * @returns {Promise<object>} Completion with `conversation` and, on request, `original_response`.
     */
    async requestTextCompletion(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options.model_options);
        convertRoles(prompt, options.model);
        insert_image_detail(prompt, options.model_options?.image_detail ?? "auto");
        const toolDefs = getToolDefinitions(options.tools);
        const useTools = toolDefs ? supportsToolUse(options.model, "openai") : false;
        let conversation = updateConversation(options.conversation, prompt);
        const res = await this.service.chat.completions.create({
            stream: false,
            ...buildChatCompletionParams(options, conversation, useTools ? toolDefs : undefined),
        });
        const completion = this.extractDataFromResponse(options, res);
        if (options.include_original_response) {
            completion.original_response = res;
        }
        conversation = updateConversation(conversation, createPromptFromResponse(res.choices[0].message));
        completion.conversation = conversation;
        return completion;
    }

    /**
     * Tells whether the model named in the options can be streamed.
     * @param {object} _options - Execution options; only `model` is read.
     * @returns {Promise<boolean>}
     */
    canStream(_options) {
        const model = _options.model;
        if (model.includes("o1") && !(model.includes("mini") || model.includes("preview"))) {
            //o1 full does not support streaming
            //TODO: Update when OpenAI adds support for streaming, last check 16/02/2025
            return Promise.resolve(false);
        }
        return Promise.resolve(true);
    }

    /**
     * Builds a training prompt; only models whose id contains "gpt" are handled.
     * @throws {Error} For any other model (babbage, davinci not yet implemented).
     */
    createTrainingPrompt(options) {
        if (options.model.includes("gpt")) {
            return super.createTrainingPrompt(options);
        }
        // babbage, davinci not yet implemented
        throw new Error(`Unsupported model for training: ${options.model}`);
    }

    /**
     * Uploads the dataset as a "fine-tune" file and creates a fine-tuning job for it.
     * @param {object} dataset - Must expose `getURL()`; the URL content is fetched and uploaded.
     * @param {object} options - `model` and `params` (passed as hyperparameters).
     * @returns {Promise<object>} Job info as produced by `jobInfo`.
     */
    async startTraining(dataset, options) {
        const url = await dataset.getURL();
        const file = await this.service.files.create({
            file: await fetch(url),
            purpose: "fine-tune",
        });
        const job = await this.service.fineTuning.jobs.create({
            training_file: file.id,
            model: options.model,
            hyperparameters: options.params,
        });
        return jobInfo(job);
    }

    /** Cancels a fine-tuning job and returns its job info. */
    async cancelTraining(jobId) {
        const job = await this.service.fineTuning.jobs.cancel(jobId);
        return jobInfo(job);
    }

    /** Retrieves a fine-tuning job and returns its job info. */
    async getTrainingJob(jobId) {
        const job = await this.service.fineTuning.jobs.retrieve(jobId);
        return jobInfo(job);
    }

    // ========= management API =============

    /**
     * Checks the connection by listing models.
     * @returns {Promise<boolean>} `true` when the listing call succeeds, `false` on any error.
     */
    async validateConnection() {
        try {
            await this.service.models.list();
            return true;
        } catch {
            // Deliberate best-effort probe: any failure is reported as "not connected".
            return false;
        }
    }

    /** Lists the models whose id is in the fine-tuning allow-list. */
    listTrainableModels() {
        return this._listModels((m) => supportFineTunning.has(m.id));
    }

    /** Lists all models that pass the name blacklist. */
    async listModels() {
        return this._listModels();
    }

    /**
     * Lists models from the service, drops ids containing a blacklisted word,
     * applies the optional filter and returns the descriptions sorted by id.
     * @param {(model: object) => boolean} [filter] - Extra predicate applied after the blacklist.
     * @returns {Promise<Array<object>>}
     */
    async _listModels(filter) {
        let result = (await this.service.models.list()).data;
        //Some of these use the completions API instead of the chat completions API.
        //Others are for non-text input modalities. Therefore common to both.
        const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
            "realtime", "dall-e", "babbage", "davinci", "codex", "o1-pro", "computer-use", "sora"];
        //OpenAI has very little information, filtering based on name.
        result = result.filter((m) => !wordBlacklist.some((word) => m.id.includes(word)));
        const models = filter ? result.filter(filter) : result;
        const aiModels = models.map((m) => {
            const modelCapability = getModelCapabilities(m.id, "openai");
            // Models owned by "system" are reported as owned by "openai".
            const owner = m.owned_by === "system" ? "openai" : m.owned_by;
            return {
                id: m.id,
                name: m.id,
                provider: this.provider,
                owner,
                type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
                can_stream: true,
                is_multimodal: m.id.includes("gpt-4"),
                input_modalities: modelModalitiesToArray(modelCapability.input),
                output_modalities: modelModalitiesToArray(modelCapability.output),
                tool_support: modelCapability.tool_support,
            };
        }).sort((a, b) => a.id.localeCompare(b.id));
        return aiModels;
    }

    /**
     * Generates a text embedding.
     * @param {{text?: string, image?: unknown, model?: string}} params
     * @returns {Promise<{values: number[], model: string}>}
     * @throws {Error} When an image is passed, when no text is passed, or when the
     *   response contains no embedding.
     */
    async generateEmbeddings({ text, image, model = "text-embedding-3-small" }) {
        if (image) {
            throw new Error("Image embeddings not supported by OpenAI");
        }
        if (!text) {
            throw new Error("No text provided");
        }
        const res = await this.service.embeddings.create({
            input: text,
            model: model,
        });
        // Optional chaining so an empty `data` array yields the explicit error below.
        const embeddings = res.data?.[0]?.embedding;
        if (!embeddings || embeddings.length === 0) {
            throw new Error("No embedding found");
        }
        return { values: embeddings, model };
    }
}

/**
 * Logs a warning when the model options do not carry one of the option ids this driver expects.
 * @param {object} logger - Logger exposing `warn(obj, msg)`.
 * @param {object | undefined} modelOptions
 */
function warnOnUnexpectedModelOptions(logger, modelOptions) {
    const optionId = modelOptions?._option_id;
    if (optionId !== "openai-text" && optionId !== "openai-thinking") {
        logger.warn({ options: modelOptions }, "Invalid model options");
    }
}

/**
 * Formats a JSON schema for strict mode, falling back to the limited (non-strict)
 * format when strict formatting rejects the schema.
 * @param {object} schema
 * @returns {{parsedSchema: object, strictMode: boolean}}
 */
function formatSchemaWithFallback(schema) {
    try {
        return { parsedSchema: openAISchemaFormat(schema), strictMode: true };
    } catch {
        return { parsedSchema: limitedSchemaFormat(schema), strictMode: false };
    }
}

/**
 * Builds the `response_format` request field from `options.result_schema`.
 * @returns {object | undefined} `undefined` when no schema is set or the model does not support schemas.
 */
function buildResponseFormat(options) {
    if (!options.result_schema || !supportsSchema(options.model)) {
        return undefined;
    }
    const { parsedSchema, strictMode } = formatSchemaWithFallback(options.result_schema);
    return {
        type: "json_schema",
        json_schema: {
            name: "format_output",
            schema: parsedSchema,
            strict: strictMode,
        },
    };
}

/**
 * Builds the request parameters shared by the streaming and non-streaming calls.
 * @param {object} options - Execution options.
 * @param {Array<object>} messages - Messages to send.
 * @param {Array<object> | undefined} tools - Tool definitions to send, if any.
 */
function buildChatCompletionParams(options, messages, tools) {
    const model_options = options.model_options;
    return {
        model: options.model,
        messages,
        reasoning_effort: model_options?.reasoning_effort,
        temperature: model_options?.temperature,
        top_p: model_options?.top_p,
        presence_penalty: model_options?.presence_penalty,
        frequency_penalty: model_options?.frequency_penalty,
        n: 1,
        max_completion_tokens: model_options?.max_tokens, //TODO: use max_tokens for older models, currently relying on OpenAI to handle it
        tools,
        stop: model_options?.stop_sequence,
        response_format: buildResponseFormat(options),
    };
}

/**
 * Maps a fine-tuning job to `{ id, model, status, details }`.
 * Job statuses: `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.
 * Any status other than succeeded/failed/cancelled is reported as running, with the
 * raw status in `details`.
 */
function jobInfo(job) {
    const jobStatus = job.status;
    let status;
    let details;
    if (jobStatus === 'succeeded') {
        status = TrainingJobStatus.succeeded;
    } else if (jobStatus === 'failed') {
        status = TrainingJobStatus.failed;
        details = job.error
            ? `${job.error.code} - ${job.error.message} ${job.error.param ? " [" + job.error.param + "]" : ""}`
            : "error";
    } else if (jobStatus === 'cancelled') {
        status = TrainingJobStatus.cancelled;
    } else {
        status = TrainingJobStatus.running;
        details = jobStatus;
    }
    return {
        id: job.id,
        model: job.fine_tuned_model || undefined,
        status,
        details,
    };
}

/**
 * Sets the `detail` field on every `image_url` part of non-assistant messages, in place.
 * Does nothing unless `detail_level` is "auto", "low" or "high".
 * @returns {Array<object>} The same `messages` array.
 */
function insert_image_detail(messages, detail_level) {
    if (detail_level !== "auto" && detail_level !== "low" && detail_level !== "high") {
        return messages;
    }
    for (const message of messages) {
        // Plain-string content has no parts to update.
        if (message.role === 'assistant' || !message.content || typeof message.content === "string") {
            continue;
        }
        for (const part of message.content) {
            if (typeof part === "string") {
                continue;
            }
            if (part.type === 'image_url') {
                part.image_url = { ...part.image_url, detail: detail_level };
            }
        }
    }
    return messages;
}

/**
 * Rewrites `system` roles in place for model ids containing "o1" or "o3".
 * @returns {Array<object>} The same `messages` array.
 */
function convertRoles(messages, model) {
    //New openai models use developer role instead of system
    if (!model.includes("o1") && !model.includes("o3")) {
        return messages;
    }
    //o1-mini and o1-preview support neither system nor developer
    //Models newer than o1 use developer role
    const legacy = model.includes("o1-mini") || model.includes("o1-preview");
    const replacement = legacy ? 'user' : 'developer';
    for (const message of messages) {
        if (message.role === 'system') {
            message.role = replacement;
        }
    }
    return messages;
}

//Structured output support is typically aligned with tool use support
//Not true for realtime models, which do not support structured output, but do support tool use.
function supportsSchema(model) {
    if (model.includes("realtime")) {
        return false;
    }
    return supportsToolUse(model, "openai");
}

/** Maps tool definitions to the request format; `undefined` when no tools are given. */
function getToolDefinitions(tools) {
    return tools ? tools.map(getToolDefinition) : undefined;
}

/**
 * Converts one tool definition into a `function` tool entry.
 * The input schema is sent in strict mode when it can be formatted for it.
 */
function getToolDefinition(toolDef) {
    let parsedSchema = undefined;
    let strictMode = false;
    if (toolDef.input_schema) {
        ({ parsedSchema, strictMode } = formatSchemaWithFallback(toolDef.input_schema));
    }
    return {
        type: "function",
        function: {
            name: toolDef.name,
            description: toolDef.description,
            parameters: parsedSchema,
            strict: strictMode,
        },
    };
}

/** Maps the "tool_calls" finish reason to "tool_use"; other values pass through. */
function openAiFinishReason(finish_reason) {
    if (finish_reason === "tool_calls") {
        return "tool_use";
    }
    return finish_reason;
}

/**
 * Appends `message` (an array of messages) to `conversation`.
 * Returns the other argument unchanged when either one is missing.
 */
function updateConversation(conversation, message) {
    if (!message) {
        return conversation;
    }
    if (!conversation) {
        return message;
    }
    return [...conversation, ...message];
}

/**
 * Converts response tool calls into `{ id, tool_name, tool_input }` entries.
 * @param {Array<object> | undefined | null} toolCalls
 * @returns {Array<object> | undefined} `undefined` when there are no tool calls.
 * @throws {SyntaxError} When a call's `function.arguments` is not valid JSON.
 */
export function collectTools(toolCalls) {
    if (!toolCalls) {
        return undefined;
    }
    const tools = toolCalls.map((call) => ({
        id: call.id,
        tool_name: call.function.name,
        tool_input: JSON.parse(call.function.arguments),
    }));
    return tools.length > 0 ? tools : undefined;
}

/** Wraps a response message into a one-element message array (empty when no response). */
function createPromptFromResponse(response) {
    const messages = [];
    if (response) {
        messages.push({
            role: response.role,
            content: [{
                type: "text",
                text: response.content ?? "",
            }],
            tool_calls: response.tool_calls,
        });
    }
    return messages;
}

/**
 * Formats a schema for strict mode false: removes `default` recursively through
 * `properties` and through object items of array properties.
 * The input schema is never modified; a new `properties` object is built.
 * @param {object} schema
 * @returns {object} Formatted copy.
 */
function limitedSchemaFormat(schema) {
    const formattedSchema = { ...schema };
    // Defaults not supported
    delete formattedSchema.default;
    if (formattedSchema?.properties) {
        // Build a fresh map: the spread above is shallow, so writing into
        // formattedSchema.properties directly would modify the caller's schema.
        const formattedProperties = {};
        for (const [propName, property] of Object.entries(formattedSchema.properties)) {
            // Recursively process properties
            let formattedProperty = limitedSchemaFormat(property);
            // Process arrays with items of type object
            if (property?.type === 'array' && property.items && property.items?.type === 'object') {
                // Spread the formatted property so its removed `default` stays removed.
                formattedProperty = {
                    ...formattedProperty,
                    items: limitedSchemaFormat(property.items),
                };
            }
            formattedProperties[propName] = formattedProperty;
        }
        formattedSchema.properties = formattedProperties;
    }
    return formattedSchema;
}

/**
 * Formats a schema for strict mode true: removes `default`, sets
 * `additionalProperties: false` on objects and marks every property as required.
 * The input schema is never modified, so a failed strict attempt leaves it intact
 * for the non-strict fallback.
 * @param {object} schema
 * @param {number} [nesting=0] - Current recursion depth.
 * @returns {object} Formatted copy.
 * @throws {Error} When nesting exceeds 5, or when an object schema has no properties.
 */
function openAISchemaFormat(schema, nesting = 0) {
    if (nesting > 5) {
        throw new Error("OpenAI schema nesting too deep");
    }
    const formattedSchema = { ...schema };
    // Defaults not supported
    delete formattedSchema.default;
    // Additional properties not supported, required to be set.
    if (formattedSchema?.type === "object") {
        formattedSchema.additionalProperties = false;
    }
    if (formattedSchema?.properties) {
        // Set all properties as required
        formattedSchema.required = Object.keys(formattedSchema.properties);
        // Build a fresh map: the spread above is shallow, so writing into
        // formattedSchema.properties directly would modify the caller's schema.
        const formattedProperties = {};
        for (const [propName, property] of Object.entries(formattedSchema.properties)) {
            // Recursively process properties
            let formattedProperty = openAISchemaFormat(property, nesting + 1);
            // Process arrays with items of type object
            if (property?.type === 'array' && property.items && property.items?.type === 'object') {
                // Spread the formatted property so its removed `default` stays removed.
                formattedProperty = {
                    ...formattedProperty,
                    items: openAISchemaFormat(property.items, nesting + 1),
                };
            }
            formattedProperties[propName] = formattedProperty;
        }
        formattedSchema.properties = formattedProperties;
    }
    if (formattedSchema?.type === 'object'
        && (!formattedSchema?.properties || Object.keys(formattedSchema?.properties ?? {}).length === 0)) {
        //If no properties are defined, then additionalProperties: true was set or the object would be empty.
        //OpenAI does not support this on structured output/ strict mode.
        throw new Error("OpenAI does not support empty objects or objects with additionalProperties set to true");
    }
    return formattedSchema;
}
//# sourceMappingURL=index.js.map