UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

275 lines 10.2 kB
"use strict"; var __importDefault = (this && this.__importDefault) || function (mod) { return (mod && mod.__esModule) ? mod : { "default": mod }; }; Object.defineProperty(exports, "__esModule", { value: true }); exports.ReplicateDriver = void 0; const core_1 = require("@llumiverse/core"); const async_1 = require("@llumiverse/core/async"); const eventsource_1 = require("eventsource"); const replicate_1 = __importDefault(require("replicate")); let cachedTrainableModels; let cachedTrainableModelsTimestamp = 0; const supportFineTunning = new Set([ "meta/llama-2-70b-chat", "meta/llama-2-13b-chat", "meta/llama-2-7b-chat", "meta/llama-2-7b", "meta/llama-2-70b", "meta/llama-2-13b", "mistralai/mistral-7b-v0.1" ]); class ReplicateDriver extends core_1.AbstractDriver { static PROVIDER = "replicate"; provider = ReplicateDriver.PROVIDER; service; static parseModelId(modelId) { const [owner, modelPart] = modelId.split("/"); const i = modelPart.indexOf(':'); if (i === -1) { throw new Error("Invalid model id. Expected format: owner/model:version"); } return { owner, model: modelPart.slice(0, i), version: modelPart.slice(i + 1) }; } constructor(options) { super(options); this.service = new replicate_1.default({ auth: options.apiKey, }); } extractDataFromResponse(response) { const text = response.output.join(""); return { result: text, }; } async requestTextCompletionStream(prompt, options) { if (options.model_options?._option_id !== "text-fallback") { this.logger.warn({ options: options.model_options }, "Invalid model options"); } options.model_options = options.model_options; const model = ReplicateDriver.parseModelId(options.model); const predictionData = { input: { prompt: prompt, max_new_tokens: options.model_options?.max_tokens, temperature: options.model_options?.temperature, }, version: model.version, stream: true, //streaming described here https://replicate.com/blog/streaming }; const prediction = await this.service.predictions.create(predictionData); const stream = new async_1.EventStream(); const source = new eventsource_1.EventSource(prediction.urls.stream); source.addEventListener("output", (e) => { stream.push({ result: [{ type: "text", value: e.data }] }); }); source.addEventListener("error", (e) => { let error; try { error = JSON.parse(e.data); } catch (error) { error = JSON.stringify(e); } this.logger.error({ e, error }, "Error in SSE stream"); }); source.addEventListener("done", () => { try { stream.close(""); // not using e.data which is {} } finally { source.close(); } }); return stream; } async requestTextCompletion(prompt, options) { if (options.model_options?._option_id !== "text-fallback") { this.logger.warn({ options: options.model_options }, "Invalid model options"); } options.model_options = options.model_options; const model = ReplicateDriver.parseModelId(options.model); const predictionData = { input: { prompt: prompt, max_new_tokens: options.model_options?.max_tokens, temperature: options.model_options?.temperature, }, version: model.version, //TODO stream //stream: stream, //streaming described here https://replicate.com/blog/streaming }; const prediction = await this.service.predictions.create(predictionData); //TODO stream //if we're streaming, return right away for the stream handler to handle // if (stream) return prediction; //not streaming, wait for the result const res = await this.service.wait(prediction, {}); const text = res.output.join(""); return { result: [{ type: "text", value: text }], original_response: options.include_original_response ? res : undefined, }; } async startTraining(dataset, options) { if (options.name.indexOf("/") === -1) { throw new Error("Invalid target model name. Expected format: owner/model"); } const { owner, model, version } = ReplicateDriver.parseModelId(options.model); const job = await this.service.trainings.create(owner, model, version, { destination: options.name, input: { train_data: await dataset.getURL(), }, }); return jobInfo(job, options.name); } /** * This method is not returning a consistent TrainingJob like the one returned by startTraining * Instead of returning the full model name `owner/model:version` it returns only the version `version * @param jobId * @returns */ async cancelTraining(jobId) { const job = await this.service.trainings.cancel(jobId); return jobInfo(job); } /** * This method is not returning a consistent TrainingJob like the one returned by startTraining * Instead of returning the full model name `owner/model:version` it returns only the version `version * @param jobId * @returns */ async getTrainingJob(jobId) { const job = await this.service.trainings.get(jobId); return jobInfo(job); } // ========= management API ============= async validateConnection() { try { await this.service.predictions.list(); return true; } catch (error) { return false; } } async _listTrainableModels() { const promises = Array.from(supportFineTunning).map(id => { const [owner, model] = id.split('/'); return this.service.models.get(owner, model); }); const results = await Promise.all(promises); return results.filter(m => !!m.latest_version).map(m => { const fullName = m.owner + '/' + m.name; const v = m.latest_version; return { id: fullName + ':' + v.id, name: fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6), provider: this.provider, owner: m.owner, description: m.description, }; }); } async listTrainableModels() { if (!cachedTrainableModels || Date.now() > cachedTrainableModelsTimestamp + 12 * 3600 * 1000) { // 12 hours cachedTrainableModels = await this._listTrainableModels(); cachedTrainableModelsTimestamp = Date.now(); } return cachedTrainableModels; } async listModels(params = { text: '' }) { if (!params.text) { return this.listTrainableModels(); } const [owner, model] = params.text.split("/"); if (!owner || !model) { throw new Error("Invalid model name. Expected format: owner/model"); } return this.listModelVersions(owner, model); } async listModelVersions(owner, model) { const [rModel, versions] = await Promise.all([ this.service.models.get(owner, model), this.service.models.versions.list(owner, model), ]); if (!rModel || !versions || versions.results?.length === 0) { throw new Error("Model not found or no versions available"); } const models = versions.results.map((v) => { const fullName = rModel.owner + '/' + rModel.name; return { id: fullName + ':' + v.id, name: fullName + "@" + v.cog_version + ":" + v.id.slice(0, 6), provider: this.provider, owner: rModel.owner, description: rModel.description, canTrain: supportFineTunning.has(fullName), }; }); //set latest version //const idx = models.findIndex(m => m.id === rModel.latest_version?.id); //models[idx].name = rModel.name + "@latest" return models; } async searchModels(params) { const res = await this.service.request("models/search", { params: { query: params.text, }, }); const rModels = (await res.json()).models; const models = rModels.map((v) => { return { id: v.name, name: v.name, provider: this.provider, owner: v.username, description: v.description, has_versions: true, }; }); return models; } async generateEmbeddings() { throw new Error("Method not implemented."); } } exports.ReplicateDriver = ReplicateDriver; function jobInfo(job, modelName) { // 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled' const jobStatus = job.status; let details; let status = core_1.TrainingJobStatus.running; if (jobStatus === 'succeeded') { status = core_1.TrainingJobStatus.succeeded; } else if (jobStatus === 'failed') { status = core_1.TrainingJobStatus.failed; const error = job.error; if (typeof error === 'string') { details = error; } else { details = JSON.stringify(error); } } else if (jobStatus === 'canceled') { status = core_1.TrainingJobStatus.cancelled; } else { status = core_1.TrainingJobStatus.running; details = job.status; } return { id: job.id, status, details, model: modelName ? modelName + ':' + job.version : job.version }; } //# sourceMappingURL=replicate.js.map