@llumiverse/drivers
Version:
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
321 lines (281 loc) • 11 kB
text/typescript
import {
AIModel,
AbstractDriver,
Completion,
CompletionChunkObject,
DataSource,
DriverOptions,
EmbeddingsResult,
ExecutionOptions,
ModelSearchPayload,
TextFallbackOptions,
TrainingJob,
TrainingJobStatus,
TrainingOptions,
} from "@llumiverse/core";
import { EventStream } from "@llumiverse/core/async";
import { EventSource } from "eventsource";
import Replicate, { Prediction } from "replicate";
// Module-level cache for the trainable-model listing (shared by all driver
// instances in this process). Refreshed by listTrainableModels() after the
// TTL expires; see the 12h check there.
let cachedTrainableModels: AIModel[] | undefined;
let cachedTrainableModelsTimestamp: number = 0;
// Replicate model ids (owner/model, without version) that support fine-tuning.
// Used to drive listTrainableModels() and the `canTrain` flag on listed models.
const supportFineTunning = new Set([
"meta/llama-2-70b-chat",
"meta/llama-2-13b-chat",
"meta/llama-2-7b-chat",
"meta/llama-2-7b",
"meta/llama-2-70b",
"meta/llama-2-13b",
"mistralai/mistral-7b-v0.1"
]);
/**
 * Options for constructing a {@link ReplicateDriver}.
 */
export interface ReplicateDriverOptions extends DriverOptions {
// The Replicate API token, passed as `auth` to the Replicate client.
apiKey: string;
}
/**
 * Driver for the Replicate inference platform (https://replicate.com).
 *
 * Model ids are expected in the form `owner/model:version`. Text completion
 * runs through Replicate predictions; streaming consumes the SSE endpoint
 * exposed on `prediction.urls.stream` (https://replicate.com/blog/streaming).
 */
export class ReplicateDriver extends AbstractDriver<DriverOptions, string> {
    static PROVIDER = "replicate";
    provider = ReplicateDriver.PROVIDER;
    service: Replicate;

    /**
     * Parse a model id of the form `owner/model:version`.
     * @param modelId the full Replicate model id
     * @returns the `{ owner, model, version }` parts
     * @throws Error when the `:version` separator is missing
     */
    static parseModelId(modelId: string) {
        const [owner, modelPart] = modelId.split("/");
        const i = modelPart.indexOf(':');
        if (i === -1) {
            throw new Error("Invalid model id. Expected format: owner/model:version");
        }
        return {
            owner, model: modelPart.slice(0, i), version: modelPart.slice(i + 1)
        };
    }

    constructor(options: ReplicateDriverOptions) {
        super(options);
        this.service = new Replicate({
            auth: options.apiKey,
        });
    }

    /**
     * Join the raw token array of a finished prediction into a single string.
     * NOTE(review): this returns `result` as a plain string while
     * requestTextCompletion returns an array of text chunks — presumably both
     * satisfy the Completion contract; confirm against @llumiverse/core.
     */
    extractDataFromResponse(response: Prediction): Completion {
        const text = response.output.join("");
        return {
            result: text,
        };
    }

    /**
     * Start a streaming text completion. Returns an EventStream fed by the
     * prediction's SSE endpoint; the stream is closed when the "done" SSE
     * event arrives.
     */
    async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions;
        const model = ReplicateDriver.parseModelId(options.model);
        const predictionData = {
            input: {
                prompt: prompt,
                max_new_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
            },
            version: model.version,
            stream: true, // streaming described here https://replicate.com/blog/streaming
        };
        const prediction = await this.service.predictions.create(predictionData);
        const stream = new EventStream<CompletionChunkObject>();
        const source = new EventSource(prediction.urls.stream!);
        source.addEventListener("output", (e: any) => {
            stream.push({ result: [{ type: "text", value: e.data }] });
        });
        source.addEventListener("error", (e: any) => {
            // BUGFIX: the catch parameter used to shadow the outer `error`
            // variable, so the JSON.stringify fallback was assigned to the dead
            // catch binding and `undefined` was logged when e.data wasn't JSON.
            let error: any;
            try {
                error = JSON.parse(e.data);
            } catch (_parseErr) {
                error = JSON.stringify(e);
            }
            // NOTE(review): the stream is not closed on error; termination
            // relies on the "done" event — confirm Replicate always sends it.
            this.logger?.error("Error in SSE stream", { e, error });
        });
        source.addEventListener("done", () => {
            try {
                stream.close(""); // not using e.data which is {}
            } finally {
                source.close();
            }
        });
        return stream;
    }

    /**
     * Run a blocking text completion: create the prediction, wait for it to
     * settle, then join the output tokens into a single text chunk.
     */
    async requestTextCompletion(prompt: string, options: ExecutionOptions) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions;
        const model = ReplicateDriver.parseModelId(options.model);
        const predictionData = {
            input: {
                prompt: prompt,
                max_new_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
            },
            version: model.version,
        };
        const prediction = await this.service.predictions.create(predictionData);
        // Not streaming: block until the prediction settles.
        const res = await this.service.wait(prediction, {});
        const text: string = res.output.join("");
        return {
            result: [{ type: "text" as const, value: text }],
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Start a fine-tuning job on Replicate.
     * @param dataset training data; its URL is passed as `train_data`
     * @param options `options.model` is the base model (`owner/model:version`),
     *   `options.name` the destination model (`owner/model`)
     * @throws Error when the destination name is not of the form `owner/model`
     */
    async startTraining(dataset: DataSource, options: TrainingOptions): Promise<TrainingJob> {
        if (!options.name.includes("/")) {
            throw new Error("Invalid target model name. Expected format: owner/model");
        }
        const { owner, model, version } = ReplicateDriver.parseModelId(options.model);
        const job = await this.service.trainings.create(owner, model, version, {
            destination: options.name as any,
            input: {
                train_data: await dataset.getURL(),
            },
        });
        return jobInfo(job, options.name);
    }

    /**
     * Cancel a running training job.
     * This method is not returning a consistent TrainingJob like the one
     * returned by startTraining: instead of the full model name
     * `owner/model:version` it only carries the `version`.
     */
    async cancelTraining(jobId: string): Promise<TrainingJob> {
        const job = await this.service.trainings.cancel(jobId);
        return jobInfo(job);
    }

    /**
     * Fetch the current state of a training job.
     * This method is not returning a consistent TrainingJob like the one
     * returned by startTraining: instead of the full model name
     * `owner/model:version` it only carries the `version`.
     */
    async getTrainingJob(jobId: string): Promise<TrainingJob> {
        const job = await this.service.trainings.get(jobId);
        return jobInfo(job);
    }

    // ========= management API =============

    /**
     * Best-effort connectivity/credentials check: any failure of a cheap
     * list call is reported as an invalid connection.
     */
    async validateConnection(): Promise<boolean> {
        try {
            await this.service.predictions.list();
            return true;
        } catch (_error) {
            return false;
        }
    }

    /** Fetch model metadata for every id in `supportFineTunning`, in parallel. */
    async _listTrainableModels(): Promise<AIModel[]> {
        const promises = Array.from(supportFineTunning).map(id => {
            const [owner, model] = id.split('/');
            return this.service.models.get(owner, model);
        });
        const results = await Promise.all(promises);
        // Models without a published version cannot be addressed and are skipped.
        return results.filter(m => !!m.latest_version).map(m => {
            const fullName = m.owner + '/' + m.name;
            const v = m.latest_version!;
            return {
                id: fullName + ':' + v.id,
                name:
                    fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
                provider: this.provider,
                owner: m.owner,
                description: m.description,
            } as AIModel;
        });
    }

    /** Like _listTrainableModels but cached at module level for 12 hours. */
    async listTrainableModels(): Promise<AIModel[]> {
        if (!cachedTrainableModels || Date.now() > cachedTrainableModelsTimestamp + 12 * 3600 * 1000) { // 12 hours
            cachedTrainableModels = await this._listTrainableModels();
            cachedTrainableModelsTimestamp = Date.now();
        }
        return cachedTrainableModels;
    }

    /**
     * List models. With no search text, returns the trainable models;
     * otherwise `params.text` must be `owner/model` and the model's versions
     * are listed.
     */
    async listModels(params: ModelSearchPayload = { text: '' }): Promise<AIModel[]> {
        if (!params.text) {
            return this.listTrainableModels();
        }
        const [owner, model] = params.text.split("/");
        if (!owner || !model) {
            throw new Error("Invalid model name. Expected format: owner/model");
        }
        return this.listModelVersions(owner, model);
    }

    /**
     * List every published version of `owner/model` as a separate AIModel.
     * @throws Error when the model is missing or has no versions
     */
    async listModelVersions(owner: string, model: string): Promise<AIModel[]> {
        const [rModel, versions] = await Promise.all([
            this.service.models.get(owner, model),
            this.service.models.versions.list(owner, model),
        ]);
        // BUGFIX: versions is a paginated response whose array lives in
        // `results` (see the map below); the old `versions.length === 0` check
        // compared `undefined === 0` and never fired, so a model with zero
        // versions silently produced an empty list instead of throwing.
        const versionList = (versions as any)?.results;
        if (!rModel || !versionList || versionList.length === 0) {
            throw new Error("Model not found or no versions available");
        }
        const models: AIModel[] = versionList.map((v: any) => {
            const fullName = rModel.owner + '/' + rModel.name;
            return {
                id: fullName + ':' + v.id,
                name:
                    fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
                provider: this.provider,
                owner: rModel.owner,
                description: rModel.description,
                canTrain: supportFineTunning.has(fullName),
            } as AIModel;
        });
        //set latest version
        //const idx = models.findIndex(m => m.id === rModel.latest_version?.id);
        //models[idx].name = rModel.name + "@latest"
        return models;
    }

    /** Full-text model search via Replicate's `models/search` endpoint. */
    async searchModels(params: ModelSearchPayload): Promise<AIModel[]> {
        const res = await this.service.request("models/search", {
            params: {
                query: params.text,
            },
        });
        const rModels = ((await res.json()) as any).models;
        const models: AIModel[] = rModels.map((v: any) => {
            return {
                id: v.name,
                name: v.name,
                provider: this.provider,
                owner: v.username,
                description: v.description,
                has_versions: true,
            };
        });
        return models;
    }

    /** Embeddings are not supported by this driver. */
    async generateEmbeddings(): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
}
/**
 * Map a Replicate prediction/training object onto the generic TrainingJob
 * shape. Replicate statuses are:
 * 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'.
 * When `modelName` is given, the job's model is `modelName:version`,
 * otherwise just the raw version.
 */
function jobInfo(job: Prediction, modelName?: string): TrainingJob {
    let status: TrainingJobStatus;
    let details: string | undefined;
    switch (job.status) {
        case 'succeeded':
            status = TrainingJobStatus.succeeded;
            break;
        case 'failed': {
            status = TrainingJobStatus.failed;
            const cause = job.error as any;
            // Replicate may report the error as a string or a structured object.
            details = typeof cause === 'string' ? cause : JSON.stringify(cause);
            break;
        }
        case 'canceled':
            status = TrainingJobStatus.cancelled;
            break;
        default:
            // 'starting' / 'processing' (or anything unknown) count as running,
            // with the raw status surfaced in the details.
            status = TrainingJobStatus.running;
            details = job.status;
    }
    return {
        id: job.id,
        status,
        details,
        model: modelName ? modelName + ':' + job.version : job.version
    } as TrainingJob;
}