UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

197 lines 6.89 kB
import { InferenceClient, } from "@huggingface/inference";
import { AIModelStatus, AbstractDriver, } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";

/**
 * Driver for models deployed on Hugging Face Inference Endpoints.
 *
 * `options.endpoint_url` is the base URL of the endpoints management API:
 * GET "/" lists the deployed endpoints and GET "/<name>" returns a single one.
 * Text generation requests are sent to the URL found in that endpoint's `status.url`.
 */
export class HuggingFaceIEDriver extends AbstractDriver {
    static PROVIDER = "huggingface_ie";
    provider = HuggingFaceIEDriver.PROVIDER;
    /** FetchClient bound to the management API (`options.endpoint_url`). */
    service;
    /** Executor returned by the most recent `getExecutor` call (kept for backward compatibility). */
    _executor;
    /** Executors cached per model id, so each model is routed to its own endpoint. */
    _executors = new Map();

    /**
     * @param {object} options - driver options; `endpoint_url` is required, `apiKey` is sent as a Bearer token.
     * @throws {Error} when `options.endpoint_url` is missing.
     */
    constructor(options) {
        super(options);
        if (!options.endpoint_url) {
            throw new Error(`Endpoint URL is required for ${this.provider}`);
        }
        this.service = new FetchClient(this.options.endpoint_url);
        // Avoid sending a literal "Bearer undefined" when no key is configured.
        if (this.options.apiKey) {
            this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`;
        }
    }

    /**
     * Looks up a deployed endpoint by name.
     * @param {string} modelId - endpoint name as listed by `listModels`.
     * @returns {Promise<{url: string | undefined, status: string}>} inference URL and mapped status.
     */
    async getModelURLEndpoint(modelId) {
        const res = await this.service.get(`/${modelId}`);
        return {
            url: res?.status?.url,
            status: getStatus(res),
        };
    }

    /**
     * Returns an InferenceClient bound to the endpoint serving `model`, creating
     * and caching it on first use. Each model id gets its own cached executor.
     * @param {string} model - endpoint name.
     * @throws {Error} when the endpoint has no URL or is not in an available state.
     */
    async getExecutor(model) {
        let executor = this._executors.get(model);
        if (!executor) {
            const endpoint = await this.getModelURLEndpoint(model);
            if (!endpoint.url) {
                throw new Error(`Endpoint URL not found for model ${model}`);
            }
            if (endpoint.status !== AIModelStatus.Available) {
                throw new Error(`Endpoint ${model} is not running - current status: ${endpoint.status}`);
            }
            // Use the InferenceClient and bind it to the endpoint URL.
            executor = new InferenceClient(this.options.apiKey).endpoint(endpoint.url);
            this._executors.set(model, executor);
        }
        this._executor = executor;
        return executor;
    }

    /**
     * Streams a text completion from the model's endpoint.
     * @param {string} prompt - raw prompt text.
     * @param {object} options - execution options (`model`, `model_options`).
     * @returns {Promise<AsyncIterable<object>>} completion chunks.
     */
    async requestTextCompletionStream(prompt, options) {
        const args = buildTextGenerationArgs(prompt, options.model_options, this.logger);
        const executor = await this.getExecutor(options.model);
        const stream = executor.textGenerationStream(args);
        return transformAsyncIterator(stream, (val) => {
            // Special tokens like <s> are not part of the result.
            if (val.token.special) return { result: [] };
            return {
                result: val.token.text ? [{ type: "text", value: val.token.text }] : [],
                finish_reason: normalizeFinishReason(val.details?.finish_reason),
                token_usage: {
                    result: val.details?.generated_tokens ?? 0,
                },
            };
        });
    }

    /**
     * Requests a complete (non-streamed) text completion from the model's endpoint.
     * @param {string} prompt - raw prompt text.
     * @param {object} options - execution options (`model`, `model_options`, `include_original_response`).
     * @returns {Promise<object>} completion with result, finish_reason and token usage.
     */
    async requestTextCompletion(prompt, options) {
        const args = buildTextGenerationArgs(prompt, options.model_options, this.logger);
        const executor = await this.getExecutor(options.model);
        const res = await executor.textGeneration(args);
        return {
            result: [{ type: "text", value: res.generated_text }],
            finish_reason: normalizeFinishReason(res.details?.finish_reason),
            token_usage: { result: res.details?.generated_tokens },
            original_response: options.include_original_response ? res : undefined,
        };
    }

    // ============== management API ==============

    /**
     * Lists deployed endpoints as models.
     * @returns {Promise<object[]>} one entry per endpoint; empty when none are deployed.
     */
    async listModels() {
        const res = await this.service.get("/");
        const hfModels = res?.items;
        if (!hfModels?.length) return [];
        return hfModels.map((model) => ({
            id: model.name,
            name: `${model.name} [${model.model.repository}:${model.model.task}]`,
            provider: this.provider,
            tags: [model.model.task],
            status: getStatus(model),
        }));
    }

    /**
     * Checks that the management API is reachable with the configured credentials.
     * Probes the same "/" listing route that `listModels` uses.
     * @returns {Promise<boolean>} true on success, false (after logging) on any error.
     */
    async validateConnection() {
        try {
            await this.service.get("/");
            return true;
        } catch (error) {
            this.logger.warn({ error }, "Failed to connect to the Hugging Face Inference Endpoints API");
            return false;
        }
    }

    async generateEmbeddings() {
        throw new Error("Method not implemented.");
    }
}

/**
 * Builds the text-generation request, warning when the model options are not
 * the expected "text-fallback" shape.
 * @param {string} prompt - raw prompt text.
 * @param {object | undefined} modelOptions - `options.model_options`.
 * @param {object} logger - driver logger (pino-style `warn(obj, msg)`).
 */
function buildTextGenerationArgs(prompt, modelOptions, logger) {
    if (modelOptions?._option_id !== "text-fallback") {
        logger.warn({ options: modelOptions }, "Invalid model options");
    }
    return {
        inputs: prompt,
        parameters: {
            temperature: modelOptions?.temperature,
            max_new_tokens: modelOptions?.max_tokens,
        },
    };
}

/**
 * Maps Hugging Face finish reasons to llumiverse ones: end-of-sequence and
 * stop-sequence terminations become "stop"; anything else passes through unchanged.
 */
function normalizeFinishReason(reason) {
    if (reason === "eos_token" || reason === "stop_sequence") {
        return "stop";
    }
    return reason;
}

/**
 * Maps a Hugging Face endpoint's `status.state` to an AIModelStatus.
 * Known states: pending, initializing, updating, updateFailed, running, paused, failed, scaledToZero.
 * Unknown or missing states map to AIModelStatus.Unknown.
 */
function getStatus(hfModel) {
    switch (hfModel?.status?.state) {
        case "running":
        case "scaledToZero":
            // NOTE(review): scaledToZero is treated as usable, presumably because HF
            // restarts the endpoint on demand — confirm against HF docs.
            return AIModelStatus.Available;
        case "pending":
        case "initializing":
        case "updating":
            return AIModelStatus.Pending;
        case "paused":
            return AIModelStatus.Stopped;
        case "updateFailed":
        case "failed":
            return AIModelStatus.Unavailable;
        default:
            return AIModelStatus.Unknown;
    }
}

/* Example of model returned by the API
{
  "items": [
    {
      "accountId": "string",
      "compute": {
        "accelerator": "cpu",
        "instanceSize": "large",
        "instanceType": "c6i",
        "scaling": { "maxReplica": 8, "minReplica": 2 }
      },
      "model": {
        "framework": "custom",
        "image": { "huggingface": {} },
        "repository": "gpt2",
        "revision": "6c0e6080953db56375760c0471a8c5f2929baf11",
        "task": "text-classification"
      },
      "name": "my-endpoint",
      "provider": { "region": "us-east-1", "vendor": "aws" },
      "status": {
        "createdAt": "2023-10-19T05:04:17.305Z",
        "createdBy": { "id": "string", "name": "string" },
        "message": "Endpoint is ready",
        "private": { "serviceName": "string" },
        "readyReplica": 2,
        "state": "pending",
        "targetReplica": 4,
        "updatedAt": "2023-10-19T05:04:17.305Z",
        "updatedBy": { "id": "string", "name": "string" },
        "url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud"
      },
      "type": "public"
    }
  ]
}
*/
//# sourceMappingURL=huggingface_ie.js.map