UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

92 lines (76 loc) 3.42 kB
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity"; import { AIModel, DriverOptions, getModelCapabilities, modelModalitiesToArray, Providers } from "@llumiverse/core"; import OpenAI, { AzureOpenAI } from "openai"; import { BaseOpenAIDriver } from "./index.js"; export interface AzureOpenAIDriverOptions extends DriverOptions { /** * The credentials to use to access Azure OpenAI */ azureADTokenProvider?: any; //type with azure credentials apiKey?: string; endpoint?: string; apiVersion?: string deployment?: string; } export class AzureOpenAIDriver extends BaseOpenAIDriver { service: AzureOpenAI; readonly provider = Providers.azure_openai; //Overload to allow independent instantiation with AzureOpenAI service constructor(serviceOrOpts: AzureOpenAI | AzureOpenAIDriverOptions) { if (serviceOrOpts instanceof AzureOpenAI) { super({}); this.service = serviceOrOpts; return; } const opts = serviceOrOpts ?? {}; super(opts); if (!opts.azureADTokenProvider && !opts.apiKey) { opts.azureADTokenProvider = this.getDefaultCognitiveServicesAuth(); } this.service = new AzureOpenAI({ apiKey: opts.apiKey, azureADTokenProvider: opts.azureADTokenProvider, endpoint: opts.endpoint, apiVersion: opts.apiVersion ?? "2024-10-21", deployment: opts.deployment }); } /** * Get default authentication for Azure Cognitive Services API */ getDefaultCognitiveServicesAuth() { const scope = "https://cognitiveservices.azure.com/.default"; const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope); return azureADTokenProvider; } async listModels(): Promise<AIModel[]> { return this._listModels(); } async _listModels(_filter?: (m: OpenAI.Models.Model) => boolean): Promise<AIModel[]> { if (!this.service.deploymentName) { throw new Error("A specific deployment is not set. Azure OpenAI cannot list deployments. Update your endpoint URL to include the deployment name, e.g., https://your-resource.openai.azure.com/openai/deployments/your-deployment/chat/completions"); } //Do a test execution to check if the model works and to get the model ID. let modelID = this.service.deploymentName; try { const testResponse = await this.service.chat.completions.create({ model: this.service.deploymentName, messages: [{ role: "user", content: "Hi" }], max_tokens: 1, }); modelID = testResponse.model; } catch (error) { this.logger.error("Failed to test model for Azure OpenAI listing :", error); } const modelCapability = getModelCapabilities(modelID, "openai"); return [{ id: modelID, name: this.service.deploymentName, provider: this.provider, owner: "openai", input_modalities: modelModalitiesToArray(modelCapability.input), output_modalities: modelModalitiesToArray(modelCapability.output), tool_support: modelCapability.tool_support, } satisfies AIModel<string>]; } }