@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
578 lines (495 loc) • 21.2 kB
text/typescript
import {
AIModel,
AbstractDriver,
Completion,
CompletionChunkObject,
CompletionResult,
DataSource,
DriverOptions,
EmbeddingsOptions,
EmbeddingsResult,
ExecutionOptions,
ExecutionTokenUsage,
JSONSchema,
ModelType,
Providers,
ToolDefinition,
ToolUse,
TrainingJob,
TrainingJobStatus,
TrainingOptions,
TrainingPromptOptions,
getModelCapabilities,
modelModalitiesToArray,
supportsToolUse,
} from "@llumiverse/core";
import { asyncMap } from "@llumiverse/core/async";
import { formatOpenAILikeMultimodalPrompt } from "./openai_format.js";
import OpenAI, { AzureOpenAI } from "openai";
import { ChatCompletionMessageParam } from "openai/resources/chat/completions";
import { Stream } from "openai/streaming";
/**
 * Wraps plain text as a single text `CompletionResult`.
 * Empty text produces an empty result list rather than an empty text entry.
 */
function textToCompletionResult(text: string): CompletionResult[] {
    if (!text) {
        return [];
    }
    return [{ type: "text", value: text }];
}
/**
 * Model ids that `listTrainableModels` reports as fine-tunable.
 * TODO: decide whether an explicit allow-list is still needed, or replace it with rules.
 */
const supportFineTunning: ReadonlySet<string> = new Set<string>([
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-0613",
    "babbage-002",
    "davinci-002",
    "gpt-4-0613",
]);

/** Options accepted by the OpenAI-compatible drivers; adds nothing to the core `DriverOptions`. */
export interface BaseOpenAIDriverOptions extends DriverOptions {
}
/**
 * Substrings of model ids excluded from model listings.
 * Some of these models use the completions API instead of the chat completions API;
 * others are for non-text input modalities.
 */
const modelIdBlacklist = ["embed", "whisper", "transcribe", "audio", "moderation", "tts",
    "realtime", "dall-e", "babbage", "davinci", "codex", "o1-pro", "computer-use", "sora"];

/**
 * Whether chat-completion streaming is expected to work for `model`.
 * Full "o1" does not support streaming; "o1-mini" and "o1-preview" do.
 * Used both by `BaseOpenAIDriver.canStream` and by the `can_stream` flag in model listings,
 * so the two always agree.
 */
function modelSupportsStreaming(model: string): boolean {
    //TODO: Update when OpenAI adds support for streaming, last check 16/02/2025
    const isFullO1 = model.includes("o1") && !(model.includes("mini") || model.includes("preview"));
    return !isFullO1;
}

/**
 * Prepares `options.result_schema` for OpenAI structured output (`response_format`).
 * The strict conversion (openAISchemaFormat) is tried first. If it throws, the relaxed
 * limitedSchemaFormat version is used with `strict: false`.
 * `schema` is undefined when no result schema was requested or the model does not support one.
 */
function resolveResultSchema(options: ExecutionOptions): { schema: JSONSchema | undefined; strict: boolean } {
    if (!options.result_schema || !supportsSchema(options.model)) {
        return { schema: undefined, strict: false };
    }
    try {
        return { schema: openAISchemaFormat(options.result_schema), strict: true };
    } catch {
        return { schema: limitedSchemaFormat(options.result_schema), strict: false };
    }
}

/**
 * Shared implementation for drivers built on the OpenAI chat-completions API.
 * Subclasses provide the concrete `provider` id and the `service` client (OpenAI or AzureOpenAI).
 */
export abstract class BaseOpenAIDriver extends AbstractDriver<
    BaseOpenAIDriverOptions,
    ChatCompletionMessageParam[]
> {
    //abstract provider: "azure_openai" | "openai" | "xai" | "azure_foundry";
    abstract provider: Providers.openai | Providers.azure_openai | "xai" | Providers.azure_foundry;
    abstract service: OpenAI | AzureOpenAI;

    constructor(opts: BaseOpenAIDriverOptions) {
        super(opts);
        //TODO: better type, we send back OpenAI.Chat.Completions.ChatCompletionMessageParam[] but just not compatible with Function call that we don't use here
        this.formatPrompt = formatOpenAILikeMultimodalPrompt;
    }

    /**
     * Converts a non-streamed chat completion into a llumiverse `Completion`.
     * Only the first choice is read; both request paths in this class send `n: 1`.
     * @throws Error if the response has no choice, or its first choice has neither text nor tool calls.
     */
    extractDataFromResponse(
        _options: ExecutionOptions,
        result: OpenAI.Chat.Completions.ChatCompletion
    ): Completion {
        const tokenInfo: ExecutionTokenUsage = {
            prompt: result.usage?.prompt_tokens,
            result: result.usage?.completion_tokens,
            total: result.usage?.total_tokens,
        };
        const choice = result.choices[0];
        if (!choice) {
            // Fail with an explicit message instead of a TypeError on `choice.message`.
            this.logger?.error("[OpenAI] Response is not valid", result);
            throw new Error("Response is not valid: no choices");
        }
        const tools = collectTools(choice.message.tool_calls);
        const data = choice.message.content ?? undefined;
        if (!data && !tools) {
            this.logger?.error("[OpenAI] Response is not valid", result);
            throw new Error("Response is not valid: no data");
        }
        return {
            result: textToCompletionResult(data || ''),
            token_usage: tokenInfo,
            finish_reason: openAiFinishReason(choice.finish_reason),
            tool_use: tools,
        };
    }

    /**
     * Streams a chat completion, mapping each chunk to a `CompletionChunkObject`.
     * Token usage is requested through `stream_options.include_usage`.
     * NOTE(review): unlike requestTextCompletion, this sends only `prompt` and does not merge
     * `options.conversation`; confirm that this is intended.
     */
    async requestTextCompletionStream(prompt: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        const toolDefs = getToolDefinitions(options.tools);
        const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai", true) : false;
        // With tools and a result schema (on providers other than xAI), the streamed text is read
        // from the first tool call's argument deltas instead of the message content.
        const streamFromToolArguments = Boolean(useTools && this.provider !== "xai" && options.result_schema);
        const mapFn = (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => {
            const choice = chunk.choices[0];
            // `?.[0]?.` so that an empty tool_calls delta yields "" instead of throwing.
            const text = streamFromToolArguments
                ? choice?.delta?.tool_calls?.[0]?.function?.arguments ?? ""
                : choice?.delta?.content ?? "";
            return {
                result: textToCompletionResult(text),
                finish_reason: openAiFinishReason(choice?.finish_reason ?? undefined), //Uses expected "stop" , "length" format
                token_usage: {
                    prompt: chunk.usage?.prompt_tokens,
                    result: chunk.usage?.completion_tokens,
                    total: (chunk.usage?.prompt_tokens ?? 0) + (chunk.usage?.completion_tokens ?? 0),
                }
            } satisfies CompletionChunkObject;
        };
        convertRoles(prompt, options.model);
        const model_options = options.model_options as any;
        insert_image_detail(prompt, model_options?.image_detail ?? "auto");
        const { schema: parsedSchema, strict: strictMode } = resolveResultSchema(options);
        const stream = await this.service.chat.completions.create({
            stream: true,
            stream_options: { include_usage: true },
            model: options.model,
            messages: prompt,
            reasoning_effort: model_options?.reasoning_effort,
            temperature: model_options?.temperature,
            top_p: model_options?.top_p,
            presence_penalty: model_options?.presence_penalty,
            frequency_penalty: model_options?.frequency_penalty,
            n: 1,
            max_completion_tokens: model_options?.max_tokens, //TODO: use max_tokens for older models, currently relying on OpenAI to handle it
            tools: useTools ? toolDefs : undefined,
            stop: model_options?.stop_sequence,
            response_format: parsedSchema ? {
                type: "json_schema",
                json_schema: {
                    name: "format_output",
                    schema: parsedSchema,
                    strict: strictMode,
                }
            } : undefined,
        } satisfies OpenAI.Chat.ChatCompletionCreateParamsStreaming
        ) satisfies Stream<OpenAI.Chat.Completions.ChatCompletionChunk>;
        return asyncMap(stream, mapFn);
    }

    /**
     * Runs a non-streamed chat completion over `options.conversation` plus `prompt`.
     * The returned completion carries the updated conversation, including the assistant reply.
     */
    async requestTextCompletion(prompt: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<Completion> {
        if (options.model_options?._option_id !== "openai-text" && options.model_options?._option_id !== "openai-thinking") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        convertRoles(prompt, options.model);
        const model_options = options.model_options as any;
        insert_image_detail(prompt, model_options?.image_detail ?? "auto");
        const toolDefs = getToolDefinitions(options.tools);
        const useTools: boolean = toolDefs ? supportsToolUse(options.model, "openai") : false;
        let conversation = updateConversation(options.conversation as ChatCompletionMessageParam[], prompt);
        const { schema: parsedSchema, strict: strictMode } = resolveResultSchema(options);
        const res = await this.service.chat.completions.create({
            stream: false,
            model: options.model,
            messages: conversation,
            reasoning_effort: model_options?.reasoning_effort,
            temperature: model_options?.temperature,
            top_p: model_options?.top_p,
            presence_penalty: model_options?.presence_penalty,
            frequency_penalty: model_options?.frequency_penalty,
            n: 1,
            max_completion_tokens: model_options?.max_tokens, //TODO: use max_tokens for older models, currently relying on OpenAI to handle it
            tools: useTools ? toolDefs : undefined,
            stop: model_options?.stop_sequence,
            response_format: parsedSchema ? {
                type: "json_schema",
                json_schema: {
                    name: "format_output",
                    schema: parsedSchema,
                    strict: strictMode,
                }
            } : undefined,
        });
        // Throws when there is no first choice, so the `choices[0]` access below is safe.
        const completion = this.extractDataFromResponse(options, res);
        if (options.include_original_response) {
            completion.original_response = res;
        }
        conversation = updateConversation(conversation, createPromptFromResponse(res.choices[0].message));
        completion.conversation = conversation;
        return completion;
    }

    protected canStream(_options: ExecutionOptions): Promise<boolean> {
        return Promise.resolve(modelSupportsStreaming(_options.model));
    }

    createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
        if (options.model.includes("gpt")) {
            return super.createTrainingPrompt(options);
        }
        // babbage, davinci not yet implemented
        throw new Error("Unsupported model for training: " + options.model);
    }

    /**
     * Uploads the dataset as a fine-tune file and starts an OpenAI fine-tuning job.
     * @throws Error if the dataset URL cannot be fetched successfully (non-2xx status).
     */
    async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
        const url = await dataset.getURL();
        const datasetResponse = await fetch(url);
        if (!datasetResponse.ok) {
            // Without this check an HTTP error page would be uploaded as training data.
            // The URL is left out of the message because it may be a signed URL.
            throw new Error(`Failed to fetch training dataset: HTTP ${datasetResponse.status} ${datasetResponse.statusText}`);
        }
        const file = await this.service.files.create({
            file: datasetResponse,
            purpose: "fine-tune",
        });
        const job = await this.service.fineTuning.jobs.create({
            training_file: file.id,
            model: options.model,
            hyperparameters: options.params
        });
        return jobInfo(job);
    }

    async cancelTraining(jobId: string): Promise<TrainingJob> {
        const job = await this.service.fineTuning.jobs.cancel(jobId);
        return jobInfo(job);
    }

    async getTrainingJob(jobId: string): Promise<TrainingJob> {
        const job = await this.service.fineTuning.jobs.retrieve(jobId);
        return jobInfo(job);
    }

    // ========= management API =============

    /** Returns true if the models endpoint can be listed; failures are logged and reported as false. */
    async validateConnection(): Promise<boolean> {
        try {
            await this.service.models.list();
            return true;
        } catch (error) {
            this.logger?.warn("[OpenAI] Connection validation failed", { error });
            return false;
        }
    }

    /**
     * Lists models whose id is in the fine-tuning allow-list.
     * NOTE(review): "babbage" and "davinci" ids are removed by the listing blacklist, so only the
     * gpt entries of the allow-list can appear here.
     */
    listTrainableModels(): Promise<AIModel<string>[]> {
        return this._listModels((m) => supportFineTunning.has(m.id));
    }

    async listModels(): Promise<AIModel[]> {
        return this._listModels();
    }

    /**
     * Lists chat-capable models, optionally narrowed by `filter`, sorted by id.
     * OpenAI exposes little metadata, so non-chat models are filtered out by name.
     */
    async _listModels(filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
        const allModels = (await this.service.models.list()).data;
        const chatModels = allModels.filter((m) => !modelIdBlacklist.some((word) => m.id.includes(word)));
        const models = filter ? chatModels.filter(filter) : chatModels;
        return models.map((m) => {
            const modelCapability = getModelCapabilities(m.id, "openai");
            const owner = m.owned_by === "system" ? "openai" : m.owned_by;
            return {
                id: m.id,
                name: m.id,
                provider: this.provider,
                owner: owner,
                type: m.object === "model" ? ModelType.Text : ModelType.Unknown,
                // Must agree with canStream (full o1 cannot stream).
                can_stream: modelSupportsStreaming(m.id),
                is_multimodal: m.id.includes("gpt-4"),
                input_modalities: modelModalitiesToArray(modelCapability.input),
                output_modalities: modelModalitiesToArray(modelCapability.output),
                tool_support: modelCapability.tool_support,
            } satisfies AIModel<string>;
        }).sort((a, b) => a.id.localeCompare(b.id));
    }

    /**
     * Embeds `text` with an OpenAI embedding model (default "text-embedding-3-small").
     * @throws Error for image input, missing text, or an empty embedding response.
     */
    async generateEmbeddings({ text, image, model = "text-embedding-3-small" }: EmbeddingsOptions): Promise<EmbeddingsResult> {
        if (image) {
            throw new Error("Image embeddings not supported by OpenAI");
        }
        if (!text) {
            throw new Error("No text provided");
        }
        const res = await this.service.embeddings.create({
            input: text,
            model: model,
        });
        // `?.` so that an empty `data` array reaches the explicit error below.
        const embeddings = res.data[0]?.embedding;
        if (!embeddings || embeddings.length === 0) {
            throw new Error("No embedding found");
        }
        return { values: embeddings, model } satisfies EmbeddingsResult;
    }
}
/**
 * Maps an OpenAI fine-tuning job onto llumiverse's `TrainingJob`.
 * OpenAI statuses are `validating_files`, `queued`, `running`, `succeeded`, `failed` and
 * `cancelled`.
 * - `failed` puts the error code, message and optional `[param]` into `details`.
 * - Any status other than succeeded/failed/cancelled is reported as running, with the raw
 *   OpenAI status in `details`.
 */
function jobInfo(job: OpenAI.FineTuning.Jobs.FineTuningJob): TrainingJob {
    let status = TrainingJobStatus.running;
    let details: string | undefined;
    switch (job.status) {
        case 'succeeded':
            status = TrainingJobStatus.succeeded;
            break;
        case 'failed': {
            status = TrainingJobStatus.failed;
            const err = job.error;
            // Single space before "[param]" and no trailing whitespace when there is no param.
            details = err ? `${err.code} - ${err.message}${err.param ? ` [${err.param}]` : ""}` : "error";
            break;
        }
        case 'cancelled':
            status = TrainingJobStatus.cancelled;
            break;
        default:
            status = TrainingJobStatus.running;
            details = job.status;
    }
    return {
        id: job.id,
        model: job.fine_tuned_model || undefined,
        status,
        details,
    };
}
/**
 * Sets `detail` on every `image_url` content part of non-assistant messages, in place.
 * Only "auto", "low" and "high" are applied; any other value leaves the messages untouched.
 * Messages whose content is not an array (plain string or empty) carry no image parts and
 * are skipped without being iterated.
 * @returns the same, mutated `messages` array.
 */
function insert_image_detail(messages: ChatCompletionMessageParam[], detail_level: string): ChatCompletionMessageParam[] {
    if (detail_level !== "auto" && detail_level !== "low" && detail_level !== "high") {
        return messages;
    }
    for (const message of messages) {
        if (message.role === 'assistant' || !Array.isArray(message.content)) {
            continue;
        }
        for (const part of message.content) {
            if (part.type === 'image_url') {
                part.image_url = { ...part.image_url, detail: detail_level };
            }
        }
    }
    return messages;
}
/**
 * Rewrites `system` messages in place for models whose id contains "o1" or "o3".
 * - "o1-mini" / "o1-preview" accept neither system nor developer messages, so system becomes user.
 * - Other o1/o3 models use the developer role instead of system.
 * Messages for any other model are left unchanged.
 * @returns the same, mutated `messages` array.
 */
function convertRoles(messages: ChatCompletionMessageParam[], model: string): ChatCompletionMessageParam[] {
    const isReasoningModel = model.includes("o1") || model.includes("o3");
    if (!isReasoningModel) {
        return messages;
    }
    const legacyO1 = model.includes("o1-mini") || model.includes("o1-preview");
    const targetRole = legacyO1 ? "user" : "developer";
    for (const message of messages.filter((m) => m.role === "system")) {
        (message as { role: string }).role = targetRole;
    }
    return messages;
}
/**
 * Whether structured output (json_schema response_format) can be requested for `model`.
 * Support generally follows tool-use support. Realtime models are the exception: they support
 * tool use but not structured output, so they are rejected before the tool-use lookup.
 */
function supportsSchema(model: string): boolean {
    return !model.includes("realtime") && supportsToolUse(model, "openai");
}
/**
 * Converts llumiverse tool definitions into OpenAI chat-completion tools.
 * @returns undefined when no tool list is given (undefined or null); an empty list maps to [].
 */
function getToolDefinitions(tools: ToolDefinition[] | undefined | null): OpenAI.ChatCompletionTool[] | undefined {
    if (!tools) {
        return undefined;
    }
    return tools.map((tool) => getToolDefinition(tool));
}
/**
 * Builds the OpenAI function-tool declaration for one tool.
 * The input schema is first converted for strict mode. If that conversion throws, a relaxed
 * schema is sent with `strict: false`. A tool without an input schema is declared with
 * `parameters: undefined` and `strict: false`.
 */
function getToolDefinition(toolDef: ToolDefinition): OpenAI.ChatCompletionTool {
    const inputSchema = toolDef.input_schema as JSONSchema | undefined;
    let parameters: JSONSchema | undefined = undefined;
    let strict = false;
    if (inputSchema) {
        try {
            parameters = openAISchemaFormat(inputSchema);
            strict = true;
        } catch {
            // Strict conversion rejected the schema; `strict` is still false here.
            parameters = limitedSchemaFormat(inputSchema);
        }
    }
    const tool = {
        type: "function",
        function: {
            name: toolDef.name,
            description: toolDef.description,
            parameters: parameters,
            strict: strict,
        },
    } satisfies OpenAI.ChatCompletionTool;
    return tool;
}
/** Maps OpenAI's "tool_calls" finish reason to llumiverse's "tool_use"; other values pass through unchanged. */
function openAiFinishReason(finish_reason?: string): string | undefined {
    return finish_reason === "tool_calls" ? "tool_use" : finish_reason;
}
/**
 * Appends `message` to `conversation` without mutating either array.
 * If `message` is missing, `conversation` is returned as-is. If `conversation` is missing,
 * the `message` array itself (not a copy) becomes the conversation.
 */
function updateConversation(conversation: ChatCompletionMessageParam[], message: ChatCompletionMessageParam[]): ChatCompletionMessageParam[] {
    if (!message) {
        return conversation;
    }
    return conversation ? [...conversation, ...message] : message;
}
/**
 * Converts the tool calls of a chat-completion message into llumiverse `ToolUse` entries.
 * Each call's JSON `arguments` string is parsed into `tool_input`. A missing, empty or
 * whitespace-only `arguments` string is treated as `{}` (no arguments), because
 * `JSON.parse("")` would otherwise throw.
 * @returns the tool uses, or undefined when there are no tool calls.
 * @throws SyntaxError if a non-empty `arguments` string is not valid JSON.
 */
export function collectTools(toolCalls?: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]): ToolUse[] | undefined {
    if (!toolCalls || toolCalls.length === 0) {
        return undefined;
    }
    return toolCalls.map((call) => {
        const rawArgs = call.function.arguments;
        return {
            id: call.id,
            tool_name: call.function.name,
            tool_input: !rawArgs || rawArgs.trim() === "" ? {} : JSON.parse(rawArgs),
        };
    });
}
/**
 * Turns a response message back into prompt messages so it can be appended to the conversation.
 * The text is wrapped as a single text part (an empty string when `content` is null), and
 * `tool_calls` is copied as-is.
 * @returns a one-element list, or an empty list when no response message is given.
 */
function createPromptFromResponse(response: OpenAI.Chat.Completions.ChatCompletionMessage): ChatCompletionMessageParam[] {
    if (!response) {
        return [];
    }
    const replayed: ChatCompletionMessageParam = {
        role: response.role,
        content: [{
            type: "text",
            text: response.content ?? "",
        }],
        tool_calls: response.tool_calls,
    };
    return [replayed];
}
/**
 * Relaxes a JSON schema for OpenAI structured output / tools with `strict: false`.
 * Removes `default` from the schema, recursively from every entry of `properties`, and from
 * the `items` of array properties whose items are objects.
 * Returns a new schema. The input schema and its nested `properties` map are never modified.
 */
function limitedSchemaFormat(schema: JSONSchema): JSONSchema {
    const formattedSchema = { ...schema };
    // Defaults not supported
    delete formattedSchema.default;
    if (formattedSchema?.properties) {
        // A spread copy still shares `properties` with the input, so build a fresh map instead
        // of assigning into it.
        const properties: Record<string, JSONSchema> = {};
        for (const propName of Object.keys(formattedSchema.properties)) {
            const processed = limitedSchemaFormat(formattedSchema.properties[propName]);
            // Arrays of objects: also relax the item schema. This updates the processed copy,
            // so the array's own `default` stays removed.
            if (processed?.type === 'array' && processed.items && processed.items?.type === 'object') {
                processed.items = limitedSchemaFormat(processed.items);
            }
            properties[propName] = processed;
        }
        formattedSchema.properties = properties;
    }
    return formattedSchema;
}
/**
 * Converts a JSON schema for OpenAI structured output / tools with `strict: true`.
 * - Removes `default` at every visited level.
 * - Sets `additionalProperties: false` on objects.
 * - Marks every declared property as required.
 * - Recurses into `properties` and into the `items` of array properties whose items are objects.
 * Returns a new schema; the input is never modified, so a failed strict conversion cannot leak
 * strict-mode changes into a later relaxed fallback.
 * @param nesting current recursion depth (callers normally omit it)
 * @throws Error if nesting exceeds 5 levels, or an object schema has no properties
 *         (empty or open objects are not allowed in strict mode).
 */
function openAISchemaFormat(schema: JSONSchema, nesting: number = 0): JSONSchema {
    if (nesting > 5) {
        throw new Error("OpenAI schema nesting too deep");
    }
    const formattedSchema = { ...schema };
    // Defaults not supported
    delete formattedSchema.default;
    // Additional properties not supported, required to be set.
    if (formattedSchema?.type === "object") {
        formattedSchema.additionalProperties = false;
    }
    if (formattedSchema?.properties) {
        // Set all properties as required
        formattedSchema.required = Object.keys(formattedSchema.properties);
        // Fresh map: the spread copy still shares `properties` with the input schema.
        const properties: Record<string, JSONSchema> = {};
        for (const propName of Object.keys(formattedSchema.properties)) {
            const processed = openAISchemaFormat(formattedSchema.properties[propName], nesting + 1);
            // Arrays of objects: convert the item schema on the processed copy, so the array's
            // own `default` stays removed.
            if (processed?.type === 'array' && processed.items && processed.items?.type === 'object') {
                processed.items = openAISchemaFormat(processed.items, nesting + 1);
            }
            properties[propName] = processed;
        }
        formattedSchema.properties = properties;
    }
    if (formattedSchema?.type === 'object' && Object.keys(formattedSchema?.properties ?? {}).length === 0) {
        //If no properties are defined, then additionalProperties: true was set or the object would be empty.
        //OpenAI does not support this on structured output/ strict mode.
        throw new Error("OpenAI does not support empty objects or objects with additionalProperties set to true");
    }
    return formattedSchema;
}