@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
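For orientation, here is a minimal usage sketch for one of the drivers. It is a hypothetical example: the OpenAIDriver export, its constructor options, and the execute() entry point inherited from @llumiverse/core are assumptions, not shown in the file below.

import { OpenAIDriver } from "@llumiverse/drivers";

// Hypothetical usage; the constructor options and execute() signature are assumptions.
const driver = new OpenAIDriver({ apiKey: process.env.OPENAI_API_KEY });
const completion = await driver.execute("Say hello", {
    model: "gpt-4o",
    model_options: { _option_id: "openai-text", max_tokens: 128 },
});
console.log(completion.result);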
JavaScript
import { AbstractDriver, ModelType, TrainingJobStatus, getModelCapabilities, modelModalitiesToArray, supportsToolUse, } from "@llumiverse/core";
import { asyncMap } from "@llumiverse/core/async";
import { formatOpenAILikeMultimodalPrompt } from "./openai_format.js";
// Helper function to convert string to CompletionResult[]
function textToCompletionResult(text) {
return text ? [{ type: "text", value: text }] : [];
}
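// For example, textToCompletionResult("hello") returns [{ type: "text", value: "hello" }],
// while textToCompletionResult("") and textToCompletionResult(undefined) return [].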
//TODO: Do we need a list? Replace with if statements and modernize?
const supportsFineTuning = new Set([
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
"babbage-002",
"davinci-002",
"gpt-4-0613"
]);
export class BaseOpenAIDriver extends AbstractDriver {
constructor(opts) {
super(opts);
this.formatPrompt = formatOpenAILikeMultimodalPrompt;
//TODO: better type; we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[], but it is not compatible with the function-call variants, which we do not use here
}
extractDataFromResponse(_options, result) {
const tokenInfo = {
prompt: result.usage?.prompt_tokens,
result: result.usage?.completion_tokens,
total: result.usage?.total_tokens,
};
const choice = result.choices[0];
const tools = collectTools(choice.message.tool_calls);
const data = choice.message.content ?? undefined;
if (!data && !tools) {
this.logger.error({ result }, "[OpenAI] Response is not valid");
throw new Error("Response is not valid: no data");
}
return {
result: textToCompletionResult(data || ''),
token_usage: tokenInfo,
finish_reason: openAiFinishReason(choice.finish_reason),
tool_use: tools,
};
}
async requestTextCompletionStream(prompt, options) {
if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
const toolDefs = getToolDefinitions(options.tools);
const useTools = toolDefs ? supportsToolUse(options.model, "openai", true) : false;
const mapFn = (chunk) => {
let result = undefined;
if (useTools && this.provider !== "xai" && options.result_schema) {
result = chunk.choices[0]?.delta?.tool_calls?.[0]?.function?.arguments ?? "";
}
else {
result = chunk.choices[0]?.delta?.content ?? "";
}
return {
result: textToCompletionResult(result),
finish_reason: openAiFinishReason(chunk.choices[0]?.finish_reason ?? undefined), //Normalized to the expected "stop" / "length" values
token_usage: {
prompt: chunk.usage?.prompt_tokens,
result: chunk.usage?.completion_tokens,
total: (chunk.usage?.prompt_tokens ?? 0) + (chunk.usage?.completion_tokens ?? 0),
}
};
};
convertRoles(prompt, options.model);
const model_options = options.model_options;
insert_image_detail(prompt, model_options?.image_detail ?? "auto");
let parsedSchema = undefined;
let strictMode = false;
if (options.result_schema && supportsSchema(options.model)) {
try {
parsedSchema = openAISchemaFormat(options.result_schema);
strictMode = true;
}
catch (e) {
parsedSchema = limitedSchemaFormat(options.result_schema);
strictMode = false;
}
}
const stream = await this.service.chat.completions.create({
stream: true,
stream_options: { include_usage: true },
model: options.model,
messages: prompt,
reasoning_effort: model_options?.reasoning_effort,
temperature: model_options?.temperature,
top_p: model_options?.top_p,
presence_penalty: model_options?.presence_penalty,
frequency_penalty: model_options?.frequency_penalty,
n: 1,
max_completion_tokens: model_options?.max_tokens, //TODO: use max_tokens for older models, currently relying on OpenAI to handle it
tools: useTools ? toolDefs : undefined,
stop: model_options?.stop_sequence,
response_format: parsedSchema ? {
type: "json_schema",
json_schema: {
name: "format_output",
schema: parsedSchema,
strict: strictMode,
}
} : undefined,
});
return asyncMap(stream, mapFn);
}
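// The returned async iterable yields one mapped completion per streamed chunk.
// A hypothetical consumer (the prompt shape and surrounding setup are assumptions):
//
// const chunks = await driver.requestTextCompletionStream(
//     [{ role: "user", content: "Say hello" }],
//     { model: "gpt-4o", model_options: { _option_id: "openai-text" } });
// for await (const chunk of chunks) {
//     process.stdout.write(chunk.result[0]?.value ?? "");
// }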
async requestTextCompletion(prompt, options) {
if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
convertRoles(prompt, options.model);
const model_options = options.model_options;
insert_image_detail(prompt, model_options?.image_detail ?? "auto");
const toolDefs = getToolDefinitions(options.tools);
const useTools = toolDefs ? supportsToolUse(options.model, "openai") : false;
let conversation = updateConversation(options.conversation, prompt);
let parsedSchema = undefined;
let strictMode = false;
if (options.result_schema && supportsSchema(options.model)) {
try {
parsedSchema = openAISchemaFormat(options.result_schema);
strictMode = true;
}
catch (e) {
parsedSchema = limitedSchemaFormat(options.result_schema);
strictMode = false;
}
}
const res = await this.service.chat.completions.create({
stream: false,
model: options.model,
messages: conversation,
reasoning_effort: model_options?.reasoning_effort,
temperature: model_options?.temperature,
top_p: model_options?.top_p,
presence_penalty: model_options?.presence_penalty,
frequency_penalty: model_options?.frequency_penalty,
n: 1,
max_completion_tokens: model_options?.max_tokens, //TODO: use max_tokens for older models, currently relying on OpenAI to handle it
tools: useTools ? toolDefs : undefined,
stop: model_options?.stop_sequence,
response_format: parsedSchema ? {
type: "json_schema",
json_schema: {
name: "format_output",
schema: parsedSchema,
strict: strictMode,
}
} : undefined,
});
const completion = this.extractDataFromResponse(options, res);
if (options.include_original_response) {
completion.original_response = res;
}
conversation = updateConversation(conversation, createPromptFromResponse(res.choices[0].message));
completion.conversation = conversation;
return completion;
}
canStream(_options) {
if (_options.model.includes("o1")
&& !(_options.model.includes("mini") || _options.model.includes("preview"))) {
//o1 full does not support streaming
//TODO: Update when OpenAI adds support for streaming, last check 16/02/2025
return Promise.resolve(false);
}
return Promise.resolve(true);
}
createTrainingPrompt(options) {
if (options.model.includes("gpt")) {
return super.createTrainingPrompt(options);
}
else {
// babbage, davinci not yet implemented
throw new Error("Unsupported model for training: " + options.model);
}
}
async startTraining(dataset, options) {
const url = await dataset.getURL();
const file = await this.service.files.create({
file: await fetch(url),
purpose: "fine-tune",
});
const job = await this.service.fineTuning.jobs.create({
training_file: file.id,
model: options.model,
hyperparameters: options.params
});
return jobInfo(job);
}
async cancelTraining(jobId) {
const job = await this.service.fineTuning.jobs.cancel(jobId);
return jobInfo(job);
}
async getTrainingJob(jobId) {
const job = await this.service.fineTuning.jobs.retrieve(jobId);
return jobInfo(job);
}
// ========= management API =============
async validateConnection() {
try {
await this.service.models.list();
return true;
}
catch (error) {
return false;
}
}
listTrainableModels() {
return this._listModels((m) => supportsFineTuning.has(m.id));
}
async listModels() {
return this._listModels();
}
async _listModels(filter) {
let result = (await this.service.models.list()).data;
//Some of these models use the completions API instead of the chat completions API;
//others handle non-text input modalities. The blacklist below covers both cases.
const wordBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
"realtime", "dall-e", "babbage", "davinci", "codex", "o1-pro", "computer-use", "sora"];
//OpenAI exposes very little model metadata, so we filter by model name.
result = result.filter((m) => {
return !wordBlacklist.some((word) => m.id.includes(word));
});
const models = filter ? result.filter(filter) : result;
const aiModels = models.map((m) => {
const modelCapability = getModelCapabilities(m.id, "openai");
let owner = m.owned_by;
if (owner === "system") {
owner = "openai";
}
return {
id: m.id,
name: m.id,
provider: this.provider,
owner: owner,
type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
can_stream: true,
is_multimodal: m.id.includes("gpt-4"),
input_modalities: modelModalitiesToArray(modelCapability.input),
output_modalities: modelModalitiesToArray(modelCapability.output),
tool_support: modelCapability.tool_support,
};
}).sort((a, b) => a.id.localeCompare(b.id));
return aiModels;
}
async generateEmbeddings({ text, image, model = "text-embedding-3-small" }) {
if (image) {
throw new Error("Image embeddings not supported by OpenAI");
}
if (!text) {
throw new Error("No text provided");
}
const res = await this.service.embeddings.create({
input: text,
model: model,
});
const embeddings = res.data[0].embedding;
if (!embeddings || embeddings.length === 0) {
throw new Error("No embedding found");
}
return { values: embeddings, model };
}
}
function jobInfo(job) {
//Job status is one of: validating_files, queued, running, succeeded, failed, or cancelled.
const jobStatus = job.status;
let status = TrainingJobStatus.running;
let details;
if (jobStatus === 'succeeded') {
status = TrainingJobStatus.succeeded;
}
else if (jobStatus === 'failed') {
status = TrainingJobStatus.failed;
details = job.error ? `${job.error.code} - ${job.error.message}${job.error.param ? " [" + job.error.param + "]" : ""}` : "error";
}
else if (jobStatus === 'cancelled') {
status = TrainingJobStatus.cancelled;
}
else {
status = TrainingJobStatus.running;
details = jobStatus;
}
return {
id: job.id,
model: job.fine_tuned_model || undefined,
status,
details
};
}
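// For example, a queued OpenAI job maps to a running status with the raw state kept in details:
// jobInfo({ id: "ftjob-1", status: "queued", fine_tuned_model: null })
//   => { id: "ftjob-1", model: undefined, status: TrainingJobStatus.running, details: "queued" }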
function insert_image_detail(messages, detail_level) {
if (detail_level === "auto" || detail_level === "low" || detail_level === "high") {
for (const message of messages) {
if (message.role !== 'assistant' && message.content) {
for (const part of message.content) {
if (typeof part === "string") {
continue;
}
if (part.type === 'image_url') {
part.image_url = { ...part.image_url, detail: detail_level };
}
}
}
}
}
return messages;
}
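// For example, with detail_level "low", a user message part
// { type: "image_url", image_url: { url: "https://..." } } becomes
// { type: "image_url", image_url: { url: "https://...", detail: "low" } };
// assistant messages and plain-text parts are left untouched.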
function convertRoles(messages, model) {
//Newer OpenAI models use the developer role instead of system
if (model.includes("o1") || model.includes("o3")) {
if (model.includes("o1-mini") || model.includes("o1-preview")) {
//o1-mini and o1-preview support neither system nor developer
for (const message of messages) {
if (message.role === 'system') {
message.role = 'user';
}
}
}
else {
//o1 (full) and newer models use the developer role
for (const message of messages) {
if (message.role === 'system') {
message.role = 'developer';
}
}
}
}
return messages;
}
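// For example, with model "o1-mini" a { role: "system" } message becomes { role: "user" },
// while with "o1" or "o3-mini" it becomes { role: "developer" }; other models pass through unchanged.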
//Structured output support is typically aligned with tool use support
//Not true for realtime models, which do not support structured output, but do support tool use.
function supportsSchema(model) {
const realtimeModel = model.includes("realtime");
if (realtimeModel) {
return false;
}
return supportsToolUse(model, "openai");
}
function getToolDefinitions(tools) {
return tools ? tools.map(getToolDefinition) : undefined;
}
function getToolDefinition(toolDef) {
let parsedSchema = undefined;
let strictMode = false;
if (toolDef.input_schema) {
try {
parsedSchema = openAISchemaFormat(toolDef.input_schema);
strictMode = true;
}
catch (e) {
parsedSchema = limitedSchemaFormat(toolDef.input_schema);
strictMode = false;
}
}
return {
type: "function",
function: {
name: toolDef.name,
description: toolDef.description,
parameters: parsedSchema,
strict: strictMode,
},
};
}
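// For example, a tool definition
// { name: "get_weather", description: "Look up weather",
//   input_schema: { type: "object", properties: { city: { type: "string" } } } }
// maps to OpenAI's function-tool shape with strict: true, since the schema
// survives openAISchemaFormat:
// { type: "function", function: { name: "get_weather", description: "Look up weather",
//   parameters: { type: "object", properties: { city: { type: "string" } },
//     additionalProperties: false, required: ["city"] }, strict: true } }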
function openAiFinishReason(finish_reason) {
if (finish_reason === "tool_calls") {
return "tool_use";
}
return finish_reason;
}
function updateConversation(conversation, message) {
if (!message) {
return conversation;
}
if (!conversation) {
return message;
}
return [...conversation, ...message];
}
export function collectTools(toolCalls) {
if (!toolCalls) {
return undefined;
}
const tools = [];
for (const call of toolCalls) {
tools.push({
id: call.id,
tool_name: call.function.name,
tool_input: JSON.parse(call.function.arguments),
});
}
return tools.length > 0 ? tools : undefined;
}
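// For example, a single tool call
// { id: "call_1", function: { name: "get_weather", arguments: '{"city":"Paris"}' } }
// collects to [{ id: "call_1", tool_name: "get_weather", tool_input: { city: "Paris" } }].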
function createPromptFromResponse(response) {
const messages = [];
if (response) {
messages.push({
role: response.role,
content: [{
type: "text",
text: response.content ?? ""
}],
tool_calls: response.tool_calls,
});
}
return messages;
}
//Schema formatting for strict mode false (the lenient fallback)
function limitedSchemaFormat(schema) {
const formattedSchema = { ...schema };
// Defaults not supported
delete formattedSchema.default;
if (formattedSchema?.properties) {
// Process each property recursively
for (const propName of Object.keys(formattedSchema.properties)) {
const property = formattedSchema.properties[propName];
// Recursively process properties
formattedSchema.properties[propName] = limitedSchemaFormat(property);
// Process arrays with items of type object
if (property?.type === 'array' && property.items && property.items?.type === 'object') {
formattedSchema.properties[propName] = {
...property,
items: limitedSchemaFormat(property.items),
};
}
}
}
return formattedSchema;
}
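// For example, limitedSchemaFormat strips defaults recursively but leaves the rest of the schema alone:
// { type: "object", properties: { count: { type: "number", default: 1 } } }
//   => { type: "object", properties: { count: { type: "number" } } }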
//Schema formatting for strict mode true
function openAISchemaFormat(schema, nesting = 0) {
if (nesting > 5) {
throw new Error("OpenAI schema nesting too deep");
}
const formattedSchema = { ...schema };
// Defaults not supported
delete formattedSchema.default;
// Additional properties not supported, required to be set.
if (formattedSchema?.type === "object") {
formattedSchema.additionalProperties = false;
}
if (formattedSchema?.properties) {
// Set all properties as required
formattedSchema.required = Object.keys(formattedSchema.properties);
// Process each property recursively
for (const propName of Object.keys(formattedSchema.properties)) {
const property = formattedSchema.properties[propName];
// Recursively process properties
formattedSchema.properties[propName] = openAISchemaFormat(property, nesting + 1);
// Process arrays with items of type object
if (property?.type === 'array' && property.items && property.items?.type === 'object') {
formattedSchema.properties[propName] = {
...property,
items: openAISchemaFormat(property.items, nesting + 1),
};
}
}
}
if (formattedSchema?.type === 'object' && (!formattedSchema?.properties || Object.keys(formattedSchema?.properties ?? {}).length == 0)) {
//If no properties are defined, then additionalProperties: true was set or the object would be empty.
//OpenAI does not support this in structured output / strict mode.
throw new Error("OpenAI does not support empty objects or objects with additionalProperties set to true");
}
return formattedSchema;
}
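// For example, openAISchemaFormat({ type: "object", properties: { name: { type: "string" } } })
// returns { type: "object", properties: { name: { type: "string" } },
//   additionalProperties: false, required: ["name"] }.
// An object schema with no properties throws instead (per the check above),
// which makes the call sites fall back to limitedSchemaFormat with strict: false.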
//# sourceMappingURL=index.js.map