@llumiverse/drivers
LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity";
import { AIModel, DriverOptions, getModelCapabilities, modelModalitiesToArray, Providers } from "@llumiverse/core";
import OpenAI, { AzureOpenAI } from "openai";
import { BaseOpenAIDriver } from "./index.js";

export interface AzureOpenAIDriverOptions extends DriverOptions {
    /**
     * The credentials used to access Azure OpenAI: an Azure AD token provider,
     * e.g. the function returned by getBearerTokenProvider().
     */
    azureADTokenProvider?: any; // TODO: type with Azure credential provider
    apiKey?: string;
    endpoint?: string;
    apiVersion?: string;
    deployment?: string;
}
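
// Example option shape (a sketch for illustration only; the key, endpoint and
// deployment below are hypothetical placeholders, not real values).
export const exampleAzureOpenAIDriverOptions: AzureOpenAIDriverOptions = {
    apiKey: "<azure-openai-api-key>",
    endpoint: "https://my-resource.openai.azure.com",
    deployment: "my-gpt-4o-deployment",
    apiVersion: "2024-10-21",
    // Omitting apiKey and azureADTokenProvider instead makes the constructor fall
    // back to Azure AD authentication via DefaultAzureCredential (see below).
};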
export class AzureOpenAIDriver extends BaseOpenAIDriver {
    service: AzureOpenAI;
    readonly provider = Providers.azure_openai;

    // Overloaded constructor: accepts either a pre-configured AzureOpenAI client or driver options.
    constructor(serviceOrOpts: AzureOpenAI | AzureOpenAIDriverOptions) {
        if (serviceOrOpts instanceof AzureOpenAI) {
            super({});
            this.service = serviceOrOpts;
            return;
        }
        const opts = serviceOrOpts ?? {};
        super(opts);
        if (!opts.azureADTokenProvider && !opts.apiKey) {
            opts.azureADTokenProvider = this.getDefaultCognitiveServicesAuth();
        }
        this.service = new AzureOpenAI({
            apiKey: opts.apiKey,
            azureADTokenProvider: opts.azureADTokenProvider,
            endpoint: opts.endpoint,
            apiVersion: opts.apiVersion ?? "2024-10-21",
            deployment: opts.deployment
        });
    }

    /**
     * Get the default authentication for the Azure Cognitive Services API:
     * a bearer token provider backed by DefaultAzureCredential.
     */
    getDefaultCognitiveServicesAuth() {
        const scope = "https://cognitiveservices.azure.com/.default";
        const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope);
        return azureADTokenProvider;
    }
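
    // Example (a sketch for illustration, not part of the upstream code): callers can
    // bypass this default by passing their own provider, e.g. one built from a
    // ManagedIdentityCredential (also exported by @azure/identity):
    //
    //   new AzureOpenAIDriver({
    //       endpoint: "https://my-resource.openai.azure.com",   // hypothetical endpoint
    //       azureADTokenProvider: getBearerTokenProvider(
    //           new ManagedIdentityCredential(),
    //           "https://cognitiveservices.azure.com/.default"),
    //   });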
    async listModels(): Promise<AIModel[]> {
        return this._listModels();
    }

    async _listModels(_filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> {
        if (!this.service.deploymentName) {
            throw new Error("A specific deployment is not set. Azure OpenAI cannot list deployments. Update your endpoint URL to include the deployment name, e.g., https://your-resource.openai.azure.com/openai/deployments/your-deployment/chat/completions");
        }
        // Run a minimal test completion to verify the deployment works and to resolve the underlying model ID.
        let modelID = this.service.deploymentName;
        try {
            const testResponse = await this.service.chat.completions.create({
                model: this.service.deploymentName,
                messages: [{ role: "user", content: "Hi" }],
                max_tokens: 1,
            });
            modelID = testResponse.model;
        } catch (error) {
            this.logger.error("Failed to test the deployment while listing Azure OpenAI models:", error);
        }
        const modelCapability = getModelCapabilities(modelID, "openai");
        return [{
            id: modelID,
            name: this.service.deploymentName,
            provider: this.provider,
            owner: "openai",
            input_modalities: modelModalitiesToArray(modelCapability.input),
            output_modalities: modelModalitiesToArray(modelCapability.output),
            tool_support: modelCapability.tool_support,
        } satisfies AIModel<string>];
    }
}
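
// Usage sketch (for illustration only; the resource and deployment names are
// hypothetical placeholders). It uses the options-based overload; passing an
// already configured AzureOpenAI client is the other supported form.
export async function exampleListAzureOpenAIModels(): Promise<AIModel[]> {
    const driver = new AzureOpenAIDriver({
        endpoint: "https://my-resource.openai.azure.com",
        deployment: "my-gpt-4o-deployment",
        apiVersion: "2024-10-21",
    });
    // listModels() probes the single configured deployment and returns its metadata.
    return driver.listModels();
}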