@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
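For orientation, here is a minimal sketch of how the Vertex AI Llama (MaaS) model definition in this file assembles a prompt. The import path and model id are assumptions for illustration only, not the package's documented API.

import { PromptRole } from "@llumiverse/core";
import { LLamaModelDefinition } from "@llumiverse/drivers"; // assumed export path
const llama = new LLamaModelDefinition("llama-3.3-70b-instruct"); // hypothetical model id
const prompt = await llama.createPrompt(undefined, [
    { role: PromptRole.user, content: "Summarize this repository." }
], {
    result_schema: { type: "object", properties: { summary: { type: "string" } } }
});
// prompt.messages holds the user message plus a trailing instruction asking for
// JSON that conforms to result_schema.

The full compiled source of the driver (llama.js) follows.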
JavaScript
import { ModelType, PromptRole } from "@llumiverse/core";
import { transformSSEStream } from "@llumiverse/core/async";
/**
* Convert a stream to a string
*/
async function streamToString(stream) {
const chunks = [];
for await (const chunk of stream) {
chunks.push(Buffer.from(chunk));
}
return Buffer.concat(chunks).toString('utf-8');
}
/**
* Update the conversation messages
* @param conversation The previous conversation context
* @param prompt The new prompt to add to the conversation
* @returns Updated conversation with combined messages
*/
function updateConversation(conversation, prompt) {
const baseMessages = conversation ? conversation.messages : [];
return {
messages: [...baseMessages, ...(prompt.messages || [])],
};
}
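/**
 * Model definition for Meta Llama models served through Vertex AI (Llama MaaS).
 * Builds OpenAI-style chat payloads and posts them to the regional MaaS
 * chat/completions endpoint, with Llama Guard safety settings disabled.
 */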
export class LLamaModelDefinition {
model;
constructor(modelId) {
this.model = {
id: modelId,
name: modelId,
provider: 'vertexai',
type: ModelType.Text,
can_stream: true,
};
}
// Return the appropriate region based on the Llama model
getLlamaModelRegion(modelName) {
// Llama 4 models are in us-east5, Llama 3.x models are in us-central1
if (modelName.startsWith('llama-4')) {
return 'us-east5';
}
else {
return 'us-central1';
}
}
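/**
 * Convert prompt segments into the chat message list expected by the Llama MaaS API.
 * Text file attachments are inlined into the message content, non-assistant roles
 * (including system) are mapped to 'user', and when a result_schema is provided a
 * final user message instructs the model to answer with JSON matching that schema.
 */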
async createPrompt(_driver, segments, options) {
const messages = [];
// Process segments and convert them to the Llama MaaS format
for (const segment of segments) {
// Convert the prompt segments to messages
const role = segment.role === PromptRole.assistant ? 'assistant' : 'user';
// Combine files and text content if needed
let messageContent = segment.content || '';
if (segment.files && segment.files.length > 0) {
for (const file of segment.files) {
if (file.mime_type?.startsWith("text/")) {
const fileStream = await file.getStream();
const fileContent = await streamToString(fileStream);
messageContent += `\n\nFile content:\n${fileContent}`;
}
}
}
messages.push({
role: role,
content: messageContent
});
}
if (options.result_schema) {
messages.push({
role: 'user',
content: "The answer must be a JSON object using the following JSON Schema:\n" + JSON.stringify(options.result_schema)
});
}
// Return the prompt in the format expected by Llama MaaS API
return {
messages: messages,
};
}
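/**
 * Execute a non-streaming chat completion against the Llama MaaS endpoint and
 * return the generated text, token usage, finish reason and updated conversation.
 */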
async requestTextCompletion(driver, prompt, options) {
const splits = options.model.split("/");
const modelName = splits[splits.length - 1];
let conversation = updateConversation(options.conversation, prompt);
const modelOptions = options.model_options;
const payload = {
model: `meta/${modelName}`,
messages: conversation.messages,
stream: false,
max_tokens: modelOptions?.max_tokens,
temperature: modelOptions?.temperature,
top_p: modelOptions?.top_p,
top_k: modelOptions?.top_k,
// Disable Llama Guard
extra_body: {
google: {
model_safety_settings: {
enabled: false,
llama_guard_settings: {}
}
}
}
};
// Make POST request to the Llama MaaS API
const region = this.getLlamaModelRegion(modelName);
const client = driver.getLLamaClient(region);
const openaiEndpoint = `endpoints/openapi/chat/completions`;
const result = await client.post(openaiEndpoint, {
payload
});
// Extract response data
const assistantMessage = result?.choices?.[0]?.message;
const text = assistantMessage?.content;
// Update conversation with the response
conversation = updateConversation(conversation, {
messages: [{
role: assistantMessage?.role,
content: text
}],
});
return {
result: [{ type: "text", value: text }],
token_usage: {
prompt: result.usage.prompt_tokens,
result: result.usage.completion_tokens,
total: result.usage.total_tokens
},
finish_reason: result.choices[0].finish_reason,
conversation
};
}
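/**
 * Execute a streaming chat completion and transform the server-sent events from
 * the Llama MaaS endpoint into a stream of completion chunks.
 */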
async requestTextCompletionStream(driver, prompt, options) {
const splits = options.model.split("/");
const modelName = splits[splits.length - 1];
const conversation = updateConversation(options.conversation, prompt);
const modelOptions = options.model_options;
const payload = {
model: `meta/${modelName}`,
messages: conversation.messages,
stream: true,
max_tokens: modelOptions?.max_tokens,
temperature: modelOptions?.temperature,
top_p: modelOptions?.top_p,
top_k: modelOptions?.top_k,
// Disable Llama Guard
extra_body: {
google: {
model_safety_settings: {
enabled: false,
llama_guard_settings: {}
}
}
}
};
// Make POST request to the Llama MaaS API
// TODO: Fix error handling with the fetch client: errors return an empty response
// but do not throw.
const region = this.getLlamaModelRegion(modelName);
const client = driver.getLLamaClient(region);
const openaiEndpoint = `endpoints/openapi/chat/completions`;
const stream = await client.post(openaiEndpoint, {
payload,
reader: 'sse'
});
return transformSSEStream(stream, (data) => {
const json = JSON.parse(data);
const choice = json.choices?.[0];
const content = choice?.delta?.content ?? '';
return {
result: content ? [{ type: "text", value: content }] : [],
finish_reason: choice?.finish_reason,
token_usage: json.usage ? {
prompt: json.usage.prompt_tokens,
result: json.usage.completion_tokens,
total: json.usage.total_tokens,
} : undefined
};
});
}
}
//# sourceMappingURL=llama.js.map