UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

122 lines 5.07 kB
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";

/** Stop sequences always sent with a completion request, ahead of any caller-supplied ones. */
const DEFAULT_STOP_SEQUENCES = ["</s>", "[/INST]"];

/** The only `model_options._option_id` value this driver accepts without logging a warning. */
const TEXT_FALLBACK_OPTION_ID = "text-fallback";

/**
 * Logs a warning when the model options do not carry the expected option id.
 * The request still proceeds; this is diagnostic only.
 *
 * @param {object} logger - Logger exposing `warn(obj, msg)`.
 * @param {object} options - Execution options; `options.model_options` may be undefined.
 */
function warnOnUnexpectedModelOptions(logger, options) {
    if (options.model_options?._option_id !== TEXT_FALLBACK_OPTION_ID) {
        logger.warn({ options: options.model_options }, "Invalid model options");
    }
}

/**
 * Builds the JSON payload shared by the streaming and non-streaming completion calls.
 *
 * @param {*} prompt - Prompt forwarded as-is in the `prompt` field.
 * @param {object} options - Execution options (`model`, `model_options`).
 * @param {object|undefined} responseFormat - Value for `response_format`, or undefined.
 * @returns {object} Payload for POST /v1/completions (without the `stream` flag).
 */
function buildCompletionPayload(prompt, options, responseFormat) {
    const modelOptions = options.model_options;
    return {
        model: options.model,
        prompt,
        response_format: responseFormat,
        max_tokens: modelOptions?.max_tokens,
        temperature: modelOptions?.temperature,
        top_p: modelOptions?.top_p,
        top_k: modelOptions?.top_k,
        // logprobs are intentionally not forwarded: logprobs output is currently not supported.
        frequency_penalty: modelOptions?.frequency_penalty,
        presence_penalty: modelOptions?.presence_penalty,
        stop: [...DEFAULT_STOP_SEQUENCES, ...(modelOptions?.stop_sequence ?? [])],
    };
}

/**
 * Maps a `usage` object from the API response to the driver's token usage shape.
 *
 * Tolerates a missing `usage` (every field is then undefined). `total` is taken
 * from `total_tokens` when present; otherwise it is the sum of prompt and
 * completion tokens, but only when both are numbers, so it is never NaN.
 *
 * @param {object|null|undefined} usage - `usage` field of a response or stream chunk.
 * @returns {{prompt: (number|undefined), result: (number|undefined), total: (number|undefined)}}
 */
function toTokenUsage(usage) {
    const prompt = usage?.prompt_tokens;
    const result = usage?.completion_tokens;
    let total = usage?.total_tokens;
    if (total == null && typeof prompt === "number" && typeof result === "number") {
        total = prompt + result;
    }
    return { prompt, result, total };
}

/**
 * Driver that sends text completion requests to https://api.together.xyz.
 */
export class TogetherAIDriver extends AbstractDriver {
    static PROVIDER = "togetherai";
    provider = TogetherAIDriver.PROVIDER;
    apiKey;
    fetchClient;

    /**
     * @param {object} options - Driver options; `options.apiKey` is sent as a Bearer token.
     */
    constructor(options) {
        super(options);
        this.apiKey = options.apiKey;
        this.fetchClient = new FetchClient('https://api.together.xyz').withHeaders({
            authorization: `Bearer ${this.apiKey}`
        });
    }

    /**
     * Returns the `response_format` request field for the given options.
     *
     * @param {object} options - Execution options; `result_schema` is optional.
     * @returns {{type: string, schema: object}|undefined} A `json_object` format
     *   wrapping the schema, or undefined when no schema was supplied.
     */
    getResponseFormat = (options) => {
        return options.result_schema ? {
            type: "json_object",
            schema: options.result_schema
        } : undefined;
    };

    /**
     * Requests a single, non-streamed completion from POST /v1/completions.
     *
     * @param {*} prompt - Prompt forwarded to the API.
     * @param {object} options - Execution options (`model`, `model_options`,
     *   `result_schema`, `include_original_response`).
     * @returns {Promise<object>} Completion with `result`, `token_usage`,
     *   `finish_reason` and, when requested, `original_response`.
     */
    async requestTextCompletion(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const res = await this.fetchClient.post('/v1/completions', {
            payload: buildCompletionPayload(prompt, options, this.getResponseFormat(options)),
        });
        const choice = res.choices[0];
        const text = choice.text ?? '';
        return {
            result: [{ type: "text", value: text }],
            token_usage: toTokenUsage(res.usage),
            finish_reason: choice.finish_reason, // Uses expected "stop", "length" format
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Requests a streamed completion from POST /v1/completions using an SSE reader.
     *
     * @param {*} prompt - Prompt forwarded to the API.
     * @param {object} options - Execution options (`model`, `model_options`, `result_schema`).
     * @returns {Promise<*>} The result of `transformSSEStream`, which maps each SSE
     *   data payload to a chunk with `result`, `finish_reason` and `token_usage`.
     */
    async requestTextCompletionStream(prompt, options) {
        warnOnUnexpectedModelOptions(this.logger, options);
        const stream = await this.fetchClient.post('/v1/completions', {
            payload: {
                ...buildCompletionPayload(prompt, options, this.getResponseFormat(options)),
                stream: true,
            },
            reader: 'sse'
        });
        return transformSSEStream(stream, (data) => {
            const json = JSON.parse(data);
            const choice = json.choices[0];
            return {
                result: [{ type: "text", value: choice?.text ?? '' }],
                finish_reason: choice?.finish_reason, // Uses expected "stop", "length" format
                // A chunk may carry no usage object; toTokenUsage then yields
                // undefined fields rather than a NaN total.
                token_usage: toTokenUsage(json.usage),
            };
        });
    }

    /**
     * Lists the models returned by GET /models/info.
     *
     * @returns {Promise<Array<{id: *, name: *, description: *, provider: string}>>}
     */
    async listModels() {
        const models = await this.fetchClient.get("/models/info");
        return models.map((m) => ({
            id: m.name,
            name: m.display_name,
            description: m.description,
            provider: this.provider,
        }));
    }

    /** Not supported by this driver; always throws. */
    validateConnection() {
        throw new Error("Method not implemented.");
    }

    /** Not supported by this driver; always throws. */
    generateEmbeddings() {
        throw new Error("Method not implemented.");
    }
}
//# sourceMappingURL=index.js.map