UNPKG

@genkit-ai/ai

Version:

Genkit AI framework generative AI APIs.

308 lines 9.59 kB
import {
  defineAction,
  GenkitError,
  getStreamingCallback,
  runWithStreamingCallback,
  stripUndefinedProps
} from "@genkit-ai/core";
import { logger } from "@genkit-ai/core/logging";
import { runInNewSpan, SPAN_TYPE_ATTR } from "@genkit-ai/core/tracing";
import {
  injectInstructions,
  resolveFormat,
  resolveInstructions
} from "../formats/index.js";
import {
  GenerateResponse,
  GenerateResponseChunk,
  GenerationResponseError,
  tagAsPreamble
} from "../generate.js";
import {
  GenerateActionOptionsSchema,
  GenerateResponseChunkSchema,
  GenerateResponseSchema,
  resolveModel
} from "../model.js";
import { resolveTools, toToolDefinition } from "../tool.js";
import {
  assertValidToolNames,
  resolveResumeOption,
  resolveToolRequests
} from "./resolve-tool-requests.js";

/**
 * Registers the "generate" utility action on the given registry.
 *
 * The action validates against GenerateActionOptionsSchema /
 * GenerateResponseSchema / GenerateResponseChunkSchema and delegates to
 * `generate` starting at turn 0 and message index 0, with no middleware.
 * When the caller supplies `sendChunk`, chunks are forwarded to it (as JSON
 * when the chunk exposes `toJSON`).
 *
 * @param registry - Registry the action is defined on and resolved against.
 * @returns Whatever `defineAction` returns for the registered action.
 */
function defineGenerateAction(registry) {
  return defineAction(
    registry,
    {
      actionType: "util",
      name: "generate",
      inputSchema: GenerateActionOptionsSchema,
      outputSchema: GenerateResponseSchema,
      streamSchema: GenerateResponseChunkSchema
    },
    async (request, { sendChunk }) => {
      const generateFn = () =>
        generate(registry, {
          rawRequest: request,
          currentTurn: 0,
          messageIndex: 0,
          // Generate util action does not support middleware. Maybe when we add named/registered middleware....
          middleware: []
        });
      return sendChunk
        ? runWithStreamingCallback(
            registry,
            (c) => sendChunk(c.toJSON ? c.toJSON() : c),
            generateFn
          )
        : generateFn();
    }
  );
}

/**
 * Runs `generate` inside a new tracing span named "generate".
 *
 * The raw request is recorded as the span input and the JSON-stringified
 * result as the span output.
 *
 * @param registry - Registry used for resolution and tracing.
 * @param options - `rawRequest`, `middleware`, and optional `currentTurn` /
 *   `messageIndex` (both default to 0 when null or undefined).
 * @returns The value produced by `generate`.
 */
async function generateHelper(registry, options) {
  const currentTurn = options.currentTurn ?? 0;
  const messageIndex = options.messageIndex ?? 0;
  return await runInNewSpan(
    registry,
    {
      metadata: {
        name: "generate"
      },
      labels: {
        [SPAN_TYPE_ATTR]: "util"
      }
    },
    async (metadata) => {
      metadata.name = "generate";
      metadata.input = options.rawRequest;
      const output = await generate(registry, {
        rawRequest: options.rawRequest,
        middleware: options.middleware,
        currentTurn,
        messageIndex
      });
      metadata.output = JSON.stringify(output);
      return output;
    }
  );
}

/**
 * Resolves the model action, tools and output format named by the request.
 * The three lookups are started together and awaited with `Promise.all`.
 *
 * @param registry - Registry to resolve against.
 * @param request - Raw request; reads `model`, `tools` and `output`.
 * @returns `{ model, tools, format }`, where `model` is the resolved
 *   `modelAction`.
 */
async function resolveParameters(registry, request) {
  const [model, tools, format] = await Promise.all([
    resolveModel(registry, request.model, { warnDeprecated: true }).then(
      (r) => r.modelAction
    ),
    resolveTools(registry, request.tools),
    resolveFormat(registry, request.output)
  ]);
  return { model, tools, format };
}

/**
 * Returns a shallow copy of the request with output-format settings applied.
 *
 * - If a JSON schema is given without an explicit format, the format
 *   defaults to "json".
 * - If a format was resolved, its instructions are injected into the
 *   messages (subject to `shouldInjectFormatInstructions`) and its config is
 *   merged underneath any explicitly set output options.
 *
 * @param rawRequest - Request to derive from; not mutated.
 * @param resolvedFormat - Format returned by `resolveFormat`, if any.
 * @returns The adjusted request copy.
 */
function applyFormat(rawRequest, resolvedFormat) {
  const outRequest = { ...rawRequest };
  if (rawRequest.output?.jsonSchema && !rawRequest.output?.format) {
    outRequest.output = { ...rawRequest.output, format: "json" };
  }
  const instructions = resolveInstructions(
    resolvedFormat,
    outRequest.output?.jsonSchema,
    outRequest?.output?.instructions
  );
  if (resolvedFormat) {
    if (shouldInjectFormatInstructions(resolvedFormat.config, rawRequest?.output)) {
      outRequest.messages = injectInstructions(
        outRequest.messages,
        instructions
      );
    }
    outRequest.output = {
      // use output config from the format
      ...resolvedFormat.config,
      // if anything is set explicitly, use that
      ...outRequest.output
    };
  }
  return outRequest;
}

/**
 * Decides whether format instructions should be injected into the messages.
 *
 * @param formatConfig - Config of the resolved format (may be undefined).
 * @param rawRequestConfig - The request's `output` options (may be undefined).
 * @returns A truthy value when the format does not set
 *   `defaultInstructions: false`, or when the request sets `instructions`.
 *   Note the result is the raw `||` operand, not necessarily a boolean.
 */
function shouldInjectFormatInstructions(formatConfig, rawRequestConfig) {
  return formatConfig?.defaultInstructions !== false || rawRequestConfig?.instructions;
}

/**
 * Applies a transfer preamble to a request.
 *
 * Messages already tagged as preamble (`metadata.preamble`) are dropped and
 * replaced by the new preamble's messages; `toolChoice`, `tools` and
 * `config` are taken from the preamble when it provides them, otherwise
 * from the request. Undefined properties are stripped from the result.
 *
 * @param rawRequest - Request to update; returned unchanged when there is
 *   no preamble.
 * @param transferPreamble - Preamble to apply, if any.
 * @returns The request with the preamble applied.
 */
function applyTransferPreamble(rawRequest, transferPreamble) {
  if (!transferPreamble) {
    return rawRequest;
  }
  return stripUndefinedProps({
    ...rawRequest,
    messages: [
      ...tagAsPreamble(transferPreamble.messages),
      ...rawRequest.messages.filter((m) => !m.metadata?.preamble)
    ],
    toolChoice: transferPreamble.toolChoice || rawRequest.toolChoice,
    tools: transferPreamble.tools || rawRequest.tools,
    config: transferPreamble.config || rawRequest.config
  });
}

/**
 * Executes one generation turn and, when the model requests tools, resolves
 * them and recurses (through `generateHelper`) for the next turn.
 *
 * @param registry - Registry used for resolution and streaming callbacks.
 * @param params - `rawRequest`, `middleware`, `currentTurn`, `messageIndex`.
 * @returns The response as JSON; an "interrupted" response when tool
 *   resolution yields a revised model message; or the result of the next
 *   turn.
 * @throws GenkitError (FAILED_PRECONDITION) when resuming reports an
 *   interrupted response.
 * @throws GenerationResponseError (ABORTED) when the next turn would exceed
 *   `rawRequest.maxTurns` (default 5).
 */
async function generate(registry, {
  rawRequest,
  middleware,
  currentTurn,
  messageIndex
}) {
  const { model, tools, format } = await resolveParameters(
    registry,
    rawRequest
  );
  rawRequest = applyFormat(rawRequest, format);
  await assertValidToolNames(tools);
  const {
    revisedRequest,
    interruptedResponse,
    toolMessage: resumedToolMessage
  } = await resolveResumeOption(registry, rawRequest);
  if (interruptedResponse) {
    throw new GenkitError({
      status: "FAILED_PRECONDITION",
      message: "One or more tools triggered an interrupt during a restarted execution.",
      detail: { message: interruptedResponse.message }
    });
  }
  rawRequest = revisedRequest;
  const request = await actionToGenerateRequest(
    rawRequest,
    tools,
    format,
    model
  );
  const previousChunks = [];
  let chunkRole = "model";
  // Wraps a raw chunk, bumping the message index whenever the role changes
  // after at least one chunk has already been produced.
  const makeChunk = (role, chunk) => {
    if (role !== chunkRole && previousChunks.length) messageIndex++;
    chunkRole = role;
    // Snapshot before pushing so the chunk only sees its predecessors.
    const prevToSend = [...previousChunks];
    previousChunks.push(chunk);
    return new GenerateResponseChunk(chunk, {
      index: messageIndex,
      role,
      previousChunks: prevToSend,
      parser: format?.handler(request.output?.schema).parseChunk
    });
  };
  const streamingCallback = getStreamingCallback(registry);
  if (resumedToolMessage && streamingCallback) {
    streamingCallback(makeChunk("tool", resumedToolMessage));
  }
  const response = await runWithStreamingCallback(
    registry,
    streamingCallback && ((chunk) => streamingCallback(makeChunk("model", chunk))),
    async () => {
      // Runs the middleware chain in order; the model is invoked once the
      // chain is exhausted (or when no middleware was supplied).
      const dispatch = async (index, req) => {
        if (!middleware || index === middleware.length) {
          return await model(req);
        }
        const currentMiddleware = middleware[index];
        return currentMiddleware(
          req,
          // A middleware may call next() without a request to pass its own
          // input through unchanged.
          async (modifiedReq) => dispatch(index + 1, modifiedReq || req)
        );
      };
      return new GenerateResponse(await dispatch(0, request), {
        request,
        parser: format?.handler(request.output?.schema).parseMessage
      });
    }
  );
  response.assertValid();
  const generatedMessage = response.message;
  const toolRequests = generatedMessage.content.filter(
    (part) => !!part.toolRequest
  );
  if (rawRequest.returnToolRequests || toolRequests.length === 0) {
    // Schema validation is only applied when the model made no tool requests.
    if (toolRequests.length === 0) response.assertValidSchema(request);
    return response.toJSON();
  }
  const maxIterations = rawRequest.maxTurns ?? 5;
  if (currentTurn + 1 > maxIterations) {
    throw new GenerationResponseError(
      response,
      `Exceeded maximum tool call iterations (${maxIterations})`,
      "ABORTED",
      { request }
    );
  }
  const { revisedModelMessage, toolMessage, transferPreamble } =
    await resolveToolRequests(registry, rawRequest, generatedMessage);
  if (revisedModelMessage) {
    return {
      ...response.toJSON(),
      finishReason: "interrupted",
      finishMessage: "One or more tool calls resulted in interrupts.",
      message: revisedModelMessage
    };
  }
  streamingCallback?.(
    makeChunk("tool", {
      content: toolMessage.content
    })
  );
  let nextRequest = {
    ...rawRequest,
    messages: [...rawRequest.messages, generatedMessage.toJSON(), toolMessage]
  };
  nextRequest = applyTransferPreamble(nextRequest, transferPreamble);
  return await generateHelper(registry, {
    rawRequest: nextRequest,
    middleware,
    currentTurn: currentTurn + 1,
    messageIndex: messageIndex + 1
  });
}

/**
 * Converts action options plus resolved tools into the request object that
 * is dispatched to the model.
 *
 * Logs a warning when tools or `toolChoice` are set but the model's
 * metadata declares `supports` without the corresponding capability.
 *
 * @param options - Raw generate action options.
 * @param resolvedTools - Tools returned by `resolveTools` (may be empty or
 *   undefined).
 * @param resolvedFormat - Resolved format; not read by this function.
 * @param model - Resolved model action; its `__action` metadata is read.
 * @returns The request with `messages`, `config`, `docs`, `tools`, `output`
 *   and, when set, `toolChoice`.
 */
async function actionToGenerateRequest(options, resolvedTools, resolvedFormat, model) {
  const modelInfo = model.__action.metadata?.model;
  if ((options.tools?.length ?? 0) > 0 && modelInfo?.supports && !modelInfo?.supports?.tools) {
    logger.warn(
      `The model '${model.__action.name}' does not support tools (you set: ${options.tools?.length} tools). The model may not behave the way you expect.`
    );
  }
  if (options.toolChoice && modelInfo?.supports && !modelInfo?.supports?.toolChoice) {
    logger.warn(
      `The model '${model.__action.name}' does not support the 'toolChoice' option (you set: ${options.toolChoice}). The model may not behave the way you expect.`
    );
  }
  const out = {
    messages: options.messages,
    config: options.config,
    docs: options.docs,
    tools: resolvedTools?.map(toToolDefinition) || [],
    output: stripUndefinedProps({
      constrained: options.output?.constrained,
      contentType: options.output?.contentType,
      format: options.output?.format,
      schema: options.output?.jsonSchema
    })
  };
  if (options.toolChoice) {
    out.toolChoice = options.toolChoice;
  }
  // Drop a falsy (e.g. null) schema so the key is absent rather than empty.
  if (out.output && !out.output.schema) delete out.output.schema;
  return out;
}

/**
 * Infers the single message role implied by a list of content parts.
 *
 * @param parts - Content parts to inspect.
 * @returns The shared role, or `undefined` when `parts` is empty.
 * @throws Error when the parts imply more than one role, or when a part has
 *   no recognized field.
 */
function inferRoleFromParts(parts) {
  const uniqueRoles = /* @__PURE__ */ new Set();
  for (const part of parts) {
    const role = getRoleFromPart(part);
    uniqueRoles.add(role);
    if (uniqueRoles.size > 1) {
      throw new Error("Contents contain mixed roles");
    }
  }
  return Array.from(uniqueRoles)[0];
}

/**
 * Maps a single content part to the role it implies: tool requests come
 * from the "model", tool responses from the "tool", and text / media / data
 * from the "user". Fields are checked in that order.
 *
 * @param part - Content part to classify.
 * @returns "model", "tool" or "user".
 * @throws Error when none of the recognized fields is defined.
 */
function getRoleFromPart(part) {
  if (part.toolRequest !== void 0) return "model";
  if (part.toolResponse !== void 0) return "tool";
  if (part.text !== void 0) return "user";
  if (part.media !== void 0) return "user";
  if (part.data !== void 0) return "user";
  throw new Error("No recognized fields in content");
}

export {
  defineGenerateAction,
  generateHelper,
  inferRoleFromParts,
  shouldInjectFormatInstructions
};
//# sourceMappingURL=action.mjs.map