UNPKG

@azure/msal-common

Version:
524 lines (476 loc) 19.1 kB
/* * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ import { ClientConfiguration, isOidcProtocolMode, } from "../config/ClientConfiguration.js"; import { BaseClient } from "./BaseClient.js"; import { CommonRefreshTokenRequest } from "../request/CommonRefreshTokenRequest.js"; import { Authority } from "../authority/Authority.js"; import { ServerAuthorizationTokenResponse } from "../response/ServerAuthorizationTokenResponse.js"; import * as RequestParameterBuilder from "../request/RequestParameterBuilder.js"; import * as UrlUtils from "../utils/UrlUtils.js"; import { GrantType, AuthenticationScheme, Errors, HeaderNames, } from "../utils/Constants.js"; import * as AADServerParamKeys from "../constants/AADServerParamKeys.js"; import { ResponseHandler } from "../response/ResponseHandler.js"; import { AuthenticationResult } from "../response/AuthenticationResult.js"; import { PopTokenGenerator } from "../crypto/PopTokenGenerator.js"; import { StringUtils } from "../utils/StringUtils.js"; import { NetworkResponse } from "../network/NetworkResponse.js"; import { CommonSilentFlowRequest } from "../request/CommonSilentFlowRequest.js"; import { createClientConfigurationError, ClientConfigurationErrorCodes, } from "../error/ClientConfigurationError.js"; import { createClientAuthError, ClientAuthErrorCodes, } from "../error/ClientAuthError.js"; import { ServerError } from "../error/ServerError.js"; import * as TimeUtils from "../utils/TimeUtils.js"; import { UrlString } from "../url/UrlString.js"; import { CcsCredentialType } from "../account/CcsCredential.js"; import { buildClientInfoFromHomeAccountId } from "../account/ClientInfo.js"; import { InteractionRequiredAuthError, InteractionRequiredAuthErrorCodes, createInteractionRequiredAuthError, } from "../error/InteractionRequiredAuthError.js"; import { PerformanceEvents } from "../telemetry/performance/PerformanceEvent.js"; import { IPerformanceClient } from 
"../telemetry/performance/IPerformanceClient.js";
import { invoke, invokeAsync } from "../utils/FunctionWrappers.js";
import { generateCredentialKey } from "../cache/utils/CacheHelpers.js";
import { ClientAssertion } from "../account/ClientCredentials.js";
import { getClientAssertion } from "../utils/ClientAssertionUtils.js";
import { getRequestThumbprint } from "../network/RequestThumbprint.js";

/**
 * Offset (in seconds) applied to the expiry check of a cached refresh token
 * when the request does not supply `refreshTokenExpirationOffsetSeconds`.
 */
const DEFAULT_REFRESH_TOKEN_EXPIRATION_OFFSET_SECONDS = 300; // 5 Minutes

/**
 * OAuth2.0 refresh token client
 * @internal
 */
export class RefreshTokenClient extends BaseClient {
    /**
     * Forwards the configuration and optional performance client to BaseClient.
     * @param configuration
     * @param performanceClient
     */
    constructor(
        configuration: ClientConfiguration,
        performanceClient?: IPerformanceClient
    ) {
        super(configuration, performanceClient);
    }

    /**
     * Exchanges the refresh token carried by the request for tokens: posts to the
     * token endpoint, validates the response body and hands it to the ResponseHandler.
     * @param request refresh token request (must already contain the refresh token)
     * @returns the result produced by ResponseHandler.handleServerTokenResponse
     */
    public async acquireToken(
        request: CommonRefreshTokenRequest
    ): Promise<AuthenticationResult> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireToken,
            request.correlationId
        );

        // Captured before the network call and passed on to the response handler
        const reqTimestamp = TimeUtils.nowSeconds();
        const response = await invokeAsync(
            this.executeTokenRequest.bind(this),
            PerformanceEvents.RefreshTokenClientExecuteTokenRequest,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request, this.authority);

        // Retrieve requestId from response headers
        const requestId = response.headers?.[HeaderNames.X_MS_REQUEST_ID];
        const responseHandler = new ResponseHandler(
            this.config.authOptions.clientId,
            this.cacheManager,
            this.cryptoUtils,
            this.logger,
            this.config.serializableCache,
            this.config.persistencePlugin
        );
        responseHandler.validateTokenResponse(response.body);

        /*
         * NOTE(review): the two `undefined` arguments and the literal `true` are positional
         * parameters of handleServerTokenResponse that are not visible in this file -
         * confirm their meaning against ResponseHandler before changing them.
         */
        return invokeAsync(
            responseHandler.handleServerTokenResponse.bind(responseHandler),
            PerformanceEvents.HandleServerTokenResponse,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            response.body,
            this.authority,
            reqTimestamp,
            request,
            undefined,
            undefined,
            true,
            request.forceCache,
            requestId
        );
    }

    /**
     * Gets cached refresh token and attaches to request, then calls acquireToken API.
     * When the app metadata marks the application as FOCI, the family refresh token is
     * tried first; the application refresh token is used if no family token is cached
     * or the server answers with invalid_grant / client_mismatch.
     * @param request silent flow request; `request.account` is required
     * @throws ClientConfigurationError (tokenRequestEmpty) when no request is given
     * @throws ClientAuthError (noAccountInSilentRequest) when the request has no account
     */
    public async acquireTokenByRefreshToken(
        request: CommonSilentFlowRequest
    ): Promise<AuthenticationResult> {
        // Cannot renew token if no request object is given.
        if (!request) {
            throw createClientConfigurationError(
                ClientConfigurationErrorCodes.tokenRequestEmpty
            );
        }

        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireTokenByRefreshToken,
            request.correlationId
        );

        // We currently do not support silent flow for account === null use cases; This will be revisited for confidential flow usecases
        if (!request.account) {
            throw createClientAuthError(
                ClientAuthErrorCodes.noAccountInSilentRequest
            );
        }

        // try checking if FOCI is enabled for the given application
        const isFOCI = this.cacheManager.isAppMetadataFOCI(
            request.account.environment
        );

        // if the app is part of the family, retrieve a Family refresh token if present and make a refreshTokenRequest
        if (isFOCI) {
            try {
                return await invokeAsync(
                    this.acquireTokenWithCachedRefreshToken.bind(this),
                    PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
                    this.logger,
                    this.performanceClient,
                    request.correlationId
                )(request, true);
            } catch (e) {
                const noFamilyRTInCache =
                    e instanceof InteractionRequiredAuthError &&
                    e.errorCode ===
                        InteractionRequiredAuthErrorCodes.noTokensFound;
                const clientMismatchErrorWithFamilyRT =
                    e instanceof ServerError &&
                    e.errorCode === Errors.INVALID_GRANT_ERROR &&
                    e.subError === Errors.CLIENT_MISMATCH_ERROR;

                /*
                 * Only a missing family Refresh Token (FRT) or a client_mismatch error seen with the FRT
                 * may be retried with the application Refresh Token (ART); throw in all other cases.
                 */
                if (!noFamilyRTInCache && !clientMismatchErrorWithFamilyRT) {
                    throw e;
                }
                // recoverable FRT failure: continue with the ART attempt below
            }
        }

        // application refresh token acquisition (non-FOCI apps and recoverable FRT failures)
        return invokeAsync(
            this.acquireTokenWithCachedRefreshToken.bind(this),
            PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request, false);
    }

    /**
     * makes a network call to acquire tokens by exchanging RefreshToken available in userCache; throws if refresh token is not cached
     * @param request silent flow request
     * @param foci true to look up the family refresh token, false for the application refresh token
     * @throws InteractionRequiredAuthError (noTokensFound) when no refresh token is cached
     * @throws InteractionRequiredAuthError (refreshTokenExpired) when the cached token is within the expiration offset
     */
    private async acquireTokenWithCachedRefreshToken(
        request: CommonSilentFlowRequest,
        foci: boolean
    ): Promise<AuthenticationResult> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
            request.correlationId
        );

        // fetches family RT or application RT based on FOCI value
        const refreshToken = invoke(
            this.cacheManager.getRefreshToken.bind(this.cacheManager),
            PerformanceEvents.CacheManagerGetRefreshToken,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            request.account,
            foci,
            undefined,
            this.performanceClient,
            request.correlationId
        );

        if (!refreshToken) {
            throw createInteractionRequiredAuthError(
                InteractionRequiredAuthErrorCodes.noTokensFound
            );
        }

        /*
         * `??` rather than `||`: an explicit offset of 0 supplied by the caller must be honoured
         * instead of being silently replaced with the 5 minute default.
         */
        const expirationOffsetSeconds =
            request.refreshTokenExpirationOffsetSeconds ??
            DEFAULT_REFRESH_TOKEN_EXPIRATION_OFFSET_SECONDS;

        if (
            refreshToken.expiresOn &&
            TimeUtils.isTokenExpired(
                refreshToken.expiresOn,
                expirationOffsetSeconds
            )
        ) {
            this.performanceClient?.addFields(
                { rtExpiresOnMs: Number(refreshToken.expiresOn) },
                request.correlationId
            );
            throw createInteractionRequiredAuthError(
                InteractionRequiredAuthErrorCodes.refreshTokenExpired
            );
        }

        // build the refresh token request from the silent request plus the cached RT secret
        const refreshTokenRequest: CommonRefreshTokenRequest = {
            ...request,
            refreshToken: refreshToken.secret,
            authenticationScheme:
                request.authenticationScheme || AuthenticationScheme.BEARER,
            ccsCredential: {
                credential: request.account.homeAccountId,
                type: CcsCredentialType.HOME_ACCOUNT_ID,
            },
        };

        try {
            return await invokeAsync(
                this.acquireToken.bind(this),
                PerformanceEvents.RefreshTokenClientAcquireToken,
                this.logger,
                this.performanceClient,
                request.correlationId
            )(refreshTokenRequest);
        } catch (e) {
            if (e instanceof InteractionRequiredAuthError) {
                this.performanceClient?.addFields(
                    { rtExpiresOnMs: Number(refreshToken.expiresOn) },
                    request.correlationId
                );

                if (e.subError === InteractionRequiredAuthErrorCodes.badToken) {
                    // Remove bad refresh token from cache
                    this.logger.verbose(
                        "acquireTokenWithRefreshToken: bad refresh token, removing from cache"
                    );
                    const badRefreshTokenKey =
                        generateCredentialKey(refreshToken);
                    this.cacheManager.removeRefreshToken(badRefreshTokenKey);
                }
            }

            // the error is always surfaced to the caller; the cache cleanup above is a side effect only
            throw e;
        }
    }

    /**
     * Constructs the network message and makes a NW call to the underlying secure token service
     * @param request
     * @param authority authority whose token endpoint receives the POST
     */
    private async executeTokenRequest(
        request: CommonRefreshTokenRequest,
        authority: Authority
    ): Promise<NetworkResponse<ServerAuthorizationTokenResponse>> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientExecuteTokenRequest,
            request.correlationId
        );

        const queryParametersString = this.createTokenQueryParameters(request);
        const endpoint = UrlString.appendQueryString(
            authority.tokenEndpoint,
            queryParametersString
        );

        const requestBody = await invokeAsync(
            this.createTokenRequestBody.bind(this),
            PerformanceEvents.RefreshTokenClientCreateTokenRequestBody,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request);
        const headers: Record<string, string> = this.createTokenRequestHeaders(
            request.ccsCredential
        );
        const thumbprint = getRequestThumbprint(
            this.config.authOptions.clientId,
            request
        );

        return invokeAsync(
            this.executePostToTokenEndpoint.bind(this),
            PerformanceEvents.RefreshTokenClientExecutePostToTokenEndpoint,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            endpoint,
            requestBody,
            headers,
            thumbprint,
            request.correlationId,
            PerformanceEvents.RefreshTokenClientExecutePostToTokenEndpoint
        );
    }

    /**
     * Helper function to create the token request body
     * @param request
     * @returns the collected parameters serialized with UrlUtils.mapToQueryString
     * @throws ClientConfigurationError (missingSshJwk) when the SSH scheme is requested without `sshJwk`
     */
    private async createTokenRequestBody(
        request: CommonRefreshTokenRequest
    ): Promise<string> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientCreateTokenRequestBody,
            request.correlationId
        );

        const parameters = new Map<string, string>();

        // client id precedence: embedded client id, then a body-parameter override, then the configured client id
        RequestParameterBuilder.addClientId(
            parameters,
            request.embeddedClientId ||
                request.tokenBodyParameters?.[AADServerParamKeys.CLIENT_ID] ||
                this.config.authOptions.clientId
        );

        if (request.redirectUri) {
            RequestParameterBuilder.addRedirectUri(
                parameters,
                request.redirectUri
            );
        }

        RequestParameterBuilder.addScopes(
            parameters,
            request.scopes,
            true,
            this.config.authOptions.authority.options.OIDCOptions?.defaultScopes
        );

        RequestParameterBuilder.addGrantType(
            parameters,
            GrantType.REFRESH_TOKEN_GRANT
        );

        RequestParameterBuilder.addClientInfo(parameters);

        RequestParameterBuilder.addLibraryInfo(
            parameters,
            this.config.libraryInfo
        );
        RequestParameterBuilder.addApplicationTelemetry(
            parameters,
            this.config.telemetry.application
        );
        RequestParameterBuilder.addThrottling(parameters);

        // server telemetry is skipped when the client runs in OIDC protocol mode
        if (this.serverTelemetryManager && !isOidcProtocolMode(this.config)) {
            RequestParameterBuilder.addServerTelemetry(
                parameters,
                this.serverTelemetryManager
            );
        }

        RequestParameterBuilder.addRefreshToken(
            parameters,
            request.refreshToken
        );

        if (this.config.clientCredentials.clientSecret) {
            RequestParameterBuilder.addClientSecret(
                parameters,
                this.config.clientCredentials.clientSecret
            );
        }

        if (this.config.clientCredentials.clientAssertion) {
            const clientAssertion: ClientAssertion =
                this.config.clientCredentials.clientAssertion;

            RequestParameterBuilder.addClientAssertion(
                parameters,
                await getClientAssertion(
                    clientAssertion.assertion,
                    this.config.authOptions.clientId,
                    request.resourceRequestUri
                )
            );
            RequestParameterBuilder.addClientAssertionType(
                parameters,
                clientAssertion.assertionType
            );
        }

        if (request.authenticationScheme === AuthenticationScheme.POP) {
            const popTokenGenerator = new PopTokenGenerator(
                this.cryptoUtils,
                this.performanceClient
            );

            let reqCnfData;
            if (!request.popKid) {
                // no key id supplied: generate the req_cnf data
                const generatedReqCnfData = await invokeAsync(
                    popTokenGenerator.generateCnf.bind(popTokenGenerator),
                    PerformanceEvents.PopTokenGenerateCnf,
                    this.logger,
                    this.performanceClient,
                    request.correlationId
                )(request, this.logger);

                reqCnfData = generatedReqCnfData.reqCnfString;
            } else {
                reqCnfData = this.cryptoUtils.encodeKid(request.popKid);
            }

            // SPA PoP requires full Base64Url encoded req_cnf string (unhashed)
            RequestParameterBuilder.addPopToken(parameters, reqCnfData);
        } else if (request.authenticationScheme === AuthenticationScheme.SSH) {
            if (request.sshJwk) {
                RequestParameterBuilder.addSshJwk(parameters, request.sshJwk);
            } else {
                throw createClientConfigurationError(
                    ClientConfigurationErrorCodes.missingSshJwk
                );
            }
        }

        // claims are added when the request carries claims or client capabilities are configured
        if (
            !StringUtils.isEmptyObj(request.claims) ||
            (this.config.authOptions.clientCapabilities &&
                this.config.authOptions.clientCapabilities.length > 0)
        ) {
            RequestParameterBuilder.addClaims(
                parameters,
                request.claims,
                this.config.authOptions.clientCapabilities
            );
        }

        // with preventCorsPreflight enabled the CCS hint is written into the body parameters
        if (
            this.config.systemOptions.preventCorsPreflight &&
            request.ccsCredential
        ) {
            switch (request.ccsCredential.type) {
                case CcsCredentialType.HOME_ACCOUNT_ID:
                    try {
                        const clientInfo = buildClientInfoFromHomeAccountId(
                            request.ccsCredential.credential
                        );
                        RequestParameterBuilder.addCcsOid(
                            parameters,
                            clientInfo
                        );
                    } catch (e) {
                        // best effort: a home account id that cannot be parsed only skips the CCS hint
                        this.logger.verbose(
                            "Could not parse home account ID for CCS Header: " +
                                e
                        );
                    }
                    break;
                case CcsCredentialType.UPN:
                    RequestParameterBuilder.addCcsUpn(
                        parameters,
                        request.ccsCredential.credential
                    );
                    break;
            }
        }

        if (request.embeddedClientId) {
            RequestParameterBuilder.addBrokerParameters(
                parameters,
                this.config.authOptions.clientId,
                this.config.authOptions.redirectUri
            );
        }

        // caller supplied body parameters are applied after the parameters added above
        if (request.tokenBodyParameters) {
            RequestParameterBuilder.addExtraQueryParameters(
                parameters,
                request.tokenBodyParameters
            );
        }

        RequestParameterBuilder.instrumentBrokerParams(
            parameters,
            request.correlationId,
            this.performanceClient
        );
        return UrlUtils.mapToQueryString(parameters);
    }
}