@sudowealth/schwab-api
Version:
TypeScript client for Charles Schwab API with OAuth support, market data, trading functionality, and complete type safety
254 lines (253 loc) • 7.71 kB
JavaScript
import { z } from 'zod';
import { createLogger } from '../utils/secure-logger';
import { safeBase64Encode, safeBase64Decode } from './auth-utils';
// Logger shared by every helper in this file, created with the 'OAuthState'
// label (presumably used as the log scope/prefix — confirm in secure-logger).
const logger = createLogger('OAuthState');
/**
 * Base OAuth state schema for PKCE flow.
 *
 * Every field is optional, so an object with none of them (e.g. `{}`)
 * validates; fields that are present must have the listed types.
 * - pkce_code_verifier: PKCE code verifier string.
 * - csrf_token: token compared against the expected value by verifyStateWithCSRF.
 * - timestamp: creation time in ms; encodeOAuthState / createStateWithCSRF
 *   fill it with Date.now().
 */
export const BasicOAuthStateSchema = z.object({
pkce_code_verifier: z.string().optional(),
csrf_token: z.string().optional(),
timestamp: z.number().optional(),
});
/**
 * Extended OAuth state schema with PKCE support.
 *
 * Same fields as BasicOAuthStateSchema, except that:
 * - pkce_code_verifier is required (a string must be present);
 * - pkce_code_challenge is an additional optional string;
 * - pkce_method is an additional optional field that, when present, must be
 *   exactly the literal 'S256'.
 */
export const PKCEOAuthStateSchema = BasicOAuthStateSchema.extend({
pkce_code_verifier: z.string(),
pkce_code_challenge: z.string().optional(),
pkce_method: z.literal('S256').optional(),
});
/**
 * Encode OAuth state to a URL-safe base64 string.
 *
 * The state is shallow-merged with `options.customData`. customData keys are
 * spread last, so they override same-named keys in `state`. The merged
 * object is optionally stamped with a CSRF token and/or timestamp, then
 * serialized to JSON and base64-encoded.
 *
 * @param {object} state State object to encode
 * @param {object} [options] Encoding options
 * @param {object} [options.customData] Extra fields merged over `state`
 * @param {boolean} [options.includeCSRF] Add a freshly generated `csrf_token`
 *   when the merged state has no truthy one
 * @param {boolean} [options.includeTimestamp] Add `timestamp: Date.now()` when
 *   the merged state has no truthy one
 * @returns {string} Base64-encoded state string
 * @throws {Error} 'Failed to encode OAuth state' when serialization or encoding
 *   fails; the underlying error is attached as `cause`
 */
export function encodeOAuthState(state, options = {}) {
  try {
    const stateObject = {
      ...state,
      ...(options.customData || {}),
    };
    // Truthiness checks: an empty csrf_token or a 0 timestamp counts as
    // missing and is replaced.
    if (options.includeCSRF && !stateObject.csrf_token) {
      stateObject.csrf_token = generateCSRFToken();
    }
    if (options.includeTimestamp && !stateObject.timestamp) {
      stateObject.timestamp = Date.now();
    }
    const jsonString = JSON.stringify(stateObject);
    return safeBase64Encode(jsonString, true); // URL-safe encoding
  } catch (error) {
    logger.error('Failed to encode OAuth state:', error);
    // Attach the original failure (e.g. a circular structure rejected by
    // JSON.stringify) so callers can inspect it instead of a bare message.
    throw new Error('Failed to encode OAuth state', { cause: error });
  }
}
/**
 * Decode OAuth state from a base64 string.
 *
 * Accepts the raw value of a `state` parameter. If it contains percent-escapes
 * it is URL-decoded first, falling back to the raw value when URL-decoding
 * fails. The decoded JSON must be a plain object.
 *
 * @param {string} encodedState Base64-encoded state string (possibly URL-encoded)
 * @returns {object|null} Decoded state object, or null when the input is not a
 *   non-empty string, cannot be decoded/parsed, or does not decode to a JSON
 *   object (primitives, `null` and arrays are rejected)
 */
export function decodeOAuthState(encodedState) {
  // Guard explicitly instead of letting `.includes` throw a TypeError below.
  if (typeof encodedState !== 'string' || encodedState.length === 0) {
    logger.warn('OAuth state parameter is missing or not a string');
    return null;
  }
  try {
    let processedState = encodedState;
    // Handle state parameters that were URL-encoded (again) in transit
    if (encodedState.includes('%')) {
      try {
        processedState = decodeURIComponent(encodedState);
        logger.debug('URL-decoded state parameter');
      } catch {
        logger.warn('Failed to URL-decode state parameter, using as-is');
      }
    }
    const jsonString = safeBase64Decode(processedState);
    const parsed = JSON.parse(jsonString);
    // Callers read named fields off the result, so only a JSON object is a
    // valid state; anything else is reported as a decode failure.
    if (parsed === null || typeof parsed !== 'object' || Array.isArray(parsed)) {
      logger.warn('Decoded OAuth state is not a JSON object');
      return null;
    }
    return parsed;
  } catch (error) {
    logger.error('Failed to decode OAuth state:', error);
    return null;
  }
}
/**
 * Validate OAuth state against a schema.
 *
 * Runs `schema.parse(state)` and reports whether it completed without
 * throwing. Any thrown error is logged at debug level and reported as
 * "invalid", including the TypeError raised when `schema` itself is missing.
 *
 * @param {unknown} state State object to validate
 * @param {{ parse(value: unknown): unknown }} schema Zod schema (or any object
 *   whose `parse` throws on invalid input)
 * @returns {boolean} True if valid, false otherwise
 */
export function validateOAuthState(state, schema) {
  let isValid = true;
  try {
    schema.parse(state);
  } catch (validationError) {
    logger.debug('OAuth state validation failed:', validationError);
    isValid = false;
  }
  return isValid;
}
/**
 * Merge application state with PKCE parameters.
 *
 * Returns a new object (the input is not mutated) that contains every field
 * of `appState` plus `pkce_code_verifier`. `pkce_code_challenge` is added only
 * when `pkceChallenge` is truthy; both overwrite same-named fields of `appState`.
 *
 * @param {object} appState Application-specific state
 * @param {string} pkceVerifier PKCE code verifier
 * @param {string} [pkceChallenge] PKCE code challenge
 * @returns {object} Merged state object
 */
export function mergeStateWithPKCE(appState, pkceVerifier, pkceChallenge) {
  const merged = { ...appState };
  merged.pkce_code_verifier = pkceVerifier;
  if (pkceChallenge) {
    merged.pkce_code_challenge = pkceChallenge;
  }
  return merged;
}
/**
 * Extract PKCE parameters from state.
 *
 * @param {unknown} state State object containing PKCE parameters
 * @returns {{ codeVerifier: unknown, codeChallenge: unknown } | null} The
 *   verifier and (possibly undefined) challenge, or null when `state` is not
 *   an object or has no truthy `pkce_code_verifier`
 */
export function extractPKCEFromState(state) {
  const isObject = Boolean(state) && typeof state === 'object';
  if (!isObject || !state.pkce_code_verifier) {
    return null;
  }
  const {
    pkce_code_verifier: codeVerifier,
    pkce_code_challenge: codeChallenge,
  } = state;
  return { codeVerifier, codeChallenge };
}
/**
 * Create OAuth state with CSRF token.
 *
 * Returns a shallow copy of `data` with a freshly generated `csrf_token` and
 * a `timestamp` of Date.now(). Both overwrite any same-named fields in `data`,
 * which itself is not mutated.
 *
 * @param {object} data State data
 * @returns {object} State with CSRF token
 */
export function createStateWithCSRF(data) {
  const token = generateCSRFToken();
  const issuedAt = Date.now();
  return { ...data, csrf_token: token, timestamp: issuedAt };
}
/**
 * Verify OAuth state CSRF token and timestamp.
 *
 * @param {unknown} state State object to verify
 * @param {string} [expectedCSRF] Expected CSRF token; the token check is
 *   skipped when this is falsy
 * @param {number} [maxAgeMs=600000] Maximum age in milliseconds (default: 10 minutes)
 * @returns {boolean} True if valid, false otherwise. A state whose timestamp
 *   is null/undefined skips the age check. A timestamp that is present but not
 *   a finite number is rejected.
 */
export function verifyStateWithCSRF(state, expectedCSRF, maxAgeMs = 600000) {
  if (!state || typeof state !== 'object') {
    return false;
  }
  const stateObj = state;
  // Verify CSRF token if provided
  if (expectedCSRF && !csrfTokensMatch(stateObj.csrf_token, expectedCSRF)) {
    logger.warn('CSRF token mismatch in OAuth state');
    return false;
  }
  const { timestamp } = stateObj;
  if (timestamp !== undefined && timestamp !== null) {
    // A non-numeric timestamp would make `Date.now() - timestamp` NaN, and
    // `NaN > maxAgeMs` is false, so such a state would never expire.
    if (!Number.isFinite(timestamp)) {
      logger.warn('OAuth state has an invalid timestamp');
      return false;
    }
    const age = Date.now() - timestamp;
    if (age > maxAgeMs) {
      logger.warn('OAuth state has expired');
      return false;
    }
  }
  return true;
}

/**
 * Compare two CSRF tokens without exiting early at the first differing
 * character. This is a best-effort way to avoid leaking how much of a guessed
 * token matched; JS engines give no hard timing guarantees. Non-string inputs
 * fall back to strict equality.
 */
function csrfTokensMatch(actual, expected) {
  if (typeof actual !== 'string' || typeof expected !== 'string') {
    return actual === expected;
  }
  if (actual.length !== expected.length) {
    return false;
  }
  let diff = 0;
  for (let i = 0; i < actual.length; i += 1) {
    diff |= actual.charCodeAt(i) ^ expected.charCodeAt(i);
  }
  return diff === 0;
}
/**
 * Generate a secure CSRF token.
 *
 * Fills 32 bytes from the global `crypto.getRandomValues` and renders each
 * byte as two lowercase hex digits.
 *
 * @returns {string} 64-character lowercase hex token
 */
function generateCSRFToken() {
  const randomBytes = crypto.getRandomValues(new Uint8Array(32));
  let hex = '';
  for (const value of randomBytes) {
    hex += value.toString(16).padStart(2, '0');
  }
  return hex;
}
/**
 * Extract client ID from various state formats.
 *
 * Looks, in priority order, at `state.clientId`, `state.oauthReqInfo.clientId`
 * (MCP pattern) and `state.auth.clientId`. It returns the first value that is
 * a non-empty string. All three locations apply the same check, so the result
 * is always a string or null.
 *
 * @param {unknown} state State object
 * @returns {string|null} Client ID or null if not found
 */
export function extractClientIdFromState(state) {
  if (!state || typeof state !== 'object') {
    return null;
  }
  // Candidate containers of a `clientId` field, highest priority first.
  const holders = [state, state.oauthReqInfo, state.auth];
  for (const holder of holders) {
    if (holder && typeof holder === 'object') {
      const { clientId } = holder;
      if (typeof clientId === 'string' && clientId !== '') {
        return clientId;
      }
    }
  }
  return null;
}
/**
 * Advanced state verification with custom validation.
 *
 * Decodes the state with decodeOAuthState, then applies each configured check
 * in order. Any failure returns null.
 *
 * @param {string} encodedState Encoded state string
 * @param {object} [options] Verification options
 * @param {{ parse(value: unknown): unknown }} [options.schema] Schema passed to validateOAuthState
 * @param {string} [options.expectedCSRF] Expected CSRF token
 * @param {number} [options.maxAgeMs] Maximum state age in ms
 * @param {string[]} [options.requiredFields] Keys that must be own properties of the state
 * @returns {object|null} Decoded and validated state or null
 */
export function decodeAndVerifyState(encodedState, options = {}) {
  try {
    const decoded = decodeOAuthState(encodedState);
    // Only a JSON object can carry the fields checked below.
    if (!decoded || typeof decoded !== 'object') {
      return null;
    }
    // Validate against schema if provided
    if (options.schema && !validateOAuthState(decoded, options.schema)) {
      logger.error('State validation failed against provided schema');
      return null;
    }
    // verifyStateWithCSRF is only consulted when a token or a (truthy) max age
    // is supplied; otherwise neither the CSRF token nor the age is checked.
    if (options.expectedCSRF || options.maxAgeMs) {
      if (!verifyStateWithCSRF(decoded, options.expectedCSRF, options.maxAgeMs)) {
        return null;
      }
    }
    if (options.requiredFields) {
      for (const field of options.requiredFields) {
        // Own-property check: the `in` operator would also accept inherited
        // names such as 'toString' or 'constructor', which every object has.
        if (!Object.prototype.hasOwnProperty.call(decoded, field)) {
          logger.error(`Required field '${field}' missing in state`);
          return null;
        }
      }
    }
    return decoded;
  } catch (error) {
    logger.error('Error decoding and verifying state:', error);
    return null;
  }
}