UNPKG

@vercel/sqs-consumer

Version:

Build SQS-based Node applications without the boilerplate

455 lines (402 loc) 14.2 kB
import { AWSError } from 'aws-sdk'; import * as SQS from 'aws-sdk/clients/sqs'; import { PromiseResult } from 'aws-sdk/lib/request'; import * as Debug from 'debug'; import { EventEmitter } from 'events'; import { autoBind } from './bind'; import { SQSError, TimeoutError } from './errors'; const debug = Debug('sqs-consumer'); type ReceieveMessageResponse = PromiseResult<SQS.Types.ReceiveMessageResult, AWSError>; type ReceiveMessageRequest = SQS.Types.ReceiveMessageRequest; export type SQSMessage = SQS.Types.Message; const requiredOptions = [ 'queueUrl', // only one of handleMessage / handleMessagesBatch is required 'handleMessage|handleMessageBatch' ]; interface TimeoutResponse { timeout: NodeJS.Timeout; pending: Promise<void>; } function createTimeout(duration: number): TimeoutResponse[] { let timeout; const pending = new Promise((_, reject) => { timeout = setTimeout((): void => { reject(new TimeoutError()); }, duration); }); return [timeout, pending]; } function assertOptions(options: ConsumerOptions): void { requiredOptions.forEach((option) => { const possibilities = option.split('|'); if (!possibilities.find((p) => options[p])) { throw new Error(`Missing SQS consumer option [ ${possibilities.join(' or ')} ].`); } }); if (options.batchSize > 10 || options.batchSize < 1) { throw new Error('SQS batchSize option must be between 1 and 10.'); } if (options.heartbeatInterval && !(options.heartbeatInterval < options.visibilityTimeout)) { throw new Error('heartbeatInterval must be less than visibilityTimeout.'); } } function isConnectionError(err: Error): boolean { if (err instanceof SQSError) { return (err.statusCode === 403 || err.code === 'CredentialsError' || err.code === 'UnknownEndpoint'); } return false; } function toSQSError(err: AWSError, message: string): SQSError { const sqsError = new SQSError(message); sqsError.code = err.code; sqsError.statusCode = err.statusCode; sqsError.region = err.region; sqsError.retryable = err.retryable; sqsError.hostname = err.hostname; sqsError.time = err.time; return sqsError; } function hasMessages(response: ReceieveMessageResponse): boolean { return response.Messages && response.Messages.length > 0; } export interface ConsumerOptions { queueUrl?: string; attributeNames?: string[]; messageAttributeNames?: string[]; stopped?: boolean; batchSize?: number; visibilityTimeout?: number; waitTimeSeconds?: number; authenticationErrorTimeout?: number; pollingWaitTimeMs?: number; terminateVisibilityTimeout?: boolean | ((message: SQSMessage) => number); heartbeatInterval?: number; sqs?: SQS; region?: string; handleMessageTimeout?: number; handleMessage?(message: SQSMessage): Promise<void>; handleMessageBatch?(messages: SQSMessage[]): Promise<void>; /** * An `async` function (or function that returns a `Promise`) to be called right * before the SQS Client sends a receive message command. * * This function is usefull if SQS Client module exports have been modified, for * example to add middlewares. */ preReceiveMessageCallback?(): Promise<void>; /** * An `async` function (or function that returns a `Promise`) to be called right * after the SQS Client sends a receive message command. * * This function is usefull if SQS Client module exports have been modified, for * example to add middlewares. */ postReceiveMessageCallback?(): Promise<void>; } interface Events { 'response_processed': []; 'empty': []; 'message_received': [SQSMessage]; 'message_processed': [SQSMessage]; 'error': [Error, void | SQSMessage | SQSMessage[]]; 'timeout_error': [Error, SQSMessage]; 'processing_error': [Error, SQSMessage]; 'stopped': []; } export class Consumer extends EventEmitter { private queueUrl: string; private handleMessage: (message: SQSMessage) => Promise<void>; private handleMessageBatch: (message: SQSMessage[]) => Promise<void>; private handleMessageTimeout: number; private attributeNames: string[]; private messageAttributeNames: string[]; private stopped: boolean; private batchSize: number; private visibilityTimeout: number; private waitTimeSeconds: number; private authenticationErrorTimeout: number; private pollingWaitTimeMs: number; private terminateVisibilityTimeout: boolean | ((message: SQSMessage) => number); private heartbeatInterval: number; private sqs: SQS; private preReceiveMessageCallback?: () => Promise<void>; private postReceiveMessageCallback?: () => Promise<void>; constructor(options: ConsumerOptions) { super(); assertOptions(options); this.queueUrl = options.queueUrl; this.handleMessage = options.handleMessage; this.handleMessageBatch = options.handleMessageBatch; this.handleMessageTimeout = options.handleMessageTimeout; this.attributeNames = options.attributeNames || []; this.messageAttributeNames = options.messageAttributeNames || []; this.stopped = true; this.batchSize = options.batchSize || 1; this.visibilityTimeout = options.visibilityTimeout; this.terminateVisibilityTimeout = options.terminateVisibilityTimeout || false; this.heartbeatInterval = options.heartbeatInterval; this.waitTimeSeconds = options.waitTimeSeconds || 20; this.authenticationErrorTimeout = options.authenticationErrorTimeout || 10000; this.pollingWaitTimeMs = options.pollingWaitTimeMs || 0; this.sqs = options.sqs || new SQS({ region: options.region || process.env.AWS_REGION || 'eu-west-1' }); this.preReceiveMessageCallback = options.preReceiveMessageCallback; this.postReceiveMessageCallback = options.postReceiveMessageCallback; autoBind(this); } emit<T extends keyof Events>(event: T, ...args: Events[T]) { return super.emit(event, ...args); } on<T extends keyof Events>(event: T, listener: (...args: Events[T]) => void): this { return super.on(event, listener); } once<T extends keyof Events>(event: T, listener: (...args: Events[T]) => void): this { return super.once(event, listener); } public get isRunning(): boolean { return !this.stopped; } public static create(options: ConsumerOptions): Consumer { return new Consumer(options); } public start(): void { if (this.stopped) { debug('Starting consumer'); this.stopped = false; this.poll(); } } public stop(): void { debug('Stopping consumer'); this.stopped = true; } private async handleSqsResponse(response: ReceieveMessageResponse): Promise<void> { debug('Received SQS response'); debug(response); if (response) { if (hasMessages(response)) { if (this.handleMessageBatch) { // prefer handling messages in batch when available await this.processMessageBatch(response.Messages); } else { await Promise.all(response.Messages.map(this.processMessage)); } this.emit('response_processed'); } else { this.emit('empty'); } } } private async processMessage(message: SQSMessage): Promise<void> { this.emit('message_received', message); let heartbeat; try { if (this.heartbeatInterval) { heartbeat = this.startHeartbeat(async (elapsedSeconds) => { return this.changeVisibilityTimeout(message, elapsedSeconds + this.visibilityTimeout); }); } await this.executeHandler(message); await this.deleteMessage(message); this.emit('message_processed', message); } catch (err) { this.emitError(err, message); if (this.terminateVisibilityTimeout) { if (typeof this.terminateVisibilityTimeout === 'function') { const visibilityTimeout = this.terminateVisibilityTimeout(message); await this.changeVisibilityTimeout(message, visibilityTimeout); } else { await this.changeVisibilityTimeout(message, 0); } } } finally { clearInterval(heartbeat); } } private async receiveMessage(params: ReceiveMessageRequest): Promise<ReceieveMessageResponse> { try { if (this.preReceiveMessageCallback) { await this.preReceiveMessageCallback(); } const result = await this.sqs .receiveMessage(params) .promise(); if (this.postReceiveMessageCallback) { await this.postReceiveMessageCallback(); } return result; } catch (err) { throw toSQSError(err, `SQS receive message failed: ${err.message}`); } } private async deleteMessage(message: SQSMessage): Promise<void> { debug('Deleting message %s', message.MessageId); const deleteParams = { QueueUrl: this.queueUrl, ReceiptHandle: message.ReceiptHandle }; try { await this.sqs .deleteMessage(deleteParams) .promise(); } catch (err) { throw toSQSError(err, `SQS delete message failed: ${err.message}`); } } private async executeHandler(message: SQSMessage): Promise<void> { let timeout; let pending; try { if (this.handleMessageTimeout) { [timeout, pending] = createTimeout(this.handleMessageTimeout); await Promise.race([ this.handleMessage(message), pending ]); } else { await this.handleMessage(message); } } catch (err) { if (err instanceof TimeoutError) { err.message = `Message handler timed out after ${this.handleMessageTimeout}ms: Operation timed out.`; } else { err.message = `Unexpected message handler failure: ${err.message}`; } throw err; } finally { clearTimeout(timeout); } } private async changeVisibilityTimeout(message: SQSMessage, timeout: number): Promise<PromiseResult<any, AWSError>> { try { return this.sqs .changeMessageVisibility({ QueueUrl: this.queueUrl, ReceiptHandle: message.ReceiptHandle, VisibilityTimeout: timeout }) .promise(); } catch (err) { this.emit('error', err, message); } } private emitError(err: Error, message: SQSMessage): void { if (err.name === SQSError.name) { this.emit('error', err, message); } else if (err instanceof TimeoutError) { this.emit('timeout_error', err, message); } else { this.emit('processing_error', err, message); } } private poll(): void { if (this.stopped) { this.emit('stopped'); return; } debug('Polling for messages'); const receiveParams = { QueueUrl: this.queueUrl, AttributeNames: this.attributeNames, MessageAttributeNames: this.messageAttributeNames, MaxNumberOfMessages: this.batchSize, WaitTimeSeconds: this.waitTimeSeconds, VisibilityTimeout: this.visibilityTimeout }; let currentPollingTimeout = this.pollingWaitTimeMs; this.receiveMessage(receiveParams) .then(this.handleSqsResponse) .catch((err) => { this.emit('error', err); if (isConnectionError(err)) { debug('There was an authentication error. Pausing before retrying.'); currentPollingTimeout = this.authenticationErrorTimeout; } return; }).then(() => { setTimeout(this.poll, currentPollingTimeout); }).catch((err) => { this.emit('error', err); }); } private async processMessageBatch(messages: SQSMessage[]): Promise<void> { messages.forEach((message) => { this.emit('message_received', message); }); let heartbeat; try { if (this.heartbeatInterval) { heartbeat = this.startHeartbeat(async (elapsedSeconds) => { return this.changeVisibilityTimeoutBatch(messages, () => elapsedSeconds + this.visibilityTimeout); }); } await this.executeBatchHandler(messages); await this.deleteMessageBatch(messages); messages.forEach((message) => { this.emit('message_processed', message); }); } catch (err) { this.emit('error', err, messages); if (this.terminateVisibilityTimeout) { if (typeof this.terminateVisibilityTimeout === 'function') { await this.changeVisibilityTimeoutBatch(messages, this.terminateVisibilityTimeout); } else { await this.changeVisibilityTimeoutBatch(messages, () => 0); } } } finally { clearInterval(heartbeat); } } private async deleteMessageBatch(messages: SQSMessage[]): Promise<void> { debug('Deleting messages %s', messages.map((msg) => msg.MessageId).join(' ,')); const deleteParams = { QueueUrl: this.queueUrl, Entries: messages.map((message) => ({ Id: message.MessageId, ReceiptHandle: message.ReceiptHandle })) }; try { await this.sqs .deleteMessageBatch(deleteParams) .promise(); } catch (err) { throw toSQSError(err, `SQS delete message failed: ${err.message}`); } } private async executeBatchHandler(messages: SQSMessage[]): Promise<void> { try { await this.handleMessageBatch(messages); } catch (err) { err.message = `Unexpected message handler failure: ${err.message}`; throw err; } } private async changeVisibilityTimeoutBatch(messages: SQSMessage[], getTimeout: (message: SQSMessage) => number): Promise<PromiseResult<any, AWSError>> { const params = { QueueUrl: this.queueUrl, Entries: messages.map((message) => ({ Id: message.MessageId, ReceiptHandle: message.ReceiptHandle, VisibilityTimeout: getTimeout(message) })) }; try { return this.sqs .changeMessageVisibilityBatch(params) .promise(); } catch (err) { this.emit('error', err, messages); } } private startHeartbeat(heartbeatFn: (elapsedSeconds: number) => void): NodeJS.Timeout { const startTime = Date.now(); return setInterval(() => { const elapsedSeconds = Math.ceil((Date.now() - startTime) / 1000); heartbeatFn(elapsedSeconds); }, this.heartbeatInterval * 1000); } }