@llumiverse/core
Provides a universal API to LLMs. Support for additional LLMs can be added by writing a driver.
/**
 * Classes that handle the execution of an interaction in an execution environment.
 * The base abstract class is implemented by each environment
 * (e.g. OpenAI, Hugging Face, etc.).
*/
import { DefaultCompletionStream, FallbackCompletionStream } from "./CompletionStream.js";
import { formatTextPrompt } from "./formatters/index.js";
import {
AIModel,
Completion,
CompletionChunkObject,
CompletionStream,
DataSource,
DriverOptions,
EmbeddingsOptions,
EmbeddingsResult,
ExecutionOptions,
ExecutionResponse,
Logger,
Modalities,
ModelSearchPayload,
PromptOptions,
PromptSegment,
TrainingJob,
TrainingOptions,
TrainingPromptOptions
} from "@llumiverse/common";
import { validateResult } from "./validation.js";
const ConsoleLogger: Logger = {
debug: console.debug,
info: console.info,
warn: console.warn,
error: console.error,
}
const noop = () => void 0;
const NoopLogger: Logger = {
debug: noop,
info: noop,
warn: noop,
error: noop,
}
export function createLogger(logger: Logger | "console" | undefined) {
if (logger === "console") {
return ConsoleLogger;
} else if (logger) {
return logger;
} else {
return NoopLogger;
}
}
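// Usage sketch (illustrative): createLogger() normalizes the `logger` driver option.
// "console" selects the console-backed logger, a custom Logger object is used as-is,
// and anything falsy falls back to the silent NoopLogger.
//
//   const logger = createLogger("console");
//   logger.info("connected");            // printed to the console
//   createLogger(undefined).info("hi");  // silently dropped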
export interface Driver<PromptT = unknown> {
    /**
     * Create a serialized training prompt from the given options.
     * @param options the training prompt options carrying the prompt segments,
     *        the expected completion, the result schema and the model to train.
     */
createTrainingPrompt(options: TrainingPromptOptions): Promise<string>;
createPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<PromptT>;
execute(segments: PromptSegment[], options: ExecutionOptions): Promise<ExecutionResponse<PromptT>>;
    // By default streaming is not supported: the driver blocks and returns the whole response at once.
stream(segments: PromptSegment[], options: ExecutionOptions): Promise<CompletionStream<PromptT>>;
startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob>;
cancelTraining(jobId: string): Promise<TrainingJob>;
getTrainingJob(jobId: string): Promise<TrainingJob>;
//list models available for this environment
listModels(params?: ModelSearchPayload): Promise<AIModel[]>;
//list models that can be trained
listTrainableModels(): Promise<AIModel[]>;
//check that it is possible to connect to the environment
validateConnection(): Promise<boolean>;
//generate embeddings for a given text or image
generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult>;
}
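// Usage sketch (illustrative): executing a prompt through any Driver implementation.
// The segment and option literals below are assumptions about the PromptSegment and
// ExecutionOptions shapes; consult @llumiverse/common for the exact fields.
//
//   async function runOnce(driver: Driver) {
//       const segments = [{ role: "user", content: "Say hello" }] as PromptSegment[];
//       const response = await driver.execute(segments, {
//           model: "gpt-4o",
//           output_modality: Modalities.text,
//       } as ExecutionOptions);
//       console.log(response.result);
//   }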
/**
* To be implemented by each driver
*/
export abstract class AbstractDriver<OptionsT extends DriverOptions = DriverOptions, PromptT = unknown> implements Driver<PromptT> {
options: OptionsT;
logger: Logger;
abstract provider: string; // the provider name
constructor(opts: OptionsT) {
this.options = opts;
this.logger = createLogger(opts.logger);
}
async createTrainingPrompt(options: TrainingPromptOptions): Promise<string> {
const prompt = await this.createPrompt(options.segments, { result_schema: options.schema, model: options.model })
return JSON.stringify({
prompt,
completion: typeof options.completion === 'string' ? options.completion : JSON.stringify(options.completion)
});
}
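    // Illustrative output of createTrainingPrompt(): one flat JSON record per
    // (prompt, completion) pair, suitable as a line of a JSONL fine-tuning dataset.
    // Object completions are stringified so the record stays a flat JSON value, e.g.:
    //
    //   {"prompt": <formatted prompt>, "completion": "Hello!"}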
startTraining(_dataset: DataSource, _options: TrainingOptions): Promise<TrainingJob> {
throw new Error("Method not implemented.");
}
cancelTraining(_jobId: string): Promise<TrainingJob> {
throw new Error("Method not implemented.");
}
getTrainingJob(_jobId: string): Promise<TrainingJob> {
throw new Error("Method not implemented.");
}
validateResult(result: Completion, options: ExecutionOptions) {
if (!result.tool_use && !result.error && options.result_schema) {
try {
result.result = validateResult(result.result, options.result_schema);
} catch (error: any) {
this.logger?.error({ err: error, data: result.result }, `[${this.provider}] [${options.model}] ${error.code ? '[' + error.code + '] ' : ''}Result validation error: ${error.message}`);
result.error = {
code: error.code || error.name,
message: error.message,
data: result.result,
}
}
}
}
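    // Illustrative effect of validateResult(): when options.result_schema is set and
    // the completion is neither a tool call nor already an error, the raw result is
    // validated in place. On failure the completion is marked instead of throwing:
    //
    //   result.error = { code: <error code or name>, message: <message>, data: <raw result> }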
async execute(segments: PromptSegment[], options: ExecutionOptions): Promise<ExecutionResponse<PromptT>> {
const prompt = await this.createPrompt(segments, options);
return this._execute(prompt, options).catch((error: any) => {
(error as any).prompt = prompt;
throw error;
});
}
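    // Error-handling sketch (illustrative): the formatted prompt is attached to any
    // thrown error, which helps when debugging provider failures.
    //
    //   try {
    //       await driver.execute(segments, options);
    //   } catch (err: any) {
    //       console.error("failed prompt was:", err.prompt);
    //   }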
async _execute(prompt: PromptT, options: ExecutionOptions): Promise<ExecutionResponse<PromptT>> {
this.logger.debug(
`[${this.provider}] Executing prompt on ${options.model}`);
try {
const start = Date.now();
let result;
switch (options.output_modality) {
case Modalities.text:
result = await this.requestTextCompletion(prompt, options);
this.validateResult(result, options);
break;
case Modalities.image:
result = await this.requestImageGeneration(prompt, options);
break;
            default:
                throw new Error(`Unsupported output modality: ${options.output_modality ?? "none specified"}`);
}
const execution_time = Date.now() - start;
return { ...result, prompt, execution_time };
} catch (error) {
(error as any).prompt = prompt;
throw error;
}
}
    // Falls back to a blocking execution emitted as a single event when the driver or modality cannot stream.
async stream(segments: PromptSegment[], options: ExecutionOptions): Promise<CompletionStream<PromptT>> {
const prompt = await this.createPrompt(segments, options);
const canStream = await this.canStream(options);
if (options.output_modality === Modalities.text && canStream) {
return new DefaultCompletionStream(this, prompt, options);
} else {
return new FallbackCompletionStream(this, prompt, options);
}
}
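    // Streaming sketch (illustrative): consuming the returned stream. This assumes
    // CompletionStream is async-iterable over completion chunks, as the Default/Fallback
    // implementations suggest; when streaming is unavailable the fallback emits the
    // whole response as a single chunk.
    //
    //   const stream = await driver.stream(segments, options);
    //   for await (const chunk of stream) {
    //       process.stdout.write(String(chunk));
    //   }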
/**
     * Override this method to provide a custom prompt formatter.
     * @param segments the prompt segments to format
     * @param opts the prompt options; the default text formatter uses opts.result_schema
     * @returns the formatted prompt in the driver's native prompt type
*/
protected async formatPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<PromptT> {
return formatTextPrompt(segments, opts.result_schema) as PromptT;
}
public async createPrompt(segments: PromptSegment[], opts: PromptOptions): Promise<PromptT> {
return await (opts.format ? opts.format(segments, opts.result_schema) : this.formatPrompt(segments, opts));
}
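    // Formatting sketch (illustrative): a caller can bypass the driver's formatPrompt()
    // by passing a custom `format` function, invoked as format(segments, result_schema).
    //
    //   const prompt = await driver.createPrompt(segments, {
    //       model: "gpt-4o",
    //       format: (segs, _schema) => segs.map((s) => s.content).join("\n"),
    //   } as PromptOptions);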
/**
     * Must be overridden if the implementation cannot stream.
     * Some implementations can stream for certain models but not for others:
     * override this method and return false when the current model does not support streaming.
     * The default implementation returns true, so streaming is assumed to be supported.
     * If this method returns false, the streaming execution falls back on a blocking
     * execution that emits the entire response as a single event.
     * @param options the execution options containing the target model name.
     * @returns true if the execution can be streamed, false otherwise.
*/
protected canStream(_options: ExecutionOptions) {
return Promise.resolve(true);
}
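    // Override sketch (illustrative): a driver whose provider streams only some models.
    // The model-name check is a made-up example.
    //
    //   protected override canStream(options: ExecutionOptions) {
    //       return Promise.resolve(!options.model.startsWith("legacy-"));
    //   }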
/**
     * Get the list of models that can be trained.
     * The default implementation returns an empty array.
     * @returns the trainable models, if any
*/
async listTrainableModels(): Promise<AIModel[]> {
return [];
}
abstract requestTextCompletion(prompt: PromptT, options: ExecutionOptions): Promise<Completion>;
abstract requestTextCompletionStream(prompt: PromptT, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>>;
async requestImageGeneration(_prompt: PromptT, _options: ExecutionOptions): Promise<Completion> {
throw new Error("Image generation not implemented.");
        // Not abstract: an abstract method would force every derived driver to implement it,
        // even those that do not support image generation.
}
//list models available for this environment
abstract listModels(params?: ModelSearchPayload): Promise<AIModel[]>;
//check that it is possible to connect to the environment
abstract validateConnection(): Promise<boolean>;
    //generate embeddings for a given text or image
abstract generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult>;
}
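// Implementation sketch (illustrative): the minimal surface a concrete driver must
// provide. The provider name and all response values are placeholders, not a real SDK.
//
//   class MyDriver extends AbstractDriver<DriverOptions, string> {
//       provider = "my-provider";
//
//       async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion> {
//           // call the provider API here and map its output to a Completion
//           return { result: "..." } as Completion;
//       }
//
//       async requestTextCompletionStream(prompt: string, options: ExecutionOptions) {
//           // return an AsyncIterable of CompletionChunkObject; drivers that cannot
//           // stream should instead override canStream() to return false
//           async function* chunks() {
//               yield { result: "..." } as CompletionChunkObject;
//           }
//           return chunks();
//       }
//
//       async listModels(): Promise<AIModel[]> { return []; }
//       async validateConnection(): Promise<boolean> { return true; }
//       async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
//           throw new Error("Embeddings not supported by this driver.");
//       }
//   }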