assistan-ts
Version:
A typesafe and code-first library to define and run OpenAI assistants
226 lines (188 loc) • 4.94 kB
text/typescript
import assert from "assert";
import { expect, test } from "bun:test";
import { RequiredActionFunctionToolCall } from "openai/resources/beta/threads/runs/runs";
import { Type } from "..";
import { AssistantVisibleError, filter, toolbox, join } from "../toolbox";
/**
 * Tool definitions for a single `sum` tool that takes two numeric
 * arguments, `a` and `b`.
 */
const sumDef = {
  sum: {
    parameters: Type.Object({
      a: Type.Number(),
      // Declared the same way as `a`. A schema must not be nested inside
      // `Type.Number(...)`: that argument slot is presumably the options
      // object (TypeBox-style builder) — confirm against the `Type` export.
      b: Type.Number(),
    }),
  },
};
/** Tool definitions for a single `noop` tool whose parameter schema is `Type.Null()`. */
const noopDef = {
  noop: { parameters: Type.Null() },
};
/**
 * Tool definitions for a single `dateFn` tool taking one string argument `d`
 * that is declared with the "date" format.
 */
const dateDef = {
  dateFn: {
    parameters: Type.Object({
      d: Type.String({ format: "date" }),
    }),
  },
};
/**
 * Builds a function tool call fixture with a fixed id of "123".
 *
 * @param tool - Name of the tool being invoked.
 * @param args - Raw argument string, passed through unchanged.
 * @returns A tool call object of type "function" targeting `tool`.
 */
function toolCall(tool: string, args: string): RequiredActionFunctionToolCall {
  const fn = { arguments: args, name: tool };
  const call: RequiredActionFunctionToolCall = {
    id: "123",
    type: "function",
    function: fn,
  };
  return call;
}
/** Shorthand for a call to the `sum` tool with the given raw argument string. */
function sumToolCall(args: string) {
  return toolCall("sum", args);
}

/** Shared fixture: a `sum` call whose arguments are a = 1 and b = 2. */
const onePlusTwo = sumToolCall('{"a": 1, "b": 2}');
// Well-formed arguments: the sum of 1 and 2 is expected back as the output "3".
test("valid", async () => {
  const box = toolbox(sumDef, {
    sum: async (args) => args.a + args.b,
  });

  const { output } = await box.handleAction(onePlusTwo);

  expect(output).toBe("3");
});
// A tool defined with `noopDef` is called with "{}" and its return value is
// expected back as the output.
test("no parameters", async () => {
  const box = toolbox(noopDef, {
    noop: async () => "success",
  });

  const call = toolCall("noop", "{}");
  const { output } = await box.handleAction(call);

  expect(output).toBe("success");
});
/* error handling */
// Leaving out `b` is expected to produce output that mentions "arguments/b".
test("missing args", async () => {
  const box = toolbox(sumDef, {
    sum: async (args) => args.a + args.b,
  });

  const call = sumToolCall(`{"a": 2}`);
  const response = await box.handleAction(call);

  expect(response.output).toContain("arguments/b");
});
// Passing a string for `b` is expected to produce output that mentions "arguments/b".
test("invalid arg type", async () => {
  const box = toolbox(sumDef, {
    sum: async (args) => args.a + args.b,
  });

  const call = sumToolCall(`{"a": 2, "b": "foo"}`);
  const response = await box.handleAction(call);

  expect(response.output).toContain("arguments/b");
});
// The string "2022-03-14" is expected to be accepted for `d`, so the
// handler's return value comes back as the output.
test("valid format type", async () => {
  const box = toolbox(dateDef, {
    dateFn: async (_args) => "success",
  });

  const call = toolCall("dateFn", `{"d": "2022-03-14"}`);
  const response = await box.handleAction(call);

  expect(response.output).toBe("success");
});
// The string "22/03/14" is expected to be rejected for `d`, producing output
// that mentions "arguments/d".
test("invalid format type", async () => {
  const box = toolbox(dateDef, {
    dateFn: async (args) => args.d,
  });

  const call = toolCall("dateFn", `{"d": "22/03/14"}`);
  const response = await box.handleAction(call);

  expect(response.output).toContain("arguments/d");
});
// A plain `Error` thrown by a tool is expected to make `handleAction` reject
// with an error whose message contains the original text.
test("tool throws fatal error", async () => {
  const tb = toolbox(sumDef, {
    sum: async () => {
      throw new Error("foo");
    },
  });

  // Capture the rejection first and assert afterwards. Failing inside the
  // `try` would be caught by the `catch` below and reported as a misleading
  // message mismatch instead of "should have thrown".
  let thrown: unknown;
  try {
    await tb.handleAction(onePlusTwo);
  } catch (e: unknown) {
    thrown = e;
  }

  // Narrows `thrown` to `Error` and fails clearly when nothing was thrown.
  assert(thrown instanceof Error, "should have thrown");
  expect(thrown.message).toContain("foo");
});
// An `AssistantVisibleError` thrown by a tool is expected to be reported in
// the output as "Error calling sum: foo" instead of rejecting.
test("tool throws visible error", async () => {
  const box = toolbox(sumDef, {
    sum: async (_args) => {
      throw new AssistantVisibleError("foo");
    },
  });

  const { output } = await box.handleAction(onePlusTwo);

  expect(output).toBe("Error calling sum: foo");
});
/* overrides */
// A custom `formatOutput` receives the tool result and a context exposing the
// action; its return value is expected to become the output.
test("override formatOutput", async () => {
  const box = toolbox(
    sumDef,
    {
      sum: async (args) => args.a + args.b,
    },
    {
      formatOutput: (output, ctx) => `${ctx.action.function.name} == ${output}`,
    }
  );

  const response = await box.handleAction(onePlusTwo);

  expect(response.output).toBe("sum == 3");
});
// With a custom `formatToolError`, a plain `Error` thrown by a tool is
// expected to be reported through the formatter's return value as the output.
test("override formatToolError", async () => {
  const tb = toolbox(
    sumDef,
    {
      sum: async () => {
        throw new Error("foo");
      },
    },
    {
      // `unknown` instead of `any`: the thrown value is narrowed before its
      // message is read, so non-Error values are stringified explicitly.
      formatToolError: (error: unknown, ctx) => {
        const message = error instanceof Error ? error.message : String(error);
        return `Error calling ${ctx.action.function.name}: ${message}`;
      },
    }
  );
  const output = await tb.handleAction(onePlusTwo);
  expect(output.output).toBe("Error calling sum: foo");
});
// A toolbox produced by `join` is expected to handle calls for the tools of
// both of its inputs.
test("join toolboxes", async () => {
  const sumBox = toolbox(sumDef, {
    sum: async (args) => args.a + args.b,
  });
  const noopBox = toolbox(noopDef, {
    noop: async () => "success",
  });
  const combined = join(sumBox, noopBox);

  const sumResponse = await combined.handleAction(onePlusTwo);
  expect(sumResponse.output).toBe("3");

  const noopResponse = await combined.handleAction(toolCall("noop", "{}"));
  expect(noopResponse.output).toBe("success");
});
// `filter` with a predicate keeping only "sum" is expected to leave one tool
// definition and one handler; a call to the dropped `noop` tool is expected
// to yield output containing "Error calling noop".
test("filter", async () => {
  const fullBox = toolbox(
    { ...noopDef, ...sumDef },
    {
      noop: async () => "success",
      sum: async (args) => args.a + args.b,
    }
  );
  const sumOnly = filter(fullBox, (name) => name === "sum");

  const remainingDefs = Object.keys(sumOnly.toolDefs);
  expect(remainingDefs.length).toBe(1);
  const remainingFns = Object.keys(sumOnly.toolsFn);
  expect(remainingFns.length).toBe(1);

  const sumResponse = await sumOnly.handleAction(onePlusTwo);
  expect(sumResponse.output).toBe("3");

  const noopResponse = await sumOnly.handleAction(toolCall("noop", "{}"));
  expect(noopResponse.output).toContain("Error calling noop");
});