UNPKG

@appium/base-driver

Version:

Base driver class for Appium drivers

508 lines (479 loc) 16.7 kB
import type {
  AppiumLogger,
  HTTPBody,
  HTTPHeaders,
  HTTPMethod,
  ProxyOptions,
  ProxyResponse,
} from '@appium/types';
import _ from 'lodash';
import {logger, util} from '@appium/support';
import {getSummaryByCode} from '../jsonwp-status/status';
import {
  errors,
  isErrorType,
  errorFromMJSONWPStatusCode,
  errorFromW3CJsonCode,
  getResponseForW3CError,
} from '../protocol/errors';
import {isSessionCommand, routeToCommandName} from '../protocol';
import {MAX_LOG_BODY_LENGTH, DEFAULT_BASE_PATH, PROTOCOLS} from '../constants';
import {ProtocolConverter} from './protocol-converter';
import {formatResponseValue, ensureW3cResponse} from '../protocol/helpers';
import http from 'node:http';
import https from 'node:https';
import {match as pathToRegexMatch} from 'path-to-regexp';
import nodeUrl from 'node:url';
import {ProxyRequest} from './proxy-request';
import type {Request, Response} from 'express';
import type {AxiosError, AxiosResponse, RawAxiosRequestConfig} from 'axios';

/** Logger used when no `log` option is passed to the constructor. */
const DEFAULT_LOG = logger.getLogger('WD Proxy');
/** Default downstream request timeout in milliseconds (passed to axios as `timeout`). */
const DEFAULT_REQUEST_TIMEOUT = 240000;
/**
 * Matches `[prefix]/session/:sessionId[/command]` paths, capturing the optional
 * prefix segments, the session id and the optional command segments.
 */
const COMMAND_WITH_SESSION_ID_MATCHER = pathToRegexMatch(
  '{/*prefix}/session/:sessionId{/*command}'
);

const {MJSONWP, W3C} = PROTOCOLS;

type Protocol = (typeof PROTOCOLS)[keyof typeof PROTOCOLS];

/** The only constructor options that are honored; any other keys are ignored. */
const ALLOWED_OPTS = [
  'scheme',
  'server',
  'port',
  'base',
  'reqBasePath',
  'sessionId',
  'timeout',
  'log',
  'keepAlive',
  'headers',
] as const;

/**
 * Forwards WebDriver commands to a downstream server over HTTP(S).
 *
 * The proxy rewrites upstream URLs onto `scheme://server:port` + `base`, scopes
 * session commands under its own `sessionId`, records the session id and protocol
 * (W3C or MJSONWP) returned by the downstream server on session creation, and keeps
 * track of in-flight requests so they can be cancelled.
 */
export class JWProxy {
  readonly scheme: string;
  readonly server: string;
  readonly port: number;
  readonly base: string;
  readonly reqBasePath: string;
  sessionId: string | null;
  timeout: number;
  readonly headers: HTTPHeaders | undefined;
  readonly httpAgent: http.Agent;
  readonly httpsAgent: https.Agent;
  readonly protocolConverter: ProtocolConverter;

  private _downstreamProtocol: Protocol | null | undefined;
  private _activeRequests: ProxyRequest[];
  private readonly _log: AppiumLogger | undefined;

  /**
   * @param opts - Proxy options. Only the keys listed in `ALLOWED_OPTS` are used.
   *   Defaults: `http://localhost:4444`, base paths `DEFAULT_BASE_PATH`, no session id,
   *   a 240000 ms timeout, and keep-alive enabled unless `keepAlive` is explicitly set.
   */
  constructor(opts: ProxyOptions = {}) {
    const filteredOpts = _.pick(opts, ALLOWED_OPTS);
    // `log` is stored separately (see `_log`) rather than copied onto the instance.
    const options = _.defaults(_.omit(filteredOpts, 'log'), {
      scheme: 'http',
      server: 'localhost',
      port: 4444,
      base: DEFAULT_BASE_PATH,
      reqBasePath: DEFAULT_BASE_PATH,
      sessionId: null,
      timeout: DEFAULT_REQUEST_TIMEOUT,
    }) as ProxyOptions & {
      scheme: string;
      server: string;
      port: number;
      base: string;
      reqBasePath: string;
      sessionId: string | null;
      timeout: number;
    };
    options.scheme = options.scheme.toLowerCase();
    Object.assign(this, options);

    this._activeRequests = [];
    // Assigned directly: the setter would touch `protocolConverter`, which does not exist yet.
    this._downstreamProtocol = null;
    const agentOpts = {
      keepAlive: opts.keepAlive ?? true,
      maxSockets: 10,
      maxFreeSockets: 5,
    };
    this.httpAgent = new http.Agent(agentOpts);
    this.httpsAgent = new https.Agent(agentOpts);
    this.protocolConverter = new ProtocolConverter(this.proxy.bind(this), opts.log);
    this._log = opts.log;

    this.log.debug(`${this.constructor.name} options: ${JSON.stringify(options)}`);
  }

  /** The logger passed via the `log` option, or the shared 'WD Proxy' logger. */
  get log(): AppiumLogger {
    return this._log ?? DEFAULT_LOG;
  }

  /**
   * Gets the protocol used by the downstream server (W3C or MJSONWP).
   */
  get downstreamProtocol(): Protocol | null | undefined {
    return this._downstreamProtocol;
  }

  /**
   * Sets the protocol used by the downstream server (W3C or MJSONWP) and
   * propagates it to the protocol converter.
   */
  set downstreamProtocol(value: Protocol | null | undefined) {
    this._downstreamProtocol = value;
    this.protocolConverter.downstreamProtocol = value;
  }

  /**
   * Returns the number of downstream HTTP requests currently in flight.
   */
  getActiveRequestsCount(): number {
    return this._activeRequests.length;
  }

  /**
   * Cancels every downstream HTTP request currently in flight and forgets them.
   */
  cancelActiveRequests(): void {
    for (const ar of this._activeRequests) {
      ar.cancel();
    }
    this._activeRequests = [];
  }

  /**
   * Builds the full downstream URL for an upstream request URL.
   *
   * The upstream path is normalized (base path / legacy prefix and any
   * `/session/:id` segment removed) and re-rooted under
   * `scheme://server:port` + `base`. Unrecognized commands and session commands
   * are additionally scoped under `/session/<this.sessionId>`. The query string
   * is preserved.
   *
   * @param url - Upstream request URL (absolute or relative).
   * @param method - HTTP method, used to resolve the command name.
   * @returns The downstream URL.
   * @throws {Error} If the URL cannot be parsed or uses a non-HTTP(S) protocol.
   * @throws {ReferenceError} If the command needs a session id but none is set.
   */
  getUrlForProxy(url: string, method?: HTTPMethod): string {
    const parsedUrl = this._parseUrl(url);
    const normalizedPathname = this._toNormalizedPathname(parsedUrl);
    const commandName = normalizedPathname
      ? routeToCommandName(normalizedPathname, method)
      : '';
    // Unknown commands are assumed to be session-scoped.
    const requiresSessionId = !commandName || isSessionCommand(commandName);
    const proxyPrefix = `${this.scheme}://${this.server}:${this.port}${this.base}`;
    let proxySuffix = normalizedPathname ? `/${_.trimStart(normalizedPathname, '/')}` : '';
    if (parsedUrl.search) {
      proxySuffix += parsedUrl.search;
    }
    if (!requiresSessionId) {
      return `${proxyPrefix}${proxySuffix}`;
    }
    if (!this.sessionId) {
      throw new ReferenceError(
        `Session ID is not set, but saw a URL that requires it (${url})`
      );
    }
    return `${proxyPrefix}/session/${this.sessionId}${proxySuffix}`;
  }

  /**
   * Sends a raw request to the downstream server and returns its response.
   *
   * - A non-object `body` is parsed as JSON. No body is ever sent with GET.
   * - On `POST .../session`, the session id from the response is stored when the
   *   status is 200, and the downstream protocol is (re)detected from the body.
   * - A response body that is not a plain object, or one that carries a non-zero
   *   MJSONWP `status`, is treated as a failure.
   *
   * @param url - Upstream request URL.
   * @param method - HTTP method (case-insensitive).
   * @param body - Request payload: an object or a JSON string.
   * @returns A tuple of the response descriptor and the response body.
   * @throws {Error} If a non-object `body` is not valid JSON, or the URL cannot be proxied.
   * @throws {ProxyRequestError} If the request fails or the response is rejected.
   */
  async proxy(
    url: string,
    method: string,
    body: HTTPBody = null
  ): Promise<[ProxyResponse, HTTPBody]> {
    method = method.toUpperCase();
    const newUrl = this.getUrlForProxy(url, method as HTTPMethod);
    const truncateBody = (content: unknown): string =>
      _.truncate(_.isString(content) ? content : JSON.stringify(content), {
        length: MAX_LOG_BODY_LENGTH,
      });
    const reqOpts: RawAxiosRequestConfig = {
      url: newUrl,
      method,
      headers: {
        'content-type': 'application/json; charset=utf-8',
        'user-agent': 'appium',
        accept: 'application/json, */*',
        ...(this.headers ?? {}),
      },
      proxy: false,
      timeout: this.timeout,
      httpAgent: this.httpAgent,
      httpsAgent: this.httpsAgent,
    };
    // GET methods shouldn't have any body. Most servers are OK with this,
    // but WebDriverAgent throws 400 errors
    if (util.hasValue(body) && method !== 'GET') {
      if (typeof body !== 'object') {
        try {
          reqOpts.data = JSON.parse(body as string);
        } catch (error) {
          this.log.warn(
            'Invalid body payload (%s): %s',
            (error as Error).message,
            logger.markSensitive(truncateBody(body))
          );
          throw new Error(
            'Cannot interpret the request body as valid JSON. Check the server log for more details.',
            {cause: error}
          );
        }
      } else {
        reqOpts.data = body;
      }
    }

    this.log.debug(
      `Proxying [%s %s] to [%s %s] with ${reqOpts.data ? 'body: %s' : '%s body'}`,
      method,
      url || '/',
      method,
      newUrl,
      reqOpts.data ? logger.markSensitive(truncateBody(reqOpts.data)) : 'no'
    );

    // Raises an axios-like error (with a `response`) so the catch block below
    // handles rejected responses the same way as downstream HTTP errors.
    const throwProxyError = (error: unknown): never => {
      const err = new Error(`The request to ${url} has failed`) as Error & {
        response: {data: unknown; status: number};
      };
      err.response = {
        data: error,
        status: 500,
      };
      throw err;
    };
    let isResponseLogged = false;
    try {
      const {data, status, headers} = await this.request(reqOpts);
      // `data` might be really big
      // Be careful while handling it to avoid memory leaks
      if (!_.isPlainObject(data)) {
        // The response should be a valid JSON object
        // If it cannot be coerced to an object then the response is wrong
        throwProxyError(data);
      }
      this.log.debug(`Got response with status ${status}: ${truncateBody(data)}`);
      isResponseLogged = true;
      const isSessionCreationRequest = url.endsWith('/session') && method === 'POST';
      if (isSessionCreationRequest) {
        if (status === 200) {
          const value = (data as Record<string, unknown>).value as
            | Record<string, unknown>
            | undefined;
          // MJSONWP puts the id at the top level, W3C inside `value`.
          const raw = (data as Record<string, unknown>).sessionId ?? value?.sessionId;
          this.sessionId = typeof raw === 'string' ? raw : raw != null ? String(raw) : null;
        }
        this.downstreamProtocol =
          this.getProtocolFromResBody(data as Record<string, unknown>) ??
          this.downstreamProtocol;
        this.log.info(`Determined the downstream protocol as '${this.downstreamProtocol}'`);
      }
      if (
        _.has(data, 'status') &&
        parseInt((data as Record<string, unknown>).status as string, 10) !== 0
      ) {
        throwProxyError(data);
      }
      return [
        {
          statusCode: status,
          headers: headers as HTTPHeaders,
          body: data,
        },
        data,
      ];
    } catch (e: unknown) {
      const err = e as AxiosError<unknown> & {message: string};
      let proxyErrorMsg = err.message;
      if (util.hasValue(err.response)) {
        if (!isResponseLogged) {
          const error = truncateBody(err.response.data);
          this.log.info(
            util.hasValue(err.response.status)
              ? `Got response with status ${err.response.status}: ${error}`
              : `Got response with unknown status: ${error}`
          );
        }
      } else {
        // No response at all (connection failure, timeout, cancellation, ...).
        proxyErrorMsg = `Could not proxy command to the remote server. Original error: ${err.message}`;
        this.log.info(err.message);
      }
      throw new errors.ProxyRequestError(
        proxyErrorMsg,
        err.response?.data,
        err.response?.status
      );
    }
  }

  /**
   * Detects the downstream protocol from a response body.
   *
   * @returns MJSONWP if `status` is an integer, W3C if `value` is defined,
   *   otherwise `undefined`.
   */
  getProtocolFromResBody(resObj: Record<string, unknown>): Protocol | undefined {
    if (_.isInteger(resObj.status)) {
      return MJSONWP;
    }
    if (!_.isUndefined(resObj.value)) {
      return W3C;
    }
    return undefined;
  }

  /**
   * Proxies a command identified by its HTTP method and URL to the downstream server.
   *
   * URLs that resolve to a known command name go through
   * `protocolConverter.convertAndProxy`; anything else is sent via `proxy` as-is.
   */
  async proxyCommand(
    url: string,
    method: HTTPMethod,
    body: HTTPBody = null
  ): Promise<[ProxyResponse, HTTPBody]> {
    const parsedUrl = this._parseUrl(url);
    const normalizedPathname = this._toNormalizedPathname(parsedUrl);
    const commandName = normalizedPathname
      ? routeToCommandName(normalizedPathname, method)
      : '';
    if (!commandName) {
      return await this.proxy(url, method, body);
    }
    this.log.debug(`Matched '${url}' to command name '${commandName}'`);

    return await this.protocolConverter.convertAndProxy(commandName, url, method, body);
  }

  /**
   * Executes a WebDriver command and returns the unwrapped `value` field (or throws).
   *
   * @returns The response `value` for successful MJSONWP/W3C responses, or the whole
   *   body when the protocol cannot be detected and the status code is 200.
   * @throws The error wrapped by a `ProxyRequestError`; an MJSONWP/W3C error built
   *   from an error response body; or `UnknownError` for anything else.
   */
  async command(
    url: string,
    method: HTTPMethod,
    body: HTTPBody = null
  ): Promise<HTTPBody> {
    let response: ProxyResponse;
    let resBodyObj: HTTPBody;
    try {
      [response, resBodyObj] = await this.proxyCommand(url, method, body);
    } catch (err: unknown) {
      if (isErrorType(err, errors.ProxyRequestError)) {
        throw err.getActualError();
      }
      throw new errors.UnknownError((err as Error).message);
    }
    const resBody = resBodyObj as Record<string, unknown>;
    const protocol = this.getProtocolFromResBody(resBody);
    if (protocol === MJSONWP) {
      if (response.statusCode === 200 && resBody.status === 0) {
        return resBody.value;
      }
      const status = parseInt(resBody.status as string, 10);
      if (!isNaN(status) && status !== 0) {
        let message: unknown = resBody.value;
        if (_.isPlainObject(message) && _.has(message, 'message')) {
          message = (message as Record<string, unknown>).message;
        }
        throw errorFromMJSONWPStatusCode(
          status,
          _.isEmpty(message)
            ? getSummaryByCode(status)
            : (message as string | {message: string})
        );
      }
    } else if (protocol === W3C) {
      if (response.statusCode < 300) {
        return resBody.value;
      }
      if (_.isPlainObject(resBody.value) && (resBody.value as Record<string, unknown>).error) {
        const value = resBody.value as Record<string, unknown>;
        throw errorFromW3CJsonCode(
          value.error as string,
          (value.message as string) ?? '',
          value.stacktrace as string | undefined
        );
      }
    } else if (response.statusCode === 200) {
      return resBodyObj;
    }
    throw new errors.UnknownError(
      `Did not know what to do with response code '${response.statusCode}' ` +
        `and response body '${_.truncate(JSON.stringify(resBodyObj), {
          length: 300,
        })}'`
    );
  }

  /**
   * Extracts a session id from a WebDriver-style URL.
   *
   * The id is the path segment right after `/session/`. Any query string or
   * fragment is excluded, because callers pass full request URLs
   * (e.g. Express `req.originalUrl`).
   *
   * @returns The session id, or `null` if the URL has no `/session/<id>` segment.
   */
  getSessionIdFromUrl(url: string): string | null {
    const match = url.match(/\/session\/([^/?#]+)/);
    return match ? match[1] : null;
  }

  /**
   * Proxies an Express `Request`/`Response` pair to the downstream server.
   *
   * Downstream failures and non-object response bodies are converted into a W3C
   * error response rather than propagated. A `sessionId` in the response body is
   * replaced with the id from the request URL, or else with this proxy's own
   * `sessionId`.
   */
  async proxyReqRes(req: Request, res: Response): Promise<void> {
    let statusCode: number;
    let resBodyObj: HTTPBody;
    try {
      const [response, body] = await this.proxyCommand(
        req.originalUrl,
        req.method as HTTPMethod,
        req.body
      );
      statusCode = response.statusCode;
      resBodyObj = body;
    } catch (err: unknown) {
      [statusCode, resBodyObj] = getResponseForW3CError(
        isErrorType(err, errors.ProxyRequestError)
          ? (err as InstanceType<typeof errors.ProxyRequestError>).getActualError()
          : err
      );
    }
    res.setHeader('content-type', 'application/json; charset=utf-8');
    if (!_.isPlainObject(resBodyObj)) {
      const error = new errors.UnknownError(
        `The downstream server response with the status code ${statusCode} is not a valid JSON object: ` +
          _.truncate(`${resBodyObj}`, {length: 300})
      );
      [statusCode, resBodyObj] = getResponseForW3CError(error);
    }

    const resBody = resBodyObj as Record<string, unknown>;
    if (_.has(resBody, 'sessionId')) {
      const reqSessionId = this.getSessionIdFromUrl(req.originalUrl);
      if (reqSessionId) {
        this.log.info(`Replacing sessionId ${resBody.sessionId} with ${reqSessionId}`);
        resBody.sessionId = reqSessionId;
      } else if (this.sessionId) {
        this.log.info(`Replacing sessionId ${resBody.sessionId} with ${this.sessionId}`);
        resBody.sessionId = this.sessionId;
      }
    }
    resBody.value = formatResponseValue(resBody.value as object | undefined);
    res.status(statusCode).json(ensureW3cResponse(resBody));
  }

  /**
   * Performs requests to the downstream server
   *
   * The request is tracked in `_activeRequests` while in flight so that
   * `cancelActiveRequests` can abort it.
   *
   * @private - Do not call this method directly,
   * it uses client-specific arguments and responses!
   */
  private async request(requestConfig: RawAxiosRequestConfig): Promise<AxiosResponse> {
    const req = new ProxyRequest(requestConfig);
    this._activeRequests.push(req);
    try {
      return await req.execute();
    } finally {
      _.pull(this._activeRequests, req);
    }
  }

  /**
   * Parses a possibly relative URL. An empty URL is treated as `/`.
   *
   * @throws {Error} If the URL has no href/pathname or a protocol other than http/https.
   */
  private _parseUrl(url: string): nodeUrl.UrlWithStringQuery {
    // eslint-disable-next-line n/no-deprecated-api -- we need relative URL support
    const parsedUrl = nodeUrl.parse(url || '/');
    if (
      _.isNil(parsedUrl.href) ||
      _.isNil(parsedUrl.pathname) ||
      (parsedUrl.protocol && !['http:', 'https:'].includes(parsedUrl.protocol))
    ) {
      throw new Error(`Did not know how to proxy the url '${url}'`);
    }
    return parsedUrl;
  }

  /**
   * Reduces a parsed URL path to its command part, without trailing slashes.
   *
   * - `reqBasePath`, when configured and present, is stripped first.
   * - For `[prefix]/session/:sessionId[/command]` paths only `/command` is kept
   *   (or `''` when there is no command), so any prefix is dropped as well.
   * - Otherwise, when no `reqBasePath` is configured, a leading `/wd/hub` is stripped.
   *
   * @returns The normalized path, or `''` when the URL has no pathname.
   */
  private _toNormalizedPathname(parsedUrl: nodeUrl.UrlWithStringQuery): string {
    if (!_.isString(parsedUrl.pathname)) {
      return '';
    }
    const pathname =
      this.reqBasePath && parsedUrl.pathname.startsWith(this.reqBasePath)
        ? parsedUrl.pathname.replace(this.reqBasePath, '')
        : parsedUrl.pathname;
    const match = COMMAND_WITH_SESSION_ID_MATCHER(pathname);
    let result: string;
    if (match && match.params) {
      // Session-scoped URL: only the segments after `/session/:sessionId` matter,
      // which also drops any prefix a driver forgot to declare via reqBasePath.
      const command = (match.params as Record<string, unknown>).command;
      result = _.isArray(command) ? `/${(command as string[]).join('/')}` : '';
    } else if (!this.reqBasePath && _.startsWith(pathname, '/wd/hub')) {
      // This is needed for the backward compatibility
      // if drivers don't set reqBasePath properly
      result = pathname.replace('/wd/hub', '');
    } else {
      result = pathname;
    }
    return _.trimEnd(result, '/');
  }
}