@llumiverse/core
Version:
Provides a universal API to LLMs. Support for additional LLMs can be added by writing a driver.
407 lines (349 loc) • 15.8 kB
text/typescript
/**
* Classes to handle the execution of an interaction in an execution environment.
* Base abstract class is then implemented by each environment
* (eg: OpenAI, HuggingFace, etc.)
*/
import {
AIModel,
Completion,
CompletionChunkObject,
CompletionStream,
DataSource,
DriverOptions,
EmbeddingsOptions,
EmbeddingsResult,
ExecutionOptions,
ExecutionResponse,
LlumiverseErrorContext,
Logger,
ModelSearchPayload,
PromptOptions,
PromptSegment,
Providers,
TrainingJob,
TrainingOptions,
TrainingPromptOptions,
LlumiverseError
} from "@llumiverse/common";
import { DefaultCompletionStream, FallbackCompletionStream } from "./CompletionStream.js";
import { formatTextPrompt } from "./formatters/index.js";
import { validateResult } from "./validation.js";
/**
 * Adapt a `console` method to the `Logger` calling conventions.
 *
 * Supported call shapes:
 * - message-only: `logger.info("message", ...args)`
 * - object-first: `logger.info({ obj }, "message", ...args)`, printed as `message obj ...args`
 * - object-only:  `logger.info({ obj })`
 *
 * Only the arguments the caller actually passed are forwarded, so a
 * message-only call does not print a trailing `undefined`.
 *
 * @param consoleMethod - a console method already bound to `console`
 * @returns a function usable as a `Logger` level method
 */
function createConsoleLoggerMethod(consoleMethod: (...args: unknown[]) => void): Logger['info'] {
    return ((objOrMsg: unknown, ...rest: unknown[]) => {
        if (typeof objOrMsg === 'string') {
            // Message-only: forward the message and any substitution args exactly as given.
            consoleMethod(objOrMsg, ...rest);
            return;
        }
        const [msg, ...args] = rest;
        if (msg !== undefined) {
            // Object-first: print the message before the structured object.
            consoleMethod(msg, objOrMsg, ...args);
        } else {
            // Object-only (or an explicitly undefined message).
            consoleMethod(objOrMsg, ...args);
        }
    }) as Logger['info'];
}
/**
 * Logger backed by the global `console`: each level delegates to the
 * console method of the same name, bound to `console`.
 */
const ConsoleLogger: Logger = (() => {
    const wrap = (method: (...args: unknown[]) => void) => createConsoleLoggerMethod(method.bind(console));
    return {
        debug: wrap(console.debug),
        info: wrap(console.info),
        warn: wrap(console.warn),
        error: wrap(console.error),
    };
})();
/** Ignores its arguments and returns `undefined`; backs every level of the silent logger. */
function noop(): void {
    // intentionally empty
}
/** Logger that silently drops every message at every level. */
const NoopLogger: Logger = {
    debug: noop as Logger['debug'],
    info: noop as Logger['info'],
    warn: noop as Logger['warn'],
    error: noop as Logger['error'],
};
/**
 * Resolve the `logger` driver option into a concrete Logger.
 *
 * @param logger - `"console"` selects the console-backed logger; any other
 *                 truthy value is used as-is; `undefined` selects a logger
 *                 that discards every message
 * @returns the Logger to use
 */
export function createLogger(logger: Logger | "console" | undefined): Logger {
    if (logger === "console") return ConsoleLogger;
    return logger ? logger : NoopLogger;
}
/**
 * Contract implemented by every provider driver: prompt construction,
 * blocking and streaming execution, training jobs, model listing,
 * connection checks and embeddings.
 *
 * @typeParam PromptT - the provider-specific prompt representation
 */
export interface Driver<PromptT = unknown> {
    /**
     * Build a single training example (prompt plus expected completion) as a string.
     *
     * @param options - the training example: prompt segments, expected completion,
     *                  target model and result schema
     * @returns the training example serialized as a string
     */
    createTrainingPrompt(options: TrainingPromptOptions): Promise<string>;
    /** Convert prompt segments into the provider-specific prompt representation. */
    createPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<PromptT>;
    /** Execute the prompt and resolve once the complete response is available. */
    execute(segments: PromptSegment[], options: ExecutionOptions): Promise<ExecutionResponse<PromptT>>;
    /**
     * Execute the prompt and return a stream of completion chunks.
     * Implementations that cannot stream for a given model may deliver the
     * whole response as a single chunk.
     */
    stream(segments: PromptSegment[], options: ExecutionOptions): Promise<CompletionStream<PromptT>>;
    /** Start a training job on the given dataset. */
    startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob>;
    /** Cancel the training job with the given id. */
    cancelTraining(jobId: string): Promise<TrainingJob>;
    /** Fetch the current state of the training job with the given id. */
    getTrainingJob(jobId: string): Promise<TrainingJob>;
    /** List the models available in this environment. */
    listModels(params?: ModelSearchPayload): Promise<AIModel[]>;
    /** List the models that can be trained. */
    listTrainableModels(): Promise<AIModel[]>;
    /** Check that it is possible to connect to the environment. */
    validateConnection(): Promise<boolean>;
    /** Generate embeddings for a given text or image. */
    generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult>;
    /**
     * Optional cleanup hook called when the driver is evicted from the cache.
     * Implement it in drivers that need to release resources.
     */
    destroy?(): void;
}
/**
 * Base class for provider drivers.
 *
 * Subclasses supply `provider`, text completion (blocking and streaming),
 * model listing, connection validation and embeddings. The base class provides
 * prompt formatting, blocking/streaming execution orchestration, validation of
 * results against `result_schema`, and normalization of errors into
 * `LlumiverseError`.
 *
 * @typeParam OptionsT - driver configuration (its `logger` option selects the logger)
 * @typeParam PromptT - the provider-specific prompt representation
 */
export abstract class AbstractDriver<OptionsT extends DriverOptions = DriverOptions, PromptT = unknown> implements Driver<PromptT> {
    options: OptionsT;
    logger: Logger;
    /** The provider name, used as a prefix in log and error messages. */
    abstract provider: Providers | string;

    constructor(opts: OptionsT) {
        this.options = opts;
        this.logger = createLogger(opts.logger);
    }

    /**
     * Build a training example as the JSON string `{ prompt, completion }`.
     * The prompt comes from `createPrompt` applied to the example's segments;
     * a non-string completion is JSON-serialized first.
     */
    async createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
        const prompt = await this.createPrompt(options.segments, { result_schema: options.schema, model: options.model });
        return JSON.stringify({
            prompt,
            completion: typeof options.completion === 'string' ? options.completion : JSON.stringify(options.completion),
        });
    }

    /**
     * Training is not supported by default; drivers that support it override this.
     * @throws Error always (synchronously, not as a rejected promise) unless overridden
     */
    startTraining(_dataset: DataSource, _options: TrainingOptions): Promise<TrainingJob> {
        throw new Error("Method not implemented.");
    }

    /**
     * Training is not supported by default; drivers that support it override this.
     * @throws Error always (synchronously, not as a rejected promise) unless overridden
     */
    cancelTraining(_jobId: string): Promise<TrainingJob> {
        throw new Error("Method not implemented.");
    }

    /**
     * Training is not supported by default; drivers that support it override this.
     * @throws Error always (synchronously, not as a rejected promise) unless overridden
     */
    getTrainingJob(_jobId: string): Promise<TrainingJob> {
        throw new Error("Method not implemented.");
    }

    /**
     * Validate a completion against `options.result_schema`, mutating `result` in place.
     *
     * Skipped when the completion is a tool call, already carries an error, or no
     * schema was requested. On success `result.result` is replaced by the
     * validator's return value; on failure the error is logged and recorded in
     * `result.error` rather than thrown.
     */
    validateResult(result: Completion, options: ExecutionOptions) {
        if (!result.tool_use && !result.error && options.result_schema) {
            try {
                // Bare `validateResult` is the imported module function, not this method.
                result.result = validateResult(result.result, options.result_schema);
            } catch (error: any) {
                const errorMessage = `[${this.provider}] [${options.model}] ${error.code ? '[' + error.code + '] ' : ''}Result validation error: ${error.message}`;
                this.logger.error({ err: error, data: result.result }, errorMessage);
                result.error = {
                    code: error.code || error.name,
                    message: error.message,
                    data: result.result,
                };
            }
        }
    }

    /**
     * Format the segments into a prompt and execute it, waiting for the full response.
     *
     * Execution errors are rethrown as `LlumiverseError` (existing
     * `LlumiverseError`s unchanged). Errors from `createPrompt` propagate as-is.
     */
    async execute(segments: PromptSegment[], options: ExecutionOptions): Promise<ExecutionResponse<PromptT>> {
        const prompt = await this.createPrompt(segments, options);
        try {
            return await this._execute(prompt, options);
        } catch (error) {
            // Defensive: the base `_execute` already wraps its errors, but an override might not.
            if (LlumiverseError.isLlumiverseError(error)) {
                throw error;
            }
            throw this.formatLlumiverseError(error, {
                provider: this.provider,
                model: options.model,
                operation: 'execute',
            });
        }
    }

    /**
     * Execute an already formatted prompt and attach the prompt and the elapsed
     * time (ms) to the response. Image models use `requestImageGeneration`;
     * all others use `requestTextCompletion` followed by result validation.
     * Non-Llumiverse errors are logged and rethrown as `LlumiverseError`.
     */
    async _execute(prompt: PromptT, options: ExecutionOptions): Promise<ExecutionResponse<PromptT>> {
        try {
            const start = Date.now();
            let result: Completion;
            if (this.isImageModel(options.model)) {
                this.logger.debug(
                    `[${this.provider}] Executing prompt on ${options.model}, image pathway.`);
                result = await this.requestImageGeneration(prompt, options);
            } else {
                this.logger.debug(
                    `[${this.provider}] Executing prompt on ${options.model}, text pathway.`);
                result = await this.requestTextCompletion(prompt, options);
                // Only text completions are checked against result_schema.
                this.validateResult(result, options);
            }
            const execution_time = Date.now() - start;
            return { ...result, prompt, execution_time };
        } catch (error) {
            // Don't wrap if already a LlumiverseError
            if (LlumiverseError.isLlumiverseError(error)) {
                throw error;
            }
            // Log the original error for debugging
            this.logger.error({ err: error, data: { provider: this.provider, model: options.model, operation: 'execute', prompt } }, `Error during execution in provider ${this.provider}:`);
            throw this.formatLlumiverseError(error, {
                provider: this.provider,
                model: options.model,
                operation: 'execute',
            });
        }
    }

    /** Whether `model` should be executed through the image-generation pathway. Defaults to false. */
    protected isImageModel(_model: string): boolean {
        return false;
    }

    /**
     * Format the segments into a prompt and return a completion stream.
     *
     * When `canStream` resolves true (the default), the provider's streaming API
     * is used via `DefaultCompletionStream`. Otherwise, for image and text models
     * alike, the request falls back to `FallbackCompletionStream`.
     */
    async stream(segments: PromptSegment[], options: ExecutionOptions): Promise<CompletionStream<PromptT>> {
        // Pass the options as the structured log payload instead of eagerly
        // JSON-serializing them: serialization costs time even when debug logging
        // is disabled, and it throws on values such as BigInt or circular references.
        this.logger.debug(options, `Executing prompt with provider ${this.provider}`);
        const prompt = await this.createPrompt(segments, options);
        if (await this.canStream(options)) {
            return new DefaultCompletionStream(this, prompt, options);
        }
        return new FallbackCompletionStream(this, prompt, options);
    }

    /**
     * Default prompt formatter: renders the segments as text, taking the result
     * schema into account. Override to provide a provider-specific format.
     * @param segments the prompt segments
     * @param opts prompt options; only `result_schema` is used here
     * @returns the formatted prompt
     */
    protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<PromptT> {
        return formatTextPrompt(segments, opts.result_schema) as PromptT;
    }

    /**
     * Build the prompt, using the caller-supplied `opts.format` function when
     * present and `formatPrompt` otherwise.
     */
    public async createPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<PromptT> {
        return await (opts.format ? opts.format(segments, opts.result_schema) : this.formatPrompt(segments, opts));
    }

    /**
     * Whether the execution described by `options` can be streamed.
     * The default resolves to true. Override it and resolve to false for models
     * that cannot stream; `stream` then falls back to a blocking execution that
     * delivers the entire response as a single event.
     * @param options the execution options containing the target model name
     * @returns true if the execution can be streamed, false otherwise
     */
    protected canStream(_options: ExecutionOptions) {
        return Promise.resolve(true);
    }

    /**
     * Get the list of models that can be trained.
     * @returns an empty array unless overridden
     */
    async listTrainableModels(): Promise<AIModel[]> {
        return [];
    }

    /**
     * Build the conversation context after a streamed completion.
     * Override this in driver implementations that support multi-turn conversations.
     *
     * @param _prompt - The prompt that was sent (includes prior conversation context)
     * @param _result - The completion results from the streamed response
     * @param _toolUse - The tool calls from the streamed response (if any)
     * @param _options - The execution options
     * @returns The updated conversation context, or undefined if not supported
     */
    buildStreamingConversation(
        _prompt: PromptT,
        _result: unknown[],
        _toolUse: unknown[] | undefined,
        _options: ExecutionOptions
    ): unknown | undefined {
        // Default implementation returns undefined - drivers can override
        return undefined;
    }

    /**
     * Format an error into a LlumiverseError. Override in driver implementations
     * to provide provider-specific error parsing.
     *
     * The status code is the first numeric, non-zero value among `status`,
     * `statusCode` and `code`. A non-numeric value (e.g. `status: "RESOURCE_EXHAUSTED"`)
     * no longer hides a numeric `code`. Retryability is decided by `isRetryableError`.
     *
     * @param error - The error to format
     * @param context - Context about where the error occurred
     * @returns A standardized LlumiverseError
     */
    public formatLlumiverseError(
        error: unknown,
        context: LlumiverseErrorContext
    ): LlumiverseError {
        // Reading properties off primitives is safe at runtime (they yield
        // undefined); the cast only describes the fields probed below.
        const err = error as {
            status?: unknown;
            statusCode?: unknown;
            code?: unknown;
            name?: unknown;
            message?: unknown;
        } | null | undefined;

        const code = [err?.status, err?.statusCode, err?.code].find(
            (candidate): candidate is number =>
                typeof candidate === 'number' && candidate !== 0 && !Number.isNaN(candidate)
        );

        const rawName = err?.name;
        const errorName = typeof rawName === 'string' ? rawName : undefined;

        // Plain objects carrying a `message` would otherwise stringify to "[object Object]".
        const rawMessage = err?.message;
        const message = error instanceof Error
            ? error.message
            : typeof rawMessage === 'string'
                ? rawMessage
                : String(error);

        const retryable = this.isRetryableError(code, message);
        return new LlumiverseError(
            `[${this.provider}] ${message}`,
            retryable,
            context,
            error,
            code,
            errorName
        );
    }

    /**
     * Determine whether an error is retryable from its status code and message.
     * Can be overridden by drivers for provider-specific logic.
     *
     * - 408, 429 and 5xx (including 529 "overloaded"): retryable
     * - other 4xx: not retryable
     * - no status, or a numeric code outside the HTTP error range (e.g. an
     *   SDK/gRPC-specific code): decided by message heuristics
     *
     * @param statusCode - The numeric status code (if available)
     * @param message - The error message
     * @returns True if retryable, false if not retryable, undefined if unknown
     */
    protected isRetryableError(statusCode: number | undefined, message: string): boolean | undefined {
        if (statusCode !== undefined) {
            if (statusCode === 408 || statusCode === 429) return true; // Timeout, rate limit
            if (statusCode >= 500 && statusCode < 600) return true; // Server errors, incl. 529 overloaded
            if (statusCode >= 400 && statusCode < 500) return false; // Other client errors
            // Not an HTTP error status: fall through to the message heuristics.
        }

        const lowerMessage = message.toLowerCase();
        // "rate limit", "rate-limit", "rate_limit", "RateLimitError". Both words
        // must be adjacent, so e.g. "generate ... token limit" does not match.
        if (/rate[\s_-]?limit/.test(lowerMessage)) return true;
        // "timeout", "timed out", "ETIMEDOUT", "time out", but not "runtime output".
        if (/timeout|timed[\s_-]?out\b|time[\s_-]out\b/.test(lowerMessage)) return true;
        // "RESOURCE_EXHAUSTED", "Resource has been exhausted"
        if (lowerMessage.includes('resource') && lowerMessage.includes('exhaust')) return true;
        // Other retryable patterns
        if (/retry|overload|throttl/.test(lowerMessage)) return true;
        // 429 / 529 embedded in the message, as whole numbers only.
        if (/\b[45]29\b/.test(lowerMessage)) return true;
        // Unknown errors - let consumer decide retry strategy
        return undefined;
    }

    /** Request a blocking text completion for an already formatted prompt. */
    abstract requestTextCompletion(prompt: PromptT, options: ExecutionOptions): Promise<Completion>;

    /** Request a streamed text completion for an already formatted prompt. */
    abstract requestTextCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>>;

    /**
     * Generate an image for an already formatted prompt. Drivers with image
     * models override this. It is kept concrete rather than abstract so that
     * drivers without image support are not forced to implement it.
     * @throws Error (as a rejected promise) unless overridden
     */
    async requestImageGeneration(_prompt: PromptT, _options: ExecutionOptions): Promise<Completion> {
        throw new Error("Image generation not implemented.");
    }

    /** List the models available in this environment. */
    abstract listModels(params?: ModelSearchPayload): Promise<AIModel[]>;

    /** Check that it is possible to connect to the environment. */
    abstract validateConnection(): Promise<boolean>;

    /** Generate embeddings for the given input. */
    abstract generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult>;

    /**
     * Cleanup method called when the driver is evicted from the cache.
     * Override this in driver implementations that need to release resources.
     */
    destroy(): void {
        // No-op by default
    }
}