UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

157 lines 6.25 kB
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";

/** Value sent as the `version` query parameter on every request made by this driver. */
const API_VERSION = "2024-03-14";

/** Endpoint used to exchange the configured API key for a bearer token. */
const IAM_TOKEN_URL = 'https://iam.cloud.ibm.com/identity/token';

/**
 * Driver that talks to a Watsonx endpoint for text generation (plain and
 * streamed), model listing and text embeddings.
 *
 * Requests go through a `FetchClient` bound to `options.endpointUrl`; every
 * request is authorized with a bearer token obtained from `getAuthToken()`.
 */
export class WatsonxDriver extends AbstractDriver {
    static PROVIDER = "watsonx";
    provider = WatsonxDriver.PROVIDER;
    apiKey;
    endpoint_url;
    projectId;
    /** Last token response received from the IAM endpoint (reused until its `expiration`). */
    authToken;
    fetcher;
    fetchClient;

    /**
     * @param {object} options - Driver options. Reads `apiKey`, `projectId` and
     *   `endpointUrl`; the whole object is also forwarded to `AbstractDriver`.
     */
    constructor(options) {
        super(options);
        this.apiKey = options.apiKey;
        this.projectId = options.projectId;
        this.endpoint_url = options.endpointUrl;
        this.fetchClient = new FetchClient(this.endpoint_url)
            .withAuthCallback(async () => this.getAuthToken().then(token => `Bearer ${token}`));
    }

    /**
     * Runs a single, non-streamed text generation.
     *
     * @param {string} prompt - Prompt text; a trailing newline is appended before sending.
     * @param {object} options - Execution options (`model`, `model_options`,
     *   `include_original_response`).
     * @returns {Promise<object>} Completion with `result`, `token_usage`,
     *   `finish_reason` and, when requested, `original_response`.
     */
    async requestTextCompletion(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const payload = buildGenerationPayload(this.projectId, prompt, options);
        const res = await this.fetchClient.post(`/ml/v1/text/generation?version=${API_VERSION}`, { payload });
        const result = res.results[0];
        return {
            result: [{ type: "text", value: result.generated_text }],
            token_usage: toTokenUsage(result),
            finish_reason: watsonFinishReason(result.stop_reason),
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Runs a streamed text generation and maps every SSE event to a completion chunk.
     *
     * @param {string} prompt - Prompt text; a trailing newline is appended before sending.
     * @param {object} options - Execution options (`model`, `model_options`).
     * @returns {Promise<*>} Whatever `transformSSEStream` returns for the mapped stream.
     */
    async requestTextCompletionStream(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const payload = buildGenerationPayload(this.projectId, prompt, options);
        const stream = await this.fetchClient.post(`/ml/v1/text/generation_stream?version=${API_VERSION}`, {
            payload: payload,
            reader: 'sse',
        });
        return transformSSEStream(stream, (data) => {
            const json = JSON.parse(data);
            // An event may carry no result entry; read it once and treat its
            // absence uniformly instead of guarding only some of the accesses.
            const result = json.results[0];
            return {
                result: result?.generated_text ? [{ type: "text", value: result.generated_text }] : [],
                finish_reason: watsonFinishReason(result?.stop_reason),
                token_usage: toTokenUsage(result),
            };
        });
    }

    /**
     * Lists the foundation models exposed by the endpoint.
     *
     * @returns {Promise<Array<{id: *, name: *, description: *, provider: string}>>}
     * @throws Rethrows the request error after logging it, so callers see the
     *   real cause instead of a `TypeError` raised on an undefined response.
     */
    async listModels() {
        let res;
        try {
            res = await this.fetchClient.get(`/ml/v1/foundation_model_specs?version=${API_VERSION}`);
        } catch (err) {
            this.logger.warn("Can't list models on Watsonx: " + err);
            throw err;
        }
        return res.resources.map((m) => ({
            id: m.model_id,
            name: m.label,
            description: m.short_description,
            provider: this.provider,
        }));
    }

    /**
     * Returns a bearer token, reusing the cached one while the current time
     * (in seconds) is below its `expiration` field.
     *
     * @returns {Promise<string>} The `access_token` of the token response.
     * @throws {Error} When the token endpoint answers with a non-2xx status or
     *   a body without `access_token`.
     */
    async getAuthToken() {
        if (this.authToken) {
            const now = Date.now() / 1000;
            if (now < this.authToken.expiration) {
                return this.authToken.access_token;
            }
            this.logger.debug("Token expired, refetching");
        }
        // URLSearchParams percent-encodes the key, so characters such as
        // `&`, `+` or `=` cannot corrupt the form body.
        const body = new URLSearchParams({
            grant_type: 'urn:ibm:params:oauth:grant-type:apikey',
            apikey: this.apiKey,
        }).toString();
        const response = await fetch(IAM_TOKEN_URL, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/x-www-form-urlencoded',
            },
            body,
        });
        if (!response.ok) {
            throw new Error(`Failed to fetch Watsonx auth token: ${response.status} ${response.statusText}`);
        }
        const authToken = await response.json();
        if (!authToken?.access_token) {
            // Never cache a response that would yield "Bearer undefined".
            throw new Error("Watsonx auth token response did not contain an access_token");
        }
        this.authToken = authToken;
        return authToken.access_token;
    }

    /**
     * Checks connectivity by listing models.
     *
     * @returns {Promise<boolean>} `true` when `listModels()` succeeds, `false`
     *   (after logging a warning) when it fails.
     */
    async validateConnection() {
        return this.listModels()
            .then(() => true)
            .catch((err) => {
                this.logger.warn({ error: err }, "Failed to connect to WatsonX");
                return false;
            });
    }

    /**
     * Computes an embedding for a single text.
     *
     * @param {object} options - Reads `text` (required), `model` (defaults to
     *   `ibm/slate-125m-english-rtrvr`) and `image` (rejected).
     * @returns {Promise<{values: *, model: *}>} Embedding of the first result
     *   and the model id reported by the response.
     * @throws {Error} When `options.image` is set or `options.text` is empty.
     */
    async generateEmbeddings(options) {
        if (options.image) {
            throw new Error("Image embeddings not supported by Watsonx");
        }
        if (!options.text) {
            throw new Error("No text provided");
        }
        const payload = {
            inputs: [options.text],
            model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
            project_id: this.projectId,
        };
        const res = await this.fetchClient.post(`/ml/v1/text/embeddings?version=${API_VERSION}`, { payload });
        return {
            values: res.results[0].embedding,
            model: res.model_id,
        };
    }
}

/**
 * Logs a warning when the model options are not tagged `"text-fallback"`
 * (this includes the case where no model options were supplied at all).
 */
function warnOnUnexpectedModelOptions(logger, options) {
    if (options.model_options?._option_id !== "text-fallback") {
        logger.warn({ options: options.model_options }, "Invalid model options");
    }
}

/** Builds the request body shared by the plain and streamed generation calls. */
function buildGenerationPayload(projectId, prompt, options) {
    const modelOptions = options.model_options;
    return {
        model_id: options.model,
        input: prompt + "\n",
        parameters: {
            max_new_tokens: modelOptions?.max_tokens,
            temperature: modelOptions?.temperature,
            top_k: modelOptions?.top_k,
            top_p: modelOptions?.top_p,
            stop_sequences: modelOptions?.stop_sequence,
        },
        project_id: projectId,
    };
}

/**
 * Maps a generation result entry to the driver's token usage shape.
 * `total` is only reported when both counts are numbers, so a missing count
 * yields `undefined` rather than `NaN`.
 */
function toTokenUsage(result) {
    const promptTokens = result?.input_token_count;
    const generatedTokens = result?.generated_token_count;
    const hasBothCounts = typeof promptTokens === 'number' && typeof generatedTokens === 'number';
    return {
        prompt: promptTokens,
        result: generatedTokens,
        total: hasBothCounts ? promptTokens + generatedTokens : undefined,
    };
}

/**
 * Translates a Watsonx stop reason into the driver's finish reason:
 * `eos_token` becomes `"stop"`, `max_tokens` becomes `"length"`, any other
 * value is passed through, and a falsy value yields `undefined`.
 */
function watsonFinishReason(reason) {
    if (!reason) {
        return undefined;
    }
    switch (reason) {
        case 'eos_token':
            return "stop";
        case 'max_tokens':
            return "length";
        default:
            return reason;
    }
}
/*interface ListModelsParams extends ModelSearchPayload {
    limit?: number;
}*/
//# sourceMappingURL=index.js.map