@llumiverse/drivers
LLM driver implementations. Currently supported: openai, huggingface, bedrock, replicate, and mistralai. The file below is the compiled MistralAI driver; a short usage sketch follows the listing.
JavaScript
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
import { formatOpenAILikeTextPrompt } from "../openai/openai_format.js";
import { FetchClient } from "@vertesia/api-fetch-client";
//TODO retry on 429
//const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];
const ENDPOINT = 'https://api.mistral.ai';
export class MistralAIDriver extends AbstractDriver {
static PROVIDER = "mistralai";
provider = MistralAIDriver.PROVIDER;
apiKey;
client;
endpointUrl;
constructor(options) {
super(options);
this.apiKey = options.apiKey;
//this.client = new MistralClient(options.apiKey, options.endpointUrl);
this.endpointUrl = options.endpoint_url || ENDPOINT;
this.client = new FetchClient(this.endpointUrl).withHeaders({
authorization: `Bearer ${this.apiKey}`
});
}
getResponseFormat = (_options) => {
// const responseFormatJson: ResponseFormat = {
// type: "json_object",
// } as ResponseFormat
// const responseFormatText: ResponseFormat = {
// type: "text",
// } as ResponseFormat;
// return _options.result_schema ? responseFormatJson : responseFormatText;
//TODO remove this workaround once Mistral properly supports the response_format parameter
// some models (e.g. Mixtral, Mistral Tiny/Medium) currently throw an error when response_format is set
return undefined;
};
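// Converts llumiverse prompt segments into OpenAI-style chat messages and, when a result
// schema is provided, appends a user message instructing the model to return JSON.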
async formatPrompt(segments, opts) {
const messages = formatOpenAILikeTextPrompt(segments);
//Add JSON instruction if a schema is provided
if (opts.result_schema) {
messages.push({
role: "user",
content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
});
}
return messages;
}
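// Non-streaming completion: posts to /v1/chat/completions and maps the first choice
// into the llumiverse completion result shape.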
async requestTextCompletion(messages, options) {
if (options.model_options?._option_id !== "text-fallback") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
const res = await this.client.post('/v1/chat/completions', {
payload: _makeChatCompletionRequest({
model: options.model,
messages: messages,
maxTokens: options.model_options?.max_tokens,
temperature: options.model_options?.temperature,
topP: options.model_options?.top_p,
stopSequences: options.model_options?.stop_sequence,
responseFormat: this.getResponseFormat(options),
})
});
const choice = res.choices[0];
const result = choice.message.content;
return {
result: result ? [{ type: "text", value: result }] : [],
token_usage: {
prompt: res.usage.prompt_tokens,
result: res.usage.completion_tokens,
total: res.usage.total_tokens,
},
finish_reason: choice.finish_reason, // uses the expected "stop" / "length" values
original_response: options.include_original_response ? res : undefined,
};
}
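// Streaming completion: posts with stream: true and an SSE reader, then maps each
// SSE chunk to a partial completion result.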
async requestTextCompletionStream(messages, options) {
if (options.model_options?._option_id !== "text-fallback") {
this.logger.warn({ options: options.model_options }, "Invalid model options");
}
const stream = await this.client.post('/v1/chat/completions', {
payload: _makeChatCompletionRequest({
model: options.model,
messages: messages,
maxTokens: options.model_options?.max_tokens,
temperature: options.model_options?.temperature,
topP: options.model_options?.top_p,
responseFormat: this.getResponseFormat(options),
stream: true,
stopSequences: options.model_options?.stop_sequence,
}),
reader: 'sse'
});
return transformSSEStream(stream, (data) => {
const json = JSON.parse(data);
const content = json.choices[0]?.delta.content;
return {
result: content ? [{ type: "text", value: content }] : [],
finish_reason: json.choices[0]?.finish_reason, // uses the expected "stop" / "length" values
token_usage: {
prompt: json.usage?.prompt_tokens,
result: json.usage?.completion_tokens,
total: json.usage?.total_tokens,
},
};
});
}
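// Lists the models available from the Mistral /v1/models endpoint.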
async listModels() {
const models = await this.client.get('/v1/models');
const aiModels = models.data.map(m => {
return {
id: m.id,
name: m.id,
description: undefined,
provider: this.provider,
owner: m.owned_by,
};
});
return aiModels;
}
validateConnection() {
throw new Error("Method not implemented.");
}
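// Computes an embedding for a single input text via /v1/embeddings (defaults to the "mistral-embed" model).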
async generateEmbeddings({ text, model = "mistral-embed" }) {
const r = await this.client.post('/v1/embeddings', {
payload: {
model,
input: [text],
encoding_format: "float"
},
});
return {
values: r.data[0].embedding,
model,
token_count: r.usage.total_tokens || r.usage.prompt_tokens + r.usage.completion_tokens,
};
}
}
/**
* Creates a chat completion request
* @param {*} model
* @param {*} messages
* @param {*} tools
* @param {*} temperature
* @param {*} maxTokens
* @param {*} topP
* @param {*} randomSeed
* @param {*} stream
* @param {*} safeMode deprecated, use safePrompt instead
* @param {*} safePrompt
* @param {*} toolChoice
* @param {*} responseFormat
* @param {*} stopSequences
* @return {Object}
*/
function _makeChatCompletionRequest({ model, messages, tools, temperature, maxTokens, topP, randomSeed, stream, safeMode, safePrompt, toolChoice, responseFormat, stopSequences, }) {
return {
model: model,
messages: messages,
tools: tools ?? undefined,
temperature: temperature ?? undefined,
max_tokens: maxTokens ?? undefined,
top_p: topP ?? undefined,
random_seed: randomSeed ?? undefined,
stream: stream ?? undefined,
safe_prompt: (safeMode || safePrompt) ?? undefined,
tool_choice: toolChoice ?? undefined,
response_format: responseFormat ?? undefined,
stop: stopSequences ?? undefined,
};
}
//# sourceMappingURL=index.js.map
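A minimal usage sketch (not part of the source above; it assumes MistralAIDriver is exported from @llumiverse/drivers and that an API key is available in the environment):

import { MistralAIDriver } from "@llumiverse/drivers";

// Placeholder key source; the driver sends it as a Bearer token (see the constructor above).
const driver = new MistralAIDriver({ apiKey: process.env.MISTRAL_API_KEY });

// List the models exposed by the Mistral API.
const models = await driver.listModels();
console.log(models.map(m => m.id));

// Embed a single text with the default "mistral-embed" model.
const { values, token_count } = await driver.generateEmbeddings({ text: "Hello, Mistral!" });
console.log(values.length, token_count);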