UNPKG

notdiamond

Version:
563 lines (556 loc) 19.4 kB
import * as dotenv from 'dotenv';
import { ChatOpenAI } from '@langchain/openai';
import { AIMessage, SystemMessage, HumanMessage } from '@langchain/core/messages';
import { ChatAnthropic } from '@langchain/anthropic';
import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import { ChatMistralAI } from '@langchain/mistralai';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import axios from 'axios';
import { ChatCohere } from '@langchain/cohere';
import { ChatTogetherAI } from '@langchain/community/chat_models/togetherai';

const version = "1.1.1";
const packageJson = { version: version };

/**
 * Translates a LangChain message type into the role name sent to Perplexity.
 *
 * LangChain reports "human" / "ai"; the chat-completions payload uses
 * "user" / "assistant". Any other type is forwarded unchanged.
 *
 * @param {string} messageType Value returned by `message._getType()`.
 * @returns {string} Role to put in the request payload.
 */
function toPerplexityRole(messageType) {
  if (messageType === "human") {
    return "user";
  }
  if (messageType === "ai") {
    return "assistant";
  }
  return messageType;
}

/**
 * Minimal chat model that talks to the Perplexity chat-completions endpoint
 * through axios. Only `invoke` is implemented; `_generate` always throws.
 */
class ChatPerplexity extends BaseChatModel {
  /**
   * Not supported by this wrapper.
   * @throws {Error} Always; the message embeds the JSON-serialised arguments.
   */
  _generate(messages, options, runManager) {
    throw new Error(
      "Method not implemented." +
        JSON.stringify(messages) +
        JSON.stringify(options) +
        JSON.stringify(runManager)
    );
  }

  apiKey;
  model;

  /**
   * @param {{ apiKey: string, model: string }} params Credentials and model id.
   */
  constructor({ apiKey, model }) {
    super({});
    this.apiKey = apiKey;
    this.model = model;
  }

  _llmType() {
    return "perplexity";
  }

  /**
   * Invokes the Perplexity model.
   * @param messages The LangChain messages to send to the model.
   * @returns An AIMessage holding the content of the first returned choice.
   * @throws {Error} "Perplexity API error: <statusText>" when the API answers
   *   with an error status; any other failure is rethrown untouched.
   */
  async invoke(messages) {
    const payload = {
      model: this.model,
      messages: messages.map((message) => ({
        role: toPerplexityRole(message._getType()),
        content: message.content
      }))
    };
    try {
      const { data } = await axios.post(
        "https://api.perplexity.ai/chat/completions",
        payload,
        { headers: { Authorization: `Bearer ${this.apiKey}` } }
      );
      return new AIMessage(data.choices[0].message.content);
    } catch (error) {
      if (axios.isAxiosError(error) && error.response) {
        // Keep the historical message, but preserve the original error for debugging.
        throw new Error(
          `Perplexity API error: ${error.response.statusText}`,
          { cause: error }
        );
      }
      throw error;
    }
  }
}

/** Provider identifiers accepted in `llmProviders[].provider`. */
const SupportedProvider = {
  OPENAI: "openai",
  ANTHROPIC: "anthropic",
  GOOGLE: "google",
  MISTRAL: "mistral",
  PERPLEXITY: "perplexity",
  COHERE: "cohere",
  TOGETHERAI: "togetherai"
};

/** Model identifiers accepted in `llmProviders[].model`. */
const SupportedModel = {
  GPT_3_5_TURBO: "gpt-3.5-turbo",
  GPT_3_5_TURBO_0125: "gpt-3.5-turbo-0125",
  GPT_4: "gpt-4",
  GPT_4_0613: "gpt-4-0613",
  GPT_4_1106_PREVIEW: "gpt-4-1106-preview",
  GPT_4_TURBO: "gpt-4-turbo",
  GPT_4_TURBO_PREVIEW: "gpt-4-turbo-preview",
  GPT_4_TURBO_2024_04_09: "gpt-4-turbo-2024-04-09",
  GPT_4O_2024_05_13: "gpt-4o-2024-05-13",
  GPT_4O_2024_08_06: "gpt-4o-2024-08-06",
  GPT_4O: "gpt-4o",
  GPT_4O_MINI_2024_07_18: "gpt-4o-mini-2024-07-18",
  GPT_4O_MINI: "gpt-4o-mini",
  GPT_4_0125_PREVIEW: "gpt-4-0125-preview",
  GPT_4_5_PREVIEW: "gpt-4.5-preview",
  GPT_4_5_PREVIEW_2025_02_27: "gpt-4.5-preview-2025-02-27",
  CHATGPT_4O_LATEST: "chatgpt-4o-latest",
  O1_PREVIEW: "o1-preview",
  O1_PREVIEW_2024_09_12: "o1-preview-2024-09-12",
  O1_MINI: "o1-mini",
  O1_MINI_2024_09_12: "o1-mini-2024-09-12",
  CLAUDE_2_1: "claude-2.1",
  CLAUDE_3_OPUS_20240229: "claude-3-opus-20240229",
  CLAUDE_3_SONNET_20240229: "claude-3-sonnet-20240229",
  CLAUDE_3_5_SONNET_20240620: "claude-3-5-sonnet-20240620",
  CLAUDE_3_5_SONNET_20241022: "claude-3-5-sonnet-20241022",
  CLAUDE_3_5_SONNET_LATEST: "claude-3-5-sonnet-latest",
  CLAUDE_3_HAIKU_20240307: "claude-3-haiku-20240307",
  CLAUDE_3_5_HAIKU_20241022: "claude-3-5-haiku-20241022",
  CLAUDE_3_7_SONNET_LATEST: "claude-3-7-sonnet-latest",
  CLAUDE_3_7_SONNET_20250219: "claude-3-7-sonnet-20250219",
  GEMINI_PRO: "gemini-pro",
  GEMINI_1_PRO_LATEST: "gemini-1.0-pro-latest",
  GEMINI_15_PRO_LATEST: "gemini-1.5-pro-latest",
  GEMINI_15_PRO_EXP_0801: "gemini-1.5-pro-exp-0801",
  GEMINI_15_FLASH_LATEST: "gemini-1.5-flash-latest",
  GEMINI_2_0_FLASH: "gemini-2.0-flash",
  GEMINI_2_0_FLASH_001: "gemini-2.0-flash-001",
  COMMAND_R: "command-r",
  COMMAND_R_PLUS: "command-r-plus",
  MISTRAL_LARGE_LATEST: "mistral-large-latest",
  MISTRAL_LARGE_2407: "mistral-large-2407",
  MISTRAL_LARGE_2402: "mistral-large-2402",
  MISTRAL_MEDIUM_LATEST: "mistral-medium-latest",
  MISTRAL_SMALL_LATEST: "mistral-small-latest",
  CODESTRAL_LATEST: "codestral-latest",
  OPEN_MISTRAL_7B: "open-mistral-7b",
  OPEN_MIXTRAL_8X7B: "open-mixtral-8x7b",
  OPEN_MIXTRAL_8X22B: "open-mixtral-8x22b",
  MISTRAL_7B_INSTRUCT_V0_2: "Mistral-7B-Instruct-v0.2",
  MIXTRAL_8X7B_INSTRUCT_V0_1: "Mixtral-8x7B-Instruct-v0.1",
  MIXTRAL_8X22B_INSTRUCT_V0_1: "Mixtral-8x22B-Instruct-v0.1",
  LLAMA_3_70B_CHAT_HF: "Llama-3-70b-chat-hf",
  LLAMA_3_8B_CHAT_HF: "Llama-3-8b-chat-hf",
  QWEN2_72B_INSTRUCT: "Qwen2-72B-Instruct",
  LLAMA_3_1_8B_INSTRUCT_TURBO: "Meta-Llama-3.1-8B-Instruct-Turbo",
  LLAMA_3_1_70B_INSTRUCT_TURBO: "Meta-Llama-3.1-70B-Instruct-Turbo",
  LLAMA_3_1_405B_INSTRUCT_TURBO: "Meta-Llama-3.1-405B-Instruct-Turbo",
  PERPLEXITY_SONAR: "sonar",
  OPEN_MISTRAL_NEMO: "open-mistral-nemo",
  DEEPSEEK_R1: "DeepSeek-R1"
};

/**
 * Models grouped by the provider that serves them.
 *
 * NOTE(review): in the bundled output this table was a bare expression
 * statement with no binding (presumably the bundler dropped the unused
 * export). It is kept here, named, purely as reference; nothing in this
 * module reads it.
 */
// eslint-disable-next-line no-unused-vars
const ProviderModelMap = {
  [SupportedProvider.OPENAI]: [
    SupportedModel.GPT_3_5_TURBO,
    SupportedModel.GPT_3_5_TURBO_0125,
    SupportedModel.GPT_4,
    SupportedModel.GPT_4_0613,
    SupportedModel.GPT_4_1106_PREVIEW,
    SupportedModel.GPT_4_TURBO,
    SupportedModel.GPT_4_TURBO_PREVIEW,
    SupportedModel.GPT_4_TURBO_2024_04_09,
    SupportedModel.GPT_4O_2024_05_13,
    SupportedModel.GPT_4O_2024_08_06,
    SupportedModel.GPT_4O,
    SupportedModel.GPT_4O_MINI_2024_07_18,
    SupportedModel.GPT_4O_MINI,
    SupportedModel.GPT_4_0125_PREVIEW,
    SupportedModel.O1_PREVIEW,
    SupportedModel.O1_PREVIEW_2024_09_12,
    SupportedModel.O1_MINI,
    SupportedModel.O1_MINI_2024_09_12,
    SupportedModel.CHATGPT_4O_LATEST,
    SupportedModel.GPT_4_5_PREVIEW,
    SupportedModel.GPT_4_5_PREVIEW_2025_02_27
  ],
  [SupportedProvider.ANTHROPIC]: [
    SupportedModel.CLAUDE_2_1,
    SupportedModel.CLAUDE_3_OPUS_20240229,
    SupportedModel.CLAUDE_3_SONNET_20240229,
    SupportedModel.CLAUDE_3_5_SONNET_20240620,
    SupportedModel.CLAUDE_3_5_SONNET_20241022,
    SupportedModel.CLAUDE_3_5_SONNET_LATEST,
    SupportedModel.CLAUDE_3_HAIKU_20240307,
    SupportedModel.CLAUDE_3_5_HAIKU_20241022,
    SupportedModel.CLAUDE_3_7_SONNET_LATEST,
    SupportedModel.CLAUDE_3_7_SONNET_20250219
  ],
  [SupportedProvider.GOOGLE]: [
    SupportedModel.GEMINI_PRO,
    SupportedModel.GEMINI_1_PRO_LATEST,
    SupportedModel.GEMINI_15_PRO_LATEST,
    SupportedModel.GEMINI_15_PRO_EXP_0801,
    SupportedModel.GEMINI_15_FLASH_LATEST,
    SupportedModel.GEMINI_2_0_FLASH,
    SupportedModel.GEMINI_2_0_FLASH_001
  ],
  [SupportedProvider.MISTRAL]: [
    SupportedModel.MISTRAL_LARGE_LATEST,
    SupportedModel.MISTRAL_LARGE_2407,
    SupportedModel.MISTRAL_LARGE_2402,
    SupportedModel.MISTRAL_MEDIUM_LATEST,
    SupportedModel.MISTRAL_SMALL_LATEST,
    SupportedModel.CODESTRAL_LATEST,
    SupportedModel.OPEN_MISTRAL_7B,
    SupportedModel.OPEN_MIXTRAL_8X7B,
    SupportedModel.OPEN_MIXTRAL_8X22B,
    SupportedModel.OPEN_MISTRAL_NEMO
  ],
  [SupportedProvider.PERPLEXITY]: [
    SupportedModel.PERPLEXITY_SONAR
  ],
  [SupportedProvider.COHERE]: [
    SupportedModel.COMMAND_R,
    SupportedModel.COMMAND_R_PLUS
  ],
  [SupportedProvider.TOGETHERAI]: [
    SupportedModel.MISTRAL_7B_INSTRUCT_V0_2,
    SupportedModel.MIXTRAL_8X7B_INSTRUCT_V0_1,
    SupportedModel.MIXTRAL_8X22B_INSTRUCT_V0_1,
    SupportedModel.LLAMA_3_70B_CHAT_HF,
    SupportedModel.LLAMA_3_8B_CHAT_HF,
    SupportedModel.QWEN2_72B_INSTRUCT,
    SupportedModel.LLAMA_3_1_8B_INSTRUCT_TURBO,
    SupportedModel.LLAMA_3_1_70B_INSTRUCT_TURBO,
    SupportedModel.LLAMA_3_1_405B_INSTRUCT_TURBO,
    SupportedModel.DEEPSEEK_R1
  ]
};

/**
 * Organisation namespace that TogetherAI expects in front of each bare
 * model id listed in SupportedModel.
 */
const TOGETHERAI_NAMESPACE_BY_MODEL = new Map([
  [SupportedModel.MISTRAL_7B_INSTRUCT_V0_2, "mistralai"],
  [SupportedModel.MIXTRAL_8X7B_INSTRUCT_V0_1, "mistralai"],
  [SupportedModel.MIXTRAL_8X22B_INSTRUCT_V0_1, "mistralai"],
  [SupportedModel.LLAMA_3_70B_CHAT_HF, "meta-llama"],
  [SupportedModel.LLAMA_3_8B_CHAT_HF, "meta-llama"],
  [SupportedModel.LLAMA_3_1_8B_INSTRUCT_TURBO, "meta-llama"],
  [SupportedModel.LLAMA_3_1_70B_INSTRUCT_TURBO, "meta-llama"],
  [SupportedModel.LLAMA_3_1_405B_INSTRUCT_TURBO, "meta-llama"],
  [SupportedModel.QWEN2_72B_INSTRUCT, "Qwen"],
  // Previously sent without a namespace, unlike every other Together model.
  // NOTE(review): confirm "deepseek-ai" against the Together model catalogue.
  [SupportedModel.DEEPSEEK_R1, "deepseek-ai"]
]);

/**
 * Maps a SupportedModel id to the fully qualified TogetherAI model name.
 *
 * @param {string} model Bare model id.
 * @returns {string} `<namespace>/<model>` for known models, otherwise `model`
 *   unchanged.
 */
const getTogetheraiModel = (model) => {
  const namespace = TOGETHERAI_NAMESPACE_BY_MODEL.get(model);
  return namespace ? `${namespace}/${model}` : model;
};

/**
 * Instantiates the chat model client for the given provider.
 *
 * For every provider an explicitly supplied key in `llmKeys` wins over the
 * matching environment variable (Cohere and TogetherAI used to do the
 * opposite, which made it impossible to override an ambient key).
 *
 * @param {{ provider: string, model: string }} provider Selected provider.
 * @param {object} llmKeys Per-provider API keys, keyed by provider id.
 * @returns The provider-specific chat model instance.
 * @throws {Error} When `provider.provider` is not a SupportedProvider value.
 */
function createChatModel(provider, llmKeys) {
  const { OPENAI, ANTHROPIC, GOOGLE, MISTRAL, PERPLEXITY, COHERE, TOGETHERAI } = SupportedProvider;
  switch (provider.provider) {
    case OPENAI:
      return new ChatOpenAI({
        modelName: provider.model,
        apiKey: llmKeys.openai || process.env.OPENAI_API_KEY
      });
    case ANTHROPIC:
      return new ChatAnthropic({
        modelName: provider.model,
        anthropicApiKey: llmKeys.anthropic || process.env.ANTHROPIC_API_KEY
      });
    case GOOGLE:
      return new ChatGoogleGenerativeAI({
        modelName: provider.model,
        apiKey: llmKeys.google || process.env.GOOGLE_API_KEY
      });
    case MISTRAL:
      return new ChatMistralAI({
        modelName: provider.model,
        apiKey: llmKeys.mistral || process.env.MISTRAL_API_KEY
      });
    case PERPLEXITY:
      return new ChatPerplexity({
        apiKey: llmKeys.perplexity || process.env.PPLX_API_KEY || "",
        model: provider.model
      });
    case COHERE:
      return new ChatCohere({
        apiKey: llmKeys.cohere || process.env.COHERE_API_KEY,
        model: provider.model
      });
    case TOGETHERAI:
      return new ChatTogetherAI({
        apiKey: llmKeys.togetherai || process.env.TOGETHERAI_API_KEY,
        model: getTogetheraiModel(provider.model)
      });
    default:
      throw new Error(`Unsupported provider: ${provider.provider}`);
  }
}

/**
 * Builds the LangChain model for a provider, optionally wrapped for
 * structured output.
 *
 * @param {{ provider: string, model: string }} provider Selected provider.
 * @param {object} [llmKeys] Per-provider API keys; defaults to none.
 * @param [responseModel] When truthy, passed to `withStructuredOutput`.
 * @returns The chat model, or its structured-output wrapper.
 * @throws {Error} When the provider is not supported.
 */
function getLangChainModel(provider, llmKeys = {}, responseModel) {
  const model = createChatModel(provider, llmKeys);
  return responseModel ? model.withStructuredOutput(responseModel) : model;
}

/**
 * Runs a single completion against the given provider.
 *
 * @param provider Provider/model pair to call.
 * @param options Request options (`messages`, `llmProviders`, `responseModel`).
 * @param llmKeys Per-provider API keys.
 * @param runtimeArgs Forwarded as the second argument of `model.invoke`.
 * @returns {Promise<string>} The response content as text.
 */
async function callLLM(provider, options, llmKeys, runtimeArgs) {
  const model = getLangChainModel(provider, llmKeys, options.responseModel);
  const langChainMessages = extendProviderSystemPrompt(
    options.messages.map(convertToLangChainMessage),
    options,
    provider
  );
  const response = await model.invoke(langChainMessages, runtimeArgs);
  return extractContent(response);
}

/**
 * Prepends the provider-specific system prompt, if the matching entry of
 * `options.llmProviders` defines one.
 *
 * @param {Array} messages LangChain messages (not modified).
 * @param options Request options; `llmProviders` is searched for a match on
 *   both `provider` and `model`.
 * @param provider The provider that is about to be called.
 * @returns {Array} `messages` itself, or a new array starting with the
 *   system prompt.
 */
function extendProviderSystemPrompt(messages, options, provider) {
  const matchingProvider = options.llmProviders.find(
    (candidate) =>
      candidate.provider === provider.provider && candidate.model === provider.model
  );
  if (!matchingProvider || !matchingProvider.systemPrompt) {
    return messages;
  }
  return [new SystemMessage(matchingProvider.systemPrompt), ...messages];
}

/**
 * Converts a `{ role, content }` message into the LangChain message class.
 * Unknown roles are treated as human messages.
 */
function convertToLangChainMessage(msg) {
  switch (msg.role) {
    case "assistant":
      return new AIMessage(msg.content);
    case "system":
      return new SystemMessage(msg.content);
    case "user":
    default:
      return new HumanMessage(msg.content);
  }
}

/**
 * Streams a completion from the given provider.
 *
 * @param provider Provider/model pair to call.
 * @param options Request options (`messages`, `llmProviders`, `responseModel`).
 * @param llmKeys Per-provider API keys.
 * @param runtimeArgs Forwarded as the second argument of `model.stream`.
 * @yields {string} The text content of each streamed chunk.
 */
async function* callLLMStream(provider, options, llmKeys, runtimeArgs) {
  const model = getLangChainModel(provider, llmKeys, options.responseModel);
  const langChainMessages = extendProviderSystemPrompt(
    options.messages.map(convertToLangChainMessage),
    options,
    provider
  );
  const stream = await model.stream(langChainMessages, runtimeArgs);
  for await (const chunk of stream) {
    yield extractContent(chunk);
  }
}

/**
 * Normalises a model response (or stream chunk) to a string.
 *
 * Objects exposing `content` yield that content; everything else is returned
 * as-is when already a string, otherwise JSON-serialised.
 *
 * @param response Message object, structured-output value, or plain string.
 * @returns {string} Text representation of the response.
 */
function extractContent(response) {
  // `in` throws a TypeError on primitives and null, so check the type first.
  const isObject = response !== null && typeof response === "object";
  if (isObject && "content" in response) {
    return typeof response.content === "string"
      ? response.content
      : JSON.stringify(response.content);
  }
  if (typeof response === "string") {
    return response;
  }
  // JSON.stringify(undefined) is undefined; keep the return type a string.
  return JSON.stringify(response) ?? "";
}

const SDK_VERSION = packageJson.version;
dotenv.config();
// Sent as `timeout` in the model-select request when the caller gives none.
// NOTE(review): unit is not visible here (presumably seconds) - confirm.
const DEFAULT_TIMEOUT = 5;
const BASE_URL = "https://api.notdiamond.ai";

/**
 * Client for the NotDiamond routing API: asks the API which model to use and
 * then calls that model through LangChain.
 */
class NotDiamond {
  apiKey;
  apiUrl;
  modelSelectUrl;
  feedbackUrl;
  createUrl;
  llmKeys;

  /**
   * @param {object} [options]
   * @param {string} [options.apiKey] Falls back to `NOTDIAMOND_API_KEY`, then "".
   * @param {string} [options.apiUrl] Falls back to `NOTDIAMOND_API_URL`, then
   *   the public base URL.
   * @param {object} [options.llmKeys] Per-provider API keys.
   */
  constructor(options = {}) {
    this.apiKey = options.apiKey || process.env.NOTDIAMOND_API_KEY || "";
    this.apiUrl = options.apiUrl || process.env.NOTDIAMOND_API_URL || BASE_URL;
    this.llmKeys = options.llmKeys || {};
    this.modelSelectUrl = `${this.apiUrl}/v2/modelRouter/modelSelect`;
    this.feedbackUrl = `${this.apiUrl}/v2/report/metrics/feedback`;
    this.createUrl = `${this.apiUrl}/v2/preferences/userPreferenceCreate`;
  }

  getAuthHeader() {
    return `Bearer ${this.apiKey}`;
  }

  /**
   * POSTs a JSON body to the API.
   *
   * Never rejects: failures are reported as a `{ detail }` object so callers
   * must inspect the result rather than rely on exceptions.
   *
   * @param {string} url Absolute endpoint URL.
   * @param {object} body JSON payload.
   * @returns {Promise<object>} The response body, or `{ detail: string }` on
   *   failure.
   */
  async postRequest(url, body) {
    try {
      const response = await axios.post(url, body, {
        headers: {
          Authorization: this.getAuthHeader(),
          Accept: "application/json",
          "Content-Type": "application/json",
          "User-Agent": `TS-SDK/${SDK_VERSION}`
        }
      });
      return response.data;
    } catch (error) {
      if (axios.isAxiosError(error) && error.response) {
        return { detail: "An error occurred." };
      }
      console.error("error", error);
      return { detail: "An unexpected error occurred." };
    }
  }

  /**
   * Selects the best model for the given messages.
   * @param options The options for the model.
   * @returns The model-select response, or `{ detail }` when the request
   *   failed.
   */
  async modelSelect(options) {
    const llmProviders = options.llmProviders.map((provider) => ({
      provider: provider.provider,
      model: provider.model,
      ...(provider.contextLength !== undefined && { context_length: provider.contextLength }),
      // For each price/latency pair the later spread wins, so the plain field
      // overrides its `custom*` counterpart when both are set.
      ...(provider.customInputPrice !== undefined && { input_price: provider.customInputPrice }),
      ...(provider.inputPrice !== undefined && { input_price: provider.inputPrice }),
      ...(provider.customOutputPrice !== undefined && { output_price: provider.customOutputPrice }),
      ...(provider.outputPrice !== undefined && { output_price: provider.outputPrice }),
      ...(provider.customLatency !== undefined && { latency: provider.customLatency }),
      ...(provider.latency !== undefined && { latency: provider.latency }),
      ...(provider.isCustom !== undefined && { is_custom: provider.isCustom })
    }));
    const requestBody = {
      messages: options.messages,
      llm_providers: llmProviders,
      ...(options.tradeoff && { tradeoff: options.tradeoff }),
      ...(options.maxModelDepth && { max_model_depth: options.maxModelDepth }),
      ...(options.tools && { tools: options.tools }),
      ...(options.hashContent !== undefined && { hash_content: options.hashContent }),
      ...(options.preferenceId && { preference_id: options.preferenceId }),
      timeout: options.timeout || DEFAULT_TIMEOUT,
      ...(options.default && { default: options.default }),
      ...(options.previousSession && { previous_session: options.previousSession }),
      ...(options.responseModel && { response_model: options.responseModel })
    };
    return this.postRequest(this.modelSelectUrl, requestBody);
  }

  /**
   * Sends feedback to the NotDiamond API.
   * @param options The options for the feedback (`sessionId`, `feedback`,
   *   `provider`).
   * @returns The results of the feedback.
   */
  async feedback(options) {
    return this.postRequest(this.feedbackUrl, {
      session_id: options.sessionId,
      feedback: options.feedback,
      provider: options.provider
    });
  }

  /**
   * Creates a preference id.
   * @returns The preference id.
   * @throws {Error} When the response carries no `preference_id`.
   */
  async createPreferenceId() {
    const response = await this.postRequest(this.createUrl, {});
    // Guard the `in` check: it throws on non-object response bodies.
    const isObject = response !== null && typeof response === "object";
    if (isObject && "preference_id" in response) {
      return response.preference_id;
    }
    throw new Error("Invalid response: preference_id not found");
  }

  /**
   * Selects a model and runs the completion against it.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @returns A promise that resolves to `{ content, providers }`.
   * @throws {Error} When model selection returned no providers (for example
   *   because the API request failed).
   */
  async acreate(options, runtimeArgs = {}) {
    const selectedModel = await this.modelSelect(options);
    const providers = selectedModel?.providers;
    if (!Array.isArray(providers) || providers.length === 0) {
      // postRequest reports failures as { detail }; surface it instead of
      // crashing on `providers[0]` with an opaque TypeError.
      const detail = typeof selectedModel?.detail === "string"
        ? `: ${selectedModel.detail}`
        : "";
      throw new Error(`Model selection returned no providers${detail}`);
    }
    const content = await callLLM(
      providers[0],
      options,
      this.llmKeys,
      runtimeArgs
    );
    return { content, providers };
  }

  /**
   * Callback-or-promise wrapper around `acreate`.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @param callback Optional error-first callback function to handle the result.
   * @returns The `acreate` promise when no callback is given, otherwise
   *   undefined.
   */
  create(options, runtimeArgs = {}, callback) {
    const promise = this.acreate(options, runtimeArgs);
    if (!callback) {
      return promise;
    }
    promise
      .then((result) => callback(null, result))
      .catch((error) => callback(error));
    return undefined;
  }

  /**
   * Streams the results of the model asynchronously.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @returns A promise that resolves to an object containing the provider and
   *   an AsyncIterable of strings. The provider is the first selected one, or
   *   `options.default` when selection returned none.
   */
  async astream(options, runtimeArgs = {}) {
    const selectedModel = await this.modelSelect(options);
    const provider = selectedModel?.providers?.[0] || options.default;
    const stream = callLLMStream(provider, options, this.llmKeys, runtimeArgs);
    return { provider, stream };
  }

  /**
   * Streams the results of the model.
   * @param options The options for the model.
   * @param runtimeArgs Extra arguments forwarded to the model invocation.
   * @param callback Optional error-first callback function to handle each
   *   chunk of the stream; it receives `{ provider, chunk }`.
   * @returns The `astream` promise when no callback is given, otherwise
   *   undefined.
   * @throws {Error} Synchronously, when `options.llmProviders` is missing or
   *   empty.
   */
  stream(options, runtimeArgs = {}, callback) {
    if (!options.llmProviders || options.llmProviders.length === 0) {
      throw new Error("No LLM providers specified");
    }
    const promise = this.astream(options, runtimeArgs);
    if (!callback) {
      return promise;
    }
    promise
      .then(async ({ provider, stream }) => {
        for await (const chunk of stream) {
          callback(null, { provider, chunk });
        }
      })
      .catch((error) => callback(error));
    return undefined;
  }
}

export { NotDiamond, SupportedModel, SupportedProvider };