UNPKG

mars-llm

Version:

Node.js client for Azure OpenAI using credentials from Azure Key Vault with streaming support

209 lines (180 loc) 7.26 kB
const { SecretClient } = require('@azure/keyvault-secrets');
const { DefaultAzureCredential } = require('@azure/identity');
const { AzureOpenAI } = require('openai');

/** Model sent with every request unless overridden via `options.model`. */
const DEFAULT_MODEL = 'gpt-4o';
/** Sampling temperature used unless overridden via `options.temperature`. */
const DEFAULT_TEMPERATURE = 0.7;
/** Default `max_tokens` for non-streaming completions. */
const DEFAULT_CHAT_MAX_TOKENS = 50;
/** Default `max_tokens` for streaming completions. */
const DEFAULT_STREAM_MAX_TOKENS = 1000;

/**
 * Build the content of the single user message sent to the model.
 *
 * When both parts are non-empty they are joined as `"<prompt>:\n <text>"`.
 * When only one part is non-empty it is sent on its own. This prevents
 * `chat(message)` (which passes an empty prompt) from sending a stray
 * `":\n "` prefix.
 *
 * @param {string} prompt - Instruction part; may be empty.
 * @param {string} text - Content part; may be empty.
 * @returns {string} The user message content.
 */
function buildUserContent(prompt, text) {
  const hasPrompt = prompt != null && prompt !== '';
  const hasText = text != null && text !== '';
  if (hasPrompt && hasText) {
    return `${prompt}:\n ${text}`;
  }
  if (hasPrompt) {
    return String(prompt);
  }
  return hasText ? String(text) : '';
}

/**
 * Turn this client's camelCase options into `chat.completions.create` params.
 *
 * Only `model`, `maxTokens` and `temperature` are read from `options`;
 * any other keys are ignored.
 *
 * @param {string} prompt - Instruction part of the user message.
 * @param {string} text - Content part of the user message.
 * @param {Object} options - Caller overrides ({ model, maxTokens, temperature }).
 * @param {number} defaultMaxTokens - `max_tokens` used when `options.maxTokens` is absent.
 * @returns {{model: string, messages: Array<{role: string, content: string}>, max_tokens: number, temperature: number}}
 */
function buildChatParams(prompt, text, options, defaultMaxTokens) {
  const settings = {
    model: DEFAULT_MODEL,
    maxTokens: defaultMaxTokens,
    temperature: DEFAULT_TEMPERATURE,
    ...options,
  };
  return {
    model: settings.model,
    messages: [{ role: 'user', content: buildUserContent(prompt, text) }],
    max_tokens: settings.maxTokens,
    temperature: settings.temperature,
  };
}

/**
 * Azure OpenAI chat client that loads its connection settings from Azure Key Vault.
 *
 * Initialization is lazy. The first chat call creates the Key Vault client,
 * reads the MARS-* secrets and then builds the `AzureOpenAI` client.
 */
class AzureOpenAIClient {
  /**
   * @param {string|null} [keyVaultUrl=null] - Key Vault URL; falls back to
   *   the AZURE_KEY_VAULT_URL environment variable.
   * @throws {Error} If no Key Vault URL is supplied by either source.
   */
  constructor(keyVaultUrl = null) {
    this.keyVaultUrl = keyVaultUrl || process.env.AZURE_KEY_VAULT_URL;
    if (!this.keyVaultUrl) {
      throw new Error('AZURE_KEY_VAULT_URL environment variable is required or must be provided via constructor');
    }
    this.secretClient = null;
    this.openaiClient = null;
    this.credentials = {};
  }

  /**
   * Initialize Azure credentials and the Key Vault client.
   * @throws {Error} Wrapping (as `cause`) any failure while constructing the clients.
   */
  async initializeCredentials() {
    try {
      console.log('Using Default Azure Credential for Azure Key Vault');
      // Intended to mirror the Python setup (exclude_visual_studio_code_credential=True).
      // NOTE(review): confirm these exclude* flags are honored by the installed
      // @azure/identity version's DefaultAzureCredentialOptions; otherwise they are ignored.
      const credential = new DefaultAzureCredential({
        excludeVsCodeCredential: true,
        excludeSharedTokenCacheCredential: true,
      });
      this.secretClient = new SecretClient(this.keyVaultUrl, credential);
      console.log(`Initialized Key Vault client for: ${this.keyVaultUrl}`);
    } catch (error) {
      throw new Error(`Failed to initialize Azure credentials: ${error.message}`, { cause: error });
    }
  }

  /**
   * Retrieve the MARS-* connection secrets from Azure Key Vault and cache them
   * on `this.credentials`.
   * @returns {Promise<{apiKey: string, deployment: string, endpoint: string, apiVersion: string}>}
   * @throws {Error} Wrapping (as `cause`) the first secret lookup that fails.
   */
  async getSecretsFromKeyVault() {
    if (!this.secretClient) {
      await this.initializeCredentials();
    }
    try {
      console.log('Retrieving secrets from Key Vault...');
      // Fetched concurrently; Promise.all rejects as soon as any lookup fails.
      const [apiKeySecret, deploymentSecret, endpointSecret, apiVersionSecret] = await Promise.all([
        this.secretClient.getSecret('MARS-API-KEY'),
        this.secretClient.getSecret('MARS-DEPLOYMENT'),
        this.secretClient.getSecret('MARS-ENDPOINT'),
        this.secretClient.getSecret('MARS-API-VERSION'),
      ]);
      this.credentials = {
        apiKey: apiKeySecret.value,
        deployment: deploymentSecret.value,
        endpoint: endpointSecret.value,
        apiVersion: apiVersionSecret.value,
      };
      console.log('Successfully retrieved all secrets from Key Vault');
      // The API key itself is not logged.
      console.log(`Endpoint: ${this.credentials.endpoint}`);
      console.log(`Deployment: ${this.credentials.deployment}`);
      console.log(`API Version: ${this.credentials.apiVersion}`);
      return this.credentials;
    } catch (error) {
      throw new Error(`Failed to retrieve secrets from Key Vault: ${error.message}`, { cause: error });
    }
  }

  /**
   * Initialize the Azure OpenAI client, fetching secrets first if no API key is cached.
   * @throws {Error} Wrapping (as `cause`) any client construction failure.
   */
  async initializeOpenAIClient() {
    if (!this.credentials.apiKey) {
      await this.getSecretsFromKeyVault();
    }
    try {
      this.openaiClient = new AzureOpenAI({
        apiKey: this.credentials.apiKey,
        deployment: this.credentials.deployment,
        endpoint: this.credentials.endpoint,
        apiVersion: this.credentials.apiVersion,
      });
      console.log('Azure OpenAI client initialized successfully');
    } catch (error) {
      throw new Error(`Failed to initialize Azure OpenAI client: ${error.message}`, { cause: error });
    }
  }

  /**
   * Chat completion without streaming.
   * @param {string} prompt - The user prompt (may be empty).
   * @param {string} [text=''] - Additional text content (may be empty).
   * @param {Object} [options={}] - { model, maxTokens (default 50), temperature (default 0.7) }.
   * @returns {Promise<{content: string, usage: Object, model: string, finishReason: string}>}
   *   Data from the first returned choice.
   * @throws {Error} "Chat completion failed: ..." with the original error as `cause`.
   */
  async chatCompletion(prompt, text = '', options = {}) {
    if (!this.openaiClient) {
      await this.initializeOpenAIClient();
    }
    const params = buildChatParams(prompt, text, options, DEFAULT_CHAT_MAX_TOKENS);
    try {
      const response = await this.openaiClient.chat.completions.create(params);
      const [choice] = response.choices;
      return {
        content: choice.message.content,
        usage: response.usage,
        model: response.model,
        finishReason: choice.finish_reason,
      };
    } catch (error) {
      throw new Error(`Chat completion failed: ${error.message}`, { cause: error });
    }
  }

  /**
   * Chat completion with streaming.
   * @param {string} prompt - The user prompt (may be empty).
   * @param {string} [text=''] - Additional text content (may be empty).
   * @param {Function|null} [onChunk=null] - Called as onChunk(delta, fullContentSoFar)
   *   for every non-empty content delta.
   * @param {Object} [options={}] - { model, maxTokens (default 1000), temperature (default 0.7) }.
   * @returns {Promise<{content: string, model: string, finishReason: (string|null)}>}
   *   `model` is the requested model name. `finishReason` is the last finish_reason
   *   seen in the stream, or null if none was reported.
   * @throws {Error} "Streaming chat completion failed: ..." with the original error as `cause`.
   */
  async chatCompletionStream(prompt, text = '', onChunk = null, options = {}) {
    if (!this.openaiClient) {
      await this.initializeOpenAIClient();
    }
    const params = buildChatParams(prompt, text, options, DEFAULT_STREAM_MAX_TOKENS);
    try {
      const stream = await this.openaiClient.chat.completions.create({ ...params, stream: true });
      let fullContent = '';
      let finishReason = null;
      for await (const chunk of stream) {
        // Tolerate chunks that carry no choice or no delta content.
        const choice = chunk.choices?.[0];
        if (choice?.finish_reason) {
          finishReason = choice.finish_reason;
        }
        const piece = choice?.delta?.content;
        if (piece) {
          fullContent += piece;
          if (typeof onChunk === 'function') {
            onChunk(piece, fullContent);
          }
        }
      }
      return { content: fullContent, model: params.model, finishReason };
    } catch (error) {
      throw new Error(`Streaming chat completion failed: ${error.message}`, { cause: error });
    }
  }

  /**
   * Simple chat: sends `message` as the entire user message.
   * @param {string} message - The complete message to send.
   * @param {Object} [options={}] - Same options as chatCompletion.
   */
  async chat(message, options = {}) {
    return this.chatCompletion('', message, options);
  }

  /**
   * Simple streaming chat: sends `message` as the entire user message.
   * @param {string} message - The complete message to send.
   * @param {Function|null} [onChunk=null] - Callback for each content delta.
   * @param {Object} [options={}] - Same options as chatCompletionStream.
   */
  async chatStream(message, onChunk = null, options = {}) {
    return this.chatCompletionStream('', message, onChunk, options);
  }
}

module.exports = AzureOpenAIClient;