@gensx/vercel-ai
Version:
Vercel AI SDK for GenSX.
177 lines (173 loc) • 6.75 kB
JavaScript
/**
* Check out the docs at https://www.gensx.com/docs
* Find us on GitHub https://github.com/gensx-inc/gensx
* Find us on Discord https://discord.gg/F5BSU8Kc
*/
import { Component, wrap } from '@gensx/core';
import * as ai from 'ai';
export { asToolSet } from './tools.js';
/* eslint-disable @typescript-eslint/no-unsafe-return */
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
// Helper function to wrap tools in GSX components
function wrapTools(tools) {
if (!tools)
return undefined;
return Object.entries(tools).reduce((acc, [name, tool]) => {
if (!tool.execute)
return acc;
const wrappedTool = {
...tool,
execute: (args, options) => {
const ToolComponent = Component(`tool.${name}`, async (toolArgs) => {
if (!tool.execute)
throw new Error(`Tool ${name} has no execute function`);
return await tool.execute(toolArgs, options);
});
return ToolComponent(args);
},
};
return {
...acc,
[name]: wrappedTool,
};
}, {});
}
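// Traced wrappers around the AI SDK entry points. Each one forwards the caller's
// options after swapping in a wrapped model (and wrapped tools, where applicable).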
/**
 * Traced `streamText`. Before delegating to `ai.streamText`, the options
 * object's `tools` map is passed through `wrapTools` and its `model` is
 * replaced with a `wrapVercelAIModel` proxy. The component is registered
 * with `textStream` as its streaming result key.
 */
const streamText = Component(
  "StreamText",
  new Proxy(ai.streamText, {
    apply(target, thisArg, [options, ...remaining]) {
      const tools = wrapTools(options.tools);
      const patchedOptions = {
        ...options,
        model: wrapVercelAIModel(options.model),
        tools,
      };
      return Reflect.apply(target, thisArg, [patchedOptions, ...remaining]);
    },
  }),
  { __streamingResultKey: "textStream" },
);
/**
 * Traced `streamObject`. The caller's options are forwarded to
 * `ai.streamObject` with `model` replaced by a `wrapVercelAIModel` proxy;
 * any further positional arguments pass through untouched.
 */
const streamObject = Component(
  "StreamObject",
  new Proxy(ai.streamObject, {
    apply(target, thisArg, [options, ...remaining]) {
      const patchedOptions = {
        ...options,
        model: wrapVercelAIModel(options.model),
      };
      return Reflect.apply(target, thisArg, [patchedOptions, ...remaining]);
    },
  }),
);
/**
 * Traced `generateObject`. Delegates to `ai.generateObject` after swapping
 * the options' `model` for a `wrapVercelAIModel` proxy.
 */
const generateObject = Component(
  "GenerateObject",
  new Proxy(ai.generateObject, {
    apply(target, thisArg, callArgs) {
      const [options, ...remaining] = callArgs;
      const forwardedArgs = [
        { ...options, model: wrapVercelAIModel(options.model) },
        ...remaining,
      ];
      return Reflect.apply(target, thisArg, forwardedArgs);
    },
  }),
);
/**
 * Traced `generateText`. Mirrors `streamText` without a streaming result key:
 * the options' `tools` go through `wrapTools` and `model` is replaced by a
 * `wrapVercelAIModel` proxy before `ai.generateText` is invoked.
 */
const generateText = Component(
  "GenerateText",
  new Proxy(ai.generateText, {
    apply(target, thisArg, [options, ...remaining]) {
      const tools = wrapTools(options.tools);
      const forwardedArgs = [
        { ...options, model: wrapVercelAIModel(options.model), tools },
        ...remaining,
      ];
      return Reflect.apply(target, thisArg, forwardedArgs);
    },
  }),
);
// Embedding and image helpers are traced as-is: unlike the text/object
// wrappers, their arguments are forwarded without any rewriting.
const {
  embed: aiEmbed,
  embedMany: aiEmbedMany,
  experimental_generateImage: aiGenerateImage,
} = ai;
const embed = Component("embed", aiEmbed);
const embedMany = Component("embedMany", aiEmbedMany);
const generateImage = Component("generateImage", aiGenerateImage);
const wrapVercelAIModel = (languageModel, componentOpts) => {
assertIsLanguageModel(languageModel);
const componentName = componentOpts?.name ?? languageModel.provider;
return new Proxy(languageModel, {
get(target, propKey, receiver) {
const originalValue = Reflect.get(target, propKey, receiver);
if (typeof originalValue === "function") {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if (originalValue.__gensxComponent) {
return originalValue;
}
// let aggregator: ((chunks: any[]) => unknown) | undefined;
let __streamingResultKey;
if (propKey === "doStream") {
__streamingResultKey = "stream";
// aggregator =
// componentOpts?.aggregator ??
// ((
// chunks: {
// type: "text-delta" | "tool-call" | "finish" | "something-else";
// textDelta: string;
// usage: unknown;
// finishReason: unknown;
// }[],
// ) => {
// return chunks.reduce((aggregated, chunk) => {
// console.log("aggregating chunk", chunk);
// if (chunk.type === "text-delta") {
// return {
// ...aggregated,
// };
// } else if (chunk.type === "tool-call") {
// return {
// ...aggregated,
// ...chunk,
// };
// } else if (chunk.type === "finish") {
// return {
// ...aggregated,
// usage: chunk.usage,
// finishReason: chunk.finishReason,
// };
// } else {
// return aggregated;
// }
// }, {});
// });
}
return Component(componentName, originalValue.bind(target), {
...componentOpts,
// aggregator,
__streamingResultKey,
idPropsKeys: ["inputFormat", "prompt", "responseFormat", "seed"],
});
}
else if (
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
originalValue != null &&
!Array.isArray(originalValue) &&
!(originalValue instanceof Date) &&
typeof originalValue === "object") {
return wrap(originalValue, {
prefix: [componentName, propKey.toString()].join("."),
});
}
else {
return originalValue;
}
},
});
};
function assertIsLanguageModel(languageModel) {
if (!("doStream" in languageModel) ||
typeof languageModel.doStream !== "function" ||
!("doGenerate" in languageModel) ||
typeof languageModel.doGenerate !== "function") {
throw new Error(`Invalid model. Is this a LanguageModelV2 instance?`);
}
}
export { embed, embedMany, generateImage, generateObject, generateText, streamObject, streamText, wrapVercelAIModel };
//# sourceMappingURL=index.js.map