@genkit-ai/core
Version:
Genkit AI framework core libraries.
241 lines (218 loc) • 6.65 kB
text/typescript
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { JSONSchema7 } from 'json-schema';
import { AsyncLocalStorage } from 'node:async_hooks';
import * as z from 'zod';
import { ActionType, lookupPlugin, registerAction } from './registry.js';
import { parseSchema } from './schema.js';
import {
SPAN_TYPE_ATTR,
runInNewSpan,
setCustomMetadataAttributes,
} from './tracing.js';
export { Status, StatusCodes, StatusSchema } from './statusTypes.js';
export { JSONSchema7 };
/**
* Descriptive metadata attached to every action as its `__action` property.
*
* `I` / `O` are the Zod schema types for the action's input / output, and `M`
* is the shape of the free-form `metadata` object.
*/
export interface ActionMetadata<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> {
/** Kind of action. Set by `defineAction`; `action` alone does not set it. */
actionType?: ActionType;
/** Action name, either a plain name or `pluginId/actionId`. */
name: string;
/** Optional human-readable description. */
description?: string;
/** Zod schema handed to `parseSchema` when parsing the action's input. */
inputSchema?: I;
/** JSON Schema handed to `parseSchema` when parsing the action's input. */
inputJsonSchema?: JSONSchema7;
/** Zod schema handed to `parseSchema` when parsing the action's output. */
outputSchema?: O;
/** JSON Schema handed to `parseSchema` when parsing the action's output. */
outputJsonSchema?: JSONSchema7;
/** Arbitrary extra metadata supplied by the action's creator. */
metadata?: M;
}
/**
* A callable action: an async function from the schema-inferred input type to
* the schema-inferred output type, carrying its descriptive metadata on the
* `__action` property.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = ((input: z.infer<I>) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, M>;
};
/**
* Free-form key/value bag.
* NOTE(review): not referenced in this module; presumably consumed by callers
* elsewhere in the package — confirm its intended use before relying on it.
*/
export type SideChannelData = Record<string, any>;
/**
* Configuration accepted by `action` / `defineAction`.
*/
type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = {
// Either a plain name (optionally `pluginId/actionId`) or a structured
// `{ pluginId, actionId }`, which `action` validates and joins as
// `pluginId/actionId`.
name:
| string
| {
pluginId: string;
actionId: string;
};
description?: string;
inputSchema?: I;
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
// Middleware wrapped around the action; the first element runs outermost.
use?: Middleware<z.infer<I>, z.infer<O>>[];
};
/**
* A function that intercepts an action call.
*
* It receives the request and a `next` function that runs the rest of the
* chain (and ultimately the action). Passing a value to `next` replaces the
* request for downstream middleware; calling `next()` with no argument
* forwards the unmodified request.
*/
export interface Middleware<I = any, O = any> {
(req: I, next: (req?: I) => Promise<O>): Promise<O>;
}
/**
 * Wraps an action so that every call passes through the given middleware
 * chain before reaching the action itself.
 *
 * Middleware run in array order: `middleware[0]` is the outermost. Each one
 * may replace the request by passing a new value to `next`; calling `next()`
 * with no argument (or with `undefined` / `null`) forwards the request it
 * received.
 *
 * @param action the action to wrap; its `__action` metadata object is shared
 *   with the returned action.
 * @param middleware middleware to apply, outermost first.
 * @returns a new action with the same metadata as `action`.
 */
export function actionWithMiddleware<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  M extends Record<string, any> = Record<string, any>,
>(
  action: Action<I, O, M>,
  middleware: Middleware<z.infer<I>, z.infer<O>>[]
): Action<I, O, M> {
  const dispatch = async (
    index: number,
    req: z.infer<I>
  ): Promise<z.infer<O>> => {
    if (index >= middleware.length) {
      // End of the chain: invoke the wrapped action.
      return await action(req);
    }
    const current = middleware[index];
    // `??` rather than `||`: a middleware that rewrites the request to a
    // falsy-but-valid value (0, '', false) must have that value forwarded,
    // not silently replaced by the original request.
    return current(req, async (modifiedReq) =>
      dispatch(index + 1, modifiedReq ?? req)
    );
  };
  const wrapped = (async (req: z.infer<I>) => dispatch(0, req)) as Action<
    I,
    O,
    M
  >;
  wrapped.__action = action.__action;
  return wrapped;
}
/**
 * Creates an action with the provided config.
 *
 * The returned function:
 * 1. parses its input with `parseSchema` using `inputSchema` /
 *    `inputJsonSchema`;
 * 2. runs `fn` inside a new tracing span labelled as an `action`, recording
 *    the parsed input and the JSON-serialized output on the span metadata;
 * 3. parses the result with `parseSchema` using `outputSchema` /
 *    `outputJsonSchema` and returns it.
 *
 * If `config.use` is provided, the action is wrapped with
 * `actionWithMiddleware`.
 *
 * @param config name, schemas, metadata and optional middleware.
 * @param fn implementation, called with the parsed input.
 * @returns the callable action with its metadata on `__action`.
 * @throws Error at creation time if the name references a plugin that
 *   `lookupPlugin` cannot find, or if the action id contains a slash.
 */
export function action<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  M extends Record<string, any> = Record<string, any>,
>(
  config: ActionParams<I, O, M>,
  fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O, M> {
  const actionName =
    typeof config.name === 'string'
      ? validateActionName(config.name)
      : `${validatePluginName(config.name.pluginId)}/${validateActionId(config.name.actionId)}`;

  // The parameter is the *inferred* input value, not the Zod schema type `I`.
  const actionFn = async (input: z.infer<I>): Promise<z.infer<O>> => {
    const parsedInput: z.infer<I> = parseSchema(input, {
      schema: config.inputSchema,
      jsonSchema: config.inputJsonSchema,
    });
    const rawOutput = await runInNewSpan(
      {
        metadata: {
          name: actionName,
        },
        labels: {
          [SPAN_TYPE_ATTR]: 'action',
        },
      },
      async (metadata) => {
        metadata.name = actionName;
        metadata.input = parsedInput;
        const result = await fn(parsedInput);
        metadata.output = JSON.stringify(result);
        return result;
      }
    );
    const output: z.infer<O> = parseSchema(rawOutput, {
      schema: config.outputSchema,
      jsonSchema: config.outputJsonSchema,
    });
    return output;
  };

  const actionMetadata: ActionMetadata<I, O, M> = {
    name: actionName,
    description: config.description,
    inputSchema: config.inputSchema,
    inputJsonSchema: config.inputJsonSchema,
    outputSchema: config.outputSchema,
    outputJsonSchema: config.outputJsonSchema,
    metadata: config.metadata,
  };
  actionFn.__action = actionMetadata;

  if (config.use) {
    return actionWithMiddleware<I, O, M>(actionFn, config.use);
  }
  return actionFn;
}
/**
 * Validates a string action name and returns it unchanged.
 *
 * A name without a slash is accepted as-is. For a name containing a slash,
 * the text before the first slash is checked with `validatePluginName`, then
 * the remainder is checked with `validateActionId`.
 *
 * @throws Error propagated from `validatePluginName` / `validateActionId`.
 */
function validateActionName(name: string) {
  const slashAt = name.indexOf('/');
  if (slashAt === -1) {
    return name;
  }
  validatePluginName(name.slice(0, slashAt));
  validateActionId(name.slice(slashAt + 1));
  return name;
}
/**
 * Ensures `pluginId` is resolvable via `lookupPlugin`.
 *
 * @returns `pluginId` unchanged.
 * @throws Error when `lookupPlugin(pluginId)` yields a falsy value.
 */
function validatePluginName(pluginId: string) {
  const plugin = lookupPlugin(pluginId);
  if (plugin) {
    return pluginId;
  }
  throw new Error(
    `Unable to find plugin name used in the action name: ${pluginId}`
  );
}
/**
 * Ensures an action id contains no slash.
 *
 * @returns `actionId` unchanged.
 * @throws Error when `actionId` contains `/`.
 */
function validateActionId(actionId: string) {
  if (!actionId.includes('/')) {
    return actionId;
  }
  throw new Error(`Action name must not include slashes (/): ${actionId}`);
}
/**
 * Defines an action with the given config and registers it in the registry.
 *
 * Each invocation calls `setCustomMetadataAttributes({ subtype: actionType })`
 * from inside the action's tracing span before delegating to `fn`. The
 * action's `__action.actionType` is set to `config.actionType`, and the action
 * is passed to `registerAction` under that type.
 *
 * @param config action config plus the required `actionType`.
 * @param fn implementation, called with the parsed input.
 * @returns the registered action.
 * @throws Error under the same naming rules as `action`.
 */
export function defineAction<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  M extends Record<string, any> = Record<string, any>,
>(
  config: ActionParams<I, O, M> & {
    actionType: ActionType;
  },
  fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O, M> {
  // The wrapper receives the inferred input value, not the schema type `I`.
  const act = action<I, O, M>(
    config,
    (input: z.infer<I>): Promise<z.infer<O>> => {
      setCustomMetadataAttributes({ subtype: config.actionType });
      return fn(input);
    }
  );
  act.__action.actionType = config.actionType;
  registerAction(config.actionType, act);
  return act;
}
/**
* Callback invoked with each streamed chunk of type `T`.
*/
export type StreamingCallback<T> = (chunk: T) => void;
// Per-async-context storage for the active streaming callback; written by
// runWithStreamingCallback and read by getStreamingCallback.
const streamingAls = new AsyncLocalStorage<StreamingCallback<any>>();
// Stored in place of a missing callback. getStreamingCallback compares the
// store against this exact function object and maps it back to `undefined`,
// so it must remain this single module-level instance.
const sentinelNoopCallback = () => null;
/**
 * Runs `fn` with `streamingCallback` installed as the current streaming
 * callback for `fn`'s async context. Code inside can retrieve it with
 * {@link getStreamingCallback}.
 *
 * When `streamingCallback` is not provided, a private sentinel is installed
 * instead. As a result, code inside `fn` sees `undefined`, even if an
 * enclosing call had installed a callback.
 *
 * @returns whatever `fn` returns.
 */
export function runWithStreamingCallback<S, O>(
  streamingCallback: StreamingCallback<S> | undefined,
  fn: () => O
): O {
  const store = streamingCallback ? streamingCallback : sentinelNoopCallback;
  return streamingAls.run(store, fn);
}
/**
 * Returns the {@link StreamingCallback} installed by the nearest enclosing
 * {@link runWithStreamingCallback}. Returns `undefined` when there is no
 * enclosing call, or when that call was made without a callback.
 */
export function getStreamingCallback<S>(): StreamingCallback<S> | undefined {
  const current = streamingAls.getStore();
  return current === sentinelNoopCallback ? undefined : current;
}