UNPKG

@genkit-ai/ai

Version:

Genkit AI framework generative AI APIs.

508 lines 15.6 kB
import {
  GenkitError,
  defineAction,
  stripUndefinedProps
} from "@genkit-ai/core";
import { logger } from "@genkit-ai/core/logging";
import { Registry } from "@genkit-ai/core/registry";
import { SPAN_TYPE_ATTR, runInNewSpan } from "@genkit-ai/core/tracing";
import {
  injectInstructions,
  resolveFormat,
  resolveInstructions
} from "../formats/index.js";
import {
  GenerateResponse,
  GenerationResponseError,
  maybeRegisterDynamicMiddlewareTools,
  normalizeMiddleware,
  tagAsPreamble
} from "../generate.js";
import { GenerateResponseChunk } from "../generate/chunk.js";
import {
  GenerateActionOptionsSchema,
  GenerateResponseChunkSchema,
  GenerateResponseSchema,
  resolveModel
} from "../model.js";
import { findMatchingResource, resolveResources } from "../resource.js";
import { resolveTools, toToolDefinition } from "../tool.js";
import { resolveMiddleware } from "./middleware.js";
import {
  assertValidToolNames,
  resolveResumeOption,
  resolveToolRequests
} from "./resolve-tool-requests.js";

/**
 * Registers the "generate" utility action on the given registry.
 *
 * Each invocation works against a child registry created from `registry`,
 * normalizes and resolves the middleware listed in `request.use`, and then
 * runs the generate loop starting at turn 0 / message index 0.
 *
 * @param registry Registry the action is defined on.
 * @returns Whatever `defineAction` returns for the registered action.
 */
function defineGenerateAction(registry) {
  return defineAction(
    registry,
    {
      actionType: "util",
      name: "generate",
      inputSchema: GenerateActionOptionsSchema,
      outputSchema: GenerateResponseSchema,
      streamSchema: GenerateResponseChunkSchema
    },
    async (request, { streamingRequested, sendChunk, context }) => {
      const childRegistry = Registry.withParent(registry);
      const middlewareRefs = await normalizeMiddleware(
        childRegistry,
        request.use
      );
      // The normalized references are written back onto the incoming request.
      request.use = middlewareRefs;
      const resolvedMiddleware = await resolveMiddleware(
        childRegistry,
        request.use
      );
      maybeRegisterDynamicMiddlewareTools(childRegistry, resolvedMiddleware);
      const generateFn = (sendChunk2) =>
        generateActionImpl(childRegistry, {
          rawRequest: request,
          currentTurn: 0,
          messageIndex: 0,
          middleware: resolvedMiddleware,
          streamingCallback: sendChunk2,
          context
        });
      // Chunks that expose toJSON() are serialized before being streamed out.
      return streamingRequested
        ? generateFn((c) => sendChunk(c.toJSON ? c.toJSON() : c))
        : generateFn();
    }
  );
}

/**
 * Runs one generate invocation inside a new tracing span.
 *
 * The span is named after `options.rawRequest.stepName` (falling back to
 * "generate"); the raw request is recorded as span input and the
 * JSON-stringified result as span output.
 *
 * @param registry Registry used to resolve models, tools, formats, etc.
 * @param options Generate arguments: `rawRequest`, `middleware`,
 *   `currentTurn` (default 0), `messageIndex` (default 0), `abortSignal`,
 *   `streamingCallback` and `context`.
 * @returns The value produced by the generate loop.
 */
async function generateHelper(registry, options) {
  const currentTurn = options.currentTurn ?? 0;
  const messageIndex = options.messageIndex ?? 0;
  return await runInNewSpan(
    {
      metadata: { name: options.rawRequest.stepName || "generate" },
      labels: { [SPAN_TYPE_ATTR]: "util" }
    },
    async (metadata) => {
      metadata.name = options.rawRequest.stepName || "generate";
      metadata.input = options.rawRequest;
      const output = await generateActionImpl(registry, {
        rawRequest: options.rawRequest,
        middleware: options.middleware,
        currentTurn,
        messageIndex,
        abortSignal: options.abortSignal,
        streamingCallback: options.streamingCallback,
        context: options.context
      });
      metadata.output = JSON.stringify(output);
      return output;
    }
  );
}

/**
 * Resolves the model action, tools, resources and output format referenced
 * by a request. The four lookups run concurrently.
 *
 * @returns `{ model, tools, resources, format }`.
 */
async function resolveParameters(registry, request) {
  const [model, tools, resources, format] = await Promise.all([
    resolveModel(registry, request.model, { warnDeprecated: true }).then(
      (r) => r.modelAction
    ),
    resolveTools(registry, request.tools),
    resolveResources(registry, request.resources),
    resolveFormat(registry, request.output)
  ]);
  return { model, tools, resources, format };
}

/**
 * Returns a shallow copy of `rawRequest` with output-format settings applied.
 *
 * - A `jsonSchema` without an explicit `format` implies `format: "json"`.
 * - When a format was resolved, its instructions are injected into the
 *   messages (subject to `shouldInjectFormatInstructions`) and its config is
 *   used as the base of `output`, with explicitly set request values winning.
 */
function applyFormat(rawRequest, resolvedFormat) {
  const outRequest = { ...rawRequest };
  if (rawRequest.output?.jsonSchema && !rawRequest.output?.format) {
    outRequest.output = { ...rawRequest.output, format: "json" };
  }
  const instructions = resolveInstructions(
    resolvedFormat,
    outRequest.output?.jsonSchema,
    outRequest?.output?.instructions
  );
  if (resolvedFormat) {
    if (
      shouldInjectFormatInstructions(resolvedFormat.config, rawRequest?.output)
    ) {
      outRequest.messages = injectInstructions(
        outRequest.messages,
        instructions
      );
    }
    outRequest.output = {
      // use output config from the format
      ...resolvedFormat.config,
      // if anything is set explicitly, use that
      ...outRequest.output
    };
  }
  return outRequest;
}

/**
 * Decides whether format instructions should be injected into the messages.
 *
 * @returns `true` unless the format config sets `defaultInstructions` to
 *   `false`; in that case the request's own `instructions` value is returned
 *   as-is (truthy when the caller supplied instructions, possibly
 *   `undefined`). Callers should treat the result as truthy/falsy.
 */
function shouldInjectFormatInstructions(formatConfig, rawRequestConfig) {
  return (
    formatConfig?.defaultInstructions !== false ||
    rawRequestConfig?.instructions
  );
}

/**
 * Builds the request for the next turn after a tool handed over a "transfer
 * preamble".
 *
 * Existing preamble-tagged messages are dropped and replaced by the new
 * preamble's messages; `model`, `toolChoice`, `tools` and `config` from the
 * preamble take precedence over the request's values when set.
 *
 * The input request is not modified; a new object is returned (or the input
 * itself when there is no preamble).
 */
function applyTransferPreamble(rawRequest, transferPreamble) {
  if (!transferPreamble) {
    return rawRequest;
  }
  return stripUndefinedProps({
    ...rawRequest,
    model: transferPreamble.model || rawRequest.model,
    messages: [
      ...tagAsPreamble(transferPreamble.messages),
      ...rawRequest.messages.filter((m) => !m.metadata?.preamble)
    ],
    toolChoice: transferPreamble.toolChoice || rawRequest.toolChoice,
    tools: transferPreamble.tools || rawRequest.tools,
    config: transferPreamble.config || rawRequest.config
  });
}

/**
 * Entry point of a generate invocation: threads the request through every
 * middleware `generate` hook (in array order) and finally executes the turn.
 *
 * Without middleware the turn is executed directly.
 *
 * @param registry Registry used for all lookups.
 * @param args `{ rawRequest, middleware, currentTurn, messageIndex,
 *   abortSignal, streamingCallback, context }`.
 */
async function generateActionImpl(registry, args) {
  const {
    rawRequest,
    middleware,
    currentTurn,
    messageIndex,
    abortSignal,
    streamingCallback,
    context
  } = args;
  const format = await resolveFormat(registry, rawRequest.output);
  // Chunk history shared between middleware-emitted chunks and the turn.
  const sharedPreviousChunks = [];
  const parser = format?.handler(rawRequest.output?.jsonSchema).parseChunk;

  if (!middleware || middleware.length === 0) {
    return generateActionTurn(registry, { ...args, sharedPreviousChunks });
  }

  const dispatchGenerate = async (
    index,
    request,
    currentTurn2,
    messageIndex2,
    ctx
  ) => {
    // Past the last middleware: run the actual turn.
    if (index === middleware.length) {
      return generateActionTurn(registry, {
        rawRequest: request,
        middleware,
        currentTurn: currentTurn2,
        messageIndex: messageIndex2,
        abortSignal: ctx.abortSignal,
        streamingCallback: ctx.onChunk,
        context: ctx.context,
        sharedPreviousChunks
      });
    }

    const currentMiddleware = middleware[index];
    if (!currentMiddleware.generate) {
      return dispatchGenerate(
        index + 1,
        request,
        currentTurn2,
        messageIndex2,
        ctx
      );
    }

    // Middleware may emit plain chunk data; wrap it so downstream consumers
    // always receive GenerateResponseChunk instances.
    const wrappedOnChunk = ctx.onChunk
      ? (c) => {
          if (c instanceof GenerateResponseChunk) {
            ctx.onChunk(c);
            return;
          }
          const chunk = new GenerateResponseChunk(c, {
            index: c.index !== undefined ? c.index : messageIndex2,
            role: c.role !== undefined ? c.role : "model",
            previousChunks: [...sharedPreviousChunks],
            parser
          });
          sharedPreviousChunks.push(c);
          ctx.onChunk(chunk);
        }
      : undefined;

    return currentMiddleware.generate(
      { request, currentTurn: currentTurn2, messageIndex: messageIndex2 },
      { ...ctx, onChunk: wrappedOnChunk },
      // `next`: anything the middleware leaves unset falls back to the
      // values it was invoked with.
      async (modifiedEnvelope, opts) =>
        dispatchGenerate(
          index + 1,
          modifiedEnvelope?.request || request,
          modifiedEnvelope?.currentTurn !== undefined
            ? modifiedEnvelope.currentTurn
            : currentTurn2,
          modifiedEnvelope?.messageIndex !== undefined
            ? modifiedEnvelope.messageIndex
            : messageIndex2,
          opts || ctx
        )
    );
  };

  return dispatchGenerate(0, rawRequest, currentTurn, messageIndex, {
    abortSignal,
    onChunk: streamingCallback,
    context
  });
}

/**
 * Executes a single turn of the generate loop: resolves parameters, handles a
 * pending resume, calls the model (through middleware `model` hooks), and, if
 * the model requested tools, resolves them and recurses into the next turn.
 *
 * @returns The serialized response (`response.toJSON()`), an "interrupted"
 *   response when tool calls produced interrupts, or the result of the next
 *   turn.
 * @throws GenkitError (FAILED_PRECONDITION) when a resumed execution was
 *   interrupted again.
 * @throws GenerationResponseError (ABORTED) when the number of tool turns
 *   exceeds `rawRequest.maxTurns` (default 5).
 */
async function generateActionTurn(
  registry,
  {
    rawRequest,
    middleware,
    currentTurn,
    messageIndex,
    abortSignal,
    streamingCallback,
    context,
    sharedPreviousChunks
  }
) {
  const { model, tools, resources, format } = await resolveParameters(
    registry,
    rawRequest
  );
  if (middleware) {
    // Tools contributed by middleware are offered alongside the request's.
    tools.push(...middleware.flatMap((m) => m.tools || []));
  }
  rawRequest = applyFormat(rawRequest, format);
  rawRequest = await applyResources(registry, rawRequest, resources);
  await assertValidToolNames(tools);

  const {
    revisedRequest,
    interruptedResponse,
    toolMessage: resumedToolMessage
  } = await resolveResumeOption(registry, rawRequest, tools, middleware || []);

  // Checked before anything else so an interrupt raised while resuming is
  // never skipped, even if no revised request accompanies it.
  if (interruptedResponse) {
    throw new GenkitError({
      status: "FAILED_PRECONDITION",
      message:
        "One or more tools triggered an interrupt during a restarted execution.",
      detail: { message: interruptedResponse.message }
    });
  }

  if (revisedRequest && revisedRequest !== rawRequest) {
    if (resumedToolMessage && streamingCallback) {
      streamingCallback(
        new GenerateResponseChunk(
          { role: "tool", content: resumedToolMessage.content },
          {
            index: messageIndex,
            role: "tool",
            previousChunks: [],
            parser: format?.handler(rawRequest.output?.jsonSchema).parseChunk
          }
        )
      );
    }
    // Restart the same turn with the revised request.
    return await generateHelper(registry, {
      rawRequest: revisedRequest,
      middleware,
      currentTurn,
      messageIndex: messageIndex + (resumedToolMessage ? 1 : 0),
      abortSignal,
      streamingCallback,
      context
    });
  }
  rawRequest = revisedRequest;

  const request = await actionToGenerateRequest(
    rawRequest,
    tools,
    format,
    model
  );

  let chunkRole = "model";
  // Wraps raw chunk data, bumping the message index whenever the emitting
  // role changes after at least one chunk has been recorded.
  const makeChunk = (role, chunk) => {
    if (role !== chunkRole && sharedPreviousChunks.length) messageIndex++;
    chunkRole = role;
    const prevToSend = [...sharedPreviousChunks];
    sharedPreviousChunks.push(chunk);
    return new GenerateResponseChunk(chunk, {
      index: messageIndex,
      role,
      previousChunks: prevToSend,
      parser: format?.handler(request.output?.schema).parseChunk
    });
  };
  const sendChunk =
    streamingCallback &&
    ((chunk) => streamingCallback(makeChunk("model", chunk)));

  // Threads the model call through every middleware `model` hook in order.
  const dispatchModel = async (index, req, actionOpts) => {
    if (!middleware || index === middleware.length) {
      return await model(req, actionOpts);
    }
    const currentMiddleware = middleware[index];
    if (!currentMiddleware.model) {
      return dispatchModel(index + 1, req, actionOpts);
    }
    return currentMiddleware.model(
      req,
      actionOpts,
      async (modifiedReq, opts) =>
        dispatchModel(index + 1, modifiedReq || req, opts || actionOpts)
    );
  };

  const modelResponse = await dispatchModel(0, request, {
    abortSignal,
    context,
    onChunk: sendChunk
  });

  // Background models hand back an operation rather than a finished
  // response; it is wrapped and returned without validation or tool handling.
  const isBackgroundModel = model.__action.actionType === "background-model";
  const response = new GenerateResponse(
    isBackgroundModel ? { operation: modelResponse } : modelResponse,
    {
      request,
      parser: format?.handler(request.output?.schema).parseMessage
    }
  );
  if (isBackgroundModel) {
    return response.toJSON();
  }

  response.assertValid();
  const generatedMessage = response.message;
  const toolRequests = generatedMessage.content.filter(
    (part) => !!part.toolRequest
  );
  if (rawRequest.returnToolRequests || toolRequests.length === 0) {
    // Schema conformance is only asserted on a response with no tool requests.
    if (toolRequests.length === 0) response.assertValidSchema(request);
    return response.toJSON();
  }

  const maxIterations = rawRequest.maxTurns ?? 5;
  if (currentTurn + 1 > maxIterations) {
    throw new GenerationResponseError(
      response,
      `Exceeded maximum tool call iterations (${maxIterations})`,
      "ABORTED",
      { request }
    );
  }

  const { revisedModelMessage, toolMessage, transferPreamble } =
    await resolveToolRequests(
      rawRequest,
      generatedMessage,
      tools,
      middleware || []
    );
  if (revisedModelMessage) {
    return {
      ...response.toJSON(),
      finishReason: "interrupted",
      finishMessage: "One or more tool calls resulted in interrupts.",
      message: revisedModelMessage
    };
  }

  const messages = [...rawRequest.messages, generatedMessage.toJSON()];
  if (toolMessage) {
    streamingCallback?.(makeChunk("tool", { content: toolMessage.content }));
    messages.push(toolMessage);
  }

  const nextRequest = applyTransferPreamble(
    { ...rawRequest, messages },
    transferPreamble
  );
  return await generateHelper(registry, {
    rawRequest: nextRequest,
    middleware,
    currentTurn: currentTurn + 1,
    messageIndex: messageIndex + 1,
    streamingCallback,
    abortSignal,
    // Forward the caller's context so later turns see the same context as
    // the first turn and the resume path.
    context
  });
}

/**
 * Converts resolved generate options into the request object handed to the
 * model action.
 *
 * Logs a warning when tools or `toolChoice` are supplied but the model's
 * metadata declares `supports` without the matching capability.
 *
 * @returns `{ messages, config, docs, tools, output[, toolChoice] }`.
 */
async function actionToGenerateRequest(
  options,
  resolvedTools,
  resolvedFormat,
  model
) {
  const modelInfo = model.__action.metadata?.model;
  if (
    (options.tools?.length ?? 0) > 0 &&
    modelInfo?.supports &&
    !modelInfo?.supports?.tools
  ) {
    logger.warn(
      `The model '${model.__action.name}' does not support tools (you set: ${options.tools?.length} tools). The model may not behave the way you expect.`
    );
  }
  if (
    options.toolChoice &&
    modelInfo?.supports &&
    !modelInfo?.supports?.toolChoice
  ) {
    logger.warn(
      `The model '${model.__action.name}' does not support the 'toolChoice' option (you set: ${options.toolChoice}). The model may not behave the way you expect.`
    );
  }
  const out = {
    messages: options.messages,
    config: options.config,
    docs: options.docs,
    tools: resolvedTools?.map(toToolDefinition) || [],
    output: stripUndefinedProps({
      constrained: options.output?.constrained,
      contentType: options.output?.contentType,
      format: options.output?.format,
      schema: options.output?.jsonSchema
    })
  };
  if (options.toolChoice) {
    out.toolChoice = options.toolChoice;
  }
  // Drop a falsy (e.g. null) schema entirely rather than passing it on.
  if (out.output && !out.output.schema) delete out.output.schema;
  return out;
}

/**
 * Infers the single message role implied by a list of content parts.
 *
 * @returns The shared role, or `undefined` for an empty list.
 * @throws Error when the parts imply more than one role, or when a part has
 *   no recognized field.
 */
function inferRoleFromParts(parts) {
  const uniqueRoles = /* @__PURE__ */ new Set();
  for (const part of parts) {
    const role = getRoleFromPart(part);
    uniqueRoles.add(role);
    if (uniqueRoles.size > 1) {
      throw new Error("Contents contain mixed roles");
    }
  }
  return Array.from(uniqueRoles)[0];
}

/**
 * Maps a content part to a role: tool requests are "model", tool responses
 * are "tool", and text / media / data parts are "user".
 *
 * @throws Error when none of those fields is present on the part.
 */
function getRoleFromPart(part) {
  if (part.toolRequest !== undefined) return "model";
  if (part.toolResponse !== undefined) return "tool";
  if (part.text !== undefined) return "user";
  if (part.media !== undefined) return "user";
  if (part.data !== undefined) return "user";
  throw new Error("No recognized fields in content");
}

/**
 * Replaces every `resource` part in the request's messages with the content
 * returned by the matching resource.
 *
 * Messages without resource parts are passed through untouched, and the
 * original request object is returned as-is when no message contains one.
 *
 * @throws GenkitError (NOT_FOUND) when no resource matches a referenced part.
 */
async function applyResources(registry, rawRequest, resources) {
  const hasResourcePart = (m) => m.content.some((c) => c.resource);
  if (!rawRequest.messages.some(hasResourcePart)) {
    return rawRequest;
  }
  const updatedMessages = [];
  for (const m of rawRequest.messages) {
    if (!hasResourcePart(m)) {
      updatedMessages.push(m);
      continue;
    }
    const updatedContent = [];
    for (const p of m.content) {
      if (!p.resource) {
        updatedContent.push(p);
        continue;
      }
      const resource = await findMatchingResource(
        registry,
        resources,
        p.resource
      );
      if (!resource) {
        throw new GenkitError({
          status: "NOT_FOUND",
          message: `failed to find matching resource for ${p.resource.uri}`
        });
      }
      const resourceParts = await resource(p.resource);
      updatedContent.push(...resourceParts.content);
    }
    updatedMessages.push({ ...m, content: updatedContent });
  }
  return { ...rawRequest, messages: updatedMessages };
}

export {
  defineGenerateAction,
  generateHelper,
  inferRoleFromParts,
  shouldInjectFormatInstructions
};
//# sourceMappingURL=action.mjs.map