@llumiverse/drivers

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
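A minimal usage sketch for the Hugging Face Inference Endpoints driver whose compiled source is shown below. It assumes HuggingFaceIEDriver is re-exported from the package root; the endpoint_url, access-token variable, and endpoint name are placeholders, and requestTextCompletion is called directly for illustration (the AbstractDriver base class may expose a higher-level entry point).

const { HuggingFaceIEDriver } = require("@llumiverse/drivers");

async function main() {
    // endpoint_url points at the Inference Endpoints management API for your
    // account; apiKey is a Hugging Face access token. Both are placeholders.
    const driver = new HuggingFaceIEDriver({
        endpoint_url: "https://api.endpoints.huggingface.cloud/v2/endpoint/my-account",
        apiKey: process.env.HF_API_KEY,
    });

    // "my-endpoint" is the name of a running inference endpoint (placeholder).
    const res = await driver.requestTextCompletion("Hello, my name is", {
        model: "my-endpoint",
        model_options: {
            _option_id: "text-fallback", // expected by the driver, see the warn() check below
            temperature: 0.7,
            max_tokens: 100,
        },
    });

    // The driver returns completion parts plus a normalized finish_reason.
    console.log(res.result[0].value, res.finish_reason);
}

main().catch(console.error);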

huggingface_ie.js (201 lines, 7.13 kB):
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.HuggingFaceIEDriver = void 0; const inference_1 = require("@huggingface/inference"); const core_1 = require("@llumiverse/core"); const async_1 = require("@llumiverse/core/async"); const api_fetch_client_1 = require("@vertesia/api-fetch-client"); class HuggingFaceIEDriver extends core_1.AbstractDriver { static PROVIDER = "huggingface_ie"; provider = HuggingFaceIEDriver.PROVIDER; service; _executor; constructor(options) { super(options); if (!options.endpoint_url) { throw new Error(`Endpoint URL is required for ${this.provider}`); } this.service = new api_fetch_client_1.FetchClient(this.options.endpoint_url); this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`; } async getModelURLEndpoint(modelId) { const res = (await this.service.get(`/${modelId}`)); return { url: res.status.url, status: getStatus(res), }; } async getExecutor(model) { if (!this._executor) { const endpoint = await this.getModelURLEndpoint(model); if (!endpoint.url) throw new Error(`Endpoint URL not found for model ${model}`); if (endpoint.status !== core_1.AIModelStatus.Available) throw new Error(`Endpoint ${model} is not running - current status: ${endpoint.status}`); // Use the new InferenceClient and bind it to the endpoint URL this._executor = new inference_1.InferenceClient(this.options.apiKey).endpoint(endpoint.url); } return this._executor; } async requestTextCompletionStream(prompt, options) { if (options.model_options?._option_id !== "text-fallback") { this.logger.warn({ options: options.model_options }, "Invalid model options"); } options.model_options = options.model_options; const executor = await this.getExecutor(options.model); const req = executor.textGenerationStream({ inputs: prompt, parameters: { temperature: options.model_options?.temperature, max_new_tokens: options.model_options?.max_tokens, }, }); return (0, async_1.transformAsyncIterator)(req, (val) => { //special like <s> are not part of the result if (val.token.special) return { result: [] }; let finish_reason = val.details?.finish_reason; if (finish_reason === "eos_token") { finish_reason = "stop"; } return { result: val.token.text ? [{ type: "text", value: val.token.text }] : [], finish_reason: finish_reason, token_usage: { result: val.details?.generated_tokens ?? 0, } }; }); } async requestTextCompletion(prompt, options) { if (options.model_options?._option_id !== "text-fallback") { this.logger.warn({ options: options.model_options }, "Invalid model options"); } options.model_options = options.model_options; const executor = await this.getExecutor(options.model); const res = await executor.textGeneration({ inputs: prompt, parameters: { temperature: options.model_options?.temperature, max_new_tokens: options.model_options?.max_tokens, }, }); let finish_reason = res.details?.finish_reason; if (finish_reason === "eos_token") { finish_reason = "stop"; } return { result: [{ type: "text", value: res.generated_text }], finish_reason: finish_reason, token_usage: { result: res.details?.generated_tokens }, original_response: options.include_original_response ? 
res : undefined, }; } // ============== management API ============== async listModels() { const res = await this.service.get("/"); const hfModels = res.items; if (!hfModels || !hfModels.length) return []; const models = hfModels.map((model) => ({ id: model.name, name: `${model.name} [${model.model.repository}:${model.model.task}]`, provider: this.provider, tags: [model.model.task], status: getStatus(model), })); return models; } async validateConnection() { try { await this.service.get("/models"); return true; } catch (error) { return false; } } async generateEmbeddings() { throw new Error("Method not implemented."); } } exports.HuggingFaceIEDriver = HuggingFaceIEDriver; //get status from HF status function getStatus(hfModel) { //[ pending, initializing, updating, updateFailed, running, paused, failed, scaledToZero ] switch (hfModel.status.state) { case "running": return core_1.AIModelStatus.Available; case "initializing": return core_1.AIModelStatus.Pending; case "updating": return core_1.AIModelStatus.Pending; case "updateFailed": return core_1.AIModelStatus.Unavailable; case "paused": return core_1.AIModelStatus.Stopped; case "failed": return core_1.AIModelStatus.Unavailable; case "scaledToZero": return core_1.AIModelStatus.Available; default: return core_1.AIModelStatus.Unknown; } } /* Example of model returned by the API { "items": [ { "accountId": "string", "compute": { "accelerator": "cpu", "instanceSize": "large", "instanceType": "c6i", "scaling": { "maxReplica": 8, "minReplica": 2 } }, "model": { "framework": "custom", "image": { "huggingface": {} }, "repository": "gpt2", "revision": "6c0e6080953db56375760c0471a8c5f2929baf11", "task": "text-classification" }, "name": "my-endpoint", "provider": { "region": "us-east-1", "vendor": "aws" }, "status": { "createdAt": "2023-10-19T05:04:17.305Z", "createdBy": { "id": "string", "name": "string" }, "message": "Endpoint is ready", "private": { "serviceName": "string" }, "readyReplica": 2, "state": "pending", "targetReplica": 4, "updatedAt": "2023-10-19T05:04:17.305Z", "updatedBy": { "id": "string", "name": "string" }, "url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud" }, "type": "public" } ] } */ //# sourceMappingURL=huggingface_ie.js.map
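requestTextCompletionStream returns an async iterator of completion chunks: each chunk carries a result array of { type: "text", value } parts, and finish_reason / token_usage once the stream reports generation details. A consumption sketch (inside an async function), reusing the hypothetical driver instance and endpoint name from the example above:

const stream = await driver.requestTextCompletionStream("Once upon a time", {
    model: "my-endpoint",
    model_options: { _option_id: "text-fallback", temperature: 0.7, max_tokens: 120 },
});

let text = "";
for await (const chunk of stream) {
    // Special tokens (e.g. <s>) arrive as chunks with an empty result array.
    for (const part of chunk.result) {
        if (part.type === "text") text += part.value;
    }
    if (chunk.finish_reason) {
        console.log("finish_reason:", chunk.finish_reason);
    }
}
console.log(text);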