UNPKG

ag-ui-cloudflare

Version:

Native AG-UI protocol implementation for Cloudflare Workers AI - Enable CopilotKit with edge AI at 93% lower cost

656 lines (645 loc) 19.4 kB
// ag-ui-cloudflare — bundled ESM build. Section comments mirror the
// original src/ modules this bundle was compiled from.

// src/stream-parser.ts
import { createParser } from "eventsource-parser";

/**
 * Incremental parser for Cloudflare Workers AI streaming responses.
 * Understands both OpenAI-style `choices[].delta` chunks and Cloudflare's
 * native `{ response: "..." }` chunks, yielding normalized
 * `{ response?, tool_calls?, usage?, done }` objects.
 */
var CloudflareStreamParser = class {
  constructor() {
    // Carry-over for a line split across reads of the byte stream.
    this.buffer = "";
    // NOTE(review): this eventsource-parser instance is never fed any data
    // (feed() is never called; parseStream/parseSSE parse lines themselves),
    // so onParse below is unreachable. Kept only to preserve the object's
    // shape for any external code that touches `.parser` — confirm and drop.
    this.parser = createParser(this.onParse.bind(this));
  }

  // eventsource-parser callback. Vestigial — see the constructor note.
  onParse(event) {
    if (event.type === "event") {
      try {
        const data = JSON.parse(event.data);
        return data;
      } catch (error) {
        console.error("Failed to parse stream chunk:", error);
      }
    }
  }

  /**
   * Parse one complete SSE line and yield the resulting chunk objects.
   * Returns true when the terminal `[DONE]` marker was seen, so the caller
   * can stop iterating the stream.
   */
  *parseLine(line) {
    if (!line.startsWith("data: ")) return false;
    const data = line.slice(6).trim();
    if (data === "[DONE]") {
      yield { done: true };
      return true;
    }
    if (!data) return false;
    try {
      const parsed = JSON.parse(data);
      const choice = parsed.choices?.[0];
      if (choice?.delta) {
        // OpenAI-compatible streaming shape.
        if (choice.delta.content) {
          yield { response: choice.delta.content, done: false };
        }
        if (choice.delta.tool_calls) {
          yield { tool_calls: choice.delta.tool_calls, done: false };
        }
      } else if (parsed.response) {
        // Cloudflare-native streaming shape.
        yield { response: parsed.response, done: false };
      }
      if (choice?.finish_reason) {
        yield { done: true, usage: parsed.usage };
      }
    } catch (error) {
      console.error("Failed to parse stream chunk:", error);
    }
    return false;
  }

  /**
   * Async-iterate a ReadableStream of SSE bytes, yielding normalized chunks.
   *
   * Fix: decoded text accumulates in `this.buffer` and only lines terminated
   * by "\n" are parsed, so a `data: {...}` line split across two network
   * reads is reassembled instead of failing JSON.parse. (The previous
   * implementation split each read on "\n" independently and could silently
   * drop fragments whose `data: ` prefix or JSON body straddled a read.)
   */
  async *parseStream(stream) {
    const reader = stream.getReader();
    const decoder = new TextDecoder();
    this.buffer = "";
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        this.buffer += decoder.decode(value, { stream: true });
        const lines = this.buffer.split("\n");
        this.buffer = lines.pop() || ""; // tail may be an incomplete line
        for (const line of lines) {
          if (yield* this.parseLine(line)) return; // [DONE] — stop early
        }
      }
      // Flush a final line that arrived without a trailing newline.
      if (this.buffer) {
        const tail = this.buffer;
        this.buffer = "";
        yield* this.parseLine(tail);
      }
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Parse a fully-buffered SSE payload into an array of JSON chunks.
   * `[DONE]` markers are skipped; unparsable lines are logged and skipped.
   */
  parseSSE(text) {
    const chunks = [];
    for (const line of text.split("\n")) {
      if (!line.startsWith("data: ")) continue;
      const data = line.slice(6);
      if (data === "[DONE]") continue;
      try {
        chunks.push(JSON.parse(data));
      } catch (error) {
        console.error("Failed to parse SSE chunk:", error);
      }
    }
    return chunks;
  }
};

// src/client.ts

/**
 * Thin REST client for Cloudflare Workers AI's OpenAI-compatible
 * /chat/completions endpoint, optionally routed through an AI Gateway.
 */
var CloudflareAIClient = class {
  /**
   * @param config { accountId, apiToken, gatewayId?, baseURL?, model? } —
   *   an explicit baseURL always wins; otherwise gatewayId selects the
   *   AI Gateway host, else the direct account API host.
   */
  constructor(config) {
    this.config = config;
    if (config.gatewayId) {
      this.baseURL = config.baseURL || `https://gateway.ai.cloudflare.com/v1/${config.accountId}/${config.gatewayId}/workers-ai/v1`;
    } else {
      this.baseURL = config.baseURL || `https://api.cloudflare.com/client/v4/accounts/${config.accountId}/ai/v1`;
    }
    this.headers = {
      "Authorization": `Bearer ${config.apiToken}`,
      "Content-Type": "application/json"
    };
  }

  /** Non-streaming completion; resolves to one assistant message. */
  async complete(options) {
    const response = await this.makeRequest(options);
    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Cloudflare AI error: ${error}`);
    }
    const data = await response.json();
    // Accept both the wrapped (`result`) and unwrapped response envelopes.
    return {
      role: "assistant",
      content: data.result?.response || data.response || "",
      tool_calls: data.result?.tool_calls || data.tool_calls
    };
  }

  /**
   * Streaming completion. Yields `{ response?, tool_calls?, usage?, done }`
   * chunks. When the server answers with plain JSON instead of
   * text/event-stream, streaming is emulated with a single chunk.
   */
  async *streamComplete(options) {
    const response = await this.makeRequest({ ...options, stream: true });
    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Cloudflare AI error: ${error}`);
    }
    const contentType = response.headers.get("content-type");
    if (contentType?.includes("text/event-stream")) {
      if (!response.body) {
        throw new Error("No response body from Cloudflare AI");
      }
      yield* new CloudflareStreamParser().parseStream(response.body);
      return;
    }
    // Non-SSE fallback: emit the full message, then a terminal chunk.
    const data = await response.json();
    if (data.choices) {
      for (const choice of data.choices) {
        if (choice.message?.content) {
          yield { response: choice.message.content, done: false };
        }
      }
      yield { done: true, usage: data.usage };
    } else if (data.result) {
      const content = data.result.response || "";
      if (content) {
        yield { response: content, done: false };
      }
      // Fix: the ~4-chars-per-token estimate is rounded up so reported
      // counts are integers (previously they could be fractional).
      const estimatedTokens = Math.ceil(content.length / 4);
      yield {
        done: true,
        usage: data.result.usage || {
          prompt_tokens: 0,
          completion_tokens: estimatedTokens,
          total_tokens: estimatedTokens
        }
      };
    } else {
      console.warn("Unexpected response format:", data);
      yield { done: true };
    }
  }

  /**
   * Build and POST the chat/completions request body. Fields left undefined
   * by the caller are stripped so the API sees only explicit settings.
   */
  async makeRequest(options) {
    const model = options.model || this.config.model || "@cf/meta/llama-3.1-8b-instruct";
    const endpoint = `${this.baseURL}/chat/completions`;
    const body = {
      model,
      messages: options.messages,
      temperature: options.temperature,
      max_tokens: options.max_tokens,
      top_p: options.top_p,
      frequency_penalty: options.frequency_penalty,
      presence_penalty: options.presence_penalty,
      stream: options.stream || false,
      tools: options.tools,
      tool_choice: options.tool_choice
    };
    for (const key of Object.keys(body)) {
      if (body[key] === void 0) {
        delete body[key];
      }
    }
    return fetch(endpoint, {
      method: "POST",
      headers: this.headers,
      body: JSON.stringify(body)
    });
  }

  /** List the models available to this account. @throws on HTTP failure. */
  async listModels() {
    const response = await fetch(`${this.baseURL}/models`, { headers: this.headers });
    if (!response.ok) {
      throw new Error(`Failed to list models: ${response.statusText}`);
    }
    const data = await response.json();
    return data.result?.models || data.models || [];
  }

  /**
   * Static capability table for known models. Unknown models fall back to a
   * conservative default (streaming only, 2k output tokens, 4k context).
   */
  getModelCapabilities(model) {
    const capabilities = {
      "@cf/meta/llama-3.3-70b-instruct": { streaming: true, functionCalling: true, maxTokens: 4096, contextWindow: 128e3 },
      "@cf/meta/llama-3.1-70b-instruct": { streaming: true, functionCalling: false, maxTokens: 4096, contextWindow: 128e3 },
      "@cf/meta/llama-3.1-8b-instruct": { streaming: true, functionCalling: false, maxTokens: 2048, contextWindow: 128e3 },
      "@cf/mistral/mistral-7b-instruct-v0.2": { streaming: true, functionCalling: false, maxTokens: 2048, contextWindow: 32768 }
    };
    return capabilities[model] || { streaming: true, functionCalling: false, maxTokens: 2048, contextWindow: 4096 };
  }
};

// src/events.ts

// AG-UI protocol event names (string enum compiled to a frozen lookup).
var EventType = /* @__PURE__ */ ((EventType3) => {
  EventType3["TEXT_MESSAGE_START"] = "TEXT_MESSAGE_START";
  EventType3["TEXT_MESSAGE_CONTENT"] = "TEXT_MESSAGE_CONTENT";
  EventType3["TEXT_MESSAGE_END"] = "TEXT_MESSAGE_END";
  EventType3["TOOL_CALL_START"] = "TOOL_CALL_START";
  EventType3["TOOL_CALL_ARGS"] = "TOOL_CALL_ARGS";
  EventType3["TOOL_CALL_END"] = "TOOL_CALL_END";
  EventType3["TOOL_CALL_RESULT"] = "TOOL_CALL_RESULT";
  EventType3["RUN_STARTED"] = "RUN_STARTED";
  EventType3["RUN_FINISHED"] = "RUN_FINISHED";
  EventType3["RUN_ERROR"] = "RUN_ERROR";
  EventType3["STEP_STARTED"] = "STEP_STARTED";
  EventType3["STEP_FINISHED"] = "STEP_FINISHED";
  EventType3["STATE_SYNC"] = "STATE_SYNC";
  EventType3["METADATA"] = "METADATA";
  EventType3["PROGRESS"] = "PROGRESS";
  EventType3["CUSTOM"] = "CUSTOM";
  return EventType3;
})(EventType || {});

// Module-private: shared envelope for every AG-UI event.
function agEvent(type, runId, extra) {
  return { type, runId, timestamp: Date.now(), ...extra };
}

/** Factory helpers producing timestamped AG-UI protocol events. */
var CloudflareAGUIEvents = class {
  static runStarted(runId, metadata) {
    return agEvent("RUN_STARTED" /* RUN_STARTED */, runId, { metadata });
  }
  static runFinished(runId, metadata) {
    return agEvent("RUN_FINISHED" /* RUN_FINISHED */, runId, { metadata });
  }
  static textMessageStart(runId, role) {
    return agEvent("TEXT_MESSAGE_START" /* TEXT_MESSAGE_START */, runId, { data: { role } });
  }
  static textMessageContent(runId, delta) {
    return agEvent("TEXT_MESSAGE_CONTENT" /* TEXT_MESSAGE_CONTENT */, runId, { data: { delta } });
  }
  static textMessageEnd(runId) {
    return agEvent("TEXT_MESSAGE_END" /* TEXT_MESSAGE_END */, runId);
  }
  static toolCallStart(runId, toolCallId, toolName) {
    return agEvent("TOOL_CALL_START" /* TOOL_CALL_START */, runId, { data: { toolCallId, toolName } });
  }
  static toolCallArgs(runId, toolCallId, args) {
    return agEvent("TOOL_CALL_ARGS" /* TOOL_CALL_ARGS */, runId, { data: { toolCallId, args } });
  }
  static toolCallEnd(runId, toolCallId) {
    return agEvent("TOOL_CALL_END" /* TOOL_CALL_END */, runId, { data: { toolCallId } });
  }
  static toolCallResult(runId, toolCallId, result) {
    return agEvent("TOOL_CALL_RESULT" /* TOOL_CALL_RESULT */, runId, { data: { toolCallId, result } });
  }
  static error(runId, error) {
    return agEvent("RUN_ERROR" /* RUN_ERROR */, runId, {
      data: { message: error.message, stack: error.stack, name: error.name }
    });
  }
  static stepStarted(runId, stepName) {
    return agEvent("STEP_STARTED" /* STEP_STARTED */, runId, { data: { stepName } });
  }
  static stepFinished(runId, stepName) {
    return agEvent("STEP_FINISHED" /* STEP_FINISHED */, runId, { data: { stepName } });
  }
  static stateSync(runId, state) {
    return agEvent("STATE_SYNC" /* STATE_SYNC */, runId, { data: { state } });
  }
  static metadata(runId, metadata) {
    return agEvent("METADATA" /* METADATA */, runId, { data: metadata });
  }
  static progress(runId, progress, message) {
    return agEvent("PROGRESS" /* PROGRESS */, runId, { data: { progress, message } });
  }
  static custom(runId, name, value) {
    return agEvent("CUSTOM" /* CUSTOM */, runId, { data: { name, value } });
  }
};

// src/adapter.ts

/**
 * AG-UI adapter: runs chat completions against Cloudflare Workers AI and
 * translates the model output into a stream of AG-UI protocol events.
 */
var CloudflareAGUIAdapter = class _CloudflareAGUIAdapter {
  constructor(options) {
    this.runCounter = 0;
    this.options = options;
    this.client = new CloudflareAIClient(options);
  }

  /**
   * Execute one conversation run. Emits RUN_STARTED, the message/tool-call
   * events, then RUN_FINISHED; on failure emits RUN_ERROR and rethrows.
   * Streaming is used unless `options.streamingEnabled === false`.
   */
  async *execute(messages, context) {
    const runId = this.generateRunId();
    try {
      yield CloudflareAGUIEvents.runStarted(runId, {
        model: this.options.model,
        messageCount: messages.length,
        ...context
      });
      const allMessages = this.options.systemPrompt
        ? [{ role: "system", content: this.options.systemPrompt }, ...messages]
        : messages;
      const completionOptions = {
        messages: allMessages,
        model: this.options.model,
        tools: this.options.tools,
        stream: this.options.streamingEnabled !== false
      };
      if (this.options.streamingEnabled !== false) {
        yield* this.handleStreaming(runId, completionOptions);
      } else {
        yield* this.handleNonStreaming(runId, completionOptions);
      }
      yield CloudflareAGUIEvents.runFinished(runId);
    } catch (error) {
      yield CloudflareAGUIEvents.error(runId, error);
      throw error;
    }
  }

  /**
   * Translate a streaming completion into AG-UI events. The assistant
   * message is opened exactly once, up front. (The previous revision also
   * set a `messageStarted` flag unconditionally before the loop, making its
   * in-loop re-check and an `accumulatedContent` local dead code — removed.)
   */
  async *handleStreaming(runId, options) {
    const toolCallsInProgress = new Map();
    yield CloudflareAGUIEvents.textMessageStart(runId, "assistant");
    for await (const chunk of this.client.streamComplete(options)) {
      if (chunk.response) {
        yield CloudflareAGUIEvents.textMessageContent(runId, chunk.response);
      }
      if (chunk.tool_calls) {
        for (const toolCall of chunk.tool_calls) {
          if (!toolCallsInProgress.has(toolCall.id)) {
            yield CloudflareAGUIEvents.toolCallStart(runId, toolCall.id, toolCall.function.name);
            toolCallsInProgress.set(toolCall.id, "");
          }
          yield CloudflareAGUIEvents.toolCallArgs(runId, toolCall.id, toolCall.function.arguments);
        }
      }
      if (chunk.done) {
        yield CloudflareAGUIEvents.textMessageEnd(runId);
        for (const [toolCallId] of toolCallsInProgress) {
          yield CloudflareAGUIEvents.toolCallEnd(runId, toolCallId);
        }
        if (chunk.usage) {
          yield CloudflareAGUIEvents.metadata(runId, { usage: chunk.usage, model: options.model });
        }
      }
    }
  }

  /** Translate a non-streaming completion into the same event sequence. */
  async *handleNonStreaming(runId, options) {
    const response = await this.client.complete(options);
    if (response.content) {
      yield CloudflareAGUIEvents.textMessageStart(runId, "assistant");
      yield CloudflareAGUIEvents.textMessageContent(runId, response.content);
      yield CloudflareAGUIEvents.textMessageEnd(runId);
    }
    if (response.tool_calls) {
      for (const toolCall of response.tool_calls) {
        yield CloudflareAGUIEvents.toolCallStart(runId, toolCall.id, toolCall.function.name);
        yield CloudflareAGUIEvents.toolCallArgs(runId, toolCall.id, toolCall.function.arguments);
        yield CloudflareAGUIEvents.toolCallEnd(runId, toolCall.id);
      }
    }
  }

  /** Run with an ad-hoc tool set without mutating this adapter's options. */
  async *executeWithTools(messages, tools, context) {
    const adapter = new _CloudflareAGUIAdapter({ ...this.options, tools });
    yield* adapter.execute(messages, context);
  }

  /**
   * Multi-stage generation: each stage's prompt includes all previously
   * generated content, and PROGRESS events report per-stage completion.
   */
  async *progressiveGeneration(prompt, stages) {
    const runId = this.generateRunId();
    yield CloudflareAGUIEvents.runStarted(runId, { stages: stages.length });
    let allContent = "";
    const totalStages = stages.length;
    for (let i = 0; i < stages.length; i++) {
      const stage = stages[i];
      yield CloudflareAGUIEvents.progress(
        runId,
        (i + 1) / totalStages * 100,
        `Processing: ${stage.name}`
      );
      // NOTE(review): whitespace inside these template literals was mangled
      // in the published bundle; the paragraph breaks below are a
      // reconstruction — confirm against the original src/adapter.ts.
      const stagePrompt = i === 0
        ? `${prompt}\n\n${stage.instruction}`
        : `${prompt}\n\nPrevious research/content:\n${allContent}\n\nNow, ${stage.instruction}`;
      const messages = [{ role: "user", content: stagePrompt }];
      const completionOptions = { messages, model: this.options.model, stream: true };
      yield CloudflareAGUIEvents.textMessageStart(runId, "assistant");
      let stageContent = "";
      for await (const chunk of this.client.streamComplete(completionOptions)) {
        if (chunk.response) {
          yield CloudflareAGUIEvents.textMessageContent(runId, chunk.response);
          stageContent += chunk.response;
        }
        if (chunk.done && chunk.usage) {
          yield CloudflareAGUIEvents.metadata(runId, { stage: stage.name, usage: chunk.usage });
        }
      }
      allContent += `\n\n## ${stage.name}\n\n${stageContent}`;
      yield CloudflareAGUIEvents.textMessageEnd(runId);
    }
    yield CloudflareAGUIEvents.runFinished(runId);
  }

  /** Switch the default model for subsequent runs. */
  setModel(model) {
    this.options.model = model;
  }

  /** Capabilities of the currently configured (or default) model. */
  getCapabilities() {
    return this.client.getModelCapabilities(
      this.options.model || "@cf/meta/llama-3.1-8b-instruct"
    );
  }

  /** List models available to the configured account. */
  async listAvailableModels() {
    return this.client.listModels();
  }

  // Unique-enough run id: wall-clock time plus a per-adapter counter.
  generateRunId() {
    return `cf-run-${Date.now()}-${++this.runCounter}`;
  }
};

// src/providers.ts

// Module-private: build an adapter, defaulting the model when unset.
function adapterWithDefaultModel(config, defaultModel) {
  return new CloudflareAGUIAdapter({ ...config, model: config.model || defaultModel });
}

/** Convenience constructors for adapters pinned to specific models. */
var CloudflareProviders = class _CloudflareProviders {
  static llama3_8b(config) {
    return adapterWithDefaultModel(config, "@cf/meta/llama-3.1-8b-instruct");
  }
  static llama3_70b(config) {
    return adapterWithDefaultModel(config, "@cf/meta/llama-3.1-70b-instruct");
  }
  static llama3_3_70b(config) {
    return adapterWithDefaultModel(config, "@cf/meta/llama-3.3-70b-instruct");
  }
  static mistral7b(config) {
    return adapterWithDefaultModel(config, "@cf/mistral/mistral-7b-instruct-v0.2");
  }
  static gemma7b(config) {
    return adapterWithDefaultModel(config, "@cf/google/gemma-7b-it");
  }
  static qwen14b(config) {
    return adapterWithDefaultModel(config, "@cf/qwen/qwen1.5-14b-chat-awq");
  }
  static phi2(config) {
    return adapterWithDefaultModel(config, "@cf/microsoft/phi-2");
  }
  static deepseekMath(config) {
    return adapterWithDefaultModel(config, "@cf/deepseek-ai/deepseek-math-7b-instruct");
  }
  static deepseekCoder(config) {
    return adapterWithDefaultModel(config, "@cf/thebloke/deepseek-coder-6.7b-instruct-awq");
  }
  /**
   * Pick a model automatically: llama-3.3-70b when tools are supplied
   * (the only listed model with function calling), llama-3.1-8b otherwise.
   */
  static auto(config) {
    const needsFunctionCalling = config.tools && config.tools.length > 0;
    if (needsFunctionCalling) {
      return _CloudflareProviders.llama3_3_70b(config);
    }
    return _CloudflareProviders.llama3_8b(config);
  }
  /** Build an adapter routed through a Cloudflare AI Gateway. */
  static createWithGateway(accountId, apiToken, gatewayId, model) {
    return new CloudflareAGUIAdapter({
      accountId,
      apiToken,
      gatewayId,
      model: model || "@cf/meta/llama-3.1-8b-instruct"
    });
  }
};

// src/index.ts

// Canonical model identifiers accepted by Cloudflare Workers AI.
var CLOUDFLARE_MODELS = {
  LLAMA_3_1_8B: "@cf/meta/llama-3.1-8b-instruct",
  LLAMA_3_1_70B: "@cf/meta/llama-3.1-70b-instruct",
  LLAMA_3_3_70B: "@cf/meta/llama-3.3-70b-instruct",
  LLAMA_2_7B: "@cf/meta/llama-2-7b-chat-int8",
  MISTRAL_7B: "@cf/mistral/mistral-7b-instruct-v0.2",
  GEMMA_7B: "@cf/google/gemma-7b-it",
  QWEN_14B: "@cf/qwen/qwen1.5-14b-chat-awq",
  PHI_2: "@cf/microsoft/phi-2",
  DEEPSEEK_MATH_7B: "@cf/deepseek-ai/deepseek-math-7b-instruct",
  DEEPSEEK_CODER_6B: "@cf/thebloke/deepseek-coder-6.7b-instruct-awq"
};

/** Primary entry point: build an adapter from a plain config object. */
function createCloudflareAdapter(config) {
  return new CloudflareAGUIAdapter(config);
}

export {
  CLOUDFLARE_MODELS,
  CloudflareAGUIAdapter,
  CloudflareAGUIEvents,
  CloudflareAIClient,
  CloudflareProviders,
  CloudflareStreamParser,
  EventType,
  createCloudflareAdapter
};