@genkit-ai/flow
Version:
Genkit AI framework workflow APIs.
265 lines (228 loc) • 7.56 kB
text/typescript
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { FlowState, __hardResetConfigForTesting } from '@genkit-ai/core';
import { __hardResetRegistryForTesting } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { z } from 'zod';
import { defineFlow, runFlow, streamFlow } from '../src/flow.js';
import { getFlowAuth } from '../src/utils.js';
import { configureInMemoryStateStore } from './testUtil.js';
/**
 * Defines the basic fixture flow used by the tests below.
 *
 * The flow is named `testFlow`, takes a string and returns that string
 * prefixed with `bar `.
 *
 * @returns whatever `defineFlow` returns for this definition.
 */
function createTestFlow() {
  const flowConfig = {
    name: 'testFlow',
    inputSchema: z.string(),
    outputSchema: z.string(),
  };
  return defineFlow(flowConfig, async (input) => `bar ${input}`);
}
/**
 * Defines a fixture flow named `testFlowWithAuth` guarded by an auth policy.
 *
 * The policy throws `forty thieves!` unless the auth value is exactly the
 * string `'open sesame'`. The flow body echoes its input together with the
 * JSON-serialized value returned by `getFlowAuth()`, which lets tests check
 * what auth value is visible inside the flow body.
 *
 * @returns whatever `defineFlow` returns for this definition.
 */
function createTestFlowWithAuth() {
  return defineFlow(
    {
      name: 'testFlowWithAuth',
      inputSchema: z.string(),
      outputSchema: z.string(),
      authPolicy: async (auth) => {
        // Strict comparison: only the primitive string may pass, with no
        // type coercion of the supplied auth value.
        if (auth !== 'open sesame') {
          throw new Error('forty thieves!');
        }
      },
    },
    async (input) => {
      return `foo ${input}, auth ${JSON.stringify(getFlowAuth())}`;
    }
  );
}
/**
 * Defines a fixture flow that can stream intermediate chunks.
 *
 * Given a number `input`, the flow emits `{ count: 0 }` through
 * `{ count: input - 1 }` via the streaming callback when one is supplied,
 * then returns `bar <input> <true|false>`, where the flag records whether a
 * streaming callback was present.
 *
 * @returns whatever `defineFlow` returns for this definition.
 */
function createTestStreamingFlow() {
  return defineFlow(
    {
      name: 'testFlow',
      inputSchema: z.number(),
      outputSchema: z.string(),
      streamSchema: z.object({ count: z.number() }),
    },
    async (input, streamingCallback) => {
      // Nothing is emitted when no callback was provided.
      let count = 0;
      while (streamingCallback && count < input) {
        streamingCallback({ count });
        count += 1;
      }
      return `bar ${input} ${Boolean(streamingCallback)}`;
    }
  );
}
describe('flow', () => {
// Before every test: reset the registry, reset the config, and clear
// GENKIT_ENV so each test starts without an environment override.
beforeEach(__hardResetRegistryForTesting);
beforeEach(function resetConfigAndEnv() {
  __hardResetConfigForTesting();
  Reflect.deleteProperty(process.env, 'GENKIT_ENV');
});
// Tests for `runFlow`: successful execution, error propagation, input
// validation and auth-policy handling.
describe('runFlow', () => {
  it('should run the flow', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestFlow();

    const result = await runFlow(testFlow, 'foo');

    assert.strictEqual(result, 'bar foo');
  });

  it('should rethrow the error', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = defineFlow(
      {
        name: 'throwing',
        inputSchema: z.string(),
        outputSchema: z.string(),
      },
      async (input) => {
        throw new Error(`bad happened: ${input}`);
      }
    );

    // The error thrown inside the flow must surface to the caller unchanged.
    await assert.rejects(() => runFlow(testFlow, 'foo'), {
      name: 'Error',
      message: 'bad happened: foo',
    });
  });

  it('should validate input', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = defineFlow(
      {
        name: 'validating',
        inputSchema: z.object({ foo: z.string(), bar: z.number() }),
        outputSchema: z.string(),
      },
      async (input) => {
        // Serialize explicitly; interpolating the object directly would
        // produce "[object Object]".
        return `ok ${JSON.stringify(input)}`;
      }
    );

    await assert.rejects(
      // `as any` is deliberate: the test must pass schema-invalid input
      // (`bar` is a string) past the compiler to exercise runtime validation.
      () => runFlow(testFlow, { foo: 'foo', bar: 'bar' } as any),
      (err: Error) => {
        assert.strictEqual(err.name, 'ZodError');
        // `assert.match` reports the actual message on failure, unlike
        // comparing an `includes()` result against `true`.
        assert.match(err.message, /Expected number, received string/);
        return true;
      }
    );
  });

  it('should pass auth context all the way', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestFlowWithAuth();

    const result = await runFlow(testFlow, 'bar', {
      withLocalAuthContext: 'open sesame',
    });

    assert.strictEqual(result, 'foo bar, auth "open sesame"');
  });

  it('should fail auth', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestFlowWithAuth();

    // Any auth value other than 'open sesame' is rejected by the flow's policy.
    await assert.rejects(
      () =>
        runFlow(testFlow, 'bar', {
          withLocalAuthContext: 'yolo',
        }),
      /forty thieves/
    );
  });
});
// Tests for `streamFlow`: chunk streaming, error propagation and auth-policy
// handling. `streamFlow` is used synchronously throughout; only the values
// obtained from `stream()` and `output()` are awaited.
describe('streamFlow', () => {
  it('should run the flow', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestStreamingFlow();

    const response = streamFlow(testFlow, 3);

    const gotChunks: unknown[] = [];
    for await (const chunk of response.stream()) {
      gotChunks.push(chunk);
    }

    assert.strictEqual(await response.output(), 'bar 3 true');
    assert.deepEqual(gotChunks, [{ count: 0 }, { count: 1 }, { count: 2 }]);
  });

  it('should rethrow the error', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = defineFlow(
      {
        name: 'throwing',
        inputSchema: z.string(),
      },
      async (input) => {
        throw new Error(`bad happened: ${input}`);
      }
    );

    const response = streamFlow(testFlow, 'foo');

    // The error thrown inside the flow surfaces when the output is awaited.
    await assert.rejects(() => response.output(), {
      name: 'Error',
      message: 'bad happened: foo',
    });
  });

  it('should pass auth context all the way', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestFlowWithAuth();

    // No `await` on the call itself, matching the other tests in this suite.
    const response = streamFlow(testFlow, 'bar', {
      withLocalAuthContext: 'open sesame',
    });

    assert.strictEqual(
      await response.output(),
      'foo bar, auth "open sesame"'
    );
  });

  it('should fail auth', async () => {
    configureInMemoryStateStore('prod');
    const testFlow = createTestFlowWithAuth();

    const response = streamFlow(testFlow, 'bar', {
      withLocalAuthContext: 'yolo',
    });

    await assert.rejects(() => response.output(), /forty thieves/);
  });
});
// Tests for state persistence: a flow run with GENKIT_ENV=dev is expected to
// leave one entry in the in-memory store, while a run without it leaves none.
describe('stateStore', () => {
  describe('dev', () => {
    beforeEach(() => {
      process.env.GENKIT_ENV = 'dev';
    });

    it('should persist state in dev', async () => {
      const stateStore = configureInMemoryStateStore('dev');
      const testFlow = createTestFlow();

      const result = await runFlow(testFlow, 'foo');

      assert.strictEqual(result, 'bar foo');

      // Read the stored entries once; exactly one flow state is expected.
      const savedStates = Object.values(stateStore.state);
      assert.strictEqual(savedStates.length, 1);

      // do some asserting on the state... TODO: make this better.
      const state = JSON.parse(savedStates[0]) as FlowState;
      assert.strictEqual(state.executions.length, 1);
      assert.strictEqual(state.operation.done, true);
      assert.deepStrictEqual(state.operation.result, {
        response: 'bar foo',
      });
      // Loose equality kept on purpose: it accepts both `null` and a field
      // that is absent after the JSON round-trip.
      assert.equal(state.blockedOnStep, null);
      assert.deepStrictEqual(state.eventsTriggered, {});
      assert.deepStrictEqual(state.cache, {});
      assert.strictEqual(state.input, 'foo');
      assert.strictEqual(state.name, 'testFlow');
    });
  });

  describe('prod', () => {
    it('should not persist the state for non-durable flow', async () => {
      const stateStore = configureInMemoryStateStore('prod');
      const testFlow = createTestFlow();

      const result = await runFlow(testFlow, 'foo');

      assert.strictEqual(result, 'bar foo');
      assert.strictEqual(Object.keys(stateStore.state).length, 0);
    });
  });
});
});