UNPKG

@genkit-ai/flow

Version:

Genkit AI framework workflow APIs.

851 lines (809 loc) 24 kB
/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  Action,
  FlowError,
  FlowState,
  FlowStateSchema,
  FlowStateStore,
  Operation,
  StreamingCallback,
  defineAction,
  getStreamingCallback,
  config as globalConfig,
  isDevEnv,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { toJsonSchema } from '@genkit-ai/core/schema';
import {
  SPAN_TYPE_ATTR,
  newTrace,
  setCustomMetadataAttribute,
  setCustomMetadataAttributes,
} from '@genkit-ai/core/tracing';
import { SpanStatusCode } from '@opentelemetry/api';
import * as bodyParser from 'body-parser';
import { CorsOptions, default as cors } from 'cors';
import express from 'express';
import { performance } from 'node:perf_hooks';
import * as z from 'zod';
import { Context } from './context.js';
import {
  FlowExecutionError,
  FlowStillRunningError,
  InterruptError,
  getErrorMessage,
  getErrorStack,
} from './errors.js';
import {
  FlowActionInputSchema,
  FlowInvokeEnvelopeMessage,
  FlowInvokeEnvelopeMessageSchema,
  Invoker,
  RetryConfig,
  Scheduler,
} from './types.js';
import {
  generateFlowId,
  metadataPrefix,
  runWithActiveContext,
} from './utils.js';

/** Separator written after every streamed chunk in the HTTP streaming response. */
const streamDelimiter = '\n';

/** Key on `global` under which the list of flows created by `defineFlow` is kept. */
const CREATED_FLOWS = 'genkit__CREATED_FLOWS';

/**
 * Returns the process-wide list of flows created through `defineFlow`,
 * creating the list on first access.
 */
function createdFlows(): Flow<any, any, any>[] {
  if (global[CREATED_FLOWS] === undefined) {
    global[CREATED_FLOWS] = [];
  }
  return global[CREATED_FLOWS];
}

/**
 * Step configuration for retries, etc.
 */
export interface RunStepConfig {
  name: string;
  retryConfig?: RetryConfig;
}

/**
 * Flow Auth policy. Consumes the authorization context of the flow and
 * performs checks before the flow runs. If this throws, the flow will not
 * be executed.
 */
export interface FlowAuthPolicy<I extends z.ZodTypeAny = z.ZodTypeAny> {
  (auth: any | undefined, input: z.infer<I>): void | Promise<void>;
}

/**
 * For express-based flows, req.auth should contain the value to be passed into
 * the flow context.
 */
export interface __RequestWithAuth extends express.Request {
  auth?: unknown;
}

/**
 * Defines the flow.
 *
 * Builds a `Flow` from the given config and steps, records it in the list of
 * created flows and registers it as an action.
 *
 * @param config flow name, optional schemas, auth policy, express middleware,
 *   and optional custom invoker/scheduler (only used outside the dev
 *   environment).
 * @param steps the function implementing the flow.
 * @returns the newly created flow.
 */
export function defineFlow<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  config: {
    name: string;
    inputSchema?: I;
    outputSchema?: O;
    streamSchema?: S;
    authPolicy?: FlowAuthPolicy<I>;
    middleware?: express.RequestHandler[];
    invoker?: Invoker<I, O, S>;
    experimentalDurable?: boolean;
    experimentalScheduler?: Scheduler<I, O, S>;
  },
  steps: StepsFunction<I, O, S>
): Flow<I, O, S> {
  const f = new Flow(
    {
      name: config.name,
      inputSchema: config.inputSchema,
      outputSchema: config.outputSchema,
      streamSchema: config.streamSchema,
      experimentalDurable: !!config.experimentalDurable,
      stateStore: globalConfig
        ? () => globalConfig.getFlowStateStore()
        : undefined,
      authPolicy: config.authPolicy,
      middleware: config.middleware,
      // We always use local dispatcher in dev mode or when one is not provided.
      invoker: async (flow, msg, streamingCallback) => {
        if (!isDevEnv() && config.invoker) {
          return config.invoker(flow, msg, streamingCallback);
        }
        const state = await flow.runEnvelope(msg, streamingCallback);
        return state.operation;
      },
      scheduler: async (flow, msg, delay = 0) => {
        if (!config.experimentalDurable) {
          throw new Error(
            'This flow is not durable, cannot use scheduling features.'
          );
        }
        if (!isDevEnv() && config.experimentalScheduler) {
          return config.experimentalScheduler(flow, msg, delay);
        }
        setTimeout(() => {
          // The timer callback has no caller to report to, so a failed run
          // must be caught here; otherwise it surfaces as an unhandled
          // promise rejection.
          flow.runEnvelope(msg).catch((e) => {
            logger.error(
              `Scheduled run of flow ${flow.name} failed: ${getErrorMessage(e)}`
            );
          });
        }, delay * 1000);
      },
    },
    steps
  );
  createdFlows().push(f);
  wrapAsAction(f);
  return f;
}

/** An object that carries a flow in its `flow` property. */
export interface FlowWrapper<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
> {
  flow: Flow<I, O, S>;
}

/**
 * A flow: a named steps function together with its schemas, state store,
 * invoker, scheduler, auth policy and express middleware.
 */
export class Flow<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
> {
  readonly name: string;
  readonly inputSchema?: I;
  readonly outputSchema?: O;
  readonly streamSchema?: S;
  readonly stateStore?: () => Promise<FlowStateStore>;
  readonly invoker: Invoker<I, O, S>;
  readonly scheduler: Scheduler<I, O, S>;
  readonly experimentalDurable: boolean;
  readonly authPolicy?: FlowAuthPolicy<I>;
  readonly middleware?: express.RequestHandler[];

  /**
   * @throws Error when both an auth policy and `experimentalDurable` are set.
   */
  constructor(
    config: {
      name: string;
      inputSchema?: I;
      outputSchema?: O;
      streamSchema?: S;
      stateStore?: () => Promise<FlowStateStore>;
      invoker: Invoker<I, O, S>;
      scheduler: Scheduler<I, O, S>;
      experimentalDurable: boolean;
      authPolicy?: FlowAuthPolicy<I>;
      middleware?: express.RequestHandler[];
    },
    private steps: StepsFunction<I, O, S>
  ) {
    this.name = config.name;
    this.inputSchema = config.inputSchema;
    this.outputSchema = config.outputSchema;
    this.streamSchema = config.streamSchema;
    this.stateStore = config.stateStore;
    this.invoker = config.invoker;
    this.scheduler = config.scheduler;
    this.experimentalDurable = config.experimentalDurable;
    this.authPolicy = config.authPolicy;
    this.middleware = config.middleware;

    // Durable flows can't use an auth policy; instead they should be invoked
    // from a privileged context after ACL checks are performed.
    if (this.authPolicy && this.experimentalDurable) {
      throw new Error('Durable flows can not define auth policies.');
    }
  }

  /**
   * Executes the flow with the input directly.
   *
   * This will either be called by runEnvelope when starting durable flows,
   * or it will be called directly when starting non-durable flows.
   *
   * The state is saved afterwards (even if execution throws) when running in
   * the dev environment or when the flow is durable.
   */
  async runDirectly(
    input: unknown,
    opts: {
      streamingCallback?: StreamingCallback<unknown>;
      labels?: Record<string, string>;
      auth?: unknown;
    }
  ): Promise<FlowState> {
    const flowId = generateFlowId();
    const state = createNewState(flowId, this.name, input);
    const ctx = new Context(this, flowId, state, opts.auth);
    try {
      await this.executeSteps(
        ctx,
        this.steps,
        'start',
        opts.streamingCallback,
        opts.labels
      );
    } finally {
      if (isDevEnv() || this.experimentalDurable) {
        await ctx.saveState();
      }
    }
    return state;
  }

  /**
   * Executes the flow with the input in the envelope format.
   *
   * Exactly one of `start`, `schedule`, `state`, `runScheduled` or `resume`
   * is expected on the envelope; every case other than `start` requires a
   * durable flow with a configured state store.
   *
   * @throws Error when the envelope case is not supported for this flow, when
   *   the referenced flow state cannot be found, or when no known case is set.
   */
  async runEnvelope(
    req: FlowInvokeEnvelopeMessage,
    streamingCallback?: StreamingCallback<any>,
    auth?: unknown
  ): Promise<FlowState> {
    logger.debug(req, 'runEnvelope');
    if (req.start) {
      // First time, create new state.
      return this.runDirectly(req.start.input, {
        streamingCallback,
        auth,
        labels: req.start.labels,
      });
    }
    if (req.schedule) {
      if (!this.experimentalDurable) {
        throw new Error('Cannot schedule a non-durable flow');
      }
      if (!this.stateStore) {
        throw new Error(
          'Flow state store for durable flows must be configured'
        );
      }
      // First time, create new state.
      const flowId = generateFlowId();
      const state = createNewState(flowId, this.name, req.schedule.input);
      try {
        await (await this.stateStore()).save(flowId, state);
        await this.scheduler(
          this,
          { runScheduled: { flowId } } as FlowInvokeEnvelopeMessage,
          req.schedule.delay
        );
      } catch (e) {
        // Scheduling failed: record the failure on the operation and persist it.
        state.operation.done = true;
        state.operation.result = {
          error: getErrorMessage(e),
          stacktrace: getErrorStack(e),
        };
        await (await this.stateStore()).save(flowId, state);
      }
      return state;
    }
    if (req.state) {
      if (!this.experimentalDurable) {
        throw new Error('Cannot state check a non-durable flow');
      }
      if (!this.stateStore) {
        throw new Error(
          'Flow state store for durable flows must be configured'
        );
      }
      const flowId = req.state.flowId;
      const state = await (await this.stateStore()).load(flowId);
      if (state === undefined) {
        throw new Error(`Unable to find flow state for ${flowId}`);
      }
      return state;
    }
    if (req.runScheduled) {
      if (!this.experimentalDurable) {
        throw new Error('Cannot run scheduled non-durable flow');
      }
      if (!this.stateStore) {
        throw new Error(
          'Flow state store for durable flows must be configured'
        );
      }
      const flowId = req.runScheduled.flowId;
      const state = await (await this.stateStore()).load(flowId);
      if (state === undefined) {
        throw new Error(`Unable to find flow state for ${flowId}`);
      }
      const ctx = new Context(this, flowId, state);
      try {
        await this.executeSteps(
          ctx,
          this.steps,
          'runScheduled',
          undefined,
          undefined
        );
      } finally {
        await ctx.saveState();
      }
      return state;
    }
    if (req.resume) {
      if (!this.experimentalDurable) {
        throw new Error('Cannot resume a non-durable flow');
      }
      if (!this.stateStore) {
        throw new Error(
          'Flow state store for durable flows must be configured'
        );
      }
      const flowId = req.resume.flowId;
      const state = await (await this.stateStore()).load(flowId);
      if (state === undefined) {
        throw new Error(`Unable to find flow state for ${flowId}`);
      }
      if (!state.blockedOnStep) {
        throw new Error(
          "Unable to resume flow that's currently not interrupted"
        );
      }
      // Record the resume payload under the name of the step the flow is
      // blocked on before re-running the steps.
      state.eventsTriggered[state.blockedOnStep.name] = req.resume.payload;
      const ctx = new Context(this, flowId, state);
      try {
        await this.executeSteps(
          ctx,
          this.steps,
          'resume',
          undefined,
          undefined
        );
      } finally {
        await ctx.saveState();
      }
      return state;
    }
    // TODO: add retry

    throw new Error(
      'Unexpected envelope message case, must set one of: ' +
        'start, schedule, runScheduled, resume, retry, state'
    );
  }

  /**
   * Runs the steps function inside a new trace with `ctx` as the active
   * context, and records the outcome on `ctx.state.operation`.
   *
   * Errors thrown by the handler are not rethrown: an `InterruptError` leaves
   * the operation untouched, any other error marks the operation done with an
   * error result.
   */
  // TODO: refactor me... this is a mess!
  private async executeSteps(
    ctx: Context<I, O, S>,
    handler: StepsFunction<I, O, S>,
    dispatchType: string,
    streamingCallback: StreamingCallback<any> | undefined,
    labels: Record<string, string> | undefined
  ) {
    // NOTE(review): startTimeMs is not read anywhere in this method.
    const startTimeMs = performance.now();
    await runWithActiveContext(ctx, async () => {
      let traceContext;
      if (ctx.state.traceContext) {
        traceContext = JSON.parse(ctx.state.traceContext);
      }

      let ctxLinks = traceContext ? [{ context: traceContext }] : [];
      let errored = false;
      const output = await newTrace(
        {
          name: ctx.flow.name,
          labels: {
            [SPAN_TYPE_ATTR]: 'flow',
          },
          links: ctxLinks,
        },
        async (metadata, rootSpan) => {
          ctx.state.executions.push({
            startTime: Date.now(),
            traceIds: [],
          });
          setCustomMetadataAttribute(
            metadataPrefix(`execution`),
            (ctx.state.executions.length - 1).toString()
          );
          if (labels) {
            Object.keys(labels).forEach((label) => {
              setCustomMetadataAttribute(
                metadataPrefix(`label:${label}`),
                labels[label]
              );
            });
          }

          setCustomMetadataAttributes({
            [metadataPrefix('name')]: this.name,
            [metadataPrefix('id')]: ctx.flowId,
          });
          ctx
            .getCurrentExecution()
            .traceIds.push(rootSpan.spanContext().traceId);
          // Save the trace in the state so that we can tie subsequent invocation together.
          if (!traceContext) {
            ctx.state.traceContext = JSON.stringify(rootSpan.spanContext());
          }
          setCustomMetadataAttribute(
            metadataPrefix('dispatchType'),
            dispatchType
          );
          try {
            const input = this.inputSchema
              ? this.inputSchema.parse(ctx.state.input)
              : ctx.state.input;
            metadata.input = input;
            const output = await handler(input, streamingCallback);
            metadata.output = JSON.stringify(output);
            setCustomMetadataAttribute(metadataPrefix('state'), 'done');
            return output;
          } catch (e) {
            if (e instanceof InterruptError) {
              setCustomMetadataAttribute(
                metadataPrefix('state'),
                'interrupted'
              );
              // Log interrupted
            } else {
              metadata.state = 'error';
              rootSpan.setStatus({
                code: SpanStatusCode.ERROR,
                message: getErrorMessage(e),
              });
              if (e instanceof Error) {
                rootSpan.recordException(e);
              }

              setCustomMetadataAttribute(metadataPrefix('state'), 'error');
              ctx.state.operation.done = true;
              ctx.state.operation.result = {
                error: getErrorMessage(e),
                stacktrace: getErrorStack(e),
              } as FlowError;
            }
            errored = true;
          }
        }
      );
      if (!errored) {
        // flow done, set response.
        ctx.state.operation.done = true;
        ctx.state.operation.result = { response: output };
      }
    });
  }

  /**
   * Express handler for durable flows. The request body (optionally wrapped
   * in a `data` object) is an invoke envelope; the resulting operation is sent
   * back. Streaming requests and malformed envelopes are rejected with a 400
   * response; failures while running the envelope are reported as a 500
   * response carrying an operation.
   */
  private async durableExpressHandler(
    req: express.Request,
    res: express.Response
  ): Promise<void> {
    if (req.query.stream === 'true') {
      const respBody = {
        error: {
          status: 'INVALID_ARGUMENT',
          message: 'Output from durable flows cannot be streamed',
        },
      };
      res.status(400).send(respBody).end();
      return;
    }

    let data = req.body;
    // Task queue will wrap body in a "data" object, unwrap it.
    if (req.body?.data) {
      data = req.body.data;
    }
    // Validate without throwing: a rejection escaping this async handler would
    // leave the request without any response.
    const parsed = FlowInvokeEnvelopeMessageSchema.safeParse(data);
    if (!parsed.success) {
      const respBody = {
        error: {
          status: 'INVALID_ARGUMENT',
          message: parsed.error.message,
        },
      };
      res.status(400).send(respBody).end();
      return;
    }
    const envMsg = parsed.data;
    try {
      const state = await this.runEnvelope(envMsg);
      res.status(200).send(state.operation).end();
    } catch (e) {
      // Pass errors as operations instead of a standard API error
      // (https://cloud.google.com/apis/design/errors#http_mapping)
      const respBody = {
        done: true,
        result: {
          error: getErrorMessage(e),
          stacktrace: getErrorStack(e),
        },
      };
      res
        .status(500)
        .send(respBody as Operation)
        .end();
    }
  }

  /**
   * Express handler for non-durable flows. Reads the flow input from
   * `req.body.data`, applies the auth policy (403 on failure), then runs the
   * flow either streaming chunks (`?stream=true`) or returning a single JSON
   * response.
   */
  private async nonDurableExpressHandler(
    req: __RequestWithAuth,
    res: express.Response
  ): Promise<void> {
    const { stream } = req.query;
    const auth = req.auth;

    // A missing body is treated as "no input" instead of throwing.
    const input = req.body?.data;

    try {
      await this.authPolicy?.(auth, input);
    } catch (e: any) {
      const respBody = {
        error: {
          status: 'PERMISSION_DENIED',
          message: e.message || 'Permission denied to resource',
        },
      };
      res.status(403).send(respBody).end();
      return;
    }

    if (stream === 'true') {
      res.writeHead(200, {
        'Content-Type': 'text/plain',
        'Transfer-Encoding': 'chunked',
      });
      try {
        const state = await this.runDirectly(input, {
          streamingCallback: (chunk) => {
            res.write(JSON.stringify(chunk) + streamDelimiter);
          },
          auth,
        });
        res.write(JSON.stringify(state.operation));
        res.end();
      } catch (e) {
        // Errors while streaming are also passed back as operations
        const respBody = {
          done: true,
          result: {
            error: getErrorMessage(e),
            stacktrace: getErrorStack(e),
          },
        };
        res.write(JSON.stringify(respBody as Operation));
        res.end();
      }
    } else {
      try {
        const state = await this.runDirectly(input, { auth });
        if (state.operation.result?.error) {
          throw new Error(state.operation.result?.error);
        }
        // Responses for non-streaming, non-durable flows are passed back
        // with the flow result stored in a field called "result."
        res
          .status(200)
          .send({
            result: state.operation.result?.response,
          })
          .end();
      } catch (e) {
        // Errors for non-durable, non-streaming flows are passed back as
        // standard API errors.
        res
          .status(500)
          .send({
            error: {
              status: 'INTERNAL',
              message: getErrorMessage(e),
              details: getErrorStack(e),
            },
          })
          .end();
      }
    }
  }

  /**
   * The express request handler for this flow, bound to the flow instance:
   * the durable handler for durable flows, the non-durable one otherwise.
   */
  get expressHandler(): (
    req: __RequestWithAuth,
    res: express.Response
  ) => Promise<void> {
    return this.experimentalDurable
      ? this.durableExpressHandler.bind(this)
      : this.nonDurableExpressHandler.bind(this);
  }
}

/**
 * Runs the flow. If the flow does not get interrupted may return a completed (done=true) operation.
 *
 * @throws FlowStillRunningError when the operation is not done after the run.
 * @throws FlowExecutionError when the operation finished with an error.
 */
export async function runFlow<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  flow: Flow<I, O, S> | FlowWrapper<I, O, S>,
  payload?: z.infer<I>,
  opts?: { withLocalAuthContext?: unknown }
): Promise<z.infer<O>> {
  if (!(flow instanceof Flow)) {
    flow = flow.flow;
  }

  const input = flow.inputSchema ? flow.inputSchema.parse(payload) : payload;
  await flow.authPolicy?.(opts?.withLocalAuthContext, payload);

  if (flow.middleware) {
    logger.warn(
      `Flow (${flow.name}) middleware won't run when invoked with runFlow.`
    );
  }

  const state = await flow.runEnvelope(
    {
      start: {
        input,
      },
    },
    undefined,
    opts?.withLocalAuthContext
  );
  if (!state.operation.done) {
    throw new FlowStillRunningError(
      `flow ${state.name} did not finish execution`
    );
  }
  if (state.operation.result?.error) {
    throw new FlowExecutionError(
      state.operation.name,
      state.operation.result?.error,
      state.operation.result?.stacktrace
    );
  }
  return state.operation.result?.response;
}

/**
 * Result of `streamFlow`: `stream()` yields the streamed chunks and returns
 * the final operation; `output()` resolves with the flow response.
 */
interface StreamingResponse<
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
> {
  stream(): AsyncGenerator<unknown, Operation, z.infer<S> | undefined>;
  output(): Promise<z.infer<O>>;
}

/**
 * Runs the flow and streams results. If the flow does not get interrupted may return a completed (done=true) operation.
 */
export function streamFlow<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  flowOrFlowWrapper: Flow<I, O, S> | FlowWrapper<I, O, S>,
  payload?: z.infer<I>,
  opts?: { withLocalAuthContext?: unknown }
): StreamingResponse<O, S> {
  const flow = !(flowOrFlowWrapper instanceof Flow)
    ? flowOrFlowWrapper.flow
    : flowOrFlowWrapper;

  let chunkStreamController: ReadableStreamController<z.infer<S>>;
  const chunkStream = new ReadableStream<z.infer<S>>({
    start(controller) {
      chunkStreamController = controller;
    },
    pull() {},
    cancel() {},
  });

  const authPromise =
    flow.authPolicy?.(opts?.withLocalAuthContext, payload) ??
    Promise.resolve();

  const operationPromise = authPromise
    .then(() =>
      flow.runEnvelope(
        {
          start: {
            input: flow.inputSchema
              ? flow.inputSchema.parse(payload)
              : payload,
          },
        },
        (c) => {
          chunkStreamController.enqueue(c);
        },
        opts?.withLocalAuthContext
      )
    )
    .then((s) => s.operation)
    .finally(() => {
      // Whether the run succeeded or failed, no more chunks will arrive.
      chunkStreamController.close();
    });

  return {
    output() {
      return operationPromise.then((op) => {
        if (!op.done) {
          throw new FlowStillRunningError(
            `flow ${op.name} did not finish execution`
          );
        }
        if (op.result?.error) {
          throw new FlowExecutionError(
            op.name,
            op.result?.error,
            op.result?.stacktrace
          );
        }
        return op.result?.response;
      });
    },
    async *stream() {
      const reader = chunkStream.getReader();
      while (true) {
        const chunk = await reader.read();
        // Decide on `done`, not on the truthiness of the value: falsy chunks
        // such as 0, '' or false are legitimate stream items and must be
        // yielded.
        if (chunk.done) {
          break;
        }
        yield chunk.value;
      }
      return await operationPromise;
    },
  };
}

/**
 * Creates the initial state for a flow run: empty cache, events and
 * executions, not blocked on any step, and a not-yet-done operation named
 * after the flow id.
 */
function createNewState(
  flowId: string,
  name: string,
  input: unknown
): FlowState {
  return {
    flowId: flowId,
    name: name,
    startTime: Date.now(),
    input: input,
    cache: {},
    eventsTriggered: {},
    blockedOnStep: null,
    executions: [],
    operation: {
      name: flowId,
      done: false,
    },
  };
}

/**
 * The function implementing a flow: receives the input and an optional
 * streaming callback, and resolves with the flow output.
 */
export type StepsFunction<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
> = (
  input: z.infer<I>,
  streamingCallback: StreamingCallback<z.infer<S>> | undefined
) => Promise<z.infer<O>>;

/**
 * Registers the flow as an action of type 'flow'. The action applies the
 * flow's auth policy (if any) and then runs the received envelope.
 */
function wrapAsAction<
  I extends z.ZodTypeAny = z.ZodTypeAny,
  O extends z.ZodTypeAny = z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  flow: Flow<I, O, S>
): Action<typeof FlowActionInputSchema, typeof FlowStateSchema> {
  return defineAction(
    {
      actionType: 'flow',
      name: flow.name,
      inputSchema: FlowActionInputSchema,
      outputSchema: FlowStateSchema,
      metadata: {
        inputSchema: toJsonSchema({ schema: flow.inputSchema }),
        outputSchema: toJsonSchema({ schema: flow.outputSchema }),
        experimentalDurable: !!flow.experimentalDurable,
        requiresAuth: !!flow.authPolicy,
      },
    },
    async (envelope) => {
      // Only non-durable flows have an authPolicy, so envelope.start should always
      // be defined here.
      await flow.authPolicy?.(
        envelope.auth,
        envelope.start?.input as I | undefined
      );
      setCustomMetadataAttribute(metadataPrefix('wrapperAction'), 'true');
      return await flow.runEnvelope(
        envelope,
        getStreamingCallback(),
        envelope.auth
      );
    }
  );
}

/**
 * Starts an express server exposing each flow as a POST endpoint at
 * `/${pathPrefix}${flow.name}`.
 *
 * @param params optional settings: `flows` (defaults to every flow created
 *   with `defineFlow`), `port` (defaults to the `PORT` environment variable,
 *   then 3400), `cors` options, `pathPrefix`, and JSON body parser options.
 */
export function startFlowsServer(params?: {
  flows?: Flow<any, any, any>[];
  port?: number;
  cors?: CorsOptions;
  pathPrefix?: string;
  jsonParserOptions?: bodyParser.OptionsJson;
}) {
  const port =
    params?.port ||
    // Explicit radix so the PORT variable is always read as a decimal number.
    (process.env.PORT ? parseInt(process.env.PORT, 10) : 0) ||
    3400;
  const pathPrefix = params?.pathPrefix ?? '';
  const app = express();
  app.use(bodyParser.json(params?.jsonParserOptions));
  app.use(cors(params?.cors));

  const flows = params?.flows || createdFlows();
  logger.info(`Starting flows server on port ${port}`);
  flows.forEach((f) => {
    const flowPath = `/${pathPrefix}${f.name}`;
    logger.info(` - ${flowPath}`);
    // Add middleware
    f.middleware?.forEach((m) => {
      app.post(flowPath, m);
    });
    app.post(flowPath, f.expressHandler);
  });

  app.listen(port, () => {
    console.log(`Flows server listening on port ${port}`);
  });
}