@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.
157 lines • 6.25 kB
JavaScript
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";
// Value sent as the `version` query parameter on every Watsonx REST endpoint called in this file.
const API_VERSION = "2024-03-14";
/**
 * Driver for the Watsonx REST API: text generation (plain and streamed),
 * foundation model listing and text embeddings.
 *
 * All requests go through a FetchClient bound to `options.endpointUrl` and are
 * authorized with a bearer token obtained from the IBM Cloud IAM token endpoint
 * using `options.apiKey`.
 */
export class WatsonxDriver extends AbstractDriver {
    static PROVIDER = "watsonx";
    provider = WatsonxDriver.PROVIDER;
    apiKey;
    endpoint_url;
    projectId;
    // Last successful IAM token response; reused until its `expiration` has passed.
    authToken;
    fetcher;
    fetchClient;

    /**
     * @param options Driver options; reads `apiKey`, `projectId` and `endpointUrl`.
     */
    constructor(options) {
        super(options);
        this.apiKey = options.apiKey;
        this.projectId = options.projectId;
        this.endpoint_url = options.endpointUrl;
        this.fetchClient = new FetchClient(this.endpoint_url).withAuthCallback(async () => {
            const token = await this.getAuthToken();
            return `Bearer ${token}`;
        });
    }

    /**
     * Runs a single, non-streamed text generation.
     *
     * @param prompt Prompt text; a trailing newline is appended before sending.
     * @param options Execution options (`model`, `model_options`, `include_original_response`).
     * @returns The generated text, token usage and normalized finish reason.
     * @throws {Error} If the response contains no generation result.
     */
    async requestTextCompletion(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const payload = buildWatsonxGenerationPayload(prompt, options, this.projectId);
        const res = await this.fetchClient.post(`/ml/v1/text/generation?version=${API_VERSION}`, { payload });
        const result = res.results?.[0];
        if (!result) {
            // Fail with a clear message instead of a TypeError on `result.generated_text`.
            throw new Error("Watsonx returned no text generation result");
        }
        return {
            result: [{ type: "text", value: result.generated_text }],
            token_usage: {
                prompt: result.input_token_count,
                result: result.generated_token_count,
                total: result.input_token_count + result.generated_token_count,
            },
            finish_reason: watsonFinishReason(result.stop_reason),
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Runs a streamed text generation over server-sent events.
     *
     * @param prompt Prompt text; a trailing newline is appended before sending.
     * @param options Execution options (`model`, `model_options`).
     * @returns The SSE stream transformed into completion chunks.
     */
    async requestTextCompletionStream(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const payload = buildWatsonxGenerationPayload(prompt, options, this.projectId);
        const stream = await this.fetchClient.post(`/ml/v1/text/generation_stream?version=${API_VERSION}`, {
            payload,
            reader: 'sse',
        });
        return transformSSEStream(stream, (data) => {
            const json = JSON.parse(data);
            // An event may carry no result entry; read every field defensively so
            // such an event yields an empty chunk instead of throwing mid-stream.
            const result = json.results?.[0];
            const promptTokens = result?.input_token_count;
            const resultTokens = result?.generated_token_count;
            return {
                result: result?.generated_text ? [{ type: "text", value: result.generated_text }] : [],
                finish_reason: watsonFinishReason(result?.stop_reason),
                token_usage: {
                    prompt: promptTokens,
                    result: resultTokens,
                    // Missing counts are treated as 0 so the total is never NaN.
                    total: (promptTokens ?? 0) + (resultTokens ?? 0),
                },
            };
        });
    }

    /**
     * Lists the foundation models exposed by the endpoint.
     *
     * @returns One `{ id, name, description, provider }` entry per model.
     * @throws Rethrows the request error after logging it, so callers such as
     *         `validateConnection` see the real failure.
     */
    async listModels() {
        let res;
        try {
            res = await this.fetchClient.get(`/ml/v1/foundation_model_specs?version=${API_VERSION}`);
        }
        catch (err) {
            this.logger.warn("Can't list models on Watsonx: " + err);
            throw err;
        }
        return (res.resources ?? []).map((m) => ({
            id: m.model_id,
            name: m.label,
            description: m.short_description,
            provider: this.provider,
        }));
    }

    /**
     * Returns a bearer access token, reusing the cached one while it is still valid.
     *
     * @returns The `access_token` string of the IAM token response.
     * @throws {Error} If the IAM endpoint answers with a non-2xx status or a body
     *         without `access_token`.
     */
    async getAuthToken() {
        if (this.authToken) {
            // `expiration` is compared against the current time in seconds.
            const now = Date.now() / 1000;
            if (now < this.authToken.expiration) {
                return this.authToken.access_token;
            }
            this.logger.debug("Token expired, refetching");
        }
        // URLSearchParams percent-encodes the values, so an API key containing
        // reserved characters (&, =, +, ...) cannot corrupt the form body.
        const body = new URLSearchParams({
            grant_type: 'urn:ibm:params:oauth:grant-type:apikey',
            apikey: this.apiKey,
        });
        const response = await fetch('https://iam.cloud.ibm.com/identity/token', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/x-www-form-urlencoded',
            },
            body: body.toString(),
        });
        if (!response.ok) {
            throw new Error(`Failed to fetch Watsonx auth token: ${response.status} ${response.statusText}`);
        }
        const authToken = await response.json();
        if (!authToken?.access_token) {
            // Never cache an unusable response: it would be sent as "Bearer undefined".
            throw new Error("Failed to fetch Watsonx auth token: response has no access_token");
        }
        this.authToken = authToken;
        return authToken.access_token;
    }

    /**
     * Checks connectivity by listing models.
     *
     * @returns `true` when the listing succeeds, `false` (after a warning) otherwise.
     */
    async validateConnection() {
        try {
            await this.listModels();
            return true;
        }
        catch (err) {
            this.logger.warn({ error: err }, "Failed to connect to WatsonX");
            return false;
        }
    }

    /**
     * Generates an embedding for a single text.
     *
     * @param options Reads `text`, `image` and an optional `model`
     *        (defaults to 'ibm/slate-125m-english-rtrvr').
     * @returns `{ values, model }` taken from the first embedding result.
     * @throws {Error} If `options.image` is set or `options.text` is missing.
     */
    async generateEmbeddings(options) {
        if (options.image) {
            throw new Error("Image embeddings not supported by Watsonx");
        }
        if (!options.text) {
            throw new Error("No text provided");
        }
        const payload = {
            inputs: [options.text],
            model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
            project_id: this.projectId,
        };
        const res = await this.fetchClient.post(`/ml/v1/text/embeddings?version=${API_VERSION}`, { payload });
        return {
            values: res.results[0].embedding,
            model: res.model_id,
        };
    }
}

// Logs a warning when `options.model_options` is not tagged with the "text-fallback" option id.
function warnOnUnexpectedModelOptions(logger, options) {
    if (options.model_options?._option_id !== "text-fallback") {
        logger.warn({ options: options.model_options }, "Invalid model options");
    }
}

// Builds the request body shared by the plain and streamed text generation endpoints.
function buildWatsonxGenerationPayload(prompt, options, projectId) {
    const modelOptions = options.model_options;
    return {
        model_id: options.model,
        input: prompt + "\n",
        parameters: {
            max_new_tokens: modelOptions?.max_tokens,
            temperature: modelOptions?.temperature,
            top_k: modelOptions?.top_k,
            top_p: modelOptions?.top_p,
            stop_sequences: modelOptions?.stop_sequence,
        },
        project_id: projectId,
    };
}
/**
 * Normalizes a Watsonx `stop_reason` into the driver's finish reason.
 *
 * @param reason The `stop_reason` value taken from a generation result.
 * @returns "stop" for 'eos_token', "length" for 'max_tokens', `undefined` for
 *          any falsy input, and the reason unchanged in every other case.
 */
function watsonFinishReason(reason) {
    if (reason === 'eos_token') {
        return "stop";
    }
    if (reason === 'max_tokens') {
        return "length";
    }
    // Unknown reasons pass through untouched; empty/missing ones collapse to undefined.
    return reason ? reason : undefined;
}
/*interface ListModelsParams extends ModelSearchPayload {
limit?: number;
}*/
//# sourceMappingURL=index.js.map