/**
 * mars-llm
 *
 * Node.js client for Azure OpenAI that loads its credentials from
 * Azure Key Vault, with support for streaming chat completions.
 */
const { SecretClient } = require('@azure/keyvault-secrets');
const { DefaultAzureCredential } = require('@azure/identity');
const { AzureOpenAI } = require('openai');
/**
 * Azure OpenAI client whose connection settings (API key, deployment,
 * endpoint, API version) are retrieved from Azure Key Vault.
 *
 * Initialization is lazy and idempotent: the first chat call triggers
 * initializeOpenAIClient() -> getSecretsFromKeyVault() ->
 * initializeCredentials(); later calls reuse the cached client.
 */
class AzureOpenAIClient {
  /**
   * @param {string|null} [keyVaultUrl] - Key Vault URL; falls back to the
   *   AZURE_KEY_VAULT_URL environment variable when omitted.
   * @throws {Error} When no Key Vault URL is available from either source.
   */
  constructor(keyVaultUrl = null) {
    this.keyVaultUrl = keyVaultUrl || process.env.AZURE_KEY_VAULT_URL;
    if (!this.keyVaultUrl) {
      throw new Error('AZURE_KEY_VAULT_URL environment variable is required or must be provided via constructor');
    }
    this.secretClient = null; // SecretClient, built in initializeCredentials()
    this.openaiClient = null; // AzureOpenAI, built in initializeOpenAIClient()
    this.credentials = {};    // { apiKey, deployment, endpoint, apiVersion }
  }

  /**
   * Initialize Azure credentials and the Key Vault SecretClient.
   * @throws {Error} When credential or client construction fails.
   */
  async initializeCredentials() {
    try {
      console.log('Using Default Azure Credential for Azure Key Vault');
      // Mirrors the Python setup (exclude_visual_studio_code_credential=True).
      // BUGFIX: the option is named excludeVisualStudioCodeCredential; the
      // previous excludeVsCodeCredential spelling was silently ignored.
      const credential = new DefaultAzureCredential({
        excludeVisualStudioCodeCredential: true,
        excludeSharedTokenCacheCredential: true
      });
      this.secretClient = new SecretClient(this.keyVaultUrl, credential);
      console.log(`Initialized Key Vault client for: ${this.keyVaultUrl}`);
    } catch (error) {
      // Preserve the original failure for debugging via Error.cause.
      throw new Error(`Failed to initialize Azure credentials: ${error.message}`, { cause: error });
    }
  }

  /**
   * Fetch the four MARS-* secrets from Key Vault in parallel and cache them.
   * @returns {Promise<{apiKey: string, deployment: string, endpoint: string, apiVersion: string}>}
   * @throws {Error} When any secret cannot be retrieved.
   */
  async getSecretsFromKeyVault() {
    if (!this.secretClient) {
      await this.initializeCredentials();
    }
    try {
      console.log('Retrieving secrets from Key Vault...');
      const [apiKeySecret, deploymentSecret, endpointSecret, apiVersionSecret] = await Promise.all([
        this.secretClient.getSecret('MARS-API-KEY'),
        this.secretClient.getSecret('MARS-DEPLOYMENT'),
        this.secretClient.getSecret('MARS-ENDPOINT'),
        this.secretClient.getSecret('MARS-API-VERSION')
      ]);
      this.credentials = {
        apiKey: apiKeySecret.value,
        deployment: deploymentSecret.value,
        endpoint: endpointSecret.value,
        apiVersion: apiVersionSecret.value
      };
      // Deliberately logs everything except the API key itself.
      console.log('Successfully retrieved all secrets from Key Vault');
      console.log(`Endpoint: ${this.credentials.endpoint}`);
      console.log(`Deployment: ${this.credentials.deployment}`);
      console.log(`API Version: ${this.credentials.apiVersion}`);
      return this.credentials;
    } catch (error) {
      throw new Error(`Failed to retrieve secrets from Key Vault: ${error.message}`, { cause: error });
    }
  }

  /**
   * Construct the AzureOpenAI client, fetching secrets first if needed.
   * @throws {Error} When client construction fails.
   */
  async initializeOpenAIClient() {
    if (!this.credentials.apiKey) {
      await this.getSecretsFromKeyVault();
    }
    try {
      this.openaiClient = new AzureOpenAI({
        apiKey: this.credentials.apiKey,
        deployment: this.credentials.deployment,
        endpoint: this.credentials.endpoint,
        apiVersion: this.credentials.apiVersion,
      });
      console.log('Azure OpenAI client initialized successfully');
    } catch (error) {
      throw new Error(`Failed to initialize Azure OpenAI client: ${error.message}`, { cause: error });
    }
  }

  /**
   * Join prompt and text into a single user message.
   * BUGFIX: when either part is empty the other is sent as-is, instead of
   * the previous behavior of always emitting `${prompt}:\n ${text}` (which
   * made chat()/chatStream() prepend a stray ":\n ").
   * @param {string} prompt
   * @param {string} text
   * @returns {string}
   */
  #composeContent(prompt, text) {
    if (prompt && text) {
      return `${prompt}:\n ${text}`;
    }
    return prompt || text;
  }

  /**
   * Chat completion without streaming.
   * @param {string} prompt - The user prompt
   * @param {string} [text] - Additional text content
   * @param {Object} [options] - { model, maxTokens, temperature }
   * @returns {Promise<{content: string, usage: Object, model: string, finishReason: string}>}
   * @throws {Error} When the completion request fails.
   */
  async chatCompletion(prompt, text = '', options = {}) {
    if (!this.openaiClient) {
      await this.initializeOpenAIClient();
    }
    const { model = 'gpt-4o', maxTokens = 50, temperature = 0.7 } = options;
    try {
      const response = await this.openaiClient.chat.completions.create({
        model,
        messages: [
          { role: 'user', content: this.#composeContent(prompt, text) }
        ],
        max_tokens: maxTokens,
        temperature
      });
      return {
        content: response.choices[0].message.content,
        usage: response.usage,
        model: response.model,
        finishReason: response.choices[0].finish_reason
      };
    } catch (error) {
      throw new Error(`Chat completion failed: ${error.message}`, { cause: error });
    }
  }

  /**
   * Chat completion with streaming.
   * @param {string} prompt - The user prompt
   * @param {string} [text] - Additional text content
   * @param {Function|null} [onChunk] - Called as onChunk(delta, fullContent) per chunk
   * @param {Object} [options] - { model, maxTokens, temperature }
   * @returns {Promise<{content: string, model: string}>}
   * @throws {Error} When the streaming request fails.
   */
  async chatCompletionStream(prompt, text = '', onChunk = null, options = {}) {
    if (!this.openaiClient) {
      await this.initializeOpenAIClient();
    }
    const { model = 'gpt-4o', maxTokens = 1000, temperature = 0.7 } = options;
    try {
      const stream = await this.openaiClient.chat.completions.create({
        model,
        messages: [
          { role: 'user', content: this.#composeContent(prompt, text) }
        ],
        max_tokens: maxTokens,
        temperature,
        stream: true
      });
      let fullContent = '';
      for await (const chunk of stream) {
        // Final/bookkeeping chunks may carry no choices or no delta content.
        const delta = chunk.choices[0]?.delta;
        if (delta?.content) {
          fullContent += delta.content;
          if (typeof onChunk === 'function') {
            onChunk(delta.content, fullContent);
          }
        }
      }
      return {
        content: fullContent,
        model
      };
    } catch (error) {
      throw new Error(`Streaming chat completion failed: ${error.message}`, { cause: error });
    }
  }

  /**
   * Simple chat method: sends the message alone (no prompt prefix).
   * @param {string} message - The complete message to send
   * @param {Object} [options] - Additional options
   */
  async chat(message, options = {}) {
    return await this.chatCompletion('', message, options);
  }

  /**
   * Simple streaming chat method: sends the message alone (no prompt prefix).
   * @param {string} message - The complete message to send
   * @param {Function|null} [onChunk] - Callback function for each chunk
   * @param {Object} [options] - Additional options
   */
  async chatStream(message, onChunk = null, options = {}) {
    return await this.chatCompletionStream('', message, onChunk, options);
  }
}
// CommonJS export: consumers do `const AzureOpenAIClient = require('mars-llm');`.
module.exports = AzureOpenAIClient;