UNPKG

@llumiverse/drivers

Version:

LLM driver implementations. Currently supported are: openai, huggingface, bedrock, replicate.

68 lines 2.96 kB
import { DefaultAzureCredential, getBearerTokenProvider } from "@azure/identity"; import { getModelCapabilities, modelModalitiesToArray, Providers } from "@llumiverse/core"; import { AzureOpenAI } from "openai"; import { BaseOpenAIDriver } from "./index.js"; export class AzureOpenAIDriver extends BaseOpenAIDriver { service; provider = Providers.azure_openai; //Overload to allow independent instantiation with AzureOpenAI service constructor(serviceOrOpts) { if (serviceOrOpts instanceof AzureOpenAI) { super({}); this.service = serviceOrOpts; return; } const opts = serviceOrOpts ?? {}; super(opts); if (!opts.azureADTokenProvider && !opts.apiKey) { opts.azureADTokenProvider = this.getDefaultCognitiveServicesAuth(); } this.service = new AzureOpenAI({ apiKey: opts.apiKey, azureADTokenProvider: opts.azureADTokenProvider, endpoint: opts.endpoint, apiVersion: opts.apiVersion ?? "2024-10-21", deployment: opts.deployment }); } /** * Get default authentication for Azure Cognitive Services API */ getDefaultCognitiveServicesAuth() { const scope = "https://cognitiveservices.azure.com/.default"; const azureADTokenProvider = getBearerTokenProvider(new DefaultAzureCredential(), scope); return azureADTokenProvider; } async listModels() { return this._listModels(); } async _listModels(_filter) { if (!this.service.deploymentName) { throw new Error("A specific deployment is not set. Azure OpenAI cannot list deployments. Update your endpoint URL to include the deployment name, e.g., https://your-resource.openai.azure.com/openai/deployments/your-deployment/chat/completions"); } //Do a test execution to check if the model works and to get the model ID. let modelID = this.service.deploymentName; try { const testResponse = await this.service.chat.completions.create({ model: this.service.deploymentName, messages: [{ role: "user", content: "Hi" }], max_tokens: 1, }); modelID = testResponse.model; } catch (error) { this.logger.error({ error }, "Failed to test model for Azure OpenAI listing :"); } const modelCapability = getModelCapabilities(modelID, "openai"); return [{ id: modelID, name: this.service.deploymentName, provider: this.provider, owner: "openai", input_modalities: modelModalitiesToArray(modelCapability.input), output_modalities: modelModalitiesToArray(modelCapability.output), tool_support: modelCapability.tool_support, }]; } } //# sourceMappingURL=azure_openai.js.map