UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported providers: openai, huggingface, bedrock, replicate.

169 lines 6.71 kB
import { AbstractDriver } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
import { getJSONSafetyNotice } from "@llumiverse/core/formatters";
import { formatOpenAILikeTextPrompt } from "../openai/openai_format.js";
import { FetchClient } from "@vertesia/api-fetch-client";

//TODO retry on 429
//const RETRY_STATUS_CODES = [429, 500, 502, 503, 504];

/** Default Mistral API base URL, used when no `endpoint_url` option is given. */
const ENDPOINT = 'https://api.mistral.ai';

/**
 * Driver for the Mistral AI HTTP API: chat completions (plain and streaming),
 * model listing and embeddings. All requests go through a FetchClient that
 * sends the API key as a bearer token.
 */
export class MistralAIDriver extends AbstractDriver {
    static PROVIDER = "mistralai";
    provider = MistralAIDriver.PROVIDER;
    apiKey;
    client;
    endpointUrl;

    /**
     * @param options driver options. Reads `apiKey` (sent as a bearer token) and
     *   the optional `endpoint_url` (defaults to ENDPOINT). The whole object is
     *   also passed to AbstractDriver.
     */
    constructor(options) {
        super(options);
        this.apiKey = options.apiKey;
        // Record the resolved base URL in the declared `endpointUrl` field.
        this.endpointUrl = options.endpoint_url || ENDPOINT;
        this.client = new FetchClient(this.endpointUrl).withHeaders({
            authorization: `Bearer ${this.apiKey}`
        });
    }

    /**
     * Returns the `response_format` to send to Mistral. It is always undefined for now.
     * Some models (e.g. mixtral, mistral tiny or medium) throw an error when the
     * `response_format` parameter is used. JSON output therefore relies on the
     * instruction that formatPrompt appends when `result_schema` is set.
     * TODO: once Mistral properly supports the parameter, return
     * `{ type: "json_object" }` when `_options.result_schema` is set and
     * `{ type: "text" }` otherwise.
     */
    getResponseFormat = (_options) => {
        return undefined;
    };

    /**
     * Converts prompt segments into chat messages via formatOpenAILikeTextPrompt.
     * When `opts.result_schema` is set, a final user message carrying the JSON
     * safety notice for that schema is appended.
     * @returns the message array to send as `messages` to the chat completion API
     */
    async formatPrompt(segments, opts) {
        const messages = formatOpenAILikeTextPrompt(segments);
        // Add a JSON instruction if a schema is provided
        if (opts.result_schema) {
            messages.push({
                role: "user",
                content: "IMPORTANT: " + getJSONSafetyNotice(opts.result_schema)
            });
        }
        return messages;
    }

    /**
     * Sends a non-streaming chat completion request to /v1/chat/completions.
     * Logs a warning when `options.model_options` is not of the "text-fallback" kind.
     * @returns the first choice's text (empty array if it has no content), token
     *   usage, the finish reason and, when `options.include_original_response`
     *   is set, the raw API response
     */
    async requestTextCompletion(messages, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const res = await this.client.post('/v1/chat/completions', {
            payload: buildChatCompletionPayload(messages, options, this.getResponseFormat(options)),
        });
        const choice = res.choices[0];
        const result = choice.message.content;
        return {
            result: result ? [{ type: "text", value: result }] : [],
            token_usage: {
                prompt: res.usage?.prompt_tokens,
                result: res.usage?.completion_tokens,
                total: res.usage?.total_tokens,
            },
            finish_reason: choice.finish_reason, //Uses expected "stop" , "length" format
            original_response: options.include_original_response ? res : undefined,
        };
    }

    /**
     * Sends a streaming chat completion request (server-sent events) to
     * /v1/chat/completions and maps each event to a partial completion chunk.
     * Logs a warning when `options.model_options` is not of the "text-fallback" kind.
     * @returns the stream produced by transformSSEStream
     */
    async requestTextCompletionStream(messages, options) {
        if (options.model_options?._option_id !== "text-fallback") {
            this.logger.warn({ options: options.model_options }, "Invalid model options");
        }
        const stream = await this.client.post('/v1/chat/completions', {
            payload: buildChatCompletionPayload(messages, options, this.getResponseFormat(options), true),
            reader: 'sse'
        });
        return transformSSEStream(stream, (data) => {
            const json = JSON.parse(data);
            const choice = json.choices?.[0];
            const content = choice?.delta?.content;
            return {
                result: content ? [{ type: "text", value: content }] : [],
                finish_reason: choice?.finish_reason, //Uses expected "stop" , "length" format
                token_usage: {
                    prompt: json.usage?.prompt_tokens,
                    result: json.usage?.completion_tokens,
                    total: json.usage?.total_tokens,
                },
            };
        });
    }

    /**
     * Lists the models returned by GET /v1/models.
     * The model id is used as both `id` and `name`, and `owned_by` is used as `owner`.
     */
    async listModels() {
        const models = await this.client.get('/v1/models');
        return models.data.map((m) => ({
            id: m.id,
            name: m.id,
            description: undefined,
            provider: this.provider,
            owner: m.owned_by,
        }));
    }

    /** Not implemented: always throws. */
    validateConnection() {
        throw new Error("Method not implemented.");
    }

    /**
     * Embeds a single text via POST /v1/embeddings with float encoding.
     * @param text the text to embed
     * @param model embedding model id, "mistral-embed" by default
     * @returns the first embedding vector, the model id and a token count.
     *   The count is `total_tokens`, falling back to prompt + completion tokens.
     *   Missing counts are treated as 0.
     */
    async generateEmbeddings({ text, model = "mistral-embed" }) {
        const r = await this.client.post('/v1/embeddings', {
            payload: { model, input: [text], encoding_format: "float" },
        });
        const usage = r.usage ?? {};
        return {
            values: r.data[0].embedding,
            model,
            // Guard the fallback sum so a missing completion_tokens field
            // does not turn the count into NaN.
            token_count: usage.total_tokens || (usage.prompt_tokens ?? 0) + (usage.completion_tokens ?? 0),
        };
    }
}

/**
 * Builds the /v1/chat/completions payload shared by the streaming and
 * non-streaming paths, so that both forward the same model options
 * (max_tokens, temperature, top_p, stop_sequence).
 * @param messages chat messages, e.g. as produced by formatPrompt
 * @param options execution options; reads `model` and `model_options`
 * @param responseFormat value for `response_format` (may be undefined)
 * @param stream true for a streaming request; leave undefined to omit the field
 * @returns the request body object
 */
function buildChatCompletionPayload(messages, options, responseFormat, stream) {
    const modelOptions = options.model_options;
    return _makeChatCompletionRequest({
        model: options.model,
        messages,
        maxTokens: modelOptions?.max_tokens,
        temperature: modelOptions?.temperature,
        topP: modelOptions?.top_p,
        stopSequences: modelOptions?.stop_sequence,
        responseFormat,
        stream,
    });
}

/**
 * Creates a chat completion request body, mapping camelCase arguments to the
 * snake_case field names used in the payload.
 * `?? undefined` maps null to undefined so that a serializer such as
 * JSON.stringify omits the field rather than sending null.
 * @param model model id
 * @param messages chat messages
 * @param tools tool definitions
 * @param temperature sampling temperature
 * @param maxTokens maximum number of tokens to generate
 * @param topP nucleus sampling parameter
 * @param randomSeed random seed
 * @param stream whether to stream the response
 * @param safeMode deprecated, use safePrompt instead
 * @param safePrompt sent as `safe_prompt` (either this or safeMode enables it)
 * @param toolChoice tool choice
 * @param responseFormat response format
 * @param stopSequences sent as `stop`
 * @returns the request body object (synchronously, not a Promise)
 */
function _makeChatCompletionRequest({ model, messages, tools, temperature, maxTokens, topP, randomSeed, stream, safeMode, safePrompt, toolChoice, responseFormat, stopSequences, }) {
    return {
        model: model,
        messages: messages,
        tools: tools ?? undefined,
        temperature: temperature ?? undefined,
        max_tokens: maxTokens ?? undefined,
        top_p: topP ?? undefined,
        random_seed: randomSeed ?? undefined,
        stream: stream ?? undefined,
        safe_prompt: (safeMode || safePrompt) ?? undefined,
        tool_choice: toolChoice ?? undefined,
        response_format: responseFormat ?? undefined,
        stop: stopSequences ?? undefined,
    };
}
//# sourceMappingURL=index.js.map