@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
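Below is the compiled Hugging Face Inference Endpoints driver. First, a minimal usage sketch, assuming HuggingFaceIEDriver is re-exported from the package root; the management API base URL and endpoint name are hypothetical placeholders, while the option names (endpoint_url, apiKey, model, model_options with _option_id: "text-fallback") are taken from the listing that follows:

import { HuggingFaceIEDriver } from "@llumiverse/drivers";

// Hypothetical setup: the base URL points at the HF Inference Endpoints
// management API for an account, and "my-endpoint" is a placeholder name.
const driver = new HuggingFaceIEDriver({
    endpoint_url: "https://api.endpoints.huggingface.cloud/v2/endpoint/my-account",
    apiKey: process.env.HF_API_KEY,
});

const completion = await driver.requestTextCompletion("Write a haiku about autumn.", {
    model: "my-endpoint",
    model_options: {
        _option_id: "text-fallback",
        temperature: 0.7,
        max_tokens: 128,
    },
});
console.log(completion.result[0].value);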
JavaScript
import { InferenceClient } from "@huggingface/inference";
import { AIModelStatus, AbstractDriver } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";
export class HuggingFaceIEDriver extends AbstractDriver {
    static PROVIDER = "huggingface_ie";
    provider = HuggingFaceIEDriver.PROVIDER;
    service;
    _executor;
    constructor(options) {
        super(options);
        if (!options.endpoint_url) {
            throw new Error(`Endpoint URL is required for ${this.provider}`);
        }
        this.service = new FetchClient(this.options.endpoint_url);
        this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`;
    }
    // Resolve the inference URL and status of a named endpoint through the management API.
    async getModelURLEndpoint(modelId) {
        const res = await this.service.get(`/${modelId}`);
        return {
            url: res.status.url,
            status: getStatus(res),
        };
    }
    // Lazily create and cache an inference client bound to the resolved endpoint URL.
    async getExecutor(model) {
        if (!this._executor) {
            const endpoint = await this.getModelURLEndpoint(model);
            if (!endpoint.url)
                throw new Error(`Endpoint URL not found for model ${model}`);
            if (endpoint.status !== AIModelStatus.Available)
                throw new Error(`Endpoint ${model} is not running - current status: ${endpoint.status}`);
            // Use the new InferenceClient and bind it to the endpoint URL
            this._executor = new InferenceClient(this.options.apiKey).endpoint(endpoint.url);
        }
        return this._executor;
    }
    // Stream a text completion, translating HF token events into llumiverse completion chunks.
    async requestTextCompletionStream(prompt, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const executor = await this.getExecutor(options.model);
        const req = executor.textGenerationStream({
            inputs: prompt,
            parameters: {
                temperature: options.model_options?.temperature,
                max_new_tokens: options.model_options?.max_tokens,
            },
        });
        return transformAsyncIterator(req, (val) => {
            // Special tokens like <s> are not part of the result
            if (val.token.special)
                return { result: [] };
            let finish_reason = val.details?.finish_reason;
            if (finish_reason === "eos_token") {
                finish_reason = "stop";
            }
            return {
                result: val.token.text ? [{ type: "text", value: val.token.text }] : [],
                finish_reason: finish_reason,
                token_usage: {
                    result: val.details?.generated_tokens ?? 0,
                }
            };
        });
    }
    // Request a full text completion in a single call.
    async requestTextCompletion(prompt, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const executor = await this.getExecutor(options.model);
        const res = await executor.textGeneration({
            inputs: prompt,
            parameters: {
                temperature: options.model_options?.temperature,
                max_new_tokens: options.model_options?.max_tokens,
            },
        });
        let finish_reason = res.details?.finish_reason;
        if (finish_reason === "eos_token") {
            finish_reason = "stop";
        }
        return {
            result: [{ type: "text", value: res.generated_text }],
            finish_reason: finish_reason,
            token_usage: {
                result: res.details?.generated_tokens
            },
            original_response: options.include_original_response ? res : undefined,
        };
    }
    // ============== management API ==============
    async listModels() {
        const res = await this.service.get("/");
        const hfModels = res.items;
        if (!hfModels || !hfModels.length)
            return [];
        const models = hfModels.map((model) => ({
            id: model.name,
            name: `${model.name} [${model.model.repository}:${model.model.task}]`,
            provider: this.provider,
            tags: [model.model.task],
            status: getStatus(model),
        }));
        return models;
    }
    async validateConnection() {
        try {
            await this.service.get("/models");
            return true;
        }
        catch (error) {
            return false;
        }
    }
    async generateEmbeddings() {
        throw new Error("Method not implemented.");
    }
}
// Map a Hugging Face endpoint state to an AIModelStatus
function getStatus(hfModel) {
    // Possible states: [ pending, initializing, updating, updateFailed, running, paused, failed, scaledToZero ]
    switch (hfModel.status.state) {
        case "running":
            return AIModelStatus.Available;
        case "pending":
        case "initializing":
        case "updating":
            return AIModelStatus.Pending;
        case "updateFailed":
        case "failed":
            return AIModelStatus.Unavailable;
        case "paused":
            return AIModelStatus.Stopped;
        case "scaledToZero":
            // Scaled-to-zero endpoints wake up on the next request, so report them as available.
            return AIModelStatus.Available;
        default:
            return AIModelStatus.Unknown;
    }
}
/*
Example response from the endpoints listing API:
{
    "items": [
        {
            "accountId": "string",
            "compute": {
                "accelerator": "cpu",
                "instanceSize": "large",
                "instanceType": "c6i",
                "scaling": {
                    "maxReplica": 8,
                    "minReplica": 2
                }
            },
            "model": {
                "framework": "custom",
                "image": {
                    "huggingface": {}
                },
                "repository": "gpt2",
                "revision": "6c0e6080953db56375760c0471a8c5f2929baf11",
                "task": "text-classification"
            },
            "name": "my-endpoint",
            "provider": {
                "region": "us-east-1",
                "vendor": "aws"
            },
            "status": {
                "createdAt": "2023-10-19T05:04:17.305Z",
                "createdBy": {
                    "id": "string",
                    "name": "string"
                },
                "message": "Endpoint is ready",
                "private": {
                    "serviceName": "string"
                },
                "readyReplica": 2,
                "state": "pending",
                "targetReplica": 4,
                "updatedAt": "2023-10-19T05:04:17.305Z",
                "updatedBy": {
                    "id": "string",
                    "name": "string"
                },
                "url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud"
            },
            "type": "public"
        }
    ]
}
*/
//# sourceMappingURL=huggingface_ie.js.map
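A companion sketch for the streaming and management paths, reusing the hypothetical driver instance from the sketch above; the chunk shape ({ result, finish_reason, token_usage }) follows the transformAsyncIterator callback in the listing:

// List the endpoints known to the management API.
const models = await driver.listModels();
console.log(models.map((m) => m.id));

// Stream a completion and print text segments as they arrive.
const stream = await driver.requestTextCompletionStream("Once upon a time", {
    model: "my-endpoint",
    model_options: { _option_id: "text-fallback", max_tokens: 64 },
});
for await (const chunk of stream) {
    for (const segment of chunk.result) {
        if (segment.type === "text") process.stdout.write(segment.value);
    }
    if (chunk.finish_reason) console.log(`\n[finish_reason: ${chunk.finish_reason}]`);
}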