/*
 * @llumiverse/drivers
 * LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
 */
import {
HfInference,
HfInferenceEndpoint,
TextGenerationStreamOutput
} from "@huggingface/inference";
import {
AIModel,
AIModelStatus,
AbstractDriver,
CompletionChunkObject,
DriverOptions,
EmbeddingsResult,
ExecutionOptions,
TextFallbackOptions,
} from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";
/**
 * Options for the Hugging Face Inference Endpoints driver.
 */
export interface HuggingFaceIEDriverOptions extends DriverOptions {
/** Hugging Face API token; sent as a Bearer token to the management API and used for inference calls. */
apiKey: string;
/** Base URL of the Inference Endpoints management API. Required; the constructor throws without it. */
endpoint_url: string;
}
/**
 * Driver for Hugging Face Inference Endpoints.
 *
 * Uses a FetchClient against the management API (`endpoint_url`) to discover
 * endpoints and their status, and the `@huggingface/inference` client to run
 * text generation against a resolved endpoint URL.
 */
export class HuggingFaceIEDriver extends AbstractDriver<HuggingFaceIEDriverOptions, string> {
    static PROVIDER = "huggingface_ie";
    provider = HuggingFaceIEDriver.PROVIDER;
    /** Client bound to the management API base URL, authenticated with the apiKey. */
    service: FetchClient;
    /** Cached inference client for the most recently resolved endpoint. */
    _executor?: HfInferenceEndpoint;
    /** Model name the cached executor was built for; used to invalidate the cache. */
    _executorModel?: string;

    constructor(
        options: HuggingFaceIEDriverOptions
    ) {
        super(options);
        if (!options.endpoint_url) {
            throw new Error(`Endpoint URL is required for ${this.provider}`);
        }
        this.service = new FetchClient(this.options.endpoint_url);
        this.service.headers["Authorization"] = `Bearer ${this.options.apiKey}`;
    }

    /**
     * Fetch one endpoint from the management API and return its inference URL
     * together with its normalized status.
     * @param modelId the endpoint name as returned by `listModels`
     */
    async getModelURLEndpoint(
        modelId: string
    ): Promise<{ url: string; status: string; }> {
        const res = (await this.service.get(`/${modelId}`)) as HuggingFaceIEModel;
        return {
            url: res.status.url,
            status: getStatus(res),
        };
    }

    /**
     * Resolve (and cache) an inference client for the given endpoint name.
     * @throws if the endpoint has no URL or is not currently available.
     */
    async getExecutor(model: string) {
        // BUG FIX: the cache previously ignored `model`, so once an executor was
        // created the first endpoint was silently reused for every other model.
        // Rebuild whenever there is no executor or it was built for another model.
        if (!this._executor || this._executorModel !== model) {
            const endpoint = await this.getModelURLEndpoint(model);
            if (!endpoint.url)
                throw new Error(
                    `Endpoint URL not found for model ${model}`
                );
            if (endpoint.status !== AIModelStatus.Available)
                throw new Error(
                    `Endpoint ${model} is not running - current status: ${endpoint.status}`
                );
            this._executor = new HfInference(this.options.apiKey).endpoint(
                endpoint.url
            );
            this._executorModel = model;
        }
        return this._executor;
    }

    /**
     * Stream a text completion, yielding one chunk per generated token.
     * Special tokens (e.g. <s>) are emitted as empty-result chunks.
     */
    async requestTextCompletionStream(prompt: string, options: ExecutionOptions) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions;
        const executor = await this.getExecutor(options.model);
        const req = executor.textGenerationStream({
            inputs: prompt,
            parameters: {
                temperature: options.model_options?.temperature,
                max_new_tokens: options.model_options?.max_tokens,
            },
        });
        return transformAsyncIterator(req, (val: TextGenerationStreamOutput) => {
            // special tokens like <s> are not part of the result
            if (val.token.special) return { result: [] };
            // normalize HF's "eos_token" to the generic "stop" finish reason
            let finish_reason = val.details?.finish_reason as string;
            if (finish_reason === "eos_token") {
                finish_reason = "stop";
            }
            return {
                result: val.token.text ? [{ type: "text" as const, value: val.token.text }] : [],
                finish_reason: finish_reason,
                token_usage: {
                    result: val.details?.generated_tokens ?? 0,
                }
            } as CompletionChunkObject;
        });
    }

    /**
     * Run a non-streaming text completion against the endpoint.
     */
    async requestTextCompletion(prompt: string, options: ExecutionOptions) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions;
        const executor = await this.getExecutor(options.model);
        const res = await executor.textGeneration({
            inputs: prompt,
            parameters: {
                temperature: options.model_options?.temperature,
                max_new_tokens: options.model_options?.max_tokens,
            },
        });
        // normalize HF's "eos_token" to the generic "stop" finish reason
        let finish_reason = res.details?.finish_reason as string;
        if (finish_reason === "eos_token") {
            finish_reason = "stop";
        }
        return {
            result: [{ type: "text" as const, value: res.generated_text }],
            finish_reason: finish_reason,
            token_usage: {
                result: res.details?.generated_tokens
            },
            original_response: options.include_original_response ? res : undefined,
        };
    }

    // ============== management API ==============

    /**
     * List all endpoints visible through the management API.
     */
    async listModels(): Promise<AIModel[]> {
        const res = await this.service.get("/");
        const hfModels = res.items as HuggingFaceIEModel[];
        if (!hfModels || !hfModels.length) return [];
        const models: AIModel[] = hfModels.map((model: HuggingFaceIEModel) => ({
            id: model.name,
            name: `${model.name} [${model.model.repository}:${model.model.task}]`,
            provider: this.provider,
            tags: [model.model.task],
            status: getStatus(model),
        }));
        return models;
    }

    /**
     * Probe the management API; returns false on any error.
     * NOTE(review): this probes "/models" while listModels uses "/" — confirm
     * which path the management API actually serves.
     */
    async validateConnection(): Promise<boolean> {
        try {
            await this.service.get("/models");
            return true;
        } catch (error) {
            return false;
        }
    }

    /** Embeddings are not supported by this driver. */
    async generateEmbeddings(): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
}
/**
 * Map a Hugging Face Inference Endpoints state to the generic AIModelStatus.
 * Known HF states: pending, initializing, updating, updateFailed, running,
 * paused, failed, scaledToZero.
 */
function getStatus(hfModel: HuggingFaceIEModel): AIModelStatus {
    switch (hfModel.status.state) {
        case "running":
            return AIModelStatus.Available;
        case "scaledToZero":
            // a scaled-to-zero endpoint wakes on the next request, so report it as available
            return AIModelStatus.Available;
        // BUG FIX: "pending" is a documented HF state (see the list above and the
        // example payload at the bottom of this file) but previously fell through
        // to the default and was reported as Unknown instead of Pending.
        case "pending":
        case "initializing":
        case "updating":
            return AIModelStatus.Pending;
        case "updateFailed":
        case "failed":
            return AIModelStatus.Unavailable;
        case "paused":
            return AIModelStatus.Stopped;
        default:
            return AIModelStatus.Unknown;
    }
}
/**
 * Shape of an endpoint object returned by the HF Inference Endpoints
 * management API (see the example payload in the comment at the bottom
 * of this file).
 */
interface HuggingFaceIEModel {
accountId: string;
// hardware/scaling configuration of the endpoint
compute: {
accelerator: string;
instanceSize: string;
instanceType: string;
scaling: {
maxReplica: number;
minReplica: number;
};
};
// the deployed model; repository and task are surfaced in listModels() display names
model: {
framework: string;
image: {
huggingface: {};
};
repository: string;
revision: string;
task: string;
};
// endpoint name — used as the AIModel id
name: string;
provider: {
region: string;
vendor: string;
};
// runtime status; `state` is mapped to AIModelStatus by getStatus(), `url` is the inference URL
status: {
createdAt: string;
createdBy: {
id: string;
name: string;
};
message: string;
private: {
serviceName: string;
};
readyReplica: number;
state: string;
targetReplica: number;
updatedAt: string;
updatedBy: {
id: string;
name: string;
};
url: string;
};
type: string;
}
/*
Example of model returned by the API
{
"items": [
{
"accountId": "string",
"compute": {
"accelerator": "cpu",
"instanceSize": "large",
"instanceType": "c6i",
"scaling": {
"maxReplica": 8,
"minReplica": 2
}
},
"model": {
"framework": "custom",
"image": {
"huggingface": {}
},
"repository": "gpt2",
"revision": "6c0e6080953db56375760c0471a8c5f2929baf11",
"task": "text-classification"
},
"name": "my-endpoint",
"provider": {
"region": "us-east-1",
"vendor": "aws"
},
"status": {
"createdAt": "2023-10-19T05:04:17.305Z",
"createdBy": {
"id": "string",
"name": "string"
},
"message": "Endpoint is ready",
"private": {
"serviceName": "string"
},
"readyReplica": 2,
"state": "pending",
"targetReplica": 4,
"updatedAt": "2023-10-19T05:04:17.305Z",
"updatedBy": {
"id": "string",
"name": "string"
},
"url": "https://endpoint-id.region.vendor.endpoints.huggingface.cloud"
},
"type": "public"
}
]
}
*/