UNPKG

@ai-sdk/openai

Version:

The **[OpenAI provider](https://ai-sdk.dev/providers/ai-sdk-providers/openai)** for the [AI SDK](https://ai-sdk.dev/docs) contains language model support for the OpenAI chat and completion APIs and embedding model support for the OpenAI embeddings API.

704 lines (627 loc) 22.3 kB
import { InvalidResponseDataError, LanguageModelV3, LanguageModelV3CallOptions, LanguageModelV3Content, LanguageModelV3FinishReason, LanguageModelV3GenerateResult, LanguageModelV3StreamPart, LanguageModelV3StreamResult, SharedV3ProviderMetadata, SharedV3Warning, } from '@ai-sdk/provider'; import { FetchFunction, ParseResult, combineHeaders, createEventSourceResponseHandler, createJsonResponseHandler, generateId, isParsableJson, parseProviderOptions, postJsonToApi, } from '@ai-sdk/provider-utils'; import { openaiFailedResponseHandler } from '../openai-error'; import { getOpenAILanguageModelCapabilities } from '../openai-language-model-capabilities'; import { OpenAIChatUsage, convertOpenAIChatUsage, } from './convert-openai-chat-usage'; import { convertToOpenAIChatMessages } from './convert-to-openai-chat-messages'; import { getResponseMetadata } from './get-response-metadata'; import { mapOpenAIFinishReason } from './map-openai-finish-reason'; import { OpenAIChatChunk, openaiChatChunkSchema, openaiChatResponseSchema, } from './openai-chat-api'; import { OpenAIChatModelId, openaiLanguageModelChatOptions, } from './openai-chat-options'; import { prepareChatTools } from './openai-chat-prepare-tools'; type OpenAIChatConfig = { provider: string; headers: () => Record<string, string | undefined>; url: (options: { modelId: string; path: string }) => string; fetch?: FetchFunction; }; export class OpenAIChatLanguageModel implements LanguageModelV3 { readonly specificationVersion = 'v3'; readonly modelId: OpenAIChatModelId; readonly supportedUrls = { 'image/*': [/^https?:\/\/.*$/], }; private readonly config: OpenAIChatConfig; constructor(modelId: OpenAIChatModelId, config: OpenAIChatConfig) { this.modelId = modelId; this.config = config; } get provider(): string { return this.config.provider; } private async getArgs({ prompt, maxOutputTokens, temperature, topP, topK, frequencyPenalty, presencePenalty, stopSequences, responseFormat, seed, tools, toolChoice, providerOptions, }: LanguageModelV3CallOptions) { const warnings: SharedV3Warning[] = []; // Parse provider options const openaiOptions = (await parseProviderOptions({ provider: 'openai', providerOptions, schema: openaiLanguageModelChatOptions, })) ?? {}; const modelCapabilities = getOpenAILanguageModelCapabilities(this.modelId); const isReasoningModel = openaiOptions.forceReasoning ?? modelCapabilities.isReasoningModel; if (topK != null) { warnings.push({ type: 'unsupported', feature: 'topK' }); } const { messages, warnings: messageWarnings } = convertToOpenAIChatMessages( { prompt, systemMessageMode: openaiOptions.systemMessageMode ?? (isReasoningModel ? 'developer' : modelCapabilities.systemMessageMode), }, ); warnings.push(...messageWarnings); const strictJsonSchema = openaiOptions.strictJsonSchema ?? true; const baseArgs = { // model id: model: this.modelId, // model specific settings: logit_bias: openaiOptions.logitBias, logprobs: openaiOptions.logprobs === true || typeof openaiOptions.logprobs === 'number' ? true : undefined, top_logprobs: typeof openaiOptions.logprobs === 'number' ? openaiOptions.logprobs : typeof openaiOptions.logprobs === 'boolean' ? openaiOptions.logprobs ? 0 : undefined : undefined, user: openaiOptions.user, parallel_tool_calls: openaiOptions.parallelToolCalls, // standardized settings: max_tokens: maxOutputTokens, temperature, top_p: topP, frequency_penalty: frequencyPenalty, presence_penalty: presencePenalty, response_format: responseFormat?.type === 'json' ? responseFormat.schema != null ? { type: 'json_schema', json_schema: { schema: responseFormat.schema, strict: strictJsonSchema, name: responseFormat.name ?? 'response', description: responseFormat.description, }, } : { type: 'json_object' } : undefined, stop: stopSequences, seed, verbosity: openaiOptions.textVerbosity, // openai specific settings: // TODO AI SDK 6: remove, we auto-map maxOutputTokens now max_completion_tokens: openaiOptions.maxCompletionTokens, store: openaiOptions.store, metadata: openaiOptions.metadata, prediction: openaiOptions.prediction, reasoning_effort: openaiOptions.reasoningEffort, service_tier: openaiOptions.serviceTier, prompt_cache_key: openaiOptions.promptCacheKey, prompt_cache_retention: openaiOptions.promptCacheRetention, safety_identifier: openaiOptions.safetyIdentifier, // messages: messages, }; // remove unsupported settings for reasoning models // see https://platform.openai.com/docs/guides/reasoning#limitations if (isReasoningModel) { // when reasoning effort is none, gpt-5.1 models allow temperature, topP, logprobs // https://platform.openai.com/docs/guides/latest-model#gpt-5-1-parameter-compatibility if ( openaiOptions.reasoningEffort !== 'none' || !modelCapabilities.supportsNonReasoningParameters ) { if (baseArgs.temperature != null) { baseArgs.temperature = undefined; warnings.push({ type: 'unsupported', feature: 'temperature', details: 'temperature is not supported for reasoning models', }); } if (baseArgs.top_p != null) { baseArgs.top_p = undefined; warnings.push({ type: 'unsupported', feature: 'topP', details: 'topP is not supported for reasoning models', }); } if (baseArgs.logprobs != null) { baseArgs.logprobs = undefined; warnings.push({ type: 'other', message: 'logprobs is not supported for reasoning models', }); } } if (baseArgs.frequency_penalty != null) { baseArgs.frequency_penalty = undefined; warnings.push({ type: 'unsupported', feature: 'frequencyPenalty', details: 'frequencyPenalty is not supported for reasoning models', }); } if (baseArgs.presence_penalty != null) { baseArgs.presence_penalty = undefined; warnings.push({ type: 'unsupported', feature: 'presencePenalty', details: 'presencePenalty is not supported for reasoning models', }); } if (baseArgs.logit_bias != null) { baseArgs.logit_bias = undefined; warnings.push({ type: 'other', message: 'logitBias is not supported for reasoning models', }); } if (baseArgs.top_logprobs != null) { baseArgs.top_logprobs = undefined; warnings.push({ type: 'other', message: 'topLogprobs is not supported for reasoning models', }); } // reasoning models use max_completion_tokens instead of max_tokens: if (baseArgs.max_tokens != null) { if (baseArgs.max_completion_tokens == null) { baseArgs.max_completion_tokens = baseArgs.max_tokens; } baseArgs.max_tokens = undefined; } } else if ( this.modelId.startsWith('gpt-4o-search-preview') || this.modelId.startsWith('gpt-4o-mini-search-preview') ) { if (baseArgs.temperature != null) { baseArgs.temperature = undefined; warnings.push({ type: 'unsupported', feature: 'temperature', details: 'temperature is not supported for the search preview models and has been removed.', }); } } // Validate flex processing support if ( openaiOptions.serviceTier === 'flex' && !modelCapabilities.supportsFlexProcessing ) { warnings.push({ type: 'unsupported', feature: 'serviceTier', details: 'flex processing is only available for o3, o4-mini, and gpt-5 models', }); baseArgs.service_tier = undefined; } // Validate priority processing support if ( openaiOptions.serviceTier === 'priority' && !modelCapabilities.supportsPriorityProcessing ) { warnings.push({ type: 'unsupported', feature: 'serviceTier', details: 'priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported', }); baseArgs.service_tier = undefined; } const { tools: openaiTools, toolChoice: openaiToolChoice, toolWarnings, } = prepareChatTools({ tools, toolChoice, }); return { args: { ...baseArgs, tools: openaiTools, tool_choice: openaiToolChoice, }, warnings: [...warnings, ...toolWarnings], }; } async doGenerate( options: LanguageModelV3CallOptions, ): Promise<LanguageModelV3GenerateResult> { const { args: body, warnings } = await this.getArgs(options); const { responseHeaders, value: response, rawValue: rawResponse, } = await postJsonToApi({ url: this.config.url({ path: '/chat/completions', modelId: this.modelId, }), headers: combineHeaders(this.config.headers(), options.headers), body, failedResponseHandler: openaiFailedResponseHandler, successfulResponseHandler: createJsonResponseHandler( openaiChatResponseSchema, ), abortSignal: options.abortSignal, fetch: this.config.fetch, }); const choice = response.choices[0]; const content: Array<LanguageModelV3Content> = []; // text content: const text = choice.message.content; if (text != null && text.length > 0) { content.push({ type: 'text', text }); } // tool calls: for (const toolCall of choice.message.tool_calls ?? []) { content.push({ type: 'tool-call' as const, toolCallId: toolCall.id ?? generateId(), toolName: toolCall.function.name, input: toolCall.function.arguments!, }); } // annotations/citations: for (const annotation of choice.message.annotations ?? []) { content.push({ type: 'source', sourceType: 'url', id: generateId(), url: annotation.url_citation.url, title: annotation.url_citation.title, }); } // provider metadata: const completionTokenDetails = response.usage?.completion_tokens_details; const promptTokenDetails = response.usage?.prompt_tokens_details; const providerMetadata: SharedV3ProviderMetadata = { openai: {} }; if (completionTokenDetails?.accepted_prediction_tokens != null) { providerMetadata.openai.acceptedPredictionTokens = completionTokenDetails?.accepted_prediction_tokens; } if (completionTokenDetails?.rejected_prediction_tokens != null) { providerMetadata.openai.rejectedPredictionTokens = completionTokenDetails?.rejected_prediction_tokens; } if (choice.logprobs?.content != null) { providerMetadata.openai.logprobs = choice.logprobs.content; } return { content, finishReason: { unified: mapOpenAIFinishReason(choice.finish_reason), raw: choice.finish_reason ?? undefined, }, usage: convertOpenAIChatUsage(response.usage), request: { body }, response: { ...getResponseMetadata(response), headers: responseHeaders, body: rawResponse, }, warnings, providerMetadata, }; } async doStream( options: LanguageModelV3CallOptions, ): Promise<LanguageModelV3StreamResult> { const { args, warnings } = await this.getArgs(options); const body = { ...args, stream: true, stream_options: { include_usage: true, }, }; const { responseHeaders, value: response } = await postJsonToApi({ url: this.config.url({ path: '/chat/completions', modelId: this.modelId, }), headers: combineHeaders(this.config.headers(), options.headers), body, failedResponseHandler: openaiFailedResponseHandler, successfulResponseHandler: createEventSourceResponseHandler( openaiChatChunkSchema, ), abortSignal: options.abortSignal, fetch: this.config.fetch, }); const toolCalls: Array<{ id: string; type: 'function'; function: { name: string; arguments: string; }; hasFinished: boolean; }> = []; let finishReason: LanguageModelV3FinishReason = { unified: 'other', raw: undefined, }; let usage: OpenAIChatUsage | undefined = undefined; let metadataExtracted = false; let isActiveText = false; const providerMetadata: SharedV3ProviderMetadata = { openai: {} }; return { stream: response.pipeThrough( new TransformStream< ParseResult<OpenAIChatChunk>, LanguageModelV3StreamPart >({ start(controller) { controller.enqueue({ type: 'stream-start', warnings }); }, transform(chunk, controller) { if (options.includeRawChunks) { controller.enqueue({ type: 'raw', rawValue: chunk.rawValue }); } // handle failed chunk parsing / validation: if (!chunk.success) { finishReason = { unified: 'error', raw: undefined }; controller.enqueue({ type: 'error', error: chunk.error }); return; } const value = chunk.value; // handle error chunks: if ('error' in value) { finishReason = { unified: 'error', raw: undefined }; controller.enqueue({ type: 'error', error: value.error }); return; } // extract and emit response metadata once. Usually it comes in the first chunk. // Azure may prepend a chunk with a `"prompt_filter_results"` key which does not contain other metadata, // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/concepts/content-filter-annotations?tabs=powershell if (!metadataExtracted) { const metadata = getResponseMetadata(value); if (Object.values(metadata).some(Boolean)) { metadataExtracted = true; controller.enqueue({ type: 'response-metadata', ...getResponseMetadata(value), }); } } if (value.usage != null) { usage = value.usage; if ( value.usage.completion_tokens_details ?.accepted_prediction_tokens != null ) { providerMetadata.openai.acceptedPredictionTokens = value.usage.completion_tokens_details?.accepted_prediction_tokens; } if ( value.usage.completion_tokens_details ?.rejected_prediction_tokens != null ) { providerMetadata.openai.rejectedPredictionTokens = value.usage.completion_tokens_details?.rejected_prediction_tokens; } } const choice = value.choices[0]; if (choice?.finish_reason != null) { finishReason = { unified: mapOpenAIFinishReason(choice.finish_reason), raw: choice.finish_reason, }; } if (choice?.logprobs?.content != null) { providerMetadata.openai.logprobs = choice.logprobs.content; } if (choice?.delta == null) { return; } const delta = choice.delta; if (delta.content != null) { if (!isActiveText) { controller.enqueue({ type: 'text-start', id: '0' }); isActiveText = true; } controller.enqueue({ type: 'text-delta', id: '0', delta: delta.content, }); } if (delta.tool_calls != null) { for (const toolCallDelta of delta.tool_calls) { const index = toolCallDelta.index; // Tool call start. OpenAI returns all information except the arguments in the first chunk. if (toolCalls[index] == null) { if ( toolCallDelta.type != null && toolCallDelta.type !== 'function' ) { throw new InvalidResponseDataError({ data: toolCallDelta, message: `Expected 'function' type.`, }); } if (toolCallDelta.id == null) { throw new InvalidResponseDataError({ data: toolCallDelta, message: `Expected 'id' to be a string.`, }); } if (toolCallDelta.function?.name == null) { throw new InvalidResponseDataError({ data: toolCallDelta, message: `Expected 'function.name' to be a string.`, }); } controller.enqueue({ type: 'tool-input-start', id: toolCallDelta.id, toolName: toolCallDelta.function.name, }); toolCalls[index] = { id: toolCallDelta.id, type: 'function', function: { name: toolCallDelta.function.name, arguments: toolCallDelta.function.arguments ?? '', }, hasFinished: false, }; const toolCall = toolCalls[index]; if ( toolCall.function?.name != null && toolCall.function?.arguments != null ) { // send delta if the argument text has already started: if (toolCall.function.arguments.length > 0) { controller.enqueue({ type: 'tool-input-delta', id: toolCall.id, delta: toolCall.function.arguments, }); } // check if tool call is complete // (some providers send the full tool call in one chunk): if (isParsableJson(toolCall.function.arguments)) { controller.enqueue({ type: 'tool-input-end', id: toolCall.id, }); controller.enqueue({ type: 'tool-call', toolCallId: toolCall.id ?? generateId(), toolName: toolCall.function.name, input: toolCall.function.arguments, }); toolCall.hasFinished = true; } } continue; } // existing tool call, merge if not finished const toolCall = toolCalls[index]; if (toolCall.hasFinished) { continue; } if (toolCallDelta.function?.arguments != null) { toolCall.function!.arguments += toolCallDelta.function?.arguments ?? ''; } // send delta controller.enqueue({ type: 'tool-input-delta', id: toolCall.id, delta: toolCallDelta.function.arguments ?? '', }); // check if tool call is complete if ( toolCall.function?.name != null && toolCall.function?.arguments != null && isParsableJson(toolCall.function.arguments) ) { controller.enqueue({ type: 'tool-input-end', id: toolCall.id, }); controller.enqueue({ type: 'tool-call', toolCallId: toolCall.id ?? generateId(), toolName: toolCall.function.name, input: toolCall.function.arguments, }); toolCall.hasFinished = true; } } } // annotations/citations: if (delta.annotations != null) { for (const annotation of delta.annotations) { controller.enqueue({ type: 'source', sourceType: 'url', id: generateId(), url: annotation.url_citation.url, title: annotation.url_citation.title, }); } } }, flush(controller) { if (isActiveText) { controller.enqueue({ type: 'text-end', id: '0' }); } controller.enqueue({ type: 'finish', finishReason, usage: convertOpenAIChatUsage(usage), ...(providerMetadata != null ? { providerMetadata } : {}), }); }, }), ), request: { body }, response: { headers: responseHeaders }, }; } }