UNPKG

@copilotkit/runtime

Version:

<img src="https://github.com/user-attachments/assets/0a6b64d9-e193-4940-a3f6-60334ac34084" alt="banner" style="border-radius: 12px; border: 2px solid #d6d4fa;" />

204 lines (184 loc) • 6.48 kB
/**
 * Copilot Runtime adapter for Groq.
 *
 * ## Example
 *
 * ```ts
 * import { CopilotRuntime, GroqAdapter } from "@copilotkit/runtime";
 * import { Groq } from "groq-sdk";
 *
 * const groq = new Groq({ apiKey: process.env["GROQ_API_KEY"] });
 *
 * const copilotKit = new CopilotRuntime();
 *
 * return new GroqAdapter({ groq, model: "<model-name>" });
 * ```
 */
import type { Groq } from "groq-sdk";
import type { ChatCompletionMessageParam } from "groq-sdk/resources/chat";
import {
  CopilotServiceAdapter,
  CopilotRuntimeChatCompletionRequest,
  CopilotRuntimeChatCompletionResponse,
} from "../service-adapter";
import {
  convertActionInputToOpenAITool,
  convertMessageToOpenAIMessage,
  limitMessagesToTokenCount,
} from "../openai/utils";
import { randomUUID } from "@copilotkit/shared";
import { convertServiceAdapterError } from "../shared";

/** Model used when neither the adapter params nor the request specify one. */
const DEFAULT_MODEL = "llama-3.3-70b-versatile";

export interface GroqAdapterParams {
  /**
   * An optional Groq instance to use.
   */
  groq?: Groq;

  /**
   * The model to use.
   */
  model?: string;

  /**
   * Whether to disable parallel tool calls.
   * You can disable parallel tool calls to force the model to execute tool calls sequentially.
   * This is useful if you want to execute tool calls in a specific order so that the state changes
   * introduced by one tool call are visible to the next tool call. (i.e. new actions or readables)
   *
   * @default false
   */
  disableParallelToolCalls?: boolean;
}

/**
 * Service adapter that forwards Copilot Runtime chat completion requests to
 * Groq's streaming chat completions endpoint and translates the streamed
 * chunks into runtime events (text messages and action executions).
 */
export class GroqAdapter implements CopilotServiceAdapter {
  public model: string = DEFAULT_MODEL;
  public provider = "groq";

  private disableParallelToolCalls: boolean = false;
  private _groq: Groq;

  /**
   * The Groq client held by this adapter.
   *
   * NOTE(review): when no client was passed to the constructor this stays
   * unset until the first `process()` call lazily creates one.
   */
  public get groq(): Groq {
    return this._groq;
  }

  public get name() {
    return "GroqAdapter";
  }

  /**
   * @param params Optional client instance, model name and tool-call settings.
   */
  constructor(params?: GroqAdapterParams) {
    if (params?.groq) {
      this._groq = params.groq;
    }
    // If no instance provided, we'll lazy-load in ensureGroq()
    if (params?.model) {
      this.model = params.model;
    }
    this.disableParallelToolCalls = params?.disableParallelToolCalls ?? false;
  }

  /**
   * Returns the Groq client, creating one with default options on first use
   * when none was supplied to the constructor.
   */
  private ensureGroq(): Groq {
    if (!this._groq) {
      // Loaded lazily via require so the SDK is only resolved when actually needed.
      // eslint-disable-next-line @typescript-eslint/no-var-requires
      const { Groq } = require("groq-sdk");
      this._groq = new Groq({});
    }
    return this._groq;
  }

  /**
   * Runs a streaming chat completion for the given request and pipes the
   * result into `request.eventSource`.
   *
   * @param request The runtime request (messages, actions, forwarded parameters, event source).
   * @returns The thread id of the request, or a freshly generated one if none was given.
   * @throws The error produced by `convertServiceAdapterError` when creating the
   * completion fails or when reading the stream fails.
   */
  async process(
    request: CopilotRuntimeChatCompletionRequest,
  ): Promise<CopilotRuntimeChatCompletionResponse> {
    const {
      threadId,
      model = this.model,
      messages,
      actions,
      eventSource,
      forwardedParameters,
    } = request;

    const tools = actions.map(convertActionInputToOpenAITool);

    let openaiMessages = messages.map((m) =>
      convertMessageToOpenAIMessage(m, { keepSystemRole: true }),
    );
    openaiMessages = limitMessagesToTokenCount(openaiMessages, tools, model);

    // A bare "function" tool choice is expanded into the object form that
    // names the specific function to call.
    let toolChoice: any = forwardedParameters?.toolChoice;
    if (forwardedParameters?.toolChoice === "function") {
      toolChoice = {
        type: "function",
        function: { name: forwardedParameters.toolChoiceFunctionName },
      };
    }

    let stream;
    try {
      const groq = this.ensureGroq();
      stream = await groq.chat.completions.create({
        model: model,
        stream: true,
        messages: openaiMessages as unknown as ChatCompletionMessageParam[],
        ...(tools.length > 0 && { tools }),
        ...(forwardedParameters?.maxTokens && {
          max_tokens: forwardedParameters.maxTokens,
        }),
        ...(forwardedParameters?.stop && { stop: forwardedParameters.stop }),
        ...(toolChoice && { tool_choice: toolChoice }),
        ...(this.disableParallelToolCalls && { parallel_tool_calls: false }),
        // Explicit nullish check: a temperature of 0 is a valid setting and
        // must not be dropped by a truthiness test.
        ...(forwardedParameters?.temperature != null && {
          temperature: forwardedParameters.temperature,
        }),
      });
    } catch (error) {
      throw convertServiceAdapterError(error, "Groq");
    }

    eventSource.stream(async (eventStream$) => {
      // Tracks which kind of event sequence is currently open, if any.
      let mode: "function" | "message" | null = null;
      let currentMessageId: string;
      let currentToolCallId: string;

      try {
        for await (const chunk of stream) {
          // Skip chunks that carry no choice/delta instead of crashing on them.
          const delta = chunk.choices?.[0]?.delta;
          if (!delta) {
            continue;
          }

          const toolCall = delta.tool_calls?.[0];
          const content = delta.content;

          // When switching from message to function or vice versa,
          // send the respective end event.
          // If toolCall?.id is defined, it means a new tool call starts.
          if (mode === "message" && toolCall?.id) {
            mode = null;
            eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
          } else if (mode === "function" && (toolCall === undefined || toolCall?.id)) {
            mode = null;
            eventStream$.sendActionExecutionEnd({ actionExecutionId: currentToolCallId });
          }

          // If we send a new message type, send the appropriate start event.
          if (mode === null) {
            if (toolCall?.id) {
              mode = "function";
              currentToolCallId = toolCall.id;
              eventStream$.sendActionExecutionStart({
                actionExecutionId: currentToolCallId,
                actionName: toolCall.function!.name,
                parentMessageId: chunk.id,
              });
            } else if (content) {
              mode = "message";
              currentMessageId = chunk.id;
              eventStream$.sendTextMessageStart({ messageId: currentMessageId });
            }
          }

          // send the content events
          if (mode === "message" && content) {
            eventStream$.sendTextMessageContent({
              messageId: currentMessageId,
              content,
            });
          } else if (mode === "function" && toolCall?.function?.arguments) {
            eventStream$.sendActionExecutionArgs({
              actionExecutionId: currentToolCallId,
              args: toolCall.function.arguments,
            });
          }
        }

        // send the end events
        if (mode === "message") {
          eventStream$.sendTextMessageEnd({ messageId: currentMessageId });
        } else if (mode === "function") {
          eventStream$.sendActionExecutionEnd({ actionExecutionId: currentToolCallId });
        }
      } catch (error) {
        throw convertServiceAdapterError(error, "Groq");
      }

      eventStream$.complete();
    });

    return {
      threadId: threadId || randomUUID(),
    };
  }
}