@llumiverse/drivers
LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate, and togetherai (the driver shown below).
JavaScript
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { FetchClient } from "@vertesia/api-fetch-client";

export class TogetherAIDriver extends AbstractDriver {
    static PROVIDER = "togetherai";
    provider = TogetherAIDriver.PROVIDER;
    apiKey;
    fetchClient;
    constructor(options) {
        super(options);
        this.apiKey = options.apiKey;
        this.fetchClient = new FetchClient('https://api.together.xyz').withHeaders({
            authorization: `Bearer ${this.apiKey}`
        });
    }
    getResponseFormat = (options) => {
        return options.result_schema ?
            {
                type: "json_object",
                schema: options.result_schema
            } : undefined;
    };
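    // Example: when options.result_schema holds a JSON schema, the request payload gains
    //   response_format: { type: "json_object", schema: { ...the provided schema... } }
    // and the field is omitted entirely when no schema is given.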
    async requestTextCompletion(prompt, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const stop_seq = options.model_options?.stop_sequence ?? [];
        const res = await this.fetchClient.post('/v1/completions', {
            payload: {
                model: options.model,
                prompt: prompt,
                response_format: this.getResponseFormat(options),
                max_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
                top_p: options.model_options?.top_p,
                top_k: options.model_options?.top_k,
                // logprobs: options.top_logprobs, // logprobs output is not currently supported
                frequency_penalty: options.model_options?.frequency_penalty,
                presence_penalty: options.model_options?.presence_penalty,
                stop: [
                    "</s>",
                    "[/INST]",
                    ...stop_seq,
                ],
            }
        });
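        // The completions response is consumed below with the shape:
        //   { choices: [{ text, finish_reason }], usage: { prompt_tokens, completion_tokens, total_tokens } }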
        const choice = res.choices[0];
        const text = choice.text ?? '';
        const usage = res.usage || {};
        return {
            result: [{ type: "text", value: text }],
            token_usage: {
                prompt: usage.prompt_tokens,
                result: usage.completion_tokens,
                total: usage.total_tokens,
            },
            finish_reason: choice.finish_reason, // already in the expected "stop" / "length" format
            original_response: options.include_original_response ? res : undefined,
        };
    }
    async requestTextCompletionStream(prompt, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const stop_seq = options.model_options?.stop_sequence ?? [];
        const stream = await this.fetchClient.post('/v1/completions', {
            payload: {
                model: options.model,
                prompt: prompt,
                max_tokens: options.model_options?.max_tokens,
                temperature: options.model_options?.temperature,
                response_format: this.getResponseFormat(options),
                top_p: options.model_options?.top_p,
                top_k: options.model_options?.top_k,
                // logprobs: options.top_logprobs, // logprobs output is not currently supported
                frequency_penalty: options.model_options?.frequency_penalty,
                presence_penalty: options.model_options?.presence_penalty,
                stream: true,
                stop: [
                    "</s>",
                    "[/INST]",
                    ...stop_seq,
                ],
            },
            reader: 'sse'
        });
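        // Each SSE data payload is a JSON chunk that mirrors the non-streaming response
        // shape; transformSSEStream maps every chunk onto a partial completion result.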
        return transformSSEStream(stream, (data) => {
            const json = JSON.parse(data);
            return {
                result: [{ type: "text", value: json.choices[0]?.text ?? '' }],
                finish_reason: json.choices[0]?.finish_reason, // already in the expected "stop" / "length" format
                token_usage: {
                    prompt: json.usage?.prompt_tokens,
                    result: json.usage?.completion_tokens,
                    // Guard the sum: adding two undefined values would yield NaN.
                    total: json.usage ? json.usage.prompt_tokens + json.usage.completion_tokens : undefined,
                }
            };
        });
    }
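    // Each entry returned by /models/info is expected to carry at least
    // { name, display_name, description }, mapped below onto the driver's model listing shape.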
    async listModels() {
        const models = await this.fetchClient.get("/models/info");
        const aiModels = models.map(m => {
            return {
                id: m.name,
                name: m.display_name,
                description: m.description,
                provider: this.provider,
            };
        });
        return aiModels;
    }
    // Connection validation and embeddings are not yet supported by this driver.
    validateConnection() {
        throw new Error("Method not implemented.");
    }
    generateEmbeddings() {
        throw new Error("Method not implemented.");
    }
}
//# sourceMappingURL=index.js.map
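A minimal usage sketch, assuming the package index re-exports TogetherAIDriver; the environment variable and model id below are placeholders to substitute with real values:

JavaScript
import { TogetherAIDriver } from "@llumiverse/drivers";

const driver = new TogetherAIDriver({ apiKey: process.env.TOGETHER_API_KEY });

// List the models exposed by the provider.
const models = await driver.listModels();
console.log(models[0]); // { id, name, description, provider: "togetherai" }

// Run a text completion; _option_id "text-fallback" matches what the driver expects.
const completion = await driver.requestTextCompletion("Write a haiku about rivers.", {
    model: "mistralai/Mixtral-8x7B-Instruct-v0.1", // hypothetical model id
    model_options: {
        _option_id: "text-fallback",
        max_tokens: 128,
        temperature: 0.7,
    },
});
console.log(completion.result[0].value);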