// notdiamond
// Version: 1.1.1
// TS/JS client for the NotDiamond API
// 584 lines (573 loc) • 20.1 kB
// JavaScript
;
const dotenv = require('dotenv');
const openai = require('@langchain/openai');
const messages = require('@langchain/core/messages');
const anthropic = require('@langchain/anthropic');
const googleGenai = require('@langchain/google-genai');
const mistralai = require('@langchain/mistralai');
const chat_models = require('@langchain/core/language_models/chat_models');
const axios = require('axios');
const cohere = require('@langchain/cohere');
const togetherai = require('@langchain/community/chat_models/togetherai');
/**
 * Unwraps the `default` export of a required module.
 *
 * @param e A module value obtained through `require`.
 * @returns `e.default` when `e` is an object that has a `default` property
 *          (own or inherited); otherwise `e` itself, unchanged.
 */
function _interopDefaultCompat(e) {
  const exposesDefault = e && typeof e === 'object' && 'default' in e;
  if (exposesDefault) {
    return e.default;
  }
  return e;
}
/**
 * Presents a required module as a namespace-like object with a `default` key.
 *
 * @param e A module value obtained through `require`.
 * @returns `e` itself when it is an object that already has a `default`
 *          property; otherwise a new prototype-less object carrying every
 *          enumerable property of `e` plus `default` set to `e`.
 */
function _interopNamespaceCompat(e) {
  const alreadyNamespace = e && typeof e === 'object' && 'default' in e;
  if (alreadyNamespace) {
    return e;
  }

  const namespace = Object.create(null);
  if (e) {
    // for...in on purpose: inherited enumerable properties are copied as well.
    for (const key in e) {
      namespace[key] = e[key];
    }
  }
  // Assigned last, so it wins over any copied key named "default".
  namespace.default = e;
  return namespace;
}
// Namespace-style view of the required `dotenv` module; its `config()` is
// invoked further down in this file.
const dotenv__namespace = /*#__PURE__*/_interopNamespaceCompat(dotenv);
// `axios.default` when the required module exposes one, otherwise the module itself.
const axios__default = /*#__PURE__*/_interopDefaultCompat(axios);
// Package version, inlined as a string literal.
const version = "1.1.1";
// Minimal stand-in for the package metadata: only `version` is present.
const packageJson = {
version: version};
/**
 * Minimal chat model that talks to the Perplexity chat-completions endpoint
 * over HTTP.
 *
 * Only `invoke` is functional. `_generate` is deliberately left
 * unimplemented and throws when called.
 */
class ChatPerplexity extends chat_models.BaseChatModel {
  /** API key sent as the Bearer token of every request. */
  apiKey;
  /** Model identifier placed in the request body. */
  model;

  /**
   * @param params.apiKey Perplexity API key.
   * @param params.model Perplexity model identifier.
   */
  constructor({ apiKey, model }) {
    super({});
    this.apiKey = apiKey;
    this.model = model;
  }

  /**
   * Not implemented; always throws. The arguments are serialised into the
   * error message.
   */
  _generate(messages, options, runManager) {
    throw new Error(
      "Method not implemented." + JSON.stringify(messages) + JSON.stringify(options) + JSON.stringify(runManager)
    );
  }

  _llmType() {
    return "perplexity";
  }

  /**
   * Invokes the Perplexity model.
   * @param messages The messages to send to the model.
   * @returns An AIMessage holding the content of the first returned choice.
   * @throws Error when the API answers with an error status (the underlying
   *         axios error is attached as `cause`) or returns no choices; any
   *         other failure is rethrown unchanged.
   */
  async invoke(messages$1) {
    // The API uses chat-completions role names, not LangChain message types:
    // "human" -> "user" and "ai" -> "assistant"; other types pass through.
    const toApiRole = (type) => {
      if (type === "human") {
        return "user";
      }
      if (type === "ai") {
        return "assistant";
      }
      return type;
    };

    let data;
    try {
      ({ data } = await axios__default.post(
        "https://api.perplexity.ai/chat/completions",
        {
          model: this.model,
          messages: messages$1.map((m) => ({
            role: toApiRole(m._getType()),
            content: m.content
          }))
        },
        {
          headers: {
            Authorization: `Bearer ${this.apiKey}`
          }
        }
      ));
    } catch (error) {
      if (axios__default.isAxiosError(error) && error.response) {
        const { status, statusText } = error.response;
        // statusText can be empty, so fall back to the numeric status; keep
        // the original error reachable through `cause`.
        throw new Error(`Perplexity API error: ${statusText || status}`, { cause: error });
      }
      throw error;
    }

    const firstMessage = data?.choices?.[0]?.message;
    if (!firstMessage) {
      throw new Error("Perplexity API error: response contained no choices");
    }
    return new messages.AIMessage(firstMessage.content);
  }
}
/**
 * Identifiers of the LLM providers this SDK can dispatch to.
 *
 * Every key is the upper-cased form of its value (e.g. `OPENAI: "openai"`),
 * so the table is derived from the list of provider ids.
 */
const SupportedProvider = Object.fromEntries(
  [
    "openai",
    "anthropic",
    "google",
    "mistral",
    "perplexity",
    "cohere",
    "togetherai"
  ].map((providerId) => [providerId.toUpperCase(), providerId])
);
// Model identifiers known to the SDK, keyed by a constant-style name.
// The group headings below follow the provider-to-models table that appears
// right after this object in the file.
const SupportedModel = {
// --- OpenAI ---
GPT_3_5_TURBO: "gpt-3.5-turbo",
GPT_3_5_TURBO_0125: "gpt-3.5-turbo-0125",
GPT_4: "gpt-4",
GPT_4_0613: "gpt-4-0613",
GPT_4_1106_PREVIEW: "gpt-4-1106-preview",
GPT_4_TURBO: "gpt-4-turbo",
GPT_4_TURBO_PREVIEW: "gpt-4-turbo-preview",
GPT_4_TURBO_2024_04_09: "gpt-4-turbo-2024-04-09",
GPT_4O_2024_05_13: "gpt-4o-2024-05-13",
GPT_4O_2024_08_06: "gpt-4o-2024-08-06",
GPT_4O: "gpt-4o",
GPT_4O_MINI_2024_07_18: "gpt-4o-mini-2024-07-18",
GPT_4O_MINI: "gpt-4o-mini",
GPT_4_0125_PREVIEW: "gpt-4-0125-preview",
GPT_4_5_PREVIEW: "gpt-4.5-preview",
GPT_4_5_PREVIEW_2025_02_27: "gpt-4.5-preview-2025-02-27",
CHATGPT_4O_LATEST: "chatgpt-4o-latest",
O1_PREVIEW: "o1-preview",
O1_PREVIEW_2024_09_12: "o1-preview-2024-09-12",
O1_MINI: "o1-mini",
O1_MINI_2024_09_12: "o1-mini-2024-09-12",
// --- Anthropic ---
CLAUDE_2_1: "claude-2.1",
CLAUDE_3_OPUS_20240229: "claude-3-opus-20240229",
CLAUDE_3_SONNET_20240229: "claude-3-sonnet-20240229",
CLAUDE_3_5_SONNET_20240620: "claude-3-5-sonnet-20240620",
CLAUDE_3_5_SONNET_20241022: "claude-3-5-sonnet-20241022",
CLAUDE_3_5_SONNET_LATEST: "claude-3-5-sonnet-latest",
CLAUDE_3_HAIKU_20240307: "claude-3-haiku-20240307",
CLAUDE_3_5_HAIKU_20241022: "claude-3-5-haiku-20241022",
CLAUDE_3_7_SONNET_LATEST: "claude-3-7-sonnet-latest",
CLAUDE_3_7_SONNET_20250219: "claude-3-7-sonnet-20250219",
// --- Google ---
GEMINI_PRO: "gemini-pro",
GEMINI_1_PRO_LATEST: "gemini-1.0-pro-latest",
GEMINI_15_PRO_LATEST: "gemini-1.5-pro-latest",
GEMINI_15_PRO_EXP_0801: "gemini-1.5-pro-exp-0801",
GEMINI_15_FLASH_LATEST: "gemini-1.5-flash-latest",
GEMINI_2_0_FLASH: "gemini-2.0-flash",
GEMINI_2_0_FLASH_001: "gemini-2.0-flash-001",
// --- Cohere ---
COMMAND_R: "command-r",
COMMAND_R_PLUS: "command-r-plus",
// --- Mistral ---
MISTRAL_LARGE_LATEST: "mistral-large-latest",
MISTRAL_LARGE_2407: "mistral-large-2407",
MISTRAL_LARGE_2402: "mistral-large-2402",
MISTRAL_MEDIUM_LATEST: "mistral-medium-latest",
MISTRAL_SMALL_LATEST: "mistral-small-latest",
CODESTRAL_LATEST: "codestral-latest",
OPEN_MISTRAL_7B: "open-mistral-7b",
OPEN_MIXTRAL_8X7B: "open-mixtral-8x7b",
OPEN_MIXTRAL_8X22B: "open-mixtral-8x22b",
// --- TogetherAI (stored without an organisation prefix) ---
MISTRAL_7B_INSTRUCT_V0_2: "Mistral-7B-Instruct-v0.2",
MIXTRAL_8X7B_INSTRUCT_V0_1: "Mixtral-8x7B-Instruct-v0.1",
MIXTRAL_8X22B_INSTRUCT_V0_1: "Mixtral-8x22B-Instruct-v0.1",
LLAMA_3_70B_CHAT_HF: "Llama-3-70b-chat-hf",
LLAMA_3_8B_CHAT_HF: "Llama-3-8b-chat-hf",
QWEN2_72B_INSTRUCT: "Qwen2-72B-Instruct",
LLAMA_3_1_8B_INSTRUCT_TURBO: "Meta-Llama-3.1-8B-Instruct-Turbo",
LLAMA_3_1_70B_INSTRUCT_TURBO: "Meta-Llama-3.1-70B-Instruct-Turbo",
LLAMA_3_1_405B_INSTRUCT_TURBO: "Meta-Llama-3.1-405B-Instruct-Turbo",
// --- Perplexity ---
PERPLEXITY_SONAR: "sonar",
// --- Mistral (listed out of group) ---
OPEN_MISTRAL_NEMO: "open-mistral-nemo",
// --- TogetherAI (listed out of group) ---
DEEPSEEK_R1: "DeepSeek-R1"
};
// Provider -> supported models table.
// NOTE(review): this object literal is a bare expression statement. In the
// visible source its value is not bound to a name or exported, so it has no
// effect beyond reading the two constant tables. Presumably the binding was
// dropped by the bundler as unused — TODO confirm against the original source.
({
[SupportedProvider.OPENAI]: [
SupportedModel.GPT_3_5_TURBO,
SupportedModel.GPT_3_5_TURBO_0125,
SupportedModel.GPT_4,
SupportedModel.GPT_4_0613,
SupportedModel.GPT_4_1106_PREVIEW,
SupportedModel.GPT_4_TURBO,
SupportedModel.GPT_4_TURBO_PREVIEW,
SupportedModel.GPT_4_TURBO_2024_04_09,
SupportedModel.GPT_4O_2024_05_13,
SupportedModel.GPT_4O_2024_08_06,
SupportedModel.GPT_4O,
SupportedModel.GPT_4O_MINI_2024_07_18,
SupportedModel.GPT_4O_MINI,
SupportedModel.GPT_4_0125_PREVIEW,
SupportedModel.O1_PREVIEW,
SupportedModel.O1_PREVIEW_2024_09_12,
SupportedModel.O1_MINI,
SupportedModel.O1_MINI_2024_09_12,
SupportedModel.CHATGPT_4O_LATEST,
SupportedModel.GPT_4_5_PREVIEW,
SupportedModel.GPT_4_5_PREVIEW_2025_02_27
],
[SupportedProvider.ANTHROPIC]: [
SupportedModel.CLAUDE_2_1,
SupportedModel.CLAUDE_3_OPUS_20240229,
SupportedModel.CLAUDE_3_SONNET_20240229,
SupportedModel.CLAUDE_3_5_SONNET_20240620,
SupportedModel.CLAUDE_3_5_SONNET_20241022,
SupportedModel.CLAUDE_3_5_SONNET_LATEST,
SupportedModel.CLAUDE_3_HAIKU_20240307,
SupportedModel.CLAUDE_3_5_HAIKU_20241022,
SupportedModel.CLAUDE_3_7_SONNET_LATEST,
SupportedModel.CLAUDE_3_7_SONNET_20250219
],
[SupportedProvider.GOOGLE]: [
SupportedModel.GEMINI_PRO,
SupportedModel.GEMINI_1_PRO_LATEST,
SupportedModel.GEMINI_15_PRO_LATEST,
SupportedModel.GEMINI_15_PRO_EXP_0801,
SupportedModel.GEMINI_15_FLASH_LATEST,
SupportedModel.GEMINI_2_0_FLASH,
SupportedModel.GEMINI_2_0_FLASH_001
],
[SupportedProvider.MISTRAL]: [
SupportedModel.MISTRAL_LARGE_LATEST,
SupportedModel.MISTRAL_LARGE_2407,
SupportedModel.MISTRAL_LARGE_2402,
SupportedModel.MISTRAL_MEDIUM_LATEST,
SupportedModel.MISTRAL_SMALL_LATEST,
SupportedModel.CODESTRAL_LATEST,
SupportedModel.OPEN_MISTRAL_7B,
SupportedModel.OPEN_MIXTRAL_8X7B,
SupportedModel.OPEN_MIXTRAL_8X22B,
SupportedModel.OPEN_MISTRAL_NEMO
],
[SupportedProvider.PERPLEXITY]: [
SupportedModel.PERPLEXITY_SONAR
],
[SupportedProvider.COHERE]: [
SupportedModel.COMMAND_R,
SupportedModel.COMMAND_R_PLUS
],
[SupportedProvider.TOGETHERAI]: [
SupportedModel.MISTRAL_7B_INSTRUCT_V0_2,
SupportedModel.MIXTRAL_8X7B_INSTRUCT_V0_1,
SupportedModel.MIXTRAL_8X22B_INSTRUCT_V0_1,
SupportedModel.LLAMA_3_70B_CHAT_HF,
SupportedModel.LLAMA_3_8B_CHAT_HF,
SupportedModel.QWEN2_72B_INSTRUCT,
SupportedModel.LLAMA_3_1_8B_INSTRUCT_TURBO,
SupportedModel.LLAMA_3_1_70B_INSTRUCT_TURBO,
SupportedModel.LLAMA_3_1_405B_INSTRUCT_TURBO,
SupportedModel.DEEPSEEK_R1
]
});
/**
 * Builds the LangChain chat model for the selected provider.
 *
 * For every provider an explicitly supplied key in `llmKeys` takes precedence
 * over the matching environment variable.
 *
 * @param provider Object with `provider` (one of the SupportedProvider
 *                 values) and `model` (the model identifier).
 * @param llmKeys Per-provider API keys, keyed by provider id.
 * @param responseModel Optional structured-output schema; when truthy the
 *                      model is wrapped with `withStructuredOutput`.
 * @returns The chat model, or its structured-output wrapper.
 * @throws Error when `provider.provider` is not a supported provider.
 */
function getLangChainModel(provider, llmKeys, responseModel) {
  const { OPENAI, ANTHROPIC, GOOGLE, MISTRAL, PERPLEXITY, COHERE, TOGETHERAI } = SupportedProvider;
  let chatModel;
  switch (provider.provider) {
    case OPENAI:
      chatModel = new openai.ChatOpenAI({
        modelName: provider.model,
        apiKey: llmKeys.openai || process.env.OPENAI_API_KEY
      });
      break;
    case ANTHROPIC:
      chatModel = new anthropic.ChatAnthropic({
        modelName: provider.model,
        anthropicApiKey: llmKeys.anthropic || process.env.ANTHROPIC_API_KEY
      });
      break;
    case GOOGLE:
      chatModel = new googleGenai.ChatGoogleGenerativeAI({
        modelName: provider.model,
        apiKey: llmKeys.google || process.env.GOOGLE_API_KEY
      });
      break;
    case MISTRAL:
      chatModel = new mistralai.ChatMistralAI({
        modelName: provider.model,
        apiKey: llmKeys.mistral || process.env.MISTRAL_API_KEY
      });
      break;
    case PERPLEXITY:
      chatModel = new ChatPerplexity({
        apiKey: llmKeys.perplexity || process.env.PPLX_API_KEY || "",
        model: provider.model
      });
      break;
    case COHERE:
      chatModel = new cohere.ChatCohere({
        // Explicit key first, environment second — same order as the other providers.
        apiKey: llmKeys.cohere || process.env.COHERE_API_KEY,
        model: provider.model
      });
      break;
    case TOGETHERAI:
      chatModel = new togetherai.ChatTogetherAI({
        // Explicit key first, environment second — same order as the other providers.
        apiKey: llmKeys.togetherai || process.env.TOGETHERAI_API_KEY,
        // TogetherAI model ids carry an organisation prefix.
        model: getTogetheraiModel(provider.model)
      });
      break;
    default:
      throw new Error(`Unsupported provider: ${provider.provider}`);
  }
  return responseModel ? chatModel.withStructuredOutput(responseModel) : chatModel;
}
/**
 * Maps a bare model name from SupportedModel to the organisation-prefixed
 * identifier used for TogetherAI (e.g. "Qwen2-72B-Instruct" ->
 * "Qwen/Qwen2-72B-Instruct").
 *
 * @param model Bare model name.
 * @returns The prefixed identifier, or `model` unchanged when no prefix is
 *          known for it.
 */
const getTogetheraiModel = (model) => {
  // Built per call so the helper only touches SupportedModel when invoked.
  const organisationByModel = new Map([
    [SupportedModel.MISTRAL_7B_INSTRUCT_V0_2, "mistralai"],
    [SupportedModel.MIXTRAL_8X7B_INSTRUCT_V0_1, "mistralai"],
    [SupportedModel.MIXTRAL_8X22B_INSTRUCT_V0_1, "mistralai"],
    [SupportedModel.LLAMA_3_70B_CHAT_HF, "meta-llama"],
    [SupportedModel.LLAMA_3_8B_CHAT_HF, "meta-llama"],
    [SupportedModel.LLAMA_3_1_8B_INSTRUCT_TURBO, "meta-llama"],
    [SupportedModel.LLAMA_3_1_70B_INSTRUCT_TURBO, "meta-llama"],
    [SupportedModel.LLAMA_3_1_405B_INSTRUCT_TURBO, "meta-llama"],
    [SupportedModel.QWEN2_72B_INSTRUCT, "Qwen"],
    // DeepSeek-R1 is listed as a TogetherAI model and needs its prefix too.
    [SupportedModel.DEEPSEEK_R1, "deepseek-ai"]
  ]);
  const organisation = organisationByModel.get(model);
  return organisation ? `${organisation}/${model}` : model;
};
/**
 * Sends the conversation to the given provider and returns its reply as text.
 *
 * @param provider Selected provider (`provider` and `model`).
 * @param options Request options; reads `messages`, `responseModel` and,
 *                through the system-prompt helper, `llmProviders`.
 * @param llmKeys Per-provider API keys.
 * @param runtimeArgs Passed through as the second argument of `invoke`.
 * @returns The model response reduced to a string by `extractContent`.
 */
async function callLLM(provider, options, llmKeys, runtimeArgs) {
  const chatModel = getLangChainModel(provider, llmKeys, options.responseModel);
  const convertedMessages = options.messages.map(convertToLangChainMessage);
  const promptMessages = extendProviderSystemPrompt(convertedMessages, options, provider);
  const reply = await chatModel.invoke(promptMessages, runtimeArgs);
  return extractContent(reply);
}
/**
 * Prepends the provider-specific system prompt, if one is configured.
 *
 * Looks up the entry of `options.llmProviders` whose `provider` and `model`
 * both match the selected provider. When that entry has a truthy
 * `systemPrompt`, a SystemMessage carrying it is inserted at the front of
 * the array.
 *
 * @param messages$1 Message array; modified in place.
 * @param options Request options; reads `llmProviders`.
 * @param provider Selected provider (`provider` and `model`).
 * @returns The same array instance that was passed in.
 */
function extendProviderSystemPrompt(messages$1, options, provider) {
  const configured = options.llmProviders.find((candidate) => {
    return candidate.provider === provider.provider && candidate.model === provider.model;
  });
  const systemPrompt = configured?.systemPrompt;
  if (!systemPrompt) {
    return messages$1;
  }
  messages$1.unshift(new messages.SystemMessage(systemPrompt));
  return messages$1;
}
/**
 * Converts a `{ role, content }` chat message into a LangChain message.
 *
 * @param msg Message with `role` and `content`.
 * @returns An AIMessage for role "assistant", a SystemMessage for role
 *          "system", and a HumanMessage for "user" or any other role.
 */
function convertToLangChainMessage(msg) {
  const { role, content } = msg;
  if (role === "assistant") {
    return new messages.AIMessage(content);
  }
  if (role === "system") {
    return new messages.SystemMessage(content);
  }
  // "user" and every unrecognised role are treated as human input.
  return new messages.HumanMessage(content);
}
/**
 * Streams the provider's reply as text chunks.
 *
 * Nothing runs until the returned async generator is first advanced; model
 * construction and the request both happen at that point.
 *
 * @param provider Selected provider (`provider` and `model`).
 * @param options Request options; reads `messages`, `responseModel` and,
 *                through the system-prompt helper, `llmProviders`.
 * @param llmKeys Per-provider API keys.
 * @param runtimeArgs Passed through as the second argument of `stream`.
 * @yields Each streamed chunk reduced to a string by `extractContent`.
 */
async function* callLLMStream(provider, options, llmKeys, runtimeArgs) {
  const chatModel = getLangChainModel(provider, llmKeys, options.responseModel);
  const convertedMessages = options.messages.map(convertToLangChainMessage);
  const promptMessages = extendProviderSystemPrompt(convertedMessages, options, provider);
  const chunks = await chatModel.stream(promptMessages, runtimeArgs);
  for await (const piece of chunks) {
    yield extractContent(piece);
  }
}
/**
 * Reduces a model response (or streamed chunk) to a string.
 *
 * - A string response is returned unchanged.
 * - An object (or function) with a `content` property yields that content:
 *   as-is when it is a string, JSON-encoded otherwise.
 * - Anything else is JSON-encoded as a whole.
 *
 * @param response Model response, structured-output object, or primitive.
 * @returns The textual form of the response.
 */
function extractContent(response) {
  if (typeof response === "string") {
    return response;
  }
  // The `in` operator throws on primitives and null, so only probe for
  // `content` on values that can actually carry properties.
  const canHaveProperties = response !== null && (typeof response === "object" || typeof response === "function");
  if (canHaveProperties && "content" in response) {
    const { content } = response;
    return typeof content === "string" ? content : JSON.stringify(content);
  }
  return JSON.stringify(response);
}
// SDK version taken from the inlined package metadata; sent in the
// User-Agent header of API requests.
const SDK_VERSION = packageJson.version;
// Runs dotenv's config() once at module load time.
dotenv__namespace.config();
// `timeout` value sent with model-select requests when the caller gives none.
// NOTE(review): the unit is not visible here — presumably seconds; confirm
// against the API documentation.
const DEFAULT_TIMEOUT = 5;
// API origin used when neither `options.apiUrl` nor NOTDIAMOND_API_URL is set.
const BASE_URL = "https://api.notdiamond.ai";
/**
 * Client for the NotDiamond model-routing API.
 *
 * Asks the service which of the candidate LLM providers should answer a
 * conversation and, for `create` / `stream`, forwards the conversation to
 * that provider through `callLLM` / `callLLMStream`.
 */
class NotDiamond {
  /** NotDiamond API key, sent as a Bearer token. */
  apiKey;
  /** Base URL of the NotDiamond API. */
  apiUrl;
  /** Endpoint used by `modelSelect`. */
  modelSelectUrl;
  /** Endpoint used by `feedback`. */
  feedbackUrl;
  /** Endpoint used by `createPreferenceId`. */
  createUrl;
  /** Per-provider LLM API keys handed to the model factory. */
  llmKeys;

  /**
   * @param options.apiKey API key; falls back to NOTDIAMOND_API_KEY, then "".
   * @param options.apiUrl Base URL; falls back to NOTDIAMOND_API_URL, then
   *                       the built-in default.
   * @param options.llmKeys Per-provider LLM API keys; defaults to `{}`.
   */
  constructor(options = {}) {
    this.apiKey = options.apiKey || process.env.NOTDIAMOND_API_KEY || "";
    this.apiUrl = options.apiUrl || process.env.NOTDIAMOND_API_URL || BASE_URL;
    this.llmKeys = options.llmKeys || {};
    this.modelSelectUrl = `${this.apiUrl}/v2/modelRouter/modelSelect`;
    this.feedbackUrl = `${this.apiUrl}/v2/report/metrics/feedback`;
    this.createUrl = `${this.apiUrl}/v2/preferences/userPreferenceCreate`;
  }

  /** @returns The Authorization header value for this client. */
  getAuthHeader() {
    return `Bearer ${this.apiKey}`;
  }

  /**
   * POSTs a JSON body to the API. Never rejects: failures are reported as a
   * `{ detail }` object.
   *
   * @param url Endpoint URL.
   * @param body JSON-serialisable request body.
   * @returns The response payload, or `{ detail: string }` on failure. For
   *          HTTP error responses the server's own string `detail` is used
   *          when it provides one.
   */
  async postRequest(url, body) {
    try {
      const response = await axios__default.post(url, body, {
        headers: {
          Authorization: this.getAuthHeader(),
          Accept: "application/json",
          "Content-Type": "application/json",
          "User-Agent": `TS-SDK/${SDK_VERSION}`
        }
      });
      return response.data;
    } catch (error) {
      if (axios__default.isAxiosError(error) && error.response) {
        // Surface the server's explanation instead of discarding it.
        const serverDetail = error.response.data?.detail;
        const hasServerDetail = typeof serverDetail === "string" && serverDetail !== "";
        return { detail: hasServerDetail ? serverDetail : "An error occurred." };
      }
      console.error("error", error);
      return { detail: "An unexpected error occurred." };
    }
  }

  /**
   * Selects the best model for the given messages.
   * @param options The options for the model.
   * @returns The API response, or `{ detail }` when the request failed.
   */
  async modelSelect(options) {
    const requestBody = {
      messages: options.messages,
      llm_providers: options.llmProviders.map((provider) => ({
        provider: provider.provider,
        model: provider.model,
        ...provider.contextLength !== void 0 && {
          context_length: provider.contextLength
        },
        // For price and latency the non-"custom" field is spread last, so it
        // wins when both variants are given.
        ...provider.customInputPrice !== void 0 && {
          input_price: provider.customInputPrice
        },
        ...provider.inputPrice !== void 0 && {
          input_price: provider.inputPrice
        },
        ...provider.customOutputPrice !== void 0 && {
          output_price: provider.customOutputPrice
        },
        ...provider.outputPrice !== void 0 && {
          output_price: provider.outputPrice
        },
        ...provider.customLatency !== void 0 && {
          latency: provider.customLatency
        },
        ...provider.latency !== void 0 && { latency: provider.latency },
        ...provider.isCustom !== void 0 && {
          is_custom: provider.isCustom
        }
      })),
      ...options.tradeoff && {
        tradeoff: options.tradeoff
      },
      ...options.maxModelDepth && {
        max_model_depth: options.maxModelDepth
      },
      ...options.tools && { tools: options.tools },
      ...options.hashContent !== void 0 && {
        hash_content: options.hashContent
      },
      ...options.preferenceId && { preference_id: options.preferenceId },
      // A timeout is always sent; a falsy caller value selects the default.
      timeout: options.timeout ? options.timeout : DEFAULT_TIMEOUT,
      ...options.default && { default: options.default },
      ...options.previousSession && {
        previous_session: options.previousSession
      },
      ...options.responseModel && {
        response_model: options.responseModel
      }
    };
    return this.postRequest(this.modelSelectUrl, requestBody);
  }

  /**
   * Sends feedback to the NotDiamond API.
   * @param options The options for the feedback (`sessionId`, `feedback`,
   *                `provider`).
   * @returns The API response, or `{ detail }` when the request failed.
   */
  async feedback(options) {
    return this.postRequest(this.feedbackUrl, {
      session_id: options.sessionId,
      feedback: options.feedback,
      provider: options.provider
    });
  }

  /**
   * Creates a preference id.
   * @returns The preference id.
   * @throws Error when the response carries no `preference_id`.
   */
  async createPreferenceId() {
    const response = await this.postRequest(this.createUrl, {});
    // `in` throws on non-objects, so make sure the payload is one first.
    if (response !== null && typeof response === "object" && "preference_id" in response) {
      return response.preference_id;
    }
    throw new Error("Invalid response: preference_id not found");
  }

  /**
   * Picks the provider to call: the first one returned by model selection,
   * otherwise `options.default`.
   * @throws Error when neither is available; the message includes the API's
   *         `detail` when present.
   */
  #resolveProvider(selection, options) {
    const provider = selection?.providers?.[0] || options.default;
    if (!provider) {
      const reason = typeof selection?.detail === "string"
        ? selection.detail
        : "no providers were returned and no default was supplied";
      throw new Error(`Model selection failed: ${reason}`);
    }
    return provider;
  }

  /**
   * Selects a model and returns its complete reply.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @returns A promise that resolves to `{ content, providers }`. When model
   *          selection fails and `options.default` is used, `providers`
   *          holds just that default.
   * @throws Error when no provider could be determined.
   */
  async acreate(options, runtimeArgs = {}) {
    const selectedModel = await this.modelSelect(options);
    const provider = this.#resolveProvider(selectedModel, options);
    const content = await callLLM(provider, options, this.llmKeys, runtimeArgs);
    return { content, providers: selectedModel?.providers ?? [provider] };
  }

  /**
   * Promise- or callback-style wrapper around `acreate`.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @param callback Optional callback function to handle the result.
   * @returns A promise that resolves to the results of the model, or
   *          nothing when a callback is supplied.
   */
  create(options, runtimeArgs = {}, callback) {
    const promise = this.acreate(options, runtimeArgs);
    if (callback) {
      promise.then((result) => callback(null, result)).catch((error) => callback(error));
    } else {
      return promise;
    }
  }

  /**
   * Streams the results of the model asynchronously.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @returns A promise that resolves to an object containing the provider and an AsyncIterable of strings.
   * @throws Error when no provider could be determined.
   */
  async astream(options, runtimeArgs = {}) {
    const selectedModel = await this.modelSelect(options);
    const provider = this.#resolveProvider(selectedModel, options);
    return {
      provider,
      stream: callLLMStream(provider, options, this.llmKeys, runtimeArgs)
    };
  }

  /**
   * Streams the results of the model.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @param callback Optional callback function to handle each chunk of the stream.
   * @returns A promise that resolves to an object containing the provider and an AsyncIterable of strings, or nothing when a callback is supplied.
   * @throws Error synchronously when `options.llmProviders` is missing or empty.
   */
  stream(options, runtimeArgs = {}, callback) {
    if (!options.llmProviders || options.llmProviders.length === 0) {
      throw new Error("No LLM providers specified");
    }
    const promise = this.astream(options, runtimeArgs);
    if (callback) {
      // `.catch` is chained (not passed as the second `then` argument) so
      // that errors raised while iterating the stream reach the callback too.
      promise.then(async ({ provider, stream }) => {
        for await (const chunk of stream) {
          callback(null, { provider, chunk });
        }
      }).catch((error) => callback(error));
    } else {
      return promise;
    }
  }
}
exports.NotDiamond = NotDiamond;
exports.SupportedModel = SupportedModel;
exports.SupportedProvider = SupportedProvider;