@llumiverse/drivers
Version:
LLM driver implementations. Currently supported providers include openai, huggingface, bedrock, replicate, and groq.
377 lines (325 loc) • 15.4 kB
text/typescript
import { AbstractDriver, AIModel, Completion, CompletionChunkObject, DriverOptions, EmbeddingsOptions, EmbeddingsResult, ExecutionOptions, PromptSegment, TextFallbackOptions, ToolDefinition, ToolUse } from "@llumiverse/core";
import { transformAsyncIterator } from "@llumiverse/core/async";
import Groq from "groq-sdk";
import type { ChatCompletionMessageParam, ChatCompletionTool } from "groq-sdk/resources/chat/completions";
import type { FunctionParameters } from "groq-sdk/resources/shared";
import type OpenAI from "openai";
import { formatOpenAILikeMultimodalPrompt } from "../openai/openai_format.js";
/** Item type produced by the shared OpenAI Responses-style prompt formatter. */
type ResponseInputItem = OpenAI.Responses.ResponseInputItem;
/** Role/content message variant of the Responses input items. */
type EasyInputMessage = OpenAI.Responses.EasyInputMessage;

/**
 * Construction options for the Groq driver.
 */
interface GroqDriverOptions extends DriverOptions {
    /** Optional base URL override, forwarded to the Groq SDK client as `baseURL`. */
    endpoint_url?: string;
    /** API key forwarded to the Groq SDK client. */
    apiKey: string;
}
/**
 * Driver for Groq's chat-completions API.
 *
 * Prompts are formatted with the shared OpenAI multimodal formatter and then converted to
 * Groq chat messages. Supports blocking and streaming text completion, tool calling and
 * model listing. Embeddings are not implemented.
 */
export class GroqDriver extends AbstractDriver<GroqDriverOptions, ChatCompletionMessageParam[]> {
    static PROVIDER = "groq";

    /** `_option_id` values this driver expects; any other id is logged at debug level. */
    private static readonly KNOWN_OPTION_IDS: readonly string[] = ["text-fallback", "groq-deepseek-thinking"];

    provider = GroqDriver.PROVIDER;
    apiKey: string;
    client: Groq;
    /** Base URL override given at construction time (the same value handed to the Groq client). */
    endpointUrl?: string;

    constructor(options: GroqDriverOptions) {
        super(options);
        this.apiKey = options.apiKey;
        this.endpointUrl = options.endpoint_url;
        this.client = new Groq({
            apiKey: options.apiKey,
            baseURL: options.endpoint_url,
        });
    }

    // protected canStream(options: ExecutionOptions): Promise<boolean> {
    //     if (options.result_schema) {
    //         // not yet streaming json responses
    //         return Promise.resolve(false);
    //     } else {
    //         return Promise.resolve(true);
    //     }
    // }

    /**
     * Response format sent to Groq. Always `undefined` for now: forcing `json_object`
     * would break streaming (see the disabled `canStream` above).
     */
    getResponseFormat(_options: ExecutionOptions): undefined {
        //TODO: when forcing json_object type the streaming is not supported.
        // either implement canStream as above or comment the code below:
        // const responseFormatJson: Groq.Chat.Completions.CompletionCreateParams.ResponseFormat = {
        //     type: "json_object",
        // }
        // return _options.result_schema ? responseFormatJson : undefined;
        return undefined;
    }

    /**
     * Format prompt segments with the OpenAI multimodal formatter, then convert the
     * resulting Responses-style items to Groq chat messages.
     */
    protected async formatPrompt(segments: PromptSegment[], opts: ExecutionOptions): Promise<ChatCompletionMessageParam[]> {
        const responseItems = await formatOpenAILikeMultimodalPrompt(segments, {
            ...opts,
            multimodal: true,
        });
        return convertResponseItemsToGroqMessages(responseItems);
    }

    /** Map llumiverse tool definitions to Groq function tools; `undefined` when there are none. */
    private getToolDefinitions(tools: ToolDefinition[] | undefined): ChatCompletionTool[] | undefined {
        if (!tools || tools.length === 0) {
            return undefined;
        }
        return tools.map(tool => ({
            type: 'function' as const,
            function: {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema satisfies FunctionParameters,
            }
        }));
    }

    /**
     * Convert the tool calls of a completion message to llumiverse `ToolUse` entries.
     * Missing/empty argument strings are treated as `{}`; malformed JSON arguments throw.
     */
    private extractToolUse(message: {
        tool_calls?: Array<{ id: string; function: { name: string; arguments?: string | null } }> | null;
    }): ToolUse[] | undefined {
        if (!message.tool_calls || message.tool_calls.length === 0) {
            return undefined;
        }
        return message.tool_calls.map(toolCall => ({
            id: toolCall.id,
            tool_name: toolCall.function.name,
            tool_input: JSON.parse(toolCall.function.arguments || '{}'),
        }));
    }

    /**
     * Strip reasoning data from messages: a top-level `reasoning` field and any content
     * part whose type is `reasoning` or that carries a `reasoning` key.
     */
    private sanitizeMessagesForGroq(messages: ChatCompletionMessageParam[]): ChatCompletionMessageParam[] {
        return messages.map(message => {
            // eslint-disable-next-line @typescript-eslint/no-explicit-any -- arbitrary message shapes from prior turns
            const { reasoning: _reasoning, ...sanitizedMessage } = message as any;
            if (Array.isArray(sanitizedMessage.content)) {
                sanitizedMessage.content = sanitizedMessage.content.filter((part: unknown) => {
                    // Non-object parts cannot carry reasoning; keep them instead of letting `in` throw.
                    if (typeof part !== 'object' || part === null) {
                        return true;
                    }
                    return (part as { type?: unknown }).type !== 'reasoning' && !('reasoning' in part);
                });
            }
            return sanitizedMessage as ChatCompletionMessageParam;
        });
    }

    /** Log (at debug level) an `_option_id` this driver does not recognise. */
    private warnOnUnexpectedOptionId(options: ExecutionOptions): void {
        const optionId = options.model_options?._option_id;
        if (optionId !== undefined && !GroqDriver.KNOWN_OPTION_IDS.includes(optionId)) {
            this.logger.debug({ options: options.model_options }, "Unexpected option id");
        }
    }

    /** Prepend the prior conversation to the new messages and drop unsupported reasoning data. */
    private prepareConversation(messages: ChatCompletionMessageParam[], options: ExecutionOptions): ChatCompletionMessageParam[] {
        const conversation = updateConversation(options.conversation as ChatCompletionMessageParam[] | undefined, messages);
        return this.sanitizeMessagesForGroq(conversation);
    }

    /**
     * Run a blocking chat completion.
     *
     * @returns The completion text, token usage, tool calls, and the conversation updated
     *          with the assistant reply. `finish_reason` is forced to `"tool_calls"` when the
     *          reply contains tool calls.
     * @throws Error when Groq returns no choices.
     */
    async requestTextCompletion(messages: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<Completion> {
        this.warnOnUnexpectedOptionId(options);
        const modelOptions = options.model_options as TextFallbackOptions | undefined;
        let conversation = this.prepareConversation(messages, options);

        const res = await this.client.chat.completions.create({
            model: options.model,
            messages: conversation,
            max_completion_tokens: modelOptions?.max_tokens,
            temperature: modelOptions?.temperature,
            top_p: modelOptions?.top_p,
            //top_logprobs: options.top_logprobs, //Logprobs output currently not supported
            //logprobs: options.top_logprobs ? true : false,
            presence_penalty: modelOptions?.presence_penalty,
            frequency_penalty: modelOptions?.frequency_penalty,
            response_format: this.getResponseFormat(options),
            tools: this.getToolDefinitions(options.tools),
        });

        const choice = res.choices[0];
        if (!choice) {
            throw new Error(`Groq returned no choices for model ${options.model}`);
        }
        const result = choice.message.content;
        const tool_use = this.extractToolUse(choice.message);

        // Record the assistant reply so the caller can continue the conversation.
        conversation = updateConversation(conversation, [choice.message]);

        let finish_reason = choice.finish_reason;
        if (tool_use && tool_use.length > 0) {
            finish_reason = "tool_calls";
        }

        return {
            result: result ? [{ type: "text", value: result }] : [],
            token_usage: {
                prompt: res.usage?.prompt_tokens,
                result: res.usage?.completion_tokens,
                total: res.usage?.total_tokens,
            },
            finish_reason: finish_reason,
            original_response: options.include_original_response ? res : undefined,
            conversation,
            tool_use,
        };
    }

    /**
     * Run a streaming chat completion, mapping each Groq chunk to a `CompletionChunkObject`.
     * Token usage is read from the chunk's `x_groq.usage` when present.
     */
    async requestTextCompletionStream(messages: ChatCompletionMessageParam[], options: ExecutionOptions): Promise<AsyncIterable<CompletionChunkObject>> {
        this.warnOnUnexpectedOptionId(options);
        const modelOptions = options.model_options as TextFallbackOptions | undefined;
        const conversation = this.prepareConversation(messages, options);

        const stream = await this.client.chat.completions.create({
            model: options.model,
            messages: conversation,
            max_completion_tokens: modelOptions?.max_tokens,
            temperature: modelOptions?.temperature,
            top_p: modelOptions?.top_p,
            //top_logprobs: options.top_logprobs, //Logprobs output currently not supported
            //logprobs: options.top_logprobs ? true : false,
            presence_penalty: modelOptions?.presence_penalty,
            frequency_penalty: modelOptions?.frequency_penalty,
            stream: true,
            tools: this.getToolDefinitions(options.tools),
        });

        return transformAsyncIterator(stream, (chunk) => {
            // Tolerate chunks with an empty `choices` array instead of crashing the stream.
            const choice = chunk.choices[0];
            let finish_reason = choice?.finish_reason ?? undefined;
            // NOTE(review): tool-call deltas only set finish_reason; the call payloads
            // themselves are not surfaced to the caller in streaming mode.
            if (choice?.delta.tool_calls && choice.delta.tool_calls.length > 0) {
                finish_reason = "tool_calls";
            }
            const text = choice?.delta.content;
            return {
                result: text ? [{ type: "text", value: text }] : [],
                finish_reason,
                token_usage: {
                    prompt: chunk.x_groq?.usage?.prompt_tokens,
                    result: chunk.x_groq?.usage?.completion_tokens,
                    total: chunk.x_groq?.usage?.total_tokens,
                },
            } satisfies CompletionChunkObject;
        });
    }

    /**
     * List the models available to this API key.
     * @throws Error when the response has no data or a model has no id.
     */
    async listModels(): Promise<AIModel<string>[]> {
        const models = await this.client.models.list();
        if (!models.data) {
            throw new Error("No models found");
        }
        return models.data.map(m => {
            if (!m.id) {
                throw new Error("Model id is missing");
            }
            return {
                id: m.id,
                name: m.id,
                description: undefined,
                provider: this.provider,
                owner: m.owned_by || '',
            };
        });
    }

    /**
     * Check connectivity and credentials by listing models.
     * @returns `true` if the model list could be fetched, `false` otherwise.
     */
    async validateConnection(): Promise<boolean> {
        try {
            await this.listModels();
            return true;
        } catch (error: unknown) {
            this.logger.debug({ error }, "Groq connection validation failed");
            return false;
        }
    }

    /** Embeddings are not supported by this driver. */
    async generateEmbeddings(_options: EmbeddingsOptions): Promise<EmbeddingsResult> {
        throw new Error("Method not implemented.");
    }
}
/**
 * Build the message list for a request: the earlier turns (if any) followed by the new messages.
 * Neither input array is modified; a fresh array is always returned.
 * @param conversation Messages from earlier turns, or `undefined` for a new conversation
 * @param messages Messages to append after the earlier turns
 * @returns A new array holding the earlier turns and then `messages`
 */
function updateConversation(
    conversation: ChatCompletionMessageParam[] | undefined,
    messages: ChatCompletionMessageParam[]
): ChatCompletionMessageParam[] {
    const previousTurns = conversation || [];
    return [...previousTurns, ...messages];
}
/** Content part shapes accepted in Groq user messages. */
type GroqUserContentPart =
    | { type: 'text'; text: string }
    | { type: 'image_url'; image_url: { url: string; detail?: 'auto' | 'low' | 'high' } };

/** Shape of a single function tool call inside a Groq assistant message. */
type GroqToolCall = { id: string; type: 'function'; function: { name: string; arguments: string } };

/** Responses-API content part types that carry plain text. */
const TEXT_PART_TYPES: ReadonlySet<string> = new Set(['input_text', 'output_text']);

/**
 * Join the text of all text parts (`input_text` / `output_text`) in a Responses content
 * list with newlines. Non-array input and non-text parts contribute nothing.
 */
function joinTextParts(content: unknown): string {
    if (!Array.isArray(content)) {
        return '';
    }
    const texts: string[] = [];
    for (const part of content) {
        if (part && TEXT_PART_TYPES.has(part.type) && typeof part.text === 'string') {
            texts.push(part.text);
        }
    }
    return texts.join('\n');
}

/**
 * Convert Responses user content to Groq user content: text parts become `text` parts and
 * images with a URL become `image_url` parts. Other parts are dropped. An empty result is
 * returned as `''`.
 */
function toGroqUserContent(content: EasyInputMessage['content']): string | GroqUserContentPart[] {
    if (typeof content === 'string') {
        return content;
    }
    if (!Array.isArray(content)) {
        return '';
    }
    const parts: GroqUserContentPart[] = [];
    for (const part of content) {
        if (part.type === 'input_text') {
            parts.push({ type: 'text', text: part.text });
        } else if (part.type === 'input_image') {
            const imgPart = part as OpenAI.Responses.ResponseInputImage;
            if (imgPart.image_url) {
                const image_url: { url: string; detail?: 'auto' | 'low' | 'high' } = { url: imgPart.image_url };
                if (imgPart.detail) {
                    image_url.detail = imgPart.detail as 'auto' | 'low' | 'high';
                }
                parts.push({ type: 'image_url', image_url });
            }
        }
    }
    return parts.length > 0 ? parts : '';
}

/**
 * Attach a tool call to the conversation. Chat-completions APIs require all tool calls of one
 * assistant turn to live in a single assistant message that is directly followed by the tool
 * results. So a call is merged into the immediately preceding assistant message when there is
 * one; otherwise a new assistant message is started.
 */
function appendToolCall(messages: ChatCompletionMessageParam[], toolCall: GroqToolCall): void {
    const last = messages[messages.length - 1];
    if (last && last.role === 'assistant') {
        last.tool_calls = [...(last.tool_calls ?? []), toolCall];
    } else {
        messages.push({ role: 'assistant', content: null, tool_calls: [toolCall] });
    }
}

/**
 * Convert Responses-style input items to Groq chat-completion messages.
 *
 * - system/developer messages become `system` messages (text parts joined with newlines);
 * - user messages keep text and image-URL parts;
 * - assistant messages keep their text (input or output text parts), `null` when empty;
 * - `function_call` items become `tool_calls` on an assistant message, with consecutive calls
 *   grouped into one message;
 * - `function_call_output` items become `tool` messages (non-string output is JSON-encoded).
 *
 * Items of any other kind are skipped.
 */
function convertResponseItemsToGroqMessages(items: ResponseInputItem[]): ChatCompletionMessageParam[] {
    const messages: ChatCompletionMessageParam[] = [];
    for (const item of items) {
        if ('role' in item && 'content' in item) {
            const msg = item as EasyInputMessage;
            switch (msg.role) {
                case 'system':
                case 'developer':
                    messages.push({
                        role: 'system',
                        content: typeof msg.content === 'string' ? msg.content : joinTextParts(msg.content),
                    });
                    break;
                case 'user':
                    messages.push({ role: 'user', content: toGroqUserContent(msg.content) });
                    break;
                case 'assistant':
                    messages.push({
                        role: 'assistant',
                        // Assistant history may use `output_text` parts; an empty list maps to null.
                        content: typeof msg.content === 'string' ? msg.content : (joinTextParts(msg.content) || null),
                    });
                    break;
            }
            continue;
        }

        if ('type' in item && item.type === 'function_call_output') {
            const output = item as OpenAI.Responses.ResponseInputItem.FunctionCallOutput;
            messages.push({
                role: 'tool',
                tool_call_id: output.call_id,
                content: typeof output.output === 'string' ? output.output : JSON.stringify(output.output),
            });
            continue;
        }

        if ('type' in item && item.type === 'function_call') {
            const call = item as OpenAI.Responses.ResponseFunctionToolCall;
            appendToolCall(messages, {
                id: call.call_id,
                type: 'function',
                function: {
                    name: call.name,
                    arguments: call.arguments,
                },
            });
        }
    }
    return messages;
}