@azure/msal-common
Version: 
Microsoft Authentication Library for js
524 lines (476 loc) • 19.1 kB
text/typescript
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License.
 */
import {
    ClientConfiguration,
    isOidcProtocolMode,
} from "../config/ClientConfiguration.js";
import { BaseClient } from "./BaseClient.js";
import { CommonRefreshTokenRequest } from "../request/CommonRefreshTokenRequest.js";
import { Authority } from "../authority/Authority.js";
import { ServerAuthorizationTokenResponse } from "../response/ServerAuthorizationTokenResponse.js";
import * as RequestParameterBuilder from "../request/RequestParameterBuilder.js";
import * as UrlUtils from "../utils/UrlUtils.js";
import {
    GrantType,
    AuthenticationScheme,
    Errors,
    HeaderNames,
} from "../utils/Constants.js";
import * as AADServerParamKeys from "../constants/AADServerParamKeys.js";
import { ResponseHandler } from "../response/ResponseHandler.js";
import { AuthenticationResult } from "../response/AuthenticationResult.js";
import { PopTokenGenerator } from "../crypto/PopTokenGenerator.js";
import { StringUtils } from "../utils/StringUtils.js";
import { NetworkResponse } from "../network/NetworkResponse.js";
import { CommonSilentFlowRequest } from "../request/CommonSilentFlowRequest.js";
import {
    createClientConfigurationError,
    ClientConfigurationErrorCodes,
} from "../error/ClientConfigurationError.js";
import {
    createClientAuthError,
    ClientAuthErrorCodes,
} from "../error/ClientAuthError.js";
import { ServerError } from "../error/ServerError.js";
import * as TimeUtils from "../utils/TimeUtils.js";
import { UrlString } from "../url/UrlString.js";
import { CcsCredentialType } from "../account/CcsCredential.js";
import { buildClientInfoFromHomeAccountId } from "../account/ClientInfo.js";
import {
    InteractionRequiredAuthError,
    InteractionRequiredAuthErrorCodes,
    createInteractionRequiredAuthError,
} from "../error/InteractionRequiredAuthError.js";
import { PerformanceEvents } from "../telemetry/performance/PerformanceEvent.js";
import { IPerformanceClient } from "../telemetry/performance/IPerformanceClient.js";
import { invoke, invokeAsync } from "../utils/FunctionWrappers.js";
import { generateCredentialKey } from "../cache/utils/CacheHelpers.js";
import { ClientAssertion } from "../account/ClientCredentials.js";
import { getClientAssertion } from "../utils/ClientAssertionUtils.js";
import { getRequestThumbprint } from "../network/RequestThumbprint.js";
/**
 * Offset, in seconds (5 minutes), handed to `TimeUtils.isTokenExpired` when a
 * silent request does not specify `refreshTokenExpirationOffsetSeconds`.
 */
const DEFAULT_REFRESH_TOKEN_EXPIRATION_OFFSET_SECONDS = 5 * 60;
/**
 * OAuth2.0 refresh token client.
 *
 * Redeems a refresh token at the token endpoint of this client's authority.
 * The refresh token either comes with the request (`acquireToken`) or is
 * looked up in the cache (`acquireTokenByRefreshToken`).
 * @internal
 */
export class RefreshTokenClient extends BaseClient {
    constructor(
        configuration: ClientConfiguration,
        performanceClient?: IPerformanceClient
    ) {
        super(configuration, performanceClient);
    }

    /**
     * Sends the refresh token carried on the request to the token endpoint.
     * The server response is then validated and handed to the ResponseHandler.
     * @param request - refresh token request; `request.refreshToken` is placed in the token request body
     * @returns the result of `ResponseHandler.handleServerTokenResponse`
     */
    public async acquireToken(
        request: CommonRefreshTokenRequest
    ): Promise<AuthenticationResult> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireToken,
            request.correlationId
        );
        // Captured before the network call; passed to the response handler below.
        const reqTimestamp = TimeUtils.nowSeconds();
        const response = await invokeAsync(
            this.executeTokenRequest.bind(this),
            PerformanceEvents.RefreshTokenClientExecuteTokenRequest,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request, this.authority);
        // Retrieve requestId from response headers (may be absent).
        const requestId = response.headers?.[HeaderNames.X_MS_REQUEST_ID];
        const responseHandler = new ResponseHandler(
            this.config.authOptions.clientId,
            this.cacheManager,
            this.cryptoUtils,
            this.logger,
            this.config.serializableCache,
            this.config.persistencePlugin
        );
        responseHandler.validateTokenResponse(response.body);
        // NOTE(review): the positional `undefined, undefined, true` arguments follow
        // handleServerTokenResponse's signature (not visible here) — confirm their meaning there.
        return invokeAsync(
            responseHandler.handleServerTokenResponse.bind(responseHandler),
            PerformanceEvents.HandleServerTokenResponse,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            response.body,
            this.authority,
            reqTimestamp,
            request,
            undefined,
            undefined,
            true,
            request.forceCache,
            requestId
        );
    }

    /**
     * Gets a cached refresh token, attaches it to the request and calls the acquireToken API.
     *
     * If app metadata for the account's environment marks the app as FOCI, a family
     * refresh token (FRT) is tried first. The client falls back to the application
     * refresh token (ART) when no FRT is cached or the server answers with
     * `invalid_grant` / `client_mismatch`.
     * @param request - silent flow request; must be non-null and carry an account
     * @returns the AuthenticationResult obtained by redeeming the cached refresh token
     * @throws ClientConfigurationError (tokenRequestEmpty) when `request` is falsy
     * @throws ClientAuthError (noAccountInSilentRequest) when `request.account` is falsy
     */
    public async acquireTokenByRefreshToken(
        request: CommonSilentFlowRequest
    ): Promise<AuthenticationResult> {
        // Cannot renew token if no request object is given.
        if (!request) {
            throw createClientConfigurationError(
                ClientConfigurationErrorCodes.tokenRequestEmpty
            );
        }
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireTokenByRefreshToken,
            request.correlationId
        );
        // We currently do not support silent flow for account === null use cases; This will be revisited for confidential flow usecases
        if (!request.account) {
            throw createClientAuthError(
                ClientAuthErrorCodes.noAccountInSilentRequest
            );
        }
        // try checking if FOCI is enabled for the given application
        const isFOCI = this.cacheManager.isAppMetadataFOCI(
            request.account.environment
        );
        // if the app is part of the family, retrieve a Family refresh token if present and make a refreshTokenRequest
        if (isFOCI) {
            try {
                return await invokeAsync(
                    this.acquireTokenWithCachedRefreshToken.bind(this),
                    PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
                    this.logger,
                    this.performanceClient,
                    request.correlationId
                )(request, true);
            } catch (e) {
                const noFamilyRTInCache =
                    e instanceof InteractionRequiredAuthError &&
                    e.errorCode ===
                        InteractionRequiredAuthErrorCodes.noTokensFound;
                const clientMismatchErrorWithFamilyRT =
                    e instanceof ServerError &&
                    e.errorCode === Errors.INVALID_GRANT_ERROR &&
                    e.subError === Errors.CLIENT_MISMATCH_ERROR;
                // if family Refresh Token (FRT) cache acquisition fails or if client_mismatch error is seen with FRT, reattempt with application Refresh Token (ART)
                if (noFamilyRTInCache || clientMismatchErrorWithFamilyRT) {
                    return invokeAsync(
                        this.acquireTokenWithCachedRefreshToken.bind(this),
                        PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
                        this.logger,
                        this.performanceClient,
                        request.correlationId
                    )(request, false);
                } else {
                    // throw in all other cases
                    throw e;
                }
            }
        }
        // fall back to application refresh token acquisition
        return invokeAsync(
            this.acquireTokenWithCachedRefreshToken.bind(this),
            PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request, false);
    }

    /**
     * Makes a network call that exchanges a refresh token from the user cache for tokens.
     * Throws if no suitable refresh token is cached.
     * @param request - silent flow request; `request.account` must be set (checked by the caller)
     * @param foci - true to fetch the family RT, false to fetch the application RT
     * @returns the AuthenticationResult from `acquireToken`
     * @throws InteractionRequiredAuthError (noTokensFound) when no refresh token is cached
     * @throws InteractionRequiredAuthError (refreshTokenExpired) when the cached RT has an
     *         `expiresOn` that `TimeUtils.isTokenExpired` reports as expired for the offset
     * @throws whatever `acquireToken` throws; a cached RT rejected with subError `badToken`
     *         is removed from the cache before the error is rethrown
     */
    private async acquireTokenWithCachedRefreshToken(
        request: CommonSilentFlowRequest,
        foci: boolean
    ): Promise<AuthenticationResult> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientAcquireTokenWithCachedRefreshToken,
            request.correlationId
        );
        // fetches family RT or application RT based on FOCI value
        const refreshToken = invoke(
            this.cacheManager.getRefreshToken.bind(this.cacheManager),
            PerformanceEvents.CacheManagerGetRefreshToken,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            request.account,
            foci,
            undefined,
            this.performanceClient,
            request.correlationId
        );
        if (!refreshToken) {
            throw createInteractionRequiredAuthError(
                InteractionRequiredAuthErrorCodes.noTokensFound
            );
        }
        // Use `??` rather than `||` so an explicit offset of 0 is honoured instead of
        // being silently replaced by the 5-minute default.
        if (
            refreshToken.expiresOn &&
            TimeUtils.isTokenExpired(
                refreshToken.expiresOn,
                request.refreshTokenExpirationOffsetSeconds ??
                    DEFAULT_REFRESH_TOKEN_EXPIRATION_OFFSET_SECONDS
            )
        ) {
            this.performanceClient?.addFields(
                { rtExpiresOnMs: Number(refreshToken.expiresOn) },
                request.correlationId
            );
            throw createInteractionRequiredAuthError(
                InteractionRequiredAuthErrorCodes.refreshTokenExpired
            );
        }
        // Build the refresh token request from the silent request plus the cached RT secret;
        // the account's home account ID is sent as the CCS routing credential.
        const refreshTokenRequest: CommonRefreshTokenRequest = {
            ...request,
            refreshToken: refreshToken.secret,
            authenticationScheme:
                request.authenticationScheme || AuthenticationScheme.BEARER,
            ccsCredential: {
                credential: request.account.homeAccountId,
                type: CcsCredentialType.HOME_ACCOUNT_ID,
            },
        };
        try {
            return await invokeAsync(
                this.acquireToken.bind(this),
                PerformanceEvents.RefreshTokenClientAcquireToken,
                this.logger,
                this.performanceClient,
                request.correlationId
            )(refreshTokenRequest);
        } catch (e) {
            if (e instanceof InteractionRequiredAuthError) {
                this.performanceClient?.addFields(
                    { rtExpiresOnMs: Number(refreshToken.expiresOn) },
                    request.correlationId
                );
                if (e.subError === InteractionRequiredAuthErrorCodes.badToken) {
                    // Remove bad refresh token from cache
                    this.logger.verbose(
                        "acquireTokenWithRefreshToken: bad refresh token, removing from cache"
                    );
                    const badRefreshTokenKey =
                        generateCredentialKey(refreshToken);
                    this.cacheManager.removeRefreshToken(badRefreshTokenKey);
                }
            }
            throw e;
        }
    }

    /**
     * Constructs the network message and makes a network call to the underlying secure token service.
     * @param request - refresh token request used to build query string, body, headers and thumbprint
     * @param authority - authority whose `tokenEndpoint` receives the POST
     * @returns the raw network response from `executePostToTokenEndpoint`
     */
    private async executeTokenRequest(
        request: CommonRefreshTokenRequest,
        authority: Authority
    ): Promise<NetworkResponse<ServerAuthorizationTokenResponse>> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientExecuteTokenRequest,
            request.correlationId
        );
        const queryParametersString = this.createTokenQueryParameters(request);
        const endpoint = UrlString.appendQueryString(
            authority.tokenEndpoint,
            queryParametersString
        );
        const requestBody = await invokeAsync(
            this.createTokenRequestBody.bind(this),
            PerformanceEvents.RefreshTokenClientCreateTokenRequestBody,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(request);
        const headers: Record<string, string> = this.createTokenRequestHeaders(
            request.ccsCredential
        );
        // Thumbprint identifies this request for executePostToTokenEndpoint
        // (presumably used for throttling bookkeeping — confirm in BaseClient).
        const thumbprint = getRequestThumbprint(
            this.config.authOptions.clientId,
            request
        );
        return invokeAsync(
            this.executePostToTokenEndpoint.bind(this),
            PerformanceEvents.RefreshTokenClientExecutePostToTokenEndpoint,
            this.logger,
            this.performanceClient,
            request.correlationId
        )(
            endpoint,
            requestBody,
            headers,
            thumbprint,
            request.correlationId,
            PerformanceEvents.RefreshTokenClientExecutePostToTokenEndpoint
        );
    }

    /**
     * Helper function to create the URL-encoded token request body for the refresh_token grant.
     *
     * Adds client id, optional redirect URI, scopes, grant type, client info, library and
     * app telemetry, throttling and (outside OIDC protocol mode) server telemetry. It also adds
     * the refresh token, any configured client secret / client assertion, PoP or SSH
     * parameters, claims, optional CCS hints, broker parameters and extra token body parameters.
     * @param request - refresh token request supplying the values above
     * @returns the parameters serialized with `UrlUtils.mapToQueryString`
     * @throws ClientConfigurationError (missingSshJwk) when the SSH scheme is used without `sshJwk`
     */
    private async createTokenRequestBody(
        request: CommonRefreshTokenRequest
    ): Promise<string> {
        this.performanceClient?.addQueueMeasurement(
            PerformanceEvents.RefreshTokenClientCreateTokenRequestBody,
            request.correlationId
        );
        const parameters = new Map<string, string>();
        // An embedded (brokered) client id, then an explicit client_id body parameter,
        // take precedence over the configured client id.
        RequestParameterBuilder.addClientId(
            parameters,
            request.embeddedClientId ||
                request.tokenBodyParameters?.[AADServerParamKeys.CLIENT_ID] ||
                this.config.authOptions.clientId
        );
        if (request.redirectUri) {
            RequestParameterBuilder.addRedirectUri(
                parameters,
                request.redirectUri
            );
        }
        RequestParameterBuilder.addScopes(
            parameters,
            request.scopes,
            true,
            this.config.authOptions.authority.options.OIDCOptions?.defaultScopes
        );
        RequestParameterBuilder.addGrantType(
            parameters,
            GrantType.REFRESH_TOKEN_GRANT
        );
        RequestParameterBuilder.addClientInfo(parameters);
        RequestParameterBuilder.addLibraryInfo(
            parameters,
            this.config.libraryInfo
        );
        RequestParameterBuilder.addApplicationTelemetry(
            parameters,
            this.config.telemetry.application
        );
        RequestParameterBuilder.addThrottling(parameters);
        // Server telemetry is not sent in OIDC protocol mode.
        if (this.serverTelemetryManager && !isOidcProtocolMode(this.config)) {
            RequestParameterBuilder.addServerTelemetry(
                parameters,
                this.serverTelemetryManager
            );
        }
        RequestParameterBuilder.addRefreshToken(
            parameters,
            request.refreshToken
        );
        if (this.config.clientCredentials.clientSecret) {
            RequestParameterBuilder.addClientSecret(
                parameters,
                this.config.clientCredentials.clientSecret
            );
        }
        if (this.config.clientCredentials.clientAssertion) {
            const clientAssertion: ClientAssertion =
                this.config.clientCredentials.clientAssertion;
            RequestParameterBuilder.addClientAssertion(
                parameters,
                await getClientAssertion(
                    clientAssertion.assertion,
                    this.config.authOptions.clientId,
                    request.resourceRequestUri
                )
            );
            RequestParameterBuilder.addClientAssertionType(
                parameters,
                clientAssertion.assertionType
            );
        }
        if (request.authenticationScheme === AuthenticationScheme.POP) {
            const popTokenGenerator = new PopTokenGenerator(
                this.cryptoUtils,
                this.performanceClient
            );
            let reqCnfData;
            // Generate a new req_cnf unless the caller supplied a key id to reuse.
            if (!request.popKid) {
                const generatedReqCnfData = await invokeAsync(
                    popTokenGenerator.generateCnf.bind(popTokenGenerator),
                    PerformanceEvents.PopTokenGenerateCnf,
                    this.logger,
                    this.performanceClient,
                    request.correlationId
                )(request, this.logger);
                reqCnfData = generatedReqCnfData.reqCnfString;
            } else {
                reqCnfData = this.cryptoUtils.encodeKid(request.popKid);
            }
            // SPA PoP requires full Base64Url encoded req_cnf string (unhashed)
            RequestParameterBuilder.addPopToken(parameters, reqCnfData);
        } else if (request.authenticationScheme === AuthenticationScheme.SSH) {
            if (request.sshJwk) {
                RequestParameterBuilder.addSshJwk(parameters, request.sshJwk);
            } else {
                throw createClientConfigurationError(
                    ClientConfigurationErrorCodes.missingSshJwk
                );
            }
        }
        if (
            !StringUtils.isEmptyObj(request.claims) ||
            (this.config.authOptions.clientCapabilities &&
                this.config.authOptions.clientCapabilities.length > 0)
        ) {
            RequestParameterBuilder.addClaims(
                parameters,
                request.claims,
                this.config.authOptions.clientCapabilities
            );
        }
        // With preventCorsPreflight, CCS routing hints are sent in the body instead of headers.
        if (
            this.config.systemOptions.preventCorsPreflight &&
            request.ccsCredential
        ) {
            switch (request.ccsCredential.type) {
                case CcsCredentialType.HOME_ACCOUNT_ID:
                    try {
                        const clientInfo = buildClientInfoFromHomeAccountId(
                            request.ccsCredential.credential
                        );
                        RequestParameterBuilder.addCcsOid(
                            parameters,
                            clientInfo
                        );
                    } catch (e) {
                        // Best effort: an unparseable home account ID only drops the hint.
                        this.logger.verbose(
                            "Could not parse home account ID for CCS Header: " +
                                e
                        );
                    }
                    break;
                case CcsCredentialType.UPN:
                    RequestParameterBuilder.addCcsUpn(
                        parameters,
                        request.ccsCredential.credential
                    );
                    break;
            }
        }
        if (request.embeddedClientId) {
            RequestParameterBuilder.addBrokerParameters(
                parameters,
                this.config.authOptions.clientId,
                this.config.authOptions.redirectUri
            );
        }
        if (request.tokenBodyParameters) {
            RequestParameterBuilder.addExtraQueryParameters(
                parameters,
                request.tokenBodyParameters
            );
        }
        RequestParameterBuilder.instrumentBrokerParams(
            parameters,
            request.correlationId,
            this.performanceClient
        );
        return UrlUtils.mapToQueryString(parameters);
    }
}