@aws-amplify/auth
Version:
Auth category of aws-amplify
334 lines (278 loc) • 9.24 kB
text/typescript
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
import { parse } from 'url'; // Used for OAuth parsing of Cognito Hosted UI
import { launchUri } from './urlOpener';
import * as oAuthStorage from './oauthStorage';
import { Buffer } from 'buffer';
import {
OAuthOpts,
isCognitoHostedOpts,
CognitoHostedUIIdentityProvider,
} from '../types/Auth';
import { ConsoleLogger as Logger, Hub, urlSafeEncode } from '@aws-amplify/core';
import { Sha256 } from '@aws-crypto/sha256-js';
// Value passed as the fourth argument of Hub.dispatch below: the registered
// symbol 'amplify_default' when Symbol.for exists, otherwise a plain string key.
const AMPLIFY_SYMBOL = (
	typeof Symbol === 'undefined' || typeof Symbol.for !== 'function'
		? '@@amplify_default'
		: Symbol.for('amplify_default')
) as Symbol;
/**
 * Publishes `{ event, data, message }` on the Hub 'auth' channel, passing
 * 'Auth' as the third argument and AMPLIFY_SYMBOL as the fourth.
 */
const dispatchAuthEvent = (event: string, data: any, message: string) => {
	const payload = { event, data, message };
	Hub.dispatch('auth', payload, 'Auth', AMPLIFY_SYMBOL);
};
// Module logger, named 'OAuth'.
const logger = new Logger('OAuth');
/**
 * OAuth 2.0 hosted-UI client.
 *
 * Builds the `/oauth2/authorize` redirect (adding a PKCE challenge when the
 * response type is 'code'), exchanges a returned authorization code at
 * `/oauth2/token`, reads tokens from the URL fragment for the non-code flow,
 * and opens `/logout`. `config` is either Cognito hosted-UI options (detected
 * with `isCognitoHostedOpts`) or a non-Cognito shape that exposes `clientID`,
 * `redirectUri` and `returnTo`.
 */
export default class OAuth {
	// Function used to open the authorize / logout URLs.
	private _urlOpener;
	// The OAuthOpts given to the constructor.
	private _config;
	// App client id used (instead of the config's `clientID`) for Cognito hosted-UI options.
	private _cognitoClientId;
	// Validated scopes; sent space-joined as `scope` on sign-in.
	private _scopes;

	/**
	 * @param config OAuth options; `config.urlOpener`, when set, replaces the default `launchUri`.
	 * @param cognitoClientId Client id used for token/logout requests with Cognito hosted-UI options.
	 * @param scopes Scopes requested on sign-in (defaults to none).
	 * @throws Error('scopes must be a String Array') if `scopes` is not an array of strings.
	 */
	constructor({
		config,
		cognitoClientId,
		scopes = [],
	}: {
		scopes: string[];
		config: OAuthOpts;
		cognitoClientId: string;
	}) {
		this._urlOpener = config.urlOpener || launchUri;
		this._config = config;
		this._cognitoClientId = cognitoClientId;
		if (!this.isValidScopes(scopes)) {
			throw Error('scopes must be a String Array');
		}
		this._scopes = scopes;
	}

	/** True when `scopes` is an array whose every element is a string. */
	private isValidScopes(scopes: string[]) {
		return (
			Array.isArray(scopes) && scopes.every(scope => typeof scope === 'string')
		);
	}

	/**
	 * Starts a hosted-UI sign-in by opening `https://{domain}/oauth2/authorize`.
	 *
	 * Before redirecting, stores the generated `state` and a fresh PKCE verifier
	 * in oauthStorage so the response can be validated and exchanged later.
	 *
	 * @param responseType OAuth `response_type`; when 'code', an S256 PKCE challenge is added.
	 * @param domain Hosted-UI domain (without scheme).
	 * @param redirectSignIn `redirect_uri` sent to the authorize endpoint.
	 * @param clientId `client_id` sent to the authorize endpoint.
	 * @param provider `identity_provider` parameter; defaults to Cognito.
	 * @param customState Optional caller state, appended to the generated state after a '-'.
	 */
	public oauthSignIn(
		responseType = 'code',
		domain: string,
		redirectSignIn: string,
		clientId: string,
		provider:
			| CognitoHostedUIIdentityProvider
			| string = CognitoHostedUIIdentityProvider.Cognito,
		customState?: string
	) {
		const generatedState = this._generateState(32);
		/* encodeURIComponent is not URL safe, use urlSafeEncode instead. Cognito
		single-encodes/decodes url on first sign in and double-encodes/decodes url
		when user already signed in. Using encodeURIComponent, Base32, Base64 add
		characters % or = which on further encoding becomes unsafe. '=' create issue
		for parsing query params.
		Refer: https://github.com/aws-amplify/amplify-js/issues/5218 */
		const state = customState
			? `${generatedState}-${urlSafeEncode(customState)}`
			: generatedState;
		oAuthStorage.setState(state);

		const pkce_key = this._generateRandom(128);
		oAuthStorage.setPKCE(pkce_key);
		const code_challenge = this._generateChallenge(pkce_key);
		const code_challenge_method = 'S256';

		const queryString = this._toQueryString({
			redirect_uri: redirectSignIn,
			response_type: responseType,
			client_id: clientId,
			identity_provider: provider,
			scope: this._scopes.join(' '),
			state,
			...(responseType === 'code'
				? { code_challenge, code_challenge_method }
				: {}),
		});
		const URL = `https://${domain}/oauth2/authorize?${queryString}`;
		logger.debug(`Redirecting to ${URL}`);
		this._urlOpener(URL, redirectSignIn);
	}

	/**
	 * Exchanges the `code` query parameter of `currentUrl` for tokens at
	 * `https://{domain}/oauth2/token`, sending the stored PKCE verifier if any.
	 *
	 * Resolves to undefined (no request made) when `currentUrl` has no code or
	 * its path differs from the configured redirect URI's path.
	 *
	 * @throws Error carrying the token endpoint's `error` field when present.
	 */
	private async _handleCodeFlow(currentUrl: string) {
		const { code } = this._parseParams(parse(currentUrl).query || '');
		const redirect_uri = isCognitoHostedOpts(this._config)
			? this._config.redirectSignIn
			: this._config.redirectUri;

		// Non-Cognito configs have no `redirectSignIn` (they use `redirectUri`), so
		// match against the same redirect URI that is sent to the token endpoint.
		const currentUrlPathname = parse(currentUrl).pathname || '/';
		const redirectPathname = parse(redirect_uri).pathname || '/';
		if (!code || currentUrlPathname !== redirectPathname) {
			return;
		}

		const oAuthTokenEndpoint =
			'https://' + this._config.domain + '/oauth2/token';
		dispatchAuthEvent(
			'codeFlow',
			{},
			`Retrieving tokens from ${oAuthTokenEndpoint}`
		);

		const client_id = isCognitoHostedOpts(this._config)
			? this._cognitoClientId
			: this._config.clientID;
		const code_verifier = oAuthStorage.getPKCE();
		const oAuthTokenBody = {
			grant_type: 'authorization_code',
			code,
			client_id,
			redirect_uri,
			...(code_verifier ? { code_verifier } : {}),
		};
		logger.debug(
			`Calling token endpoint: ${oAuthTokenEndpoint} with`,
			oAuthTokenBody
		);

		const { access_token, refresh_token, id_token, error } = await (
			(await fetch(oAuthTokenEndpoint, {
				method: 'POST',
				headers: {
					'Content-Type': 'application/x-www-form-urlencoded',
				},
				body: this._toQueryString(oAuthTokenBody),
			})) as any
		).json();

		if (error) {
			throw new Error(error);
		}
		return {
			accessToken: access_token,
			refreshToken: refresh_token,
			idToken: id_token,
		};
	}

	/**
	 * Reads `id_token` and `access_token` from the fragment of `currentUrl`.
	 * No refresh token is returned by this flow.
	 */
	private async _handleImplicitFlow(currentUrl: string) {
		// hash is `null` if `#` doesn't exist on URL; slice(1) drops the '#'.
		const { id_token, access_token } = this._parseParams(
			(parse(currentUrl).hash || '#').slice(1)
		);
		dispatchAuthEvent('implicitFlow', {}, `Got tokens from ${currentUrl}`);
		logger.debug(`Retrieving implicit tokens from ${currentUrl}`);
		return {
			accessToken: access_token,
			idToken: id_token,
			refreshToken: null,
		};
	}

	/**
	 * Handles the redirect back from the hosted UI.
	 *
	 * Merges fragment and query parameters (query wins on conflicts), rethrows an
	 * OAuth `error` as Error(error_description), validates `state`, then runs the
	 * code flow when `config.responseType` is 'code' and the fragment flow otherwise.
	 *
	 * @returns The tokens (if any) plus the returned `state`.
	 * @throws Error on an OAuth error response, a state mismatch, or a token-endpoint error.
	 */
	public async handleAuthResponse(currentUrl?: string) {
		try {
			const urlParams: Record<string, string | undefined> = currentUrl
				? {
						...this._parseParams((parse(currentUrl).hash || '#').slice(1)),
						...this._parseParams(parse(currentUrl).query || ''),
				  }
				: {};
			const { error, error_description } = urlParams;

			if (error) {
				throw new Error(error_description);
			}

			const state: string = this._validateState(urlParams);

			logger.debug(
				`Starting ${this._config.responseType} flow with ${currentUrl}`
			);
			if (this._config.responseType === 'code') {
				return { ...(await this._handleCodeFlow(currentUrl)), state };
			} else {
				return { ...(await this._handleImplicitFlow(currentUrl)), state };
			}
		} catch (e) {
			logger.error(`Error handling auth response.`, e);
			throw e;
		}
	}

	/**
	 * Returns the `state` from `urlParams`.
	 *
	 * @throws Error('Invalid state in OAuth flow') when a state was saved by
	 *   oauthSignIn and the returned one differs.
	 */
	private _validateState(urlParams: any): string {
		if (!urlParams) {
			return;
		}

		const savedState = oAuthStorage.getState();
		const { state: returnedState } = urlParams;

		// This is because savedState only exists if the flow was initiated by Amplify
		if (savedState && savedState !== returnedState) {
			throw new Error('Invalid state in OAuth flow');
		}
		return returnedState;
	}

	/**
	 * Opens `https://{domain}/logout` with `client_id` and `logout_uri`.
	 *
	 * Cognito options use the constructor's client id and `redirectSignOut`;
	 * other configs use their own `clientID` and `returnTo`.
	 *
	 * @returns Whatever the URL opener returns.
	 */
	public async signOut() {
		// Non-Cognito configs carry `clientID` at the top level (as in
		// _handleCodeFlow); there is no nested `oauth` object to read from.
		const client_id = isCognitoHostedOpts(this._config)
			? this._cognitoClientId
			: this._config.clientID;
		const signout_uri = isCognitoHostedOpts(this._config)
			? this._config.redirectSignOut
			: this._config.returnTo;

		const oAuthLogoutEndpoint =
			'https://' +
			this._config.domain +
			'/logout?' +
			this._toQueryString({ client_id, logout_uri: signout_uri });

		dispatchAuthEvent(
			'oAuthSignOut',
			{ oAuth: 'signOut' },
			`Signing out from ${oAuthLogoutEndpoint}`
		);
		logger.debug(`Signing out from ${oAuthLogoutEndpoint}`);

		return this._urlOpener(oAuthLogoutEndpoint, signout_uri);
	}

	/**
	 * Returns a random `length`-character OAuth `state` over [A-Za-z0-9].
	 *
	 * Delegates to _generateRandom, so crypto.getRandomValues is used when
	 * available. The alphabet contains no '-', which oauthSignIn relies on to
	 * separate the generated state from custom state.
	 */
	private _generateState(length: number) {
		return this._generateRandom(length);
	}

	/**
	 * Returns the PKCE S256 code challenge for `code`: the SHA-256 digest,
	 * base64-encoded and converted to base64url without padding.
	 */
	private _generateChallenge(code: string) {
		const awsCryptoHash = new Sha256();
		awsCryptoHash.update(code);

		const resultFromAWSCrypto = awsCryptoHash.digestSync();
		const b64 = Buffer.from(resultFromAWSCrypto).toString('base64');
		return this._base64URL(b64);
	}

	/** Converts base64 to unpadded base64url: drops '=', maps '+' to '-' and '/' to '_'. */
	private _base64URL(b64: string) {
		return b64.replace(/=/g, '').replace(/\+/g, '-').replace(/\//g, '_');
	}

	/**
	 * Returns `size` random characters over [A-Za-z0-9].
	 *
	 * Uses crypto.getRandomValues from `window`, or from the global `crypto`
	 * when `window` is unavailable. Falls back to Math.random, which is not
	 * cryptographically secure, only when neither exists.
	 */
	private _generateRandom(size: number) {
		const buffer = new Uint8Array(size);
		const cryptoSource =
			typeof window !== 'undefined' && !!window.crypto
				? window.crypto
				: typeof crypto !== 'undefined' &&
				  typeof crypto.getRandomValues === 'function'
				? crypto
				: undefined;

		if (cryptoSource) {
			cryptoSource.getRandomValues(buffer);
		} else {
			// Fill the full 0-255 byte range, like getRandomValues, so
			// _bufferToString's modulo mapping treats both sources alike.
			for (let i = 0; i < size; i += 1) {
				buffer[i] = Math.floor(Math.random() * 256);
			}
		}

		return this._bufferToString(buffer);
	}

	/**
	 * Maps each byte to a character of [A-Za-z0-9] via `byte % 62`.
	 * Because 256 is not a multiple of 62, the first 8 characters ('A'-'H') are
	 * slightly more likely than the rest.
	 */
	private _bufferToString(buffer: Uint8Array) {
		const CHARSET =
			'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';
		const state = [];
		for (let i = 0; i < buffer.byteLength; i += 1) {
			const index = buffer[i] % CHARSET.length;
			state.push(CHARSET[index]);
		}
		return state.join('');
	}

	/**
	 * Splits an `a=1&b=2` string into `{ a: '1', b: '2' }`.
	 * Values are not URL-decoded; a pair without '=' maps to undefined, and a
	 * later duplicate key overwrites an earlier one.
	 */
	private _parseParams(paramString: string): Record<string, string | undefined> {
		return paramString
			.split('&')
			.map(pair => pair.split('='))
			.reduce<Record<string, string | undefined>>((acc, [k, v]) => {
				acc[k] = v;
				return acc;
			}, {});
	}

	/**
	 * Serializes `params` as `k=v&…` in insertion order, applying
	 * encodeURIComponent to each key and to the string form of each value.
	 */
	private _toQueryString(params: Record<string, unknown>): string {
		return Object.entries(params)
			.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`)
			.join('&');
	}
}