@llumiverse/drivers

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

import {
    AIModel,
    AbstractDriver,
    Completion,
    CompletionChunkObject,
    DriverOptions,
    EmbeddingsOptions,
    EmbeddingsResult,
    ExecutionOptions,
    TextFallbackOptions,
} from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";
import {
    GenerateEmbeddingPayload,
    GenerateEmbeddingResponse,
    WatsonAuthToken,
    WatsonxListModelResponse,
    WatsonxModelSpec,
    WatsonxTextGenerationPayload,
    WatsonxTextGenerationResponse,
} from "./interfaces.js";

interface WatsonxDriverOptions extends DriverOptions {
    apiKey: string;
    projectId: string;
    endpointUrl: string;
}

const API_VERSION = "2024-03-14";

export class WatsonxDriver extends AbstractDriver<WatsonxDriverOptions, string> {
    static PROVIDER = "watsonx";
    provider = WatsonxDriver.PROVIDER;
    apiKey: string;
    endpoint_url: string;
    projectId: string;
    authToken?: WatsonAuthToken;
    fetcher?: FetchClient;
    fetchClient: FetchClient;

    constructor(options: WatsonxDriverOptions) {
        super(options);
        this.apiKey = options.apiKey;
        this.projectId = options.projectId;
        this.endpoint_url = options.endpointUrl;
        // All Watsonx API calls go through a client that injects a fresh bearer token on demand.
        this.fetchClient = new FetchClient(this.endpoint_url).withAuthCallback(
            async () => this.getAuthToken().then(token => `Bearer ${token}`)
        );
    }

    async requestTextCompletion(prompt: string, options: ExecutionOptions): Promise<Completion> {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions | undefined;

        const payload: WatsonxTextGenerationPayload = {
            model_id: options.model,
            input: prompt + "\n",
            parameters: {
                max_new_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
                top_k: options.model_options?.top_k,
                top_p: options.model_options?.top_p,
                stop_sequences: options.model_options?.stop_sequence,
            },
            project_id: this.projectId,
        };

        const res = await this.fetchClient.post(
            `/ml/v1/text/generation?version=${API_VERSION}`,
            { payload }
        ) as WatsonxTextGenerationResponse;

        const result = res.results[0];
        return {
            result: [{ type: "text", value: result.generated_text }],
            token_usage: {
                prompt: result.input_token_count,
                result: result.generated_token_count,
                total: result.input_token_count + result.generated_token_count,
            },
            finish_reason: watsonFinishReason(result.stop_reason),
            original_response: options.include_original_response ? res : undefined,
        };
    }

    async requestTextCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn("Invalid model options", { options: options.model_options });
        }
        options.model_options = options.model_options as TextFallbackOptions | undefined;

        const payload: WatsonxTextGenerationPayload = {
            model_id: options.model,
            input: prompt + "\n",
            parameters: {
                max_new_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
                top_k: options.model_options?.top_k,
                top_p: options.model_options?.top_p,
                stop_sequences: options.model_options?.stop_sequence,
            },
            project_id: this.projectId,
        };

        const stream = await this.fetchClient.post(
            `/ml/v1/text/generation_stream?version=${API_VERSION}`,
            { payload: payload, reader: 'sse' }
        );

        // Each SSE event carries a partial WatsonxTextGenerationResponse. Guard all
        // accesses to results[0], which may be absent on some events.
        return transformSSEStream(stream, (data: string) => {
            const json = JSON.parse(data) as WatsonxTextGenerationResponse;
            const result = json.results[0];
            return {
                result: result?.generated_text
                    ? [{ type: "text", value: result.generated_text }]
                    : [],
                finish_reason: watsonFinishReason(result?.stop_reason),
                token_usage: {
                    prompt: result?.input_token_count,
                    result: result?.generated_token_count,
                    total: (result?.input_token_count ?? 0) + (result?.generated_token_count ?? 0),
                },
            };
        });
    }

    async listModels(): Promise<AIModel<string>[]> {
        const res = await this.fetchClient.get(`/ml/v1/foundation_model_specs?version=${API_VERSION}`)
            .catch(err => this.logger.warn("Can't list models on Watsonx: " + err)) as WatsonxListModelResponse;

        return res.resources.map((m: WatsonxModelSpec) => ({
            id: m.model_id,
            name: m.label,
            description: m.short_description,
            provider: this.provider,
        }));
    }

    // Exchanges the IBM Cloud API key for an IAM bearer token, caching it until it expires.
    async getAuthToken(): Promise<string> {
        if (this.authToken) {
            const now = Date.now() / 1000;
            if (now < this.authToken.expiration) {
                return this.authToken.access_token;
            } else {
                this.logger.debug("Token expired, refetching");
            }
        }

        const authToken = await fetch('https://iam.cloud.ibm.com/identity/token', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/x-www-form-urlencoded',
            },
            body: `grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=${this.apiKey}`,
        }).then(response => response.json()) as WatsonAuthToken;

        this.authToken = authToken;
        return this.authToken.access_token;
    }

    async validateConnection(): Promise<boolean> {
        return this.listModels()
            .then(() => true)
            .catch((err) => {
                this.logger.warn("Failed to connect to WatsonX", { error: err });
                return false;
            });
    }

    async generateEmbeddings(options: EmbeddingsOptions): Promise<EmbeddingsResult> {
        if (options.image) {
            throw new Error("Image embeddings not supported by Watsonx");
        }
        if (!options.text) {
            throw new Error("No text provided");
        }

        const payload: GenerateEmbeddingPayload = {
            inputs: [options.text],
            model_id: options.model ?? 'ibm/slate-125m-english-rtrvr',
            project_id: this.projectId,
        };

        const res = await this.fetchClient.post(
            `/ml/v1/text/embeddings?version=${API_VERSION}`,
            { payload }
        ) as GenerateEmbeddingResponse;

        return { values: res.results[0].embedding, model: res.model_id };
    }
}

// Maps Watsonx stop reasons onto normalized llumiverse finish reasons.
function watsonFinishReason(reason: string | undefined) {
    if (!reason) return undefined;
    switch (reason) {
        case 'eos_token': return "stop";
        case 'max_tokens': return "length";
        default: return reason;
    }
}

/*interface ListModelsParams extends ModelSearchPayload {
    limit?: number;
}*/
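Below is a minimal usage sketch of this driver. It assumes WatsonxDriver is exported from the package entry point and that ExecutionOptions accepts the model and model_options fields used above; the endpoint URL, environment variable names, and model id are placeholders, not values taken from this file.

import { WatsonxDriver } from "@llumiverse/drivers";

async function main() {
    // Placeholder credentials and region endpoint -- substitute your own.
    const driver = new WatsonxDriver({
        apiKey: process.env.WATSONX_API_KEY!,
        projectId: process.env.WATSONX_PROJECT_ID!,
        endpointUrl: "https://us-south.ml.cloud.ibm.com",
    });

    // listModels() doubles as a connectivity check (see validateConnection above).
    const models = await driver.listModels();
    console.log(models.map(m => m.id));

    // "text-fallback" model options map onto the Watsonx generation parameters
    // (max_new_tokens, temperature, top_k, top_p, stop_sequences).
    const completion = await driver.requestTextCompletion("Say hello.", {
        model: "ibm/granite-13b-chat-v2", // placeholder model id
        model_options: { _option_id: "text-fallback", max_tokens: 100, temperature: 0.7 },
    } as any);
    console.log(completion.result);
}

main().catch(console.error);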