@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
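A minimal usage sketch for the Azure AI Foundry driver below (hypothetical: the endpoint is a placeholder, and re-export of AzureFoundryDriver from the package root is assumed):

const { AzureFoundryDriver } = require("@llumiverse/drivers");

async function main() {
const driver = new AzureFoundryDriver({
endpoint: "https://<your-project>.services.ai.azure.com", // placeholder
});
const models = await driver.listModels(); // chat-capable deployments only
console.log(models.map((m) => m.id));
}
main().catch(console.error);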
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.AzureFoundryDriver = void 0;
exports.parseAzureFoundryModelId = parseAzureFoundryModelId;
exports.isCompositeModelId = isCompositeModelId;
const identity_1 = require("@azure/identity");
const core_1 = require("@llumiverse/core");
const ai_projects_1 = require("@azure/ai-projects");
const ai_inference_1 = require("@azure-rest/ai-inference");
const azure_openai_js_1 = require("../openai/azure_openai.js");
const core_sse_1 = require("@azure/core-sse");
const openai_format_js_1 = require("../openai/openai_format.js");
class AzureFoundryDriver extends core_1.AbstractDriver {
service;
provider = core_1.Providers.azure_foundry;
OPENAI_API_VERSION = "2025-01-01-preview";
INFERENCE_API_VERSION = "2024-05-01-preview";
constructor(opts) {
super(opts);
this.formatPrompt = openai_format_js_1.formatOpenAILikeMultimodalPrompt;
if (!opts.endpoint) {
throw new Error("Azure AI Foundry endpoint is required");
}
try {
if (!opts.azureADTokenProvider) {
// Fall back to Microsoft Entra ID (Azure AD) authentication via DefaultAzureCredential
opts.azureADTokenProvider = new identity_1.DefaultAzureCredential();
}
}
catch (error) {
this.logger.error({ error }, "Failed to initialize Azure AD token provider:");
throw new Error("Failed to initialize Azure AD token provider");
}
// Initialize AI Projects client which provides access to inference operations
this.service = new ai_projects_1.AIProjectClient(opts.endpoint, opts.azureADTokenProvider);
if (opts.apiVersion) {
this.OPENAI_API_VERSION = opts.apiVersion;
this.INFERENCE_API_VERSION = opts.apiVersion;
this.logger.info(`[Azure Foundry] Overriding default API version with ${opts.apiVersion}`);
}
}
/**
* Get default authentication for Azure AI Foundry API
*/
getDefaultAIFoundryAuth() {
const scope = "https://ai.azure.com/.default";
const azureADTokenProvider = (0, identity_1.getBearerTokenProvider)(new identity_1.DefaultAzureCredential(), scope);
return azureADTokenProvider;
}
async isOpenAIDeployment(model) {
const { deploymentName } = parseAzureFoundryModelId(model);
let deployment;
// First, verify the deployment exists
try {
deployment = await this.service.deployments.get(deploymentName);
this.logger.debug(`[Azure Foundry] Deployment ${deploymentName} found`);
}
catch (deploymentError) {
this.logger.error({ deploymentError }, `[Azure Foundry] Deployment ${deploymentName} not found:`);
// Fail fast: without the deployment metadata we cannot pick the right client
throw new Error(`Azure AI Foundry deployment ${deploymentName} not found`);
}
return deployment.modelPublisher === "OpenAI";
}
canStream(_options) {
return Promise.resolve(true);
}
async requestTextCompletion(prompt, options) {
const { deploymentName } = parseAzureFoundryModelId(options.model);
const model_options = options.model_options;
const isOpenAI = await this.isOpenAIDeployment(options.model);
if (isOpenAI) {
// Use the Azure OpenAI client for OpenAI models
const azureOpenAI = await this.service.inference.azureOpenAI({ apiVersion: this.OPENAI_API_VERSION });
const subDriver = new azure_openai_js_1.AzureOpenAIDriver(azureOpenAI);
// Use the deployment name, not the composite id, for API calls
const modifiedOptions = { ...options, model: deploymentName };
return subDriver.requestTextCompletion(prompt, modifiedOptions);
}
else {
// Use the chat completions client from the inference operations
const chatClient = this.service.inference.chatCompletions({ apiVersion: this.INFERENCE_API_VERSION });
const response = await chatClient.post({
body: {
messages: prompt,
max_tokens: model_options?.max_tokens,
model: deploymentName,
// Non-streaming request: the full completion is returned in one response body
stream: false,
temperature: model_options?.temperature,
top_p: model_options?.top_p,
frequency_penalty: model_options?.frequency_penalty,
presence_penalty: model_options?.presence_penalty,
stop: model_options?.stop_sequence,
}
});
if (response.status !== "200") {
this.logger.error({ response }, `[Azure Foundry] Chat completion request failed:`);
throw new Error(`Chat completion request failed with status ${response.status}: ${JSON.stringify(response.body)}`);
}
return this.extractDataFromResponse(response.body);
}
}
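// Illustrative call (hypothetical values; `prompt` is a formatted message list
// and the composite model id pairs deployment name and base model):
// const completion = await driver.requestTextCompletion(prompt, {
//     model: "gpt-4o-deploy::gpt-4o",
//     model_options: { max_tokens: 256, temperature: 0.2 },
// });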
async requestTextCompletionStream(prompt, options) {
const { deploymentName } = parseAzureFoundryModelId(options.model);
const model_options = options.model_options;
const isOpenAI = await this.isOpenAIDeployment(options.model);
if (isOpenAI) {
const azureOpenAI = await this.service.inference.azureOpenAI({ apiVersion: this.OPENAI_API_VERSION });
const subDriver = new azure_openai_js_1.AzureOpenAIDriver(azureOpenAI);
const modifiedOptions = { ...options, model: deploymentName };
const stream = await subDriver.requestTextCompletionStream(prompt, modifiedOptions);
return stream;
}
else {
const chatClient = this.service.inference.chatCompletions({ apiVersion: this.INFERENCE_API_VERSION });
const response = await chatClient.post({
body: {
messages: prompt,
max_tokens: model_options?.max_tokens,
model: deploymentName,
stream: true,
temperature: model_options?.temperature,
top_p: model_options?.top_p,
frequency_penalty: model_options?.frequency_penalty,
presence_penalty: model_options?.presence_penalty,
stop: model_options?.stop_sequence,
}
}).asNodeStream();
// Type-assert from NodeJS.ReadableStream to NodeJSReadableStream:
// the Azure SDK examples expect a .destroy() method on the stream
const stream = response.body;
if (!stream) {
throw new Error("The response stream is undefined");
}
if (response.status !== "200") {
stream.destroy();
throw new Error(`Failed to get chat completions, http operation failed with ${response.status} code`);
}
const sseStream = (0, core_sse_1.createSseStream)(stream);
return this.processStreamResponse(sseStream);
}
}
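// Illustrative consumption of the stream returned above (hypothetical call site):
// for await (const chunk of await driver.requestTextCompletionStream(prompt, options)) {
//     process.stdout.write(chunk.result); // each chunk also carries finish_reason and token_usage
// }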
async *processStreamResponse(sseStream) {
try {
for await (const event of sseStream) {
if (event.data === "[DONE]") {
break;
}
try {
const data = JSON.parse(event.data);
if (!data) {
this.logger.warn(`[Azure Foundry] Received empty data in streaming response`);
continue;
}
const choice = data.choices?.[0];
if (!choice) {
continue;
}
const chunk = {
result: choice.delta?.content || "",
finish_reason: this.convertFinishReason(choice.finish_reason),
token_usage: {
prompt: data.usage?.prompt_tokens,
result: data.usage?.completion_tokens,
total: data.usage?.total_tokens,
},
};
yield chunk;
}
catch (parseError) {
this.logger.warn({ parseError }, `[Azure Foundry] Failed to parse streaming response:`);
continue;
}
}
}
catch (error) {
this.logger.error({ error }, `[Azure Foundry] Streaming error:`);
throw error;
}
}
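// A typical (illustrative, OpenAI-style) SSE event payload parsed above:
// { "choices": [{ "delta": { "content": "Hel" }, "finish_reason": null }],
//   "usage": { "prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15 } }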
extractDataFromResponse(result) {
const tokenInfo = {
prompt: result.usage?.prompt_tokens,
result: result.usage?.completion_tokens,
total: result.usage?.total_tokens,
};
const choice = result.choices?.[0];
if (!choice) {
this.logger.error({ result }, "[Azure Foundry] No choices in response");
throw new Error("No choices in response");
}
const data = choice.message?.content;
const toolCalls = choice.message?.tool_calls;
if (!data && !toolCalls) {
this.logger.error({ result }, "[Azure Foundry] Response is not valid");
throw new Error("Response is not valid: no content or tool calls");
}
const completion = {
result: data ? [{ type: "text", value: data }] : [],
token_usage: tokenInfo,
finish_reason: this.convertFinishReason(choice.finish_reason),
};
if (toolCalls && toolCalls.length > 0) {
completion.tool_use = toolCalls.map((call) => ({
id: call.id,
tool_name: call.function?.name,
tool_input: call.function?.arguments ? JSON.parse(call.function.arguments) : {}
}));
}
return completion;
}
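// Shape of the completion returned above (illustrative values):
// { result: [{ type: "text", value: "..." }],
//   token_usage: { prompt: 12, result: 34, total: 46 },
//   finish_reason: "stop",
//   tool_use: [{ id, tool_name, tool_input }] } // present only when the model called tools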
convertFinishReason(reason) {
if (!reason)
return undefined;
// Map Azure AI finish reasons to standard format
switch (reason) {
case 'stop': return 'stop';
case 'length': return 'length';
case 'tool_calls': return 'tool_use';
default: return reason;
}
}
async validateConnection() {
try {
// Test the AI Projects client by listing deployments
const deploymentsIterable = this.service.deployments.list();
let hasDeployments = false;
for await (const deployment of deploymentsIterable) {
hasDeployments = true;
this.logger.debug(`[Azure Foundry] Found deployment: ${deployment.name} (${deployment.type})`);
break; // Just check if we can get at least one deployment
}
if (!hasDeployments) {
this.logger.warn("[Azure Foundry] No deployments found in the project");
}
return true;
}
catch (error) {
this.logger.error({ error }, "Azure Foundry connection validation failed:");
return false;
}
}
async generateEmbeddings(options) {
if (!options.model) {
throw new Error("Default embedding model selection not supported for Azure Foundry. Please specify a model.");
}
if (options.text) {
return this.generateTextEmbeddings(options);
}
else if (options.image) {
return this.generateImageEmbeddings(options);
}
else {
throw new Error("No text or images provided for embeddings");
}
}
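// Illustrative call (hypothetical composite model id):
// const { values } = await driver.generateEmbeddings({
//     model: "embed-deploy::text-embedding-3-small",
//     text: "hello world",
// });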
async generateTextEmbeddings(options) {
if (!options.text) {
throw new Error("No text provided for text embeddings");
}
const { deploymentName } = parseAzureFoundryModelId(options.model || "");
let response;
try {
// Use the embeddings client from the inference operations
const embeddingsClient = this.service.inference.embeddings({ apiVersion: this.INFERENCE_API_VERSION });
response = await embeddingsClient.post({
body: {
input: Array.isArray(options.text) ? options.text : [options.text],
model: deploymentName
}
});
}
catch (error) {
this.logger.error({ error }, "Azure Foundry text embeddings error:");
throw error;
}
if ((0, ai_inference_1.isUnexpected)(response)) {
throw new Error(`Text embeddings request failed: ${response.status} ${response.body?.error?.message || 'Unknown error'}`);
}
const embeddings = response.body.data?.[0]?.embedding;
if (!embeddings || !Array.isArray(embeddings) || embeddings.length === 0) {
throw new Error("No valid embedding array found in response");
}
return {
values: embeddings,
model: options.model ?? ""
};
}
async generateImageEmbeddings(options) {
if (!options.image) {
throw new Error("No images provided for image embeddings");
}
const { deploymentName } = parseAzureFoundryModelId(options.model || "");
let response;
try {
// Use the embeddings client from the inference operations
const embeddingsClient = this.service.inference.embeddings({ apiVersion: this.INFERENCE_API_VERSION });
response = await embeddingsClient.post({
body: {
input: Array.isArray(options.image) ? options.image : [options.image],
model: deploymentName
}
});
}
catch (error) {
this.logger.error({ error }, "Azure Foundry image embeddings error:");
throw error;
}
if ((0, ai_inference_1.isUnexpected)(response)) {
throw new Error(`Image embeddings request failed: ${response.status} ${response.body?.error?.message || 'Unknown error'}`);
}
const embeddings = response.body.data?.[0]?.embedding;
if (!embeddings || !Array.isArray(embeddings) || embeddings.length === 0) {
throw new Error("No valid embedding array found in response");
}
return {
values: embeddings,
model: options.model ?? ""
};
}
async listModels() {
const filter = (m) => {
// Only include models that advertise chat-completion capability
return !!m.capabilities?.chat_completion;
};
return this._listModels(filter);
}
async _listModels(filter) {
let deploymentsIterable;
try {
// List all deployments in the Azure AI Foundry project
deploymentsIterable = this.service.deployments.list();
}
catch (error) {
this.logger.error({ error }, "Failed to list deployments:");
throw new Error("Failed to list deployments in Azure AI Foundry project");
}
const deployments = [];
for await (const page of deploymentsIterable.byPage()) {
for (const deployment of page) {
deployments.push(deployment);
}
}
let modelDeployments = deployments.filter((d) => {
return d.type === "ModelDeployment";
});
if (filter) {
modelDeployments = modelDeployments.filter(filter);
}
const aiModels = modelDeployments.map((model) => {
// Create composite ID: deployment_name::base_model
const compositeId = `${model.name}::${model.modelName}`;
const modelCapability = (0, core_1.getModelCapabilities)(model.modelName, core_1.Providers.azure_foundry);
return {
id: compositeId,
name: model.name,
description: `${model.modelName} - ${model.modelVersion}`,
version: model.modelVersion,
provider: this.provider,
owner: model.modelPublisher,
input_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.input),
output_modalities: (0, core_1.modelModalitiesToArray)(modelCapability.output),
tool_support: modelCapability.tool_support,
};
}).sort((modelA, modelB) => modelA.id.localeCompare(modelB.id));
return aiModels;
}
}
exports.AzureFoundryDriver = AzureFoundryDriver;
// Helper functions to parse the composite ID
function parseAzureFoundryModelId(compositeId) {
const parts = compositeId.split('::');
if (parts.length === 2) {
return {
deploymentName: parts[0],
baseModel: parts[1]
};
}
// Backwards compatibility: if no delimiter found, treat as deployment name
return {
deploymentName: compositeId,
baseModel: compositeId
};
}
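// Examples (from the logic above):
// parseAzureFoundryModelId("gpt-4o-deploy::gpt-4o")
//   => { deploymentName: "gpt-4o-deploy", baseModel: "gpt-4o" }
// parseAzureFoundryModelId("gpt-4o-deploy")
//   => { deploymentName: "gpt-4o-deploy", baseModel: "gpt-4o-deploy" }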
function isCompositeModelId(modelId) {
return modelId.includes('::');
}
//# sourceMappingURL=azure_foundry.js.map