// @llumiverse/drivers
// Version:
// LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
// 275 lines • 10.2 kB
// JavaScript
;
var __importDefault = (this && this.__importDefault) || function (mod) {
return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.ReplicateDriver = void 0;
const core_1 = require("@llumiverse/core");
const async_1 = require("@llumiverse/core/async");
const eventsource_1 = require("eventsource");
const replicate_1 = __importDefault(require("replicate"));
let cachedTrainableModels;
let cachedTrainableModelsTimestamp = 0;
const supportFineTunning = new Set([
"meta/llama-2-70b-chat",
"meta/llama-2-13b-chat",
"meta/llama-2-7b-chat",
"meta/llama-2-7b",
"meta/llama-2-70b",
"meta/llama-2-13b",
"mistralai/mistral-7b-v0.1"
]);
class ReplicateDriver extends core_1.AbstractDriver {
static PROVIDER = "replicate";
provider = ReplicateDriver.PROVIDER;
service;
static parseModelId(modelId) {
const [owner, modelPart] = modelId.split("/");
const i = modelPart.indexOf(':');
if (i === -1) {
throw new Error("Invalid model id. Expected format: owner/model:version");
}
return {
owner, model: modelPart.slice(0, i), version: modelPart.slice(i + 1)
};
}
constructor(options) {
super(options);
this.service = new replicate_1.default({
auth: options.apiKey,
});
}
extractDataFromResponse(response) {
const text = response.output.join("");
return {
result: text,
};
}
async requestTextCompletionStream(prompt, options) {
if (options.model_options?._option_id !== "text-fallback") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
options.model_options = options.model_options;
const model = ReplicateDriver.parseModelId(options.model);
const predictionData = {
input: {
prompt: prompt,
max_new_tokens: options.model_options?.max_tokens,
temperature: options.model_options?.temperature,
},
version: model.version,
stream: true, //streaming described here https://replicate.com/blog/streaming
};
const prediction = await this.service.predictions.create(predictionData);
const stream = new async_1.EventStream();
const source = new eventsource_1.EventSource(prediction.urls.stream);
source.addEventListener("output", (e) => {
stream.push({ result: [{ type: "text", value: e.data }] });
});
source.addEventListener("error", (e) => {
let error;
try {
error = JSON.parse(e.data);
}
catch (error) {
error = JSON.stringify(e);
}
this.logger.error({ e, error }, "Error in SSE stream");
});
source.addEventListener("done", () => {
try {
stream.close(""); // not using e.data which is {}
}
finally {
source.close();
}
});
return stream;
}
async requestTextCompletion(prompt, options) {
if (options.model_options?._option_id !== "text-fallback") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
options.model_options = options.model_options;
const model = ReplicateDriver.parseModelId(options.model);
const predictionData = {
input: {
prompt: prompt,
max_new_tokens: options.model_options?.max_tokens,
temperature: options.model_options?.temperature,
},
version: model.version,
//TODO stream
//stream: stream, //streaming described here https://replicate.com/blog/streaming
};
const prediction = await this.service.predictions.create(predictionData);
//TODO stream
//if we're streaming, return right away for the stream handler to handle
// if (stream) return prediction;
//not streaming, wait for the result
const res = await this.service.wait(prediction, {});
const text = res.output.join("");
return {
result: [{ type: "text", value: text }],
original_response: options.include_original_response ? res : undefined,
};
}
async startTraining(dataset, options) {
if (options.name.indexOf("/") === -1) {
throw new Error("Invalid target model name. Expected format: owner/model");
}
const { owner, model, version } = ReplicateDriver.parseModelId(options.model);
const job = await this.service.trainings.create(owner, model, version, {
destination: options.name,
input: {
train_data: await dataset.getURL(),
},
});
return jobInfo(job, options.name);
}
/**
* This method is not returning a consistent TrainingJob like the one returned by startTraining
* Instead of returning the full model name `owner/model:version` it returns only the version `version
* @param jobId
* @returns
*/
async cancelTraining(jobId) {
const job = await this.service.trainings.cancel(jobId);
return jobInfo(job);
}
/**
* This method is not returning a consistent TrainingJob like the one returned by startTraining
* Instead of returning the full model name `owner/model:version` it returns only the version `version
* @param jobId
* @returns
*/
async getTrainingJob(jobId) {
const job = await this.service.trainings.get(jobId);
return jobInfo(job);
}
// ========= management API =============
async validateConnection() {
try {
await this.service.predictions.list();
return true;
}
catch (error) {
return false;
}
}
async _listTrainableModels() {
const promises = Array.from(supportFineTunning).map(id => {
const [owner, model] = id.split('/');
return this.service.models.get(owner, model);
});
const results = await Promise.all(promises);
return results.filter(m => !!m.latest_version).map(m => {
const fullName = m.owner + '/' + m.name;
const v = m.latest_version;
return {
id: fullName + ':' + v.id,
name: fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
provider: this.provider,
owner: m.owner,
description: m.description,
};
});
}
async listTrainableModels() {
if (!cachedTrainableModels || Date.now() > cachedTrainableModelsTimestamp + 12 * 3600 * 1000) { // 12 hours
cachedTrainableModels = await this._listTrainableModels();
cachedTrainableModelsTimestamp = Date.now();
}
return cachedTrainableModels;
}
async listModels(params = { text: '' }) {
if (!params.text) {
return this.listTrainableModels();
}
const [owner, model] = params.text.split("/");
if (!owner || !model) {
throw new Error("Invalid model name. Expected format: owner/model");
}
return this.listModelVersions(owner, model);
}
async listModelVersions(owner, model) {
const [rModel, versions] = await Promise.all([
this.service.models.get(owner, model),
this.service.models.versions.list(owner, model),
]);
if (!rModel || !versions || versions.results?.length === 0) {
throw new Error("Model not found or no versions available");
}
const models = versions.results.map((v) => {
const fullName = rModel.owner + '/' + rModel.name;
return {
id: fullName + ':' + v.id,
name: fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6),
provider: this.provider,
owner: rModel.owner,
description: rModel.description,
canTrain: supportFineTunning.has(fullName),
};
});
//set latest version
//const idx = models.findIndex(m => m.id === rModel.latest_version?.id);
//models[idx].name = rModel.name + "@latest"
return models;
}
async searchModels(params) {
const res = await this.service.request("models/search", {
params: {
query: params.text,
},
});
const rModels = (await res.json()).models;
const models = rModels.map((v) => {
return {
id: v.name,
name: v.name,
provider: this.provider,
owner: v.username,
description: v.description,
has_versions: true,
};
});
return models;
}
async generateEmbeddings() {
throw new Error("Method not implemented.");
}
}
exports.ReplicateDriver = ReplicateDriver;
function jobInfo(job, modelName) {
// 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'
const jobStatus = job.status;
let details;
let status = core_1.TrainingJobStatus.running;
if (jobStatus === 'succeeded') {
status = core_1.TrainingJobStatus.succeeded;
}
else if (jobStatus === 'failed') {
status = core_1.TrainingJobStatus.failed;
const error = job.error;
if (typeof error === 'string') {
details = error;
}
else {
details = JSON.stringify(error);
}
}
else if (jobStatus === 'canceled') {
status = core_1.TrainingJobStatus.cancelled;
}
else {
status = core_1.TrainingJobStatus.running;
details = job.status;
}
return {
id: job.id,
status,
details,
model: modelName ? modelName + ':' + job.version : job.version
};
}
//# sourceMappingURL=replicate.js.map