UNPKG

genkitx-ollama

Version:

Genkit AI framework plugin for Ollama APIs.

397 lines 10.6 kB
import { embedderRef, modelActionMetadata, z } from "genkit";
import { logger } from "genkit/logging";
import {
  GenerationCommonConfigDescriptions,
  GenerationCommonConfigSchema,
  getBasicUsageStats,
  modelRef
} from "genkit/model";
import { genkitPlugin } from "genkit/plugin";
import { defineOllamaEmbedder } from "./embeddings.js";

/** JSON schema that accepts anything; used for tools that declare no input schema. */
const ANY_JSON_SCHEMA = {
  $schema: "http://json-schema.org/draft-07/schema#"
};

/** Capabilities advertised for models discovered dynamically through `listActions`. */
const GENERIC_MODEL_INFO = {
  supports: {
    multiturn: true,
    media: true,
    tools: true,
    toolChoice: true,
    systemRole: true,
    constrained: "all"
  }
};

const DEFAULT_OLLAMA_SERVER_ADDRESS = "http://localhost:11434";

/**
 * Registers every model and embedder explicitly listed in the plugin params.
 *
 * @param {object} ai - Genkit registry handle passed by `genkitPlugin`.
 * @param {string} serverAddress - Base URL of the Ollama server.
 * @param {object} params - Resolved plugin params (`models`, `embedders`, `requestHeaders`, ...).
 */
async function initializer(ai, serverAddress, params) {
  params?.models?.forEach((model) =>
    defineOllamaModel(ai, model, serverAddress, params?.requestHeaders)
  );
  params?.embedders?.forEach((model) =>
    defineOllamaEmbedder(ai, {
      name: model.name,
      modelName: model.name,
      dimensions: model.dimensions,
      options: params
    })
  );
}

/**
 * Lazily defines a model the first time Genkit asks for an action that was not
 * registered up front. Only `model` actions are resolved here.
 */
function resolveAction(ai, actionType, actionName, serverAddress, requestHeaders) {
  if (actionType === "model") {
    defineOllamaModel(ai, { name: actionName }, serverAddress, requestHeaders);
  }
}

/**
 * Lists locally installed Ollama models as Genkit model action metadata.
 * Models whose name contains "embed" are skipped (assumed to be embedders).
 *
 * @returns {Promise<object[]>} Action metadata; empty when the server reports no models.
 */
async function listActions(serverAddress, requestHeaders) {
  const models = await listLocalModels(serverAddress, requestHeaders);
  return (
    models
      ?.filter((m) => m.model && !m.model.includes("embed"))
      .map((m) =>
        modelActionMetadata({
          name: `ollama/${m.model}`,
          info: GENERIC_MODEL_INFO
        })
      ) || []
  );
}

/**
 * Creates the Genkit Ollama plugin.
 *
 * The caller's `params` object is not mutated; a resolved copy carrying the
 * effective `serverAddress` is used internally and handed to the embedders.
 *
 * @param {object} [params] - Plugin options (`serverAddress`, `models`, `embedders`, `requestHeaders`).
 */
function ollamaPlugin(params) {
  const resolvedParams = {
    ...params,
    serverAddress: params?.serverAddress || DEFAULT_OLLAMA_SERVER_ADDRESS
  };
  const { serverAddress } = resolvedParams;
  return genkitPlugin(
    "ollama",
    async (ai) => {
      await initializer(ai, serverAddress, resolvedParams);
    },
    async (ai, actionType, actionName) => {
      resolveAction(ai, actionType, actionName, serverAddress, resolvedParams.requestHeaders);
    },
    async () => await listActions(serverAddress, resolvedParams.requestHeaders)
  );
}

/**
 * Fetches the models installed on the Ollama server (`GET /api/tags`).
 *
 * @returns {Promise<object[]|undefined>} The `models` array from the response.
 * @throws {Error} When the server is unreachable or answers with a non-2xx status.
 */
async function listLocalModels(serverAddress, requestHeaders) {
  let res;
  try {
    res = await fetch(serverAddress + "/api/tags", {
      method: "GET",
      headers: {
        "Content-Type": "application/json",
        ...(await getHeaders(serverAddress, requestHeaders))
      }
    });
  } catch (e) {
    throw new Error(`Make sure the Ollama server is running.`, { cause: e });
  }
  if (!res.ok) {
    // Surface the HTTP failure instead of a confusing JSON parse error on the body.
    throw new Error(
      `Failed to list Ollama models: ${res.status} ${res.statusText}`
    );
  }
  const modelResponse = JSON.parse(await res.text());
  return modelResponse.models;
}

const OllamaConfigSchema = GenerationCommonConfigSchema.extend({
  temperature: z
    .number()
    .min(0)
    .max(1)
    .describe(GenerationCommonConfigDescriptions.temperature + " The default value is 0.8.")
    .optional(),
  topK: z
    .number()
    .describe(GenerationCommonConfigDescriptions.topK + " The default value is 40.")
    .optional(),
  topP: z
    .number()
    .min(0)
    .max(1)
    .describe(GenerationCommonConfigDescriptions.topP + " The default value is 0.9.")
    .optional()
});

/**
 * Maps Genkit generation config onto Ollama's `options` object.
 * Unrecognised keys (e.g. `temperature`) are passed through unchanged.
 *
 * @param {object|undefined} config - Genkit request config; may be absent.
 * @returns {object} Ollama request options.
 */
function toOllamaOptions(config) {
  const { topP, topK, stopSequences, maxOutputTokens, ...rest } = config ?? {};
  const options = { ...rest };
  if (topP !== undefined) {
    options.top_p = topP;
  }
  if (topK !== undefined) {
    options.top_k = topK;
  }
  if (stopSequences !== undefined) {
    // Ollama's `stop` option is a list of strings; joining them would merge
    // several stop sequences into one unintended sequence.
    options.stop = stopSequences;
  }
  if (maxOutputTokens !== undefined) {
    options.num_predict = maxOutputTokens;
  }
  return options;
}

/**
 * Consumes a streamed Ollama response body (newline-delimited JSON) and
 * forwards each parsed chunk to `streamingCallback`.
 *
 * Network reads are not aligned with JSON lines: one read may hold several
 * objects or only part of one, so text is buffered and split on "\n". The
 * decoder runs in streaming mode so multi-byte characters split across reads
 * are decoded correctly.
 *
 * @returns {Promise<object>} The aggregated model message: accumulated text
 *   plus any non-text parts (e.g. tool requests) seen in the stream.
 */
async function readStreamedMessage(body, type, streamingCallback) {
  const reader = body.getReader();
  const textDecoder = new TextDecoder();
  const otherParts = [];
  let textResponse = "";
  let pending = "";

  const handleLine = (line) => {
    if (!line.trim()) {
      return;
    }
    const chunkMessage = parseMessage(JSON.parse(line), type);
    streamingCallback({ index: 0, content: chunkMessage.content });
    for (const part of chunkMessage.content) {
      if (part.text !== undefined) {
        textResponse += part.text;
      } else {
        otherParts.push(part);
      }
    }
  };

  for await (const chunk of readChunks(reader)) {
    pending += textDecoder.decode(chunk, { stream: true });
    const lines = pending.split("\n");
    // The last element is an incomplete line (or "") kept for the next read.
    pending = lines.pop();
    lines.forEach(handleLine);
  }
  pending += textDecoder.decode();
  handleLine(pending);

  const content = [];
  // Keep a text part for pure-text streams (even when empty); omit an empty
  // one when the stream produced tool requests, matching the non-streaming shape.
  if (textResponse || otherParts.length === 0) {
    content.push({ text: textResponse });
  }
  content.push(...otherParts);
  return { role: "model", content };
}

/**
 * Registers an Ollama-backed Genkit model named `ollama/<model.name>`.
 *
 * Chat models (the default `type`) call `/api/chat`; any other type calls
 * `/api/generate`.
 *
 * @param {object} ai - Genkit registry handle.
 * @param {object} model - Model definition (`name`, optional `type`, optional `supports.tools`).
 * @param {string} serverAddress - Base URL of the Ollama server.
 * @param {object|Function} [requestHeaders] - Static headers or an async header factory.
 */
function defineOllamaModel(ai, model, serverAddress, requestHeaders) {
  return ai.defineModel(
    {
      name: `ollama/${model.name}`,
      label: `Ollama - ${model.name}`,
      configSchema: OllamaConfigSchema,
      supports: {
        multiturn: !model.type || model.type === "chat",
        systemRole: true,
        tools: model.supports?.tools
      }
    },
    async (input, streamingCallback) => {
      const type = model.type ?? "chat";
      const request = toOllamaRequest(
        model.name,
        input,
        toOllamaOptions(input.config),
        type,
        !!streamingCallback
      );
      logger.debug(request, `ollama request (${type})`);
      const extraHeaders = await getHeaders(serverAddress, requestHeaders, model, input);
      let res;
      try {
        res = await fetch(
          serverAddress + (type === "chat" ? "/api/chat" : "/api/generate"),
          {
            method: "POST",
            body: JSON.stringify(request),
            headers: {
              "Content-Type": "application/json",
              ...extraHeaders
            }
          }
        );
      } catch (e) {
        const cause = e.cause;
        if (cause && cause instanceof Error && cause.message?.includes("ECONNREFUSED")) {
          cause.message += ". Make sure the Ollama server is running.";
          throw cause;
        }
        throw e;
      }
      if (!res.body) {
        throw new Error("Response has no body");
      }
      let message;
      if (streamingCallback) {
        message = await readStreamedMessage(res.body, type, streamingCallback);
      } else {
        const txtBody = await res.text();
        // Log before parsing so malformed bodies are still visible in debug logs.
        logger.debug(txtBody, "ollama raw response");
        message = parseMessage(JSON.parse(txtBody), type);
      }
      return {
        message,
        usage: getBasicUsageStats(input.messages, message),
        finishReason: "stop"
      };
    }
  );
}

/**
 * Converts one Ollama response object (full body or stream chunk) into a
 * Genkit message.
 *
 * @throws {Error} When the response carries an `error` field.
 */
function parseMessage(response, type) {
  if (response.error) {
    throw new Error(response.error);
  }
  if (type !== "chat") {
    return { role: "model", content: [{ text: response.response }] };
  }
  const { role, tool_calls: toolCalls, content } = response.message;
  if (toolCalls && toolCalls.length > 0) {
    return { role: toGenkitRole(role), content: toGenkitToolRequest(toolCalls) };
  }
  return { role: toGenkitRole(role), content: [{ text: content }] };
}

/**
 * Resolves the extra HTTP headers to send to Ollama.
 *
 * @param {object|Function|undefined} requestHeaders - Static headers, or a function
 *   called with `({ serverAddress, model }, input)` that returns (a promise of) headers.
 * @returns {Promise<object>} Headers to merge into the request; `{}` when none are configured.
 */
async function getHeaders(serverAddress, requestHeaders, model, input) {
  if (!requestHeaders) {
    return {};
  }
  if (typeof requestHeaders === "function") {
    return await requestHeaders({ serverAddress, model }, input);
  }
  return requestHeaders;
}

/**
 * Builds the JSON body for `/api/chat` (type "chat") or `/api/generate`.
 *
 * @throws {Error} Via `isValidOllamaTool` when a tool's input schema is not an object.
 */
function toOllamaRequest(name, input, options, type, stream) {
  const request = {
    model: name,
    options,
    stream,
    tools: input.tools?.filter(isValidOllamaTool).map(toOllamaTool)
  };
  if (type !== "chat") {
    request.prompt = getPrompt(input);
    request.system = getSystemMessage(input);
    return request;
  }
  const messages = [];
  input.messages.forEach((m) => {
    let messageText = "";
    const role = toOllamaRole(m.role);
    const images = [];
    const toolRequests = [];
    const toolResponses = [];
    m.content.forEach((c) => {
      if (c.text) {
        messageText += c.text;
      }
      if (c.media) {
        let imageUri = c.media.url;
        // Ollama expects bare base64 image data, so strip any data-URI prefix.
        if (imageUri.startsWith("data:")) {
          imageUri = imageUri.substring(imageUri.indexOf(",") + 1);
        }
        images.push(imageUri);
      }
      if (c.toolRequest) {
        toolRequests.push(c.toolRequest);
      }
      if (c.toolResponse) {
        toolResponses.push(c.toolResponse);
      }
    });
    toolResponses.forEach((t) => {
      messages.push({
        role,
        content: typeof t.output === "string" ? t.output : JSON.stringify(t.output)
      });
    });
    const hasOwnPayload =
      messageText !== "" || images.length > 0 || toolRequests.length > 0;
    // A message made only of tool responses is fully represented by the
    // messages pushed above; don't follow them with an empty message.
    if (hasOwnPayload || toolResponses.length === 0) {
      messages.push({
        role,
        content: toolRequests.length > 0 ? "" : messageText,
        images: images.length > 0 ? images : undefined,
        tool_calls: toolRequests.length > 0 ? toOllamaToolCall(toolRequests) : undefined
      });
    }
  });
  request.messages = messages;
  return request;
}

/** Maps Genkit's "model" role to Ollama's "assistant"; other roles pass through. */
function toOllamaRole(role) {
  return role === "model" ? "assistant" : role;
}

/** Maps Ollama's "assistant" role to Genkit's "model"; other roles pass through. */
function toGenkitRole(role) {
  return role === "assistant" ? "model" : role;
}

/** Converts a Genkit tool definition into an Ollama function tool. */
function toOllamaTool(tool) {
  return {
    type: "function",
    function: {
      name: tool.name,
      description: tool.description,
      parameters: tool.inputSchema ?? ANY_JSON_SCHEMA
    }
  };
}

/** Converts Genkit tool requests into Ollama `tool_calls` entries. */
function toOllamaToolCall(toolRequests) {
  return toolRequests.map((t) => ({
    function: {
      name: t.name,
      // This should be safe since we already filtered tools that don't accept
      // objects
      arguments: t.input
    }
  }));
}

/**
 * Converts Ollama `tool_calls` into Genkit tool-request parts. A numeric
 * `index` (including 0) becomes the string `ref`.
 */
function toGenkitToolRequest(tool_calls) {
  return tool_calls.map((t) => ({
    toolRequest: {
      name: t.function.name,
      ref: t.function.index != null ? t.function.index.toString() : undefined,
      input: t.function.arguments
    }
  }));
}

/** Adapts a ReadableStream reader into an async iterable of its chunks. */
function readChunks(reader) {
  return {
    async *[Symbol.asyncIterator]() {
      let readResult = await reader.read();
      while (!readResult.done) {
        yield readResult.value;
        readResult = await reader.read();
      }
    }
  };
}

/**
 * Joins the text of the given messages: parts within a message are
 * concatenated (as the chat path does); messages are separated by newlines.
 * Non-text parts contribute nothing.
 */
function joinMessagesText(messages) {
  return messages
    .map((m) => m.content.map((c) => c.text ?? "").join(""))
    .join("\n");
}

/** Builds the `/api/generate` prompt from all non-system messages. */
function getPrompt(input) {
  return joinMessagesText(input.messages.filter((m) => m.role !== "system"));
}

/** Builds the `/api/generate` system prompt from all system messages. */
function getSystemMessage(input) {
  return joinMessagesText(input.messages.filter((m) => m.role === "system"));
}

/**
 * Filter predicate for tools; rejects (throws) rather than silently dropping
 * tools whose input schema is not an object.
 *
 * @throws {Error} When `tool.inputSchema.type` is not "object".
 */
function isValidOllamaTool(tool) {
  if (tool.inputSchema?.type !== "object") {
    throw new Error(
      `Unsupported tool: '${tool.name}'. Ollama only supports tools with object inputs`
    );
  }
  return true;
}

const ollama = ollamaPlugin;
ollama.model = (name, config) => {
  return modelRef({
    name: `ollama/${name}`,
    config,
    configSchema: OllamaConfigSchema
  });
};
ollama.embedder = (name, config) => {
  return embedderRef({
    name: `ollama/${name}`,
    config
  });
};
export { OllamaConfigSchema, ollama };
//# sourceMappingURL=index.mjs.map