/*
 * rocket-ai
 * Version:
 * Simple AI client that lets you access different LLMs in a unified way.
 * 1,004 lines (973 loc) • 37.4 kB
 * JavaScript
 */
"use strict";
// ---------------------------------------------------------------------------
// Bundler runtime helpers (generated by the build, not hand-written).
// The short aliases below cache built-in Object methods used by the helpers.
// ---------------------------------------------------------------------------
var __create = Object.create;
var __defProp = Object.defineProperty;
var __defProps = Object.defineProperties;
var __getOwnPropDesc = Object.getOwnPropertyDescriptor;
var __getOwnPropDescs = Object.getOwnPropertyDescriptors;
var __getOwnPropNames = Object.getOwnPropertyNames;
var __getOwnPropSymbols = Object.getOwnPropertySymbols;
var __getProtoOf = Object.getPrototypeOf;
var __hasOwnProp = Object.prototype.hasOwnProperty;
var __propIsEnum = Object.prototype.propertyIsEnumerable;
// Returns the well-known symbol `Symbol[name]` (e.g. Symbol.asyncIterator); when the
// runtime lacks it, falls back to the registry symbol "Symbol.<name>".
var __knownSymbol = (name, symbol) => (symbol = Symbol[name]) ? symbol : Symbol.for("Symbol." + name);
// Sets obj[key] = value. If `key` already exists on obj or its prototype chain, it uses
// defineProperty (enumerable/configurable/writable) so inherited setters are not triggered.
var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value;
// Copies the own enumerable string- and symbol-keyed properties of `b` onto `a` and
// returns `a` (the down-levelled form of object spread).
var __spreadValues = (a, b) => {
for (var prop in b || (b = {}))
if (__hasOwnProp.call(b, prop))
__defNormalProp(a, prop, b[prop]);
if (__getOwnPropSymbols)
for (var prop of __getOwnPropSymbols(b)) {
if (__propIsEnum.call(b, prop))
__defNormalProp(a, prop, b[prop]);
}
return a;
};
// Defines every own property descriptor of `b` on `a` and returns `a`.
var __spreadProps = (a, b) => __defProps(a, __getOwnPropDescs(b));
// Installs one enumerable getter per entry of `all` on `target`; each getter calls the
// corresponding thunk, so the exported value is read lazily on access.
var __export = (target, all) => {
for (var name in all)
__defProp(target, name, { get: all[name], enumerable: true });
};
// Re-exposes every own property name of `from` on `to` as a getter. Keys that `to`
// already owns and the key `except` are skipped; the source property's enumerability is kept.
var __copyProps = (to, from, except, desc) => {
if (from && typeof from === "object" || typeof from === "function") {
for (let key of __getOwnPropNames(from))
if (!__hasOwnProp.call(to, key) && key !== except)
__defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable });
}
return to;
};
// Wraps a require()'d module so it can be consumed like an ES module namespace: the
// result inherits from the module's prototype, mirrors its properties, and gains a
// "default" entry when appropriate (see the comment inside).
var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__getProtoOf(mod)) : {}, __copyProps(
// If the importer is in node compatibility mode or this is not an ESM
// file that has been converted to a CommonJS file using a Babel-
// compatible transform (i.e. "__esModule" has not been set), then set
// "default" to the CommonJS "module.exports" for node compatibility.
isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target,
mod
));
// Builds a CommonJS exports object flagged with `__esModule: true` whose properties
// are getters forwarding to `mod`.
var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: true }), mod);
// Runs a generator function as if it were an `async` function: every yielded value
// is awaited, and its result (or rejection) is fed back via next()/throw(). Returns a
// Promise that settles with the generator's return value or first uncaught error.
var __async = (__this, __arguments, generator) => {
return new Promise((resolve, reject) => {
var fulfilled = (value) => {
try {
step(generator.next(value));
} catch (e) {
reject(e);
}
};
var rejected = (value) => {
try {
step(generator.throw(value));
} catch (e) {
reject(e);
}
};
var step = (x) => x.done ? resolve(x.value) : Promise.resolve(x.value).then(fulfilled, rejected);
step((generator = generator.apply(__this, __arguments)).next());
});
};
// Marker wrapper used inside __asyncGenerator bodies: `yield new __await(p)` means
// "await p", while a plain `yield v` emits v to the consumer. [1] flags yield*-delegation.
var __await = function(promise, isYieldStar) {
this[0] = promise;
this[1] = isYieldStar;
};
// Turns a generator function that uses the __await marker into an async iterator
// exposing next/throw/return and Symbol.asyncIterator (down-levelled `async function*`).
var __asyncGenerator = (__this, __arguments, generator) => {
var resume = (k, v, yes, no) => {
try {
var x = generator[k](v), isAwait = (v = x.value) instanceof __await, done = x.done;
Promise.resolve(isAwait ? v[0] : v).then((y) => isAwait ? resume(k === "return" ? k : "next", v[1] ? { done: y.done, value: y.value } : y, yes, no) : yes({ value: y, done })).catch((e) => resume("throw", e, yes, no));
} catch (e) {
no(e);
}
}, method = (k) => it[k] = (x) => new Promise((yes, no) => resume(k, x, yes, no)), it = {};
return generator = generator.apply(__this, __arguments), it[__knownSymbol("asyncIterator")] = () => it, method("next"), method("throw"), method("return"), it;
};
// Obtains an async iterator for a `for await` loop: uses obj[Symbol.asyncIterator] when
// present, otherwise adapts the sync iterator so each step's value is awaited.
var __forAwait = (obj, it, method) => (it = obj[__knownSymbol("asyncIterator")]) ? it.call(obj) : (obj = obj[__knownSymbol("iterator")](), it = {}, method = (key, fn) => (fn = obj[key]) && (it[key] = (arg) => new Promise((yes, no, done) => (arg = fn.call(obj, arg), done = arg.done, Promise.resolve(arg.value).then((value) => yes({ value, done }), no)))), method("next"), method("return"), it);
// src/index.ts
// Public API of the package. __export installs lazy getters, so the names referenced
// below are only read on access and may be defined further down in this file.
var index_exports = {};
__export(index_exports, {
Agent: () => Agent,
AiClient: () => AiClient,
AiModelType: () => AiModelType,
BasicPromptTemplate: () => BasicPromptTemplate,
StructuredOutputTemplate: () => StructuredOutputTemplate,
Tool: () => Tool,
ToolRegistry: () => ToolRegistry,
tool: () => tool
});
// Expose the API as this CommonJS module's exports (flagged with __esModule: true).
module.exports = __toCommonJS(index_exports);
// src/client/clients/open-ai.ts
var import_dotenv = __toESM(require("dotenv"));
var import_openai = __toESM(require("openai"));
import_dotenv.default.config();
var OpenAiClient = class {
  /**
   * Constructs an OpenAiClient instance.
   * @param {OpenAI} [client] - A pre-configured OpenAI SDK client. When omitted (or null), one is
   *   created from the resolved `apiKey` and `baseUrl`.
   * @param {string} [apiKey] - API key; falls back to the OPENAI_API_KEY environment variable.
   * @param {string} [baseUrl] - API base URL; defaults to "https://api.openai.com/v1".
   * @throws {Error} If no API key is given and OPENAI_API_KEY is not set.
   */
  constructor(client, apiKey, baseUrl) {
    this.apiKey = apiKey != null ? apiKey : process.env.OPENAI_API_KEY;
    this.baseUrl = baseUrl != null ? baseUrl : "https://api.openai.com/v1";
    if (!this.apiKey) {
      throw new Error("OpenAI API key is required. Set it in config or via OPENAI_API_KEY environment variable.");
    }
    // Build the default SDK client only after the key check, and hand it the resolved
    // key/base URL so explicitly passed credentials are actually used.
    this.client = client != null ? client : new import_openai.default({ apiKey: this.apiKey, baseURL: this.baseUrl });
  }
  /**
   * Builds the message list for a request without mutating the caller's array.
   * @param {AiMessage[]} messages - Conversation messages.
   * @param {string} [systemPrompt] - Prepended as a "system" message when non-empty.
   * @returns {Array<object>} A new array of messages.
   */
  _buildMessages(messages, systemPrompt) {
    const prepared = messages.slice();
    if (systemPrompt) {
      prepared.unshift({ role: "system", content: systemPrompt });
    }
    return prepared;
  }
  /**
   * Maps an OpenAI `usage` object (which may be absent) to { inputTokens, outputTokens, totalTokens }.
   * @param {object} [usage] - The `usage` field of a response or stream chunk.
   * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}}
   */
  _toUsage(usage) {
    const count = (key) => usage != null && usage[key] != null ? usage[key] : 0;
    return {
      inputTokens: count("prompt_tokens"),
      outputTokens: count("completion_tokens"),
      totalTokens: count("total_tokens")
    };
  }
  /**
   * Sends a chat completion request and returns the first choice's text.
   * @param {string} model - Model name passed to the API.
   * @param {AiMessage[]} messages - Conversation messages (not modified).
   * @param {string} [systemPrompt] - Optional system prompt.
   * @param {number} [temperature=0.7] - Sampling temperature.
   * @returns {Promise<AiMessageResponse>} Text content ("" when absent) and token usage.
   */
  invoke(model, messages, systemPrompt, temperature = 0.7) {
    return __async(this, null, function* () {
      const response = yield this.client.chat.completions.create({
        model,
        messages: this._buildMessages(messages, systemPrompt),
        temperature
      });
      const choice = response.choices != null ? response.choices[0] : void 0;
      const content = choice != null && choice.message != null ? choice.message.content : void 0;
      return {
        content: content != null ? content : "",
        usage: this._toUsage(response.usage)
      };
    });
  }
  /**
   * Streams a chat completion, yielding one response per received chunk.
   * Usage is requested via `stream_options.include_usage`; chunks without a `usage` field
   * report zero counts.
   * @param {string} model - Model name passed to the API.
   * @param {AiMessage[]} messages - Conversation messages (not modified).
   * @param {string} [systemPrompt] - Optional system prompt.
   * @param {number} [temperature=0.7] - Sampling temperature.
   * @returns {AsyncGenerator<AiMessageResponse, void, unknown>}
   */
  stream(model, messages, systemPrompt, temperature = 0.7) {
    return __asyncGenerator(this, null, function* () {
      const stream = yield new __await(this.client.chat.completions.create({
        model,
        messages: this._buildMessages(messages, systemPrompt),
        temperature,
        stream: true,
        stream_options: {
          include_usage: true
        }
      }));
      try {
        for (var iter = __forAwait(stream), more, temp, error; more = !(temp = yield new __await(iter.next())).done; more = false) {
          const chunk = temp.value;
          const choice = chunk.choices != null ? chunk.choices[0] : void 0;
          const delta = choice != null && choice.delta != null ? choice.delta.content : void 0;
          yield {
            content: delta != null ? delta : "",
            usage: this._toUsage(chunk.usage)
          };
        }
      } catch (temp) {
        error = [temp];
      } finally {
        try {
          more && (temp = iter.return) && (yield new __await(temp.call(iter)));
        } finally {
          if (error)
            throw error[0];
        }
      }
    });
  }
  /**
   * Generates an image from a text prompt.
   * NOTE: the `model` argument is ignored; "dall-e-3" is always requested.
   * @param {string} model - Unused (see note).
   * @param {string} messages - The image prompt.
   * @param {string} size - Requested image size, forwarded to the API.
   * @param {number} [n=1] - Number of images requested; only the first one is returned.
   * @returns {Promise<{url: string, revisedPrompt: string}>} Empty strings when absent.
   */
  generateImage(model, messages, size, n = 1) {
    return __async(this, null, function* () {
      const response = yield this.client.images.generate({
        model: "dall-e-3",
        prompt: messages,
        n,
        size
      });
      const image = response.data != null ? response.data[0] : void 0;
      return {
        url: image != null && image.url != null ? image.url : "",
        revisedPrompt: image != null && image.revised_prompt != null ? image.revised_prompt : ""
      };
    });
  }
  /**
   * Converts text to speech and returns the audio as a base64 string.
   * NOTE: the `model` argument is ignored; "tts-1" is always requested.
   * @param {string} model - Unused (see note).
   * @param {string} messages - Text to speak.
   * @param {string} voice - Voice name forwarded to the API.
   * @returns {Promise<string>} Base64-encoded audio bytes.
   */
  generateSpeech(model, messages, voice) {
    return __async(this, null, function* () {
      const voiceResponse = yield this.client.audio.speech.create({
        model: "tts-1",
        input: messages,
        voice
      });
      const buffer = Buffer.from(yield voiceResponse.arrayBuffer());
      return buffer.toString("base64");
    });
  }
};
// src/client/clients/anthropic.ts
var import_sdk = __toESM(require("@anthropic-ai/sdk"));
var import_dotenv2 = __toESM(require("dotenv"));
import_dotenv2.default.config();
var AnthropicAiClient = class {
  /**
   * Constructs an AnthropicAiClient instance.
   * @param {Anthropic} [client] - A pre-configured Anthropic SDK client. When omitted (or null), one
   *   is created with the resolved API key.
   * @param {string} [apiKey] - API key; falls back to the ANTHROPIC_API_KEY environment variable.
   * @param {string} [baseUrl] - Stored on the instance (default "https://api.anthropic.com/v1"); it
   *   is not forwarded to the SDK client.
   * @throws {Error} If no API key is given and ANTHROPIC_API_KEY is not set.
   */
  constructor(client, apiKey, baseUrl) {
    this.apiKey = apiKey != null ? apiKey : process.env.ANTHROPIC_API_KEY;
    this.baseUrl = baseUrl != null ? baseUrl : "https://api.anthropic.com/v1";
    if (!this.apiKey) {
      throw new Error("Anthropic API key is required. Set it in config or via ANTHROPIC_API_KEY environment variable.");
    }
    // Create the default SDK client only after the key check and give it the resolved key.
    this.client = client != null ? client : new import_sdk.default({ apiKey: this.apiKey });
  }
  /**
   * Builds a usage record whose total is always a number.
   * @param {number} inputTokens
   * @param {number} outputTokens
   * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}}
   */
  _toUsage(inputTokens, outputTokens) {
    return { inputTokens, outputTokens, totalTokens: inputTokens + outputTokens };
  }
  /**
   * Invokes the model with the provided messages and system prompt.
   * @param {string} model - The model to use.
   * @param {AiMessage[]} messages - Messages sent to the model.
   * @param {string} [systemPrompt] - System prompt for the request.
   * @returns {Promise<AiMessageResponse>} All text blocks joined ("" if none) plus token usage.
   */
  invoke(model, messages, systemPrompt) {
    return __async(this, null, function* () {
      const response = yield this.client.messages.create({
        model,
        max_tokens: 4096,
        stream: false,
        messages,
        system: systemPrompt
      });
      // A reply can hold several content blocks, and the first one is not guaranteed to be
      // text; join every text block instead of reading only content[0].
      const blocks = Array.isArray(response.content) ? response.content : [];
      const content = blocks.filter((block) => block.type === "text").map((block) => block.text).join("");
      const usage = response.usage;
      const inputTokens = usage != null && usage.input_tokens != null ? usage.input_tokens : 0;
      const outputTokens = usage != null && usage.output_tokens != null ? usage.output_tokens : 0;
      return { content, usage: this._toUsage(inputTokens, outputTokens) };
    });
  }
  /**
   * Streams model output for the provided messages and system prompt.
   * Text deltas are yielded with zero usage. Each `message_delta` event yields an empty-content
   * response carrying the request's usage: input tokens come from the preceding `message_start`
   * event (or from the delta itself when present), output tokens from the delta.
   * @param {string} model - The model to use.
   * @param {AiMessage[]} messages - Messages sent to the model.
   * @param {string} [systemPrompt] - System prompt for the request.
   * @returns {AsyncGenerator<AiMessageResponse, void, unknown>}
   */
  stream(model, messages, systemPrompt) {
    return __asyncGenerator(this, null, function* () {
      const stream = this.client.messages.stream({
        model,
        messages,
        max_tokens: 4096,
        system: systemPrompt
      });
      let inputTokens = 0;
      try {
        for (var iter = __forAwait(stream), more, temp, error; more = !(temp = yield new __await(iter.next())).done; more = false) {
          const chunk = temp.value;
          if (chunk.type === "message_start") {
            const startUsage = chunk.message != null ? chunk.message.usage : void 0;
            inputTokens = startUsage != null && startUsage.input_tokens != null ? startUsage.input_tokens : 0;
          } else if (chunk.type === "content_block_delta") {
            const text = chunk.delta != null ? chunk.delta.text : void 0;
            yield {
              content: text != null ? text : "",
              usage: this._toUsage(0, 0)
            };
          } else if (chunk.type === "message_delta") {
            const deltaUsage = chunk.usage;
            const input = deltaUsage != null && deltaUsage.input_tokens != null ? deltaUsage.input_tokens : inputTokens;
            const output = deltaUsage != null && deltaUsage.output_tokens != null ? deltaUsage.output_tokens : 0;
            yield {
              content: "",
              usage: this._toUsage(input, output)
            };
          }
        }
      } catch (temp) {
        error = [temp];
      } finally {
        try {
          more && (temp = iter.return) && (yield new __await(temp.call(iter)));
        } finally {
          if (error)
            throw error[0];
        }
      }
    });
  }
};
// src/client/clients/google.ts
var import_dotenv3 = __toESM(require("dotenv"));
var import_generative_ai = require("@google/generative-ai");
import_dotenv3.default.config();
var GoogleAiClient = class {
  /**
   * Constructs a GoogleAiClient instance.
   * @param {string} [apiKey] - API key; falls back to GOOGLE_API_KEY, then GEMINI_API_KEY.
   * @throws {Error} If no API key can be resolved.
   */
  constructor(apiKey) {
    const envKey = process.env.GOOGLE_API_KEY != null ? process.env.GOOGLE_API_KEY : process.env.GEMINI_API_KEY;
    this.apiKey = apiKey != null ? apiKey : envKey;
    if (!this.apiKey) {
      throw new Error("Gemini API key is required. Set it in config or via GOOGLE_API_KEY (or GEMINI_API_KEY) environment variable.");
    }
    this.client = new import_generative_ai.GoogleGenerativeAI(this.apiKey);
  }
  /**
   * Converts a library message into Gemini chat content; every non-"user" role maps to "model".
   * @param {AiMessage} message
   * @returns {{role: string, parts: Array<{text: string}>}}
   */
  _toContent(message) {
    return {
      role: message.role === "user" ? "user" : "model",
      parts: [{ text: message.content }]
    };
  }
  /**
   * Maps Gemini `usageMetadata` (which may be absent) to the library's usage shape.
   * @param {object} [metadata]
   * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}}
   */
  _toUsage(metadata) {
    const count = (key) => metadata != null && metadata[key] != null ? metadata[key] : 0;
    return {
      inputTokens: count("promptTokenCount"),
      outputTokens: count("candidatesTokenCount"),
      totalTokens: count("totalTokenCount")
    };
  }
  /**
   * Opens a chat session whose history is every message except the last one. The last
   * message is returned separately because it is sent via sendMessage/sendMessageStream,
   * which appends it to the history itself.
   * @param {string} model - Gemini model name.
   * @param {AiMessage[]} messages - Non-empty conversation; the last entry is the new prompt.
   * @param {string} [systemPrompt] - Passed as `systemInstruction` when non-empty.
   * @returns {{chat: object, prompt: string}}
   * @throws {Error} If `messages` is empty or not an array.
   */
  _startChat(model, messages, systemPrompt) {
    if (!Array.isArray(messages) || messages.length === 0) {
      throw new Error("At least one message is required.");
    }
    const modelParams = systemPrompt ? { systemInstruction: systemPrompt, model } : { model };
    const chatModel = this.client.getGenerativeModel(modelParams);
    const chat = chatModel.startChat({
      history: messages.slice(0, -1).map((message) => this._toContent(message)),
      generationConfig: {
        maxOutputTokens: 4096,
        temperature: 0.7
      }
    });
    return { chat, prompt: messages[messages.length - 1].content };
  }
  /**
   * Sends the conversation to Gemini and returns the full reply.
   * @param {string} model - Gemini model name.
   * @param {AiMessage[]} messages - Conversation; the last message is the prompt.
   * @param {string} [systemPrompt] - Optional system instruction.
   * @returns {Promise<AiMessageResponse>}
   */
  invoke(model, messages, systemPrompt) {
    return __async(this, null, function* () {
      const { chat, prompt } = this._startChat(model, messages, systemPrompt);
      const result = yield chat.sendMessage(prompt);
      const response = result.response;
      return {
        content: response.text(),
        usage: this._toUsage(response.usageMetadata)
      };
    });
  }
  /**
   * Streams Gemini's reply chunk by chunk.
   * @param {string} model - Gemini model name.
   * @param {AiMessage[]} messages - Conversation; the last message is the prompt.
   * @param {string} [systemPrompt] - Optional system instruction.
   * @returns {AsyncGenerator<AiMessageResponse, void, unknown>}
   */
  stream(model, messages, systemPrompt) {
    return __asyncGenerator(this, null, function* () {
      const { chat, prompt } = this._startChat(model, messages, systemPrompt);
      const result = yield new __await(chat.sendMessageStream(prompt));
      try {
        for (var iter = __forAwait(result.stream), more, temp, error; more = !(temp = yield new __await(iter.next())).done; more = false) {
          const chunk = temp.value;
          yield {
            content: chunk.text(),
            usage: this._toUsage(chunk.usageMetadata)
          };
        }
      } catch (temp) {
        error = [temp];
      } finally {
        try {
          more && (temp = iter.return) && (yield new __await(temp.call(iter)));
        } finally {
          if (error)
            throw error[0];
        }
      }
    });
  }
};
// src/client/clients/fireworks-ai.ts
var import_dotenv4 = __toESM(require("dotenv"));
var import_openai2 = __toESM(require("openai"));
// src/types/ai-model-type.ts
/**
 * Supported model identifiers (enum-like object). Each key's value is the string used both as
 * the routing key in AiClient and, for most providers, as the model name sent to the API.
 * The "fw-" values are internal aliases translated by getFwModel.
 */
var AiModelType = /* @__PURE__ */ ((enumObject) => Object.assign(enumObject, {
  // OpenAI
  Gpt4o: "gpt-4o",
  Gpt4oMini: "gpt-4o-mini",
  DallE3: "dall-e3",
  GptTextToSpeech: "tts-1",
  O1Preview: "o1-preview",
  O1Mini: "o1-mini",
  // Anthropic
  Claude37SonnetLatest: "claude-3-7-sonnet-latest",
  Claude35HaikuLatest: "claude-3-5-haiku-latest",
  // Google Gemini
  Gemini20FlashLatest: "gemini-2.0-flash-latest",
  Gemini15FlashLatest: "gemini-1.5-flash-latest",
  Gemini15ProLatest: "gemini-1.5-pro-latest",
  // Fireworks
  Llama33: "fw-llama-3-3",
  DeepSeekV3: "fw-deepseek-v3",
  DeepSeekR1: "fw-deepseek-r1"
}))(AiModelType || {});
// src/client/libs/get-fw-model.ts
/**
 * Translates an internal Fireworks alias ("fw-…") into the fully qualified Fireworks model path.
 * Any unrecognised value falls back to the Llama 3.3 70B Instruct model.
 * @param {string} model - An AiModelType value.
 * @returns {string} The Fireworks model identifier.
 */
var getFwModel = (model) => {
  switch (model) {
    case "fw-deepseek-v3" /* DeepSeekV3 */:
      return "accounts/fireworks/models/deepseek-v3";
    case "fw-deepseek-r1" /* DeepSeekR1 */:
      return "accounts/fireworks/models/deepseek-r1";
    case "fw-llama-3-3" /* Llama33 */:
    default:
      return "accounts/fireworks/models/llama-v3p3-70b-instruct";
  }
};
// src/client/clients/fireworks-ai.ts
import_dotenv4.default.config();
var FireworksAiClient = class {
  /**
   * Constructs a FireworksAiClient (Fireworks exposes an OpenAI-compatible API).
   * @param {OpenAI} [client] - A pre-configured OpenAI SDK client. When omitted (or null), one is
   *   created that points at `baseUrl` and authenticates with the Fireworks key.
   * @param {string} [apiKey] - API key; falls back to the FIREWORKS_API_KEY environment variable.
   * @param {string} [baseUrl] - API base URL; defaults to "https://api.fireworks.ai/inference/v1".
   * @throws {Error} If no API key is given and FIREWORKS_API_KEY is not set.
   */
  constructor(client, apiKey, baseUrl) {
    this.apiKey = apiKey != null ? apiKey : process.env.FIREWORKS_API_KEY;
    this.baseUrl = baseUrl != null ? baseUrl : "https://api.fireworks.ai/inference/v1";
    if (!this.apiKey) {
      throw new Error("FireworksAI API key is required. Set it in config or via FIREWORKS_API_KEY environment variable.");
    }
    // A bare `new OpenAI()` would talk to api.openai.com with OPENAI_API_KEY; point the
    // default client at Fireworks with the Fireworks key instead.
    this.client = client != null ? client : new import_openai2.default({ apiKey: this.apiKey, baseURL: this.baseUrl });
  }
  /**
   * Builds the message list for a request without mutating the caller's array.
   * @param {AiMessage[]} messages - Conversation messages.
   * @param {string} [systemPrompt] - Prepended as a "system" message when non-empty.
   * @returns {Array<object>} A new array of messages.
   */
  _buildMessages(messages, systemPrompt) {
    const prepared = messages.slice();
    if (systemPrompt) {
      prepared.unshift({ role: "system", content: systemPrompt });
    }
    return prepared;
  }
  /**
   * Maps an OpenAI-style `usage` object (which may be absent) to the library's usage shape.
   * @param {object} [usage]
   * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}}
   */
  _toUsage(usage) {
    const count = (key) => usage != null && usage[key] != null ? usage[key] : 0;
    return {
      inputTokens: count("prompt_tokens"),
      outputTokens: count("completion_tokens"),
      totalTokens: count("total_tokens")
    };
  }
  /**
   * Sends a chat completion request to Fireworks.
   * @param {string} model - An internal "fw-…" alias, translated by getFwModel.
   * @param {AiMessage[]} messages - Conversation messages (not modified).
   * @param {string} [systemPrompt] - Optional system prompt.
   * @returns {Promise<AiMessageResponse>}
   */
  invoke(model, messages, systemPrompt) {
    return __async(this, null, function* () {
      const response = yield this.client.chat.completions.create({
        model: getFwModel(model),
        messages: this._buildMessages(messages, systemPrompt)
      });
      const choice = response.choices != null ? response.choices[0] : void 0;
      const content = choice != null && choice.message != null ? choice.message.content : void 0;
      return {
        content: content != null ? content : "",
        usage: this._toUsage(response.usage)
      };
    });
  }
  /**
   * Streams a chat completion from Fireworks, one response per chunk.
   * @param {string} model - An internal "fw-…" alias, translated by getFwModel.
   * @param {AiMessage[]} messages - Conversation messages (not modified).
   * @param {string} [systemPrompt] - Optional system prompt.
   * @returns {AsyncGenerator<AiMessageResponse, void, unknown>}
   */
  stream(model, messages, systemPrompt) {
    return __asyncGenerator(this, null, function* () {
      const stream = yield new __await(this.client.chat.completions.create({
        model: getFwModel(model),
        messages: this._buildMessages(messages, systemPrompt),
        stream: true,
        stream_options: {
          include_usage: true
        }
      }));
      try {
        for (var iter = __forAwait(stream), more, temp, error; more = !(temp = yield new __await(iter.next())).done; more = false) {
          const chunk = temp.value;
          const choice = chunk.choices != null ? chunk.choices[0] : void 0;
          const delta = choice != null && choice.delta != null ? choice.delta.content : void 0;
          yield {
            content: delta != null ? delta : "",
            usage: this._toUsage(chunk.usage)
          };
        }
      } catch (temp) {
        error = [temp];
      } finally {
        try {
          more && (temp = iter.return) && (yield new __await(temp.call(iter)));
        } finally {
          if (error)
            throw error[0];
        }
      }
    });
  }
};
// src/client/prompts/structured-output.ts
var import_zod_to_json_schema = __toESM(require("zod-to-json-schema"));
var StructuredOutputTemplate = class {
  /**
   * Builds prompt instructions telling the model to reply with JSON that matches `schema`.
   *
   * The examples use single braces on purpose: this is a plain JS template literal, so
   * doubled braces (a format-string escape in other templating systems) would reach the
   * model verbatim and show it invalid JSON.
   *
   * @param {ZodTypeAny} schema - Zod schema, converted with zod-to-json-schema.
   * @returns {string} The instruction text, ending with the JSON Schema inside a ```json fence.
   */
  static getFormatInstructions(schema) {
    return `
You must format your output as a JSON value that adheres to a given "JSON Schema" instance.
"JSON Schema" is a declarative language that allows you to annotate and validate JSON documents.
For example, the example "JSON Schema" instance {"properties": {"foo": {"description": "a list of test words", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
would match an object with one required property, "foo". The "type" property specifies "foo" must be an "array", and the "description" property semantically describes it as "a list of test words". The items within "foo" must be strings.
Thus, the object {"foo": ["bar", "baz"]} is a well-formatted instance of this example "JSON Schema". The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.
Your output will be parsed and type-checked according to the provided schema instance, so make sure all fields in your output match the schema exactly and there are no trailing commas!
Here is the JSON Schema instance your output must adhere to. Do not include the enclosing markdown codeblock:
\`\`\`json
${JSON.stringify((0, import_zod_to_json_schema.default)(schema))}
\`\`\`
`;
  }
};
// src/client/ai-client.ts
var AiClient = class {
  /**
   * Constructs an instance of AiClient.
   *
   * Default provider clients are created for "openai", "anthropic", "gemini" and "fireworks"
   * unless a client for that provider is supplied. If creating a default client fails (for
   * example because its API key is missing), the error is remembered and reported only when a
   * model of that provider is used, so the other providers stay usable.
   *
   * @param clients - Optional custom clients keyed by provider name; they replace the defaults
   *   and may add extra providers.
   * @param modelProviderMap - Optional model → provider entries merged over the default mapping.
   */
  constructor(clients, modelProviderMap) {
    this.options = {};
    this.clientErrors = {};
    this.clients = Object.assign({}, clients);
    const defaultFactories = {
      openai: () => new OpenAiClient(),
      anthropic: () => new AnthropicAiClient(),
      gemini: () => new GoogleAiClient(),
      fireworks: () => new FireworksAiClient()
    };
    Object.keys(defaultFactories).forEach((provider) => {
      if (this.clients[provider]) {
        return;
      }
      try {
        this.clients[provider] = defaultFactories[provider]();
      } catch (e) {
        this.clientErrors[provider] = e;
      }
    });
    this.modelProviderMap = Object.assign({
      ["gpt-4o" /* Gpt4o */]: "openai",
      ["gpt-4o-mini" /* Gpt4oMini */]: "openai",
      ["tts-1" /* GptTextToSpeech */]: "openai",
      ["dall-e3" /* DallE3 */]: "openai",
      ["o1-preview" /* O1Preview */]: "openai",
      ["o1-mini" /* O1Mini */]: "openai",
      ["claude-3-7-sonnet-latest" /* Claude37SonnetLatest */]: "anthropic",
      ["claude-3-5-haiku-latest" /* Claude35HaikuLatest */]: "anthropic",
      ["gemini-2.0-flash-latest" /* Gemini20FlashLatest */]: "gemini",
      ["gemini-1.5-flash-latest" /* Gemini15FlashLatest */]: "gemini",
      ["gemini-1.5-pro-latest" /* Gemini15ProLatest */]: "gemini",
      ["fw-llama-3-3" /* Llama33 */]: "fireworks",
      ["fw-deepseek-v3" /* DeepSeekV3 */]: "fireworks",
      ["fw-deepseek-r1" /* DeepSeekR1 */]: "fireworks"
    }, modelProviderMap);
  }
  /**
   * Gets the provider name by model type.
   * @param model - The model type.
   * @returns The provider name.
   * @throws Will throw an error if no provider is found for the model.
   */
  _getProviderByModel(model) {
    // Own-property lookup so names such as "toString" are not resolved through the prototype.
    const provider = Object.prototype.hasOwnProperty.call(this.modelProviderMap, model) ? this.modelProviderMap[model] : void 0;
    if (!provider) {
      throw new Error(`No provider found for model: ${model}`);
    }
    return provider;
  }
  /**
   * Returns the configured client for a provider.
   * @param provider - Provider name.
   * @returns The provider client.
   * @throws If the provider has no client; the message includes the reason its default client
   *   could not be created, when known.
   */
  _getClient(provider) {
    const client = Object.prototype.hasOwnProperty.call(this.clients, provider) ? this.clients[provider] : void 0;
    if (!client) {
      const cause = this.clientErrors[provider];
      throw new Error(`Provider ${provider} is not configured.` + (cause ? ` ${cause.message}` : ""));
    }
    return client;
  }
  /**
   * Sets the structured output schema used by the next invoke/stream call only.
   * @param schema - The Zod schema for structured output.
   * @returns The AiClient instance.
   */
  withStructuredOutput(schema) {
    this.structuredOutputSchema = schema;
    return this;
  }
  /**
   * Sets default options for invoking or streaming; per-call options take precedence.
   * @param options - The options to set (merged into any previously set options).
   * @returns The AiClient instance.
   */
  setOptions(options) {
    this.options = Object.assign({}, this.options, options);
    return this;
  }
  /**
   * Invokes an AI model with the given options.
   * @param options - { model, messages, systemPrompt?, temperature? }.
   * @returns A promise that resolves to the AI message response.
   */
  invoke(options) {
    return __async(this, null, function* () {
      const fullOptions = this.prepareOptions(options);
      return this._invokeOrStream(fullOptions, "invoke");
    });
  }
  /**
   * Streams responses from an AI model with the given options.
   * @param options - { model, messages, systemPrompt?, temperature? }.
   * @returns A promise that resolves to an async generator of AI message responses.
   */
  stream(options) {
    return __async(this, null, function* () {
      const fullOptions = this.prepareOptions(options);
      return this._invokeOrStream(fullOptions, "stream");
    });
  }
  /**
   * Merges the defaults from setOptions with `options` and, when a structured output schema is
   * set, appends its format instructions to the system prompt.
   * @param options - Per-call options.
   * @returns A new, merged options object.
   */
  prepareOptions(options) {
    const merged = Object.assign({}, this.options, options);
    if (this.structuredOutputSchema) {
      // Avoid producing "undefined…" when no system prompt was given.
      const basePrompt = merged.systemPrompt != null ? merged.systemPrompt : "";
      merged.systemPrompt = basePrompt + StructuredOutputTemplate.getFormatInstructions(this.structuredOutputSchema);
    }
    return merged;
  }
  /**
   * Dispatches a prepared request to the provider client that serves `options.model`.
   * @param options - Prepared options.
   * @param method - "invoke" or "stream".
   * @returns The provider client's result.
   */
  _invokeOrStream(options, method) {
    return __async(this, null, function* () {
      const { model, messages, systemPrompt, temperature } = options;
      try {
        const client = this._getClient(this._getProviderByModel(model));
        // Hand providers a copy so a client that prepends a system message cannot mutate
        // the caller's array.
        const messageCopy = Array.isArray(messages) ? messages.slice() : messages;
        return yield client[method](model, messageCopy, systemPrompt, temperature);
      } finally {
        // The schema applies to exactly one request, whether it succeeds or fails.
        this.structuredOutputSchema = void 0;
      }
    });
  }
  /**
   * Generates an image using the specified options.
   * @param options - { model, prompt, size, n }.
   * @returns A promise that resolves to the AI message response image.
   * @throws If the provider is not configured or cannot generate images.
   */
  generateImage(options) {
    return __async(this, null, function* () {
      const { model, prompt, size, n } = options;
      const provider = this._getProviderByModel(model);
      const client = this._getClient(provider);
      if (typeof client.generateImage !== "function") {
        throw new Error(`Provider ${provider} does not support image generation.`);
      }
      return yield client.generateImage(model, prompt, size, n);
    });
  }
  /**
   * Generates speech using the specified model, messages, and voice.
   * @param model - The model type.
   * @param messages - The text to convert to speech.
   * @param voice - The voice to use for speech generation.
   * @returns A promise that resolves to the generated speech as a string.
   * @throws If the provider is not configured or cannot generate speech.
   */
  generateSpeech(model, messages, voice) {
    return __async(this, null, function* () {
      const provider = this._getProviderByModel(model);
      const client = this._getClient(provider);
      if (typeof client.generateSpeech !== "function") {
        throw new Error(`Provider ${provider} does not support speech generation.`);
      }
      return yield client.generateSpeech(model, messages, voice);
    });
  }
};
// src/client/agents/agent.ts
var import_zod = require("zod");
// src/client/agents/tools/tool-registry.ts
var ToolRegistry = class {
  /**
   * Creates an empty registry; tools are stored by name in `this.tools`.
   */
  constructor() {
    this.tools = {};
  }
  /**
   * Looks up a tool by name among this registry's own entries only, so names such as
   * "toString" or "constructor" are not resolved through Object.prototype.
   * @param {string} toolName
   * @returns {Tool|undefined}
   */
  _findTool(toolName) {
    return Object.prototype.hasOwnProperty.call(this.tools, toolName) ? this.tools[toolName] : void 0;
  }
  /**
   * Registers tools by their `name`; a later tool with the same name replaces an earlier one.
   * @param {Tool[]} tools
   */
  registerTools(tools) {
    tools.forEach((tool2) => {
      this.tools[tool2.name] = tool2;
      console.log(`Registered tool: ${tool2.name}`);
    });
  }
  /**
   * Executes a registered tool.
   * @param {string} toolName - Name of the tool.
   * @param {...*} args - Arguments forwarded to the tool's `execute`.
   * @returns {Promise<*>} The tool's result.
   * @throws {Error} If no tool with that name is registered.
   */
  callTool(toolName, ...args) {
    return __async(this, null, function* () {
      const tool2 = this._findTool(toolName);
      if (!tool2) {
        throw new Error(`Tool '${toolName}' not found.`);
      }
      console.log(`Calling tool: ${tool2.name}`);
      return yield tool2.execute(...args);
    });
  }
  /**
   * Describes one registered tool.
   * @param {string} toolName
   * @returns {{name: string, description: string, queryFormat: string}}
   * @throws {Error} If no tool with that name is registered.
   */
  getToolInfo(toolName) {
    const tool2 = this._findTool(toolName);
    if (!tool2) {
      throw new Error(`Tool '${toolName}' not found.`);
    }
    return { name: tool2.name, description: tool2.description, queryFormat: tool2.getQueryFormat() };
  }
  /**
   * Describes every registered tool.
   * @returns {Object<string, {name: string, description: string, queryFormat: string}>} Keyed by tool name.
   */
  getAllTools() {
    const toolInfos = {};
    Object.keys(this.tools).forEach((toolName) => {
      toolInfos[toolName] = this.getToolInfo(toolName);
    });
    return toolInfos;
  }
};
// src/client/agents/tools/tool.ts
var import_zod_openapi = require("@anatine/zod-openapi");
var Tool = class {
  /**
   * Wraps a function as an agent tool.
   * @param {Function} func - Implementation; receives the validated query followed by any extra arguments.
   * @param {{name: string, description: string, queryFormat: ZodTypeAny}} options - Tool metadata and
   *   the Zod schema for the first argument.
   */
  constructor(func, options) {
    this.name = options.name;
    this.description = options.description;
    this.queryFormat = options.queryFormat;
    this.func = func;
  }
  /**
   * Validates the first argument against `queryFormat` and runs the tool.
   * The parsed value (not the raw input) is forwarded, so schema defaults, coercions and
   * transforms take effect.
   * @param {...*} args - The query, followed by any extra arguments passed through unchanged.
   * @returns {Promise<*>} The tool's result.
   * @throws {Error} "Invalid argument: …" when validation fails.
   */
  execute(...args) {
    return __async(this, null, function* () {
      const [query, ...rest] = args;
      const validation = this.queryFormat.safeParse(query);
      if (!validation.success) {
        throw new Error("Invalid argument: " + validation.error.message);
      }
      return yield this.func(validation.data, ...rest);
    });
  }
  /**
   * Returns the query schema as a JSON string (OpenAPI-style schema generated from the Zod type).
   * @returns {string}
   */
  getQueryFormat() {
    return JSON.stringify((0, import_zod_openapi.generateSchema)(this.queryFormat));
  }
};
// src/client/agents/tools/tool-helper.ts
/**
 * Factory shorthand for `new Tool(func, options)`.
 * @param {Function} func - The tool implementation.
 * @param {object} options - Tool name, description and query schema.
 * @returns {Tool} The new tool.
 */
function tool(func, options) {
  return Reflect.construct(Tool, [func, options]);
}
// src/client/libs/usage-counter.ts
var UsageCounter = class {
  /**
   * Constructs a UsageCounter instance and initializes token counts to zero.
   */
  constructor() {
    this.inputTokens = 0;
    this.outputTokens = 0;
    this.totalTokens = 0;
  }
  /**
   * Adds the usage tokens to the current counts.
   * A missing usage object is ignored. Missing or non-finite fields (e.g. NaN from a sum over
   * absent counts) count as 0, so one malformed response cannot poison the running totals.
   * @param {Usage} [usage] - The usage object containing input, output, and total tokens.
   */
  addUsageTokens(usage) {
    if (usage == null) {
      return;
    }
    const asCount = (value) => typeof value === "number" && Number.isFinite(value) ? value : 0;
    this.inputTokens += asCount(usage.inputTokens);
    this.outputTokens += asCount(usage.outputTokens);
    this.totalTokens += asCount(usage.totalTokens);
  }
  /**
   * Retrieves the current usage counts.
   * @returns {{inputTokens: number, outputTokens: number, totalTokens: number}} A snapshot of the counts.
   */
  getUsage() {
    return {
      inputTokens: this.inputTokens,
      outputTokens: this.outputTokens,
      totalTokens: this.totalTokens
    };
  }
};
// src/client/prompts/react-template.ts
var ReActTemplate = class {
  /**
   * Creates a ReAct prompt made of five sections, each starting with its markdown header.
   * @param {string} name - The name of the agent (accepted but not rendered).
   */
  constructor(name) {
    this.originalRequestTemplate = "## Original Request\n\n";
    this.toolTemplate = "## Tools\n\n";
    this.instructionTemplate = "## Instructions\n\n";
    this.iterationTemplate = "## Iterations\n\n";
    this.template = "# ReAct Agent\n\n";
  }
  /**
   * Appends `text` plus a newline to the section stored under `sectionKey`.
   * @param {string} sectionKey - Property name of the section.
   * @param {string} text - Text to append.
   */
  _appendLine(sectionKey, text) {
    this[sectionKey] = `${this[sectionKey]}${text}\n`;
  }
  /**
   * Appends the built-in Thought/Action/PAUSE/Observation loop instructions.
   */
  useDefaultInstructions() {
    const lines = [
      "You run in a loop of Thought, Action, PAUSE, Observation.",
      "At the end of the loop you output an Answer",
      "Strictly follow the provided response format.",
      "Use Thought to describe your thoughts about the question you have been asked.",
      "Use Action to run one of the actions available to you",
      "Observation will be the result of running those actions.",
      ""
    ];
    this.addInstruction(lines.join("\n"));
  }
  /**
   * Appends a line to the "Original Request" section.
   * @param {string} request
   */
  addOriginalRequest(request) {
    this._appendLine("originalRequestTemplate", request);
  }
  /**
   * Appends a line to the "Instructions" section.
   * @param {string} instruction
   */
  addInstruction(instruction) {
    this._appendLine("instructionTemplate", instruction);
  }
  /**
   * Appends a line to the "Tools" section.
   * @param {string} tool2 - Tool description text.
   */
  addTool(tool2) {
    this._appendLine("toolTemplate", tool2);
  }
  /**
   * Appends an "### Iteration <n>" entry followed by its content to the "Iterations" section.
   * @param {number} iteration - Iteration number.
   * @param {string} content - Iteration content.
   */
  addIteration(iteration, content) {
    this._appendLine("iterationTemplate", `### Iteration ${iteration}\n${content}`);
  }
  /**
   * @returns {string} Header, original request, instructions, tools and iterations, in that order.
   */
  getTemplate() {
    return `${this.template}${this.originalRequestTemplate}${this.instructionTemplate}${this.toolTemplate}${this.iterationTemplate}`;
  }
  /**
   * @returns {string} The same as getTemplate() but without the iterations section.
   */
  getInstructions() {
    return `${this.template}${this.originalRequestTemplate}${this.instructionTemplate}${this.toolTemplate}`;
  }
};
// src/client/agents/agent.ts
var Agent = class {
  /**
   * Constructs an Agent instance.
   * @param {string} name - The name of the agent.
   * @param {string} [systemPrompt] - Optional instructions; when omitted, the default ReAct
   *   instructions are used.
   * @param {string} [model="gpt-4o-mini"] - Model used for every reasoning step.
   */
  constructor(name, systemPrompt, model = "gpt-4o-mini") {
    this.name = name;
    this.systemPrompt = systemPrompt;
    this.model = model;
    this.maxIterations = 10;
    this.client = new AiClient();
    this.toolRegistry = new ToolRegistry();
    this.promptTemplate = this._createPromptTemplate();
    this.usageCounter = new UsageCounter();
  }
  /**
   * Creates a fresh ReAct prompt seeded with this agent's instructions.
   * @returns {ReActTemplate}
   */
  _createPromptTemplate() {
    const template = new ReActTemplate(this.name);
    if (this.systemPrompt) {
      template.addInstruction(this.systemPrompt);
    } else {
      template.useDefaultInstructions();
    }
    return template;
  }
  /**
   * Registers a list of tools with the agent.
   * @param {Array<any>} tools - An array of tools to register.
   */
  registerTools(tools) {
    this.toolRegistry.registerTools(tools);
  }
  /**
   * Parses the model's reply as JSON, tolerating a surrounding markdown code fence.
   * @param {string} content - Raw model output.
   * @returns {*} The parsed value, or undefined when the content is not valid JSON.
   */
  _parseJson(content) {
    if (typeof content !== "string") {
      return void 0;
    }
    const fenced = /^\s*```(?:json)?\s*([\s\S]*?)\s*```\s*$/.exec(content);
    try {
      return JSON.parse(fenced ? fenced[1] : content);
    } catch (e) {
      return void 0;
    }
  }
  /**
   * Runs the requested tool and renders its outcome as observation text. Failures (unknown
   * tool, invalid input, tool errors) become "Error: …" so the model can react to them instead
   * of the whole task aborting.
   * @param {{tool: string, input: *}} action - The model's requested action.
   * @returns {Promise<string>}
   */
  _runTool(action) {
    return __async(this, null, function* () {
      try {
        const result = yield this.toolRegistry.callTool(action.tool, action.input);
        return typeof result === "string" ? result : JSON.stringify(result);
      } catch (e) {
        return `Error: ${e instanceof Error ? e.message : String(e)}`;
      }
    });
  }
  /**
   * Executes a task based on the provided input using a Thought → Action → Observation loop.
   *
   * Each step asks the model for structured JSON { thought, action?, answer? }. If an action is
   * present, the tool runs and its observation is recorded as an iteration that is included in
   * the next step's system prompt. Otherwise a non-empty answer ends the loop. The loop also
   * stops on unparseable output or after `maxIterations` steps.
   *
   * The prompt is rebuilt for every task. Token usage accumulates across all tasks run by
   * this agent.
   *
   * @param {string} input - The input string for the task.
   * @returns {Promise<AiMessageResponse>} - The final answer (or a fallback message) and usage.
   */
  executeTask(input) {
    return __async(this, null, function* () {
      this.promptTemplate = this._createPromptTemplate();
      this.promptTemplate.addOriginalRequest(input);
      const tools = this.toolRegistry.getAllTools();
      const toolInfos = Object.keys(tools).map((toolName) => tools[toolName]);
      this.promptTemplate.addTool(`You have these tools at your disposal: ${JSON.stringify(toolInfos, null, 2)}`);
      const isKnownTool = (toolName) => Object.prototype.hasOwnProperty.call(tools, toolName);
      // `answer` is optional: if it were required, the model would always fill it in and the loop
      // would stop before any action could run.
      const reActSchema = import_zod.z.object({
        thought: import_zod.z.string(),
        action: import_zod.z.object({
          tool: import_zod.z.string().refine(isKnownTool, { message: "Invalid tool" }),
          input: import_zod.z.any()
        }).optional(),
        answer: import_zod.z.string().optional()
      });
      let nextPrompt = input;
      let finalResponse = null;
      for (let iteration = 1; iteration <= this.maxIterations; iteration++) {
        const response = yield this.client.withStructuredOutput(
          reActSchema
        ).invoke({
          model: this.model,
          messages: [{ role: "user", content: nextPrompt }],
          // Includes the iterations recorded so far, so the model sees its earlier steps.
          systemPrompt: this.promptTemplate.getTemplate(),
          temperature: 0
        });
        this.usageCounter.addUsageTokens(response.usage);
        const parsedResponse = reActSchema.safeParse(this._parseJson(response.content));
        if (!parsedResponse.success) {
          break;
        }
        const { thought, action, answer } = parsedResponse.data;
        if (action) {
          const observation = yield this._runTool(action);
          nextPrompt = `Observation: ${observation}`;
          this.promptTemplate.addIteration(iteration, `Thought: ${thought}\nAction: ${JSON.stringify(action)}\n${nextPrompt}`);
          continue;
        }
        if (answer) {
          finalResponse = answer;
        }
        break;
      }
      return {
        content: finalResponse ? finalResponse : "No valid answer could be provided.",
        usage: this.usageCounter.getUsage()
      };
    });
  }
};
// src/client/prompts/templates.ts
/**
 * Minimal markdown prompt builder: a "# <name>" title followed by "## <section>" blocks.
 */
var BasicPromptTemplate = class {
  /**
   * @param {string} name - Title rendered as the top-level heading.
   */
  constructor(name) {
    this.template = `# ${name}\n`;
  }
  /**
   * Appends a "## <name>" heading followed by `content`, each on its own line.
   * @param {string} name - Section heading.
   * @param {string} content - Section body.
   */
  addTemplateSection(name, content) {
    const heading = `## ${name}\n`;
    const body = `${content}\n`;
    this.template = this.template.concat(heading, body);
  }
  /**
   * @returns {string} The accumulated prompt text.
   */
  getTemplate() {
    const { template } = this;
    return template;
  }
};
// Annotate the CommonJS export names for ESM import in node:
// `0 && (...)` short-circuits, so this assignment never runs; it only spells out the
// export names so they can be read statically. Keep it byte-for-byte in sync with the
// __export(index_exports, ...) list near the top of the file.
0 && (module.exports = {
Agent,
AiClient,
AiModelType,
BasicPromptTemplate,
StructuredOutputTemplate,
Tool,
ToolRegistry,
tool
});
//# sourceMappingURL=index.js.map