UNPKG

@chainsafe/libp2p-yamux

Version:
610 lines (512 loc) 17.8 kB
import { InvalidParametersError, MuxerClosedError, TooManyOutboundProtocolStreamsError, serviceCapabilities, setMaxListeners } from '@libp2p/interface'
import { getIterator } from 'get-iterator'
import { pushable } from 'it-pushable'
import { Uint8ArrayList } from 'uint8arraylist'
import { defaultConfig, verifyConfig } from './config.js'
import { PROTOCOL_ERRORS } from './constants.js'
import { Decoder } from './decode.js'
import { encodeHeader } from './encode.js'
import { InvalidFrameError, NotMatchingPingError, UnrequestedPingError } from './errors.js'
import { Flag, FrameType, GoAwayCode } from './frame.js'
import { StreamState, YamuxStream } from './stream.js'
import type { Config } from './config.js'
import type { FrameHeader } from './frame.js'
import type { YamuxMuxerComponents } from './index.js'
import type { AbortOptions, ComponentLogger, Logger, Stream, StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface'
import type { Pushable } from 'it-pushable'
import type { Sink, Source } from 'it-stream-types'

/** Protocol identifier advertised by both the factory and the muxer */
const YAMUX_PROTOCOL_ID = '/yamux/1.0.0'

/** Default timeout (ms) applied to `close()` when the caller supplies no abort signal */
const CLOSE_TIMEOUT = 500

/**
 * Options accepted when creating a muxer: the generic stream muxer init
 * options plus any subset of the yamux `Config`.
 */
export interface YamuxMuxerInit extends StreamMuxerInit, Partial<Config> {
}

/**
 * Stream muxer factory. Holds the shared components and default init options
 * and creates a `YamuxMuxer` per call to `createStreamMuxer`.
 */
export class Yamux implements StreamMuxerFactory {
  protocol = YAMUX_PROTOCOL_ID
  private readonly _components: YamuxMuxerComponents
  private readonly _init: YamuxMuxerInit

  constructor (components: YamuxMuxerComponents, init: YamuxMuxerInit = {}) {
    this._components = components
    this._init = init
  }

  readonly [Symbol.toStringTag] = '@chainsafe/libp2p-yamux'

  readonly [serviceCapabilities]: string[] = [
    '@libp2p/stream-multiplexing'
  ]

  /**
   * Create a new muxer.
   *
   * @param init - per-muxer options; these override the options passed to the factory
   */
  createStreamMuxer (init?: YamuxMuxerInit): YamuxMuxer {
    return new YamuxMuxer(this._components, {
      ...this._init,
      ...init
    })
  }
}

/** Options for `YamuxMuxer.close` */
export interface CloseOptions extends AbortOptions {
  /** GoAway code sent to the remote; defaults to `GoAwayCode.NormalTermination` */
  reason?: GoAwayCode
}

/**
 * A yamux session. Outgoing frames are pushed onto `source`; incoming bytes
 * are consumed by `sink`, decoded into frames and dispatched to the session
 * (stream id 0) or to the matching stream.
 */
export class YamuxMuxer implements StreamMuxer {
  protocol = YAMUX_PROTOCOL_ID
  source: Pushable<Uint8ArrayList | Uint8Array>
  sink: Sink<Source<Uint8ArrayList | Uint8Array>, Promise<void>>

  private readonly config: Config
  private readonly log?: Logger
  private readonly logger: ComponentLogger

  /** Used to close the muxer from either the sink or source */
  private readonly closeController: AbortController

  /** The next stream id to be used when initiating a new stream */
  private nextStreamID: number

  /** Primary stream mapping, streamID => stream */
  private readonly _streams: Map<number, YamuxStream>

  /** The next ping id to be used when pinging */
  private nextPingID: number

  /** Tracking info for the currently active ping */
  private activePing?: { id: number, promise: Promise<void>, resolve(): void }

  /** Round trip time in milliseconds, -1 until the first ping completes */
  private rtt: number

  /** True if client, false if server */
  private readonly client: boolean

  /** GoAway code we sent to the remote, set once we have sent one */
  private localGoAway?: GoAwayCode

  /** GoAway code received from the remote, set once we have received one */
  private remoteGoAway?: GoAwayCode

  /** Number of tracked inbound streams */
  private numInboundStreams: number

  /** Number of tracked outbound streams */
  private numOutboundStreams: number

  private readonly onIncomingStream?: (stream: Stream) => void
  private readonly onStreamEnd?: (stream: Stream) => void

  /**
   * @param components - provides the logger used by the muxer and its streams
   * @param init - muxer options; merged over `defaultConfig` and validated with `verifyConfig`
   */
  constructor (components: YamuxMuxerComponents, init: YamuxMuxerInit) {
    this.client = init.direction === 'outbound'
    this.config = { ...defaultConfig, ...init }
    this.logger = components.logger
    this.log = this.logger.forComponent('libp2p:yamux')
    verifyConfig(this.config)

    this.closeController = new AbortController()
    // every in-flight ping and the keep-alive loop attach a listener to this signal
    setMaxListeners(Infinity, this.closeController.signal)

    this.onIncomingStream = init.onIncomingStream
    this.onStreamEnd = init.onStreamEnd

    this._streams = new Map()

    this.source = pushable({
      onEnd: (): void => {
        this.log?.trace('muxer source ended')

        this._streams.forEach(stream => {
          stream.destroy()
        })
      }
    })

    this.sink = async (source: Source<Uint8ArrayList | Uint8Array>): Promise<void> => {
      // when the muxer is closed, ask the incoming source to finish so the
      // frame-reading loop below terminates
      const shutDownListener = (): void => {
        const iterator = getIterator(source)

        if (iterator.return != null) {
          const res = iterator.return()

          if (isPromise(res)) {
            res.catch(err => {
              this.log?.('could not cause sink source to return', err)
            })
          }
        }
      }

      let reason, error
      try {
        const decoder = new Decoder(source)

        try {
          this.closeController.signal.addEventListener('abort', shutDownListener)

          for await (const frame of decoder.emitFrames()) {
            await this.handleFrame(frame.header, frame.readData)
          }
        } finally {
          this.closeController.signal.removeEventListener('abort', shutDownListener)
        }

        reason = GoAwayCode.NormalTermination
      } catch (err: any) {
        // either a protocol or internal error
        if (PROTOCOL_ERRORS.has(err.name)) {
          this.log?.error('protocol error in sink', err)
          reason = GoAwayCode.ProtocolError
        } else {
          this.log?.error('internal error in sink', err)
          reason = GoAwayCode.InternalError
        }

        error = err as Error
      }

      this.log?.trace('muxer sink ended')

      if (error != null) {
        this.abort(error, reason)
      } else {
        await this.close({ reason })
      }
    }

    this.numInboundStreams = 0
    this.numOutboundStreams = 0

    // client uses odd streamIDs, server uses even streamIDs
    this.nextStreamID = this.client ? 1 : 2

    this.nextPingID = 0
    this.rtt = -1

    this.log?.trace('muxer created')

    if (this.config.enableKeepAlive) {
      this.keepAliveLoop().catch(e => this.log?.error('keepalive error: %s', e))
    }

    // send an initial ping to establish RTT
    this.ping().catch(e => this.log?.error('ping error: %s', e))
  }

  /** A snapshot array of all currently tracked streams */
  get streams (): YamuxStream[] {
    return Array.from(this._streams.values())
  }

  /**
   * Open a new outbound stream.
   *
   * @param name - optional name passed through to the created stream
   * @throws MuxerClosedError if a GoAway has been sent or received
   * @throws TooManyOutboundProtocolStreamsError if `maxOutboundStreams` is already reached
   */
  newStream (name?: string | undefined): YamuxStream {
    if (this.remoteGoAway !== undefined) {
      throw new MuxerClosedError('Muxer closed remotely')
    }
    if (this.localGoAway !== undefined) {
      throw new MuxerClosedError('Muxer closed locally')
    }

    // check against our configured maximum number of outbound streams before
    // allocating a stream id, so that a rejected attempt does not consume one
    if (this.numOutboundStreams >= this.config.maxOutboundStreams) {
      throw new TooManyOutboundProtocolStreamsError('max outbound streams exceeded')
    }

    const id = this.nextStreamID
    this.nextStreamID += 2

    this.log?.trace('new outgoing stream id=%s', id)

    const stream = this._newStream(id, name, StreamState.Init, 'outbound')
    this._streams.set(id, stream)

    this.numOutboundStreams++

    // send a window update to open the stream on the receiver end
    stream.sendWindowUpdate()

    return stream
  }

  /**
   * Initiate a ping and wait for a response
   *
   * Note: only a single ping will be initiated at a time.
   * If a ping is already in progress, a new ping will not be initiated.
   *
   * @returns the round-trip-time in milliseconds
   * @throws MuxerClosedError if a GoAway has been sent or received, or if the
   * muxer is closed while the ping is in flight
   */
  async ping (): Promise<number> {
    if (this.remoteGoAway !== undefined) {
      throw new MuxerClosedError('Muxer closed remotely')
    }
    if (this.localGoAway !== undefined) {
      throw new MuxerClosedError('Muxer closed locally')
    }

    // An active ping does not yet exist, handle the process here
    if (this.activePing === undefined) {
      // create active ping
      let _resolve = (): void => {}
      this.activePing = {
        id: this.nextPingID++,
        // this promise awaits resolution or the close controller aborting
        promise: new Promise<void>((resolve, reject) => {
          const closed = (): void => {
            reject(new MuxerClosedError('Muxer closed locally'))
          }
          this.closeController.signal.addEventListener('abort', closed, { once: true })
          _resolve = (): void => {
            this.closeController.signal.removeEventListener('abort', closed)
            resolve()
          }
        }),
        // the promise executor runs synchronously, so `_resolve` is already
        // the real resolver by the time it is read here
        resolve: _resolve
      }
      // send ping
      const start = Date.now()
      this.sendPing(this.activePing.id)
      // await pong
      try {
        await this.activePing.promise
      } finally {
        // clean-up active ping
        delete this.activePing
      }
      // update rtt
      const end = Date.now()
      this.rtt = end - start
    } else {
      // an active ping is already in progress, piggyback off that
      await this.activePing.promise
    }
    return this.rtt
  }

  /**
   * Get the ping round trip time
   *
   * Note: Will return -1 if no successful ping has yet been completed
   *
   * @returns the round-trip-time in milliseconds
   */
  getRTT (): number {
    return this.rtt
  }

  /**
   * Close the muxer
   *
   * Closes every tracked stream, then sends a GoAway with the given reason and
   * shuts the muxer down. If closing the streams fails (including the abort
   * signal firing) the muxer is aborted instead.
   *
   * @param options - optional GoAway reason and abort signal; when no signal is
   * supplied a `CLOSE_TIMEOUT` ms timeout signal is used
   */
  async close (options: CloseOptions = {}): Promise<void> {
    if (this.closeController.signal.aborted) {
      // already closed
      return
    }

    const reason = options?.reason ?? GoAwayCode.NormalTermination

    this.log?.trace('muxer close reason=%s', reason)

    if (options.signal == null) {
      const signal = AbortSignal.timeout(CLOSE_TIMEOUT)
      setMaxListeners(Infinity, signal)

      options = {
        ...options,
        signal
      }
    }

    try {
      await Promise.all(
        [...this._streams.values()].map(async s => s.close(options))
      )

      // send reason to the other side, allow the other side to close gracefully
      this.sendGoAway(reason)

      this._closeMuxer()
    } catch (err: any) {
      this.abort(err)
    }
  }

  /**
   * Immediately shut the muxer down: abort every tracked stream with `err`,
   * send a GoAway and close.
   *
   * @param err - the error passed to each stream's `abort`
   * @param reason - GoAway code to send; defaults to `GoAwayCode.InternalError`
   */
  abort (err: Error, reason?: GoAwayCode): void {
    if (this.closeController.signal.aborted) {
      // already closed
      return
    }

    // If no reason was provided, report an internal error to the remote
    reason = reason ?? GoAwayCode.InternalError

    this.log?.error('muxer abort reason=%s error=%s', reason, err)

    // Abort all underlying streams
    for (const stream of this._streams.values()) {
      stream.abort(err)
    }

    // send reason to the other side, allow the other side to close gracefully
    this.sendGoAway(reason)

    this._closeMuxer()
  }

  /** True once the close controller has been aborted */
  isClosed (): boolean {
    return this.closeController.signal.aborted
  }

  /**
   * Called when either the local or remote shuts down the muxer
   */
  private _closeMuxer (): void {
    // stop the sink and any other processes
    this.closeController.abort()

    // stop the source
    this.source.end()
  }

  /**
   * Create a new stream
   *
   * The stream is NOT added to the stream map and no counter is updated;
   * callers do that themselves.
   *
   * @throws InvalidParametersError if a stream with this id is already tracked
   */
  private _newStream (id: number, name: string | undefined, state: StreamState, direction: 'inbound' | 'outbound'): YamuxStream {
    if (this._streams.get(id) != null) {
      throw new InvalidParametersError('Stream already exists with that id')
    }

    const stream = new YamuxStream({
      id: id.toString(),
      name,
      state,
      direction,
      sendFrame: this.sendFrame.bind(this),
      onEnd: () => {
        this.closeStream(id)
        this.onStreamEnd?.(stream)
      },
      log: this.logger.forComponent(`libp2p:yamux:${direction}:${id}`),
      config: this.config,
      getRTT: this.getRTT.bind(this)
    })

    return stream
  }

  /**
   * closeStream is used to close a stream once both sides have
   * issued a close.
   */
  private closeStream (id: number): void {
    // clients open odd ids and servers even ids, so a stream whose id parity
    // is the remote's counts as inbound
    if (this.client === (id % 2 === 0)) {
      this.numInboundStreams--
    } else {
      this.numOutboundStreams--
    }
    this._streams.delete(id)
  }

  /**
   * Send a ping every `keepAliveInterval` ms until the muxer is closed.
   * Ping failures are logged and do not stop the loop.
   */
  private async keepAliveLoop (): Promise<void> {
    // rejects when the muxer is closed, which breaks out of the loop below
    const abortPromise = new Promise((_resolve, reject) => {
      this.closeController.signal.addEventListener('abort', reject, { once: true })
    })

    this.log?.trace('muxer keepalive enabled interval=%s', this.config.keepAliveInterval)

    while (true) {
      let timeoutId: ReturnType<typeof setTimeout> | undefined

      try {
        await Promise.race([
          abortPromise,
          new Promise((resolve) => {
            timeoutId = setTimeout(resolve, this.config.keepAliveInterval)
          })
        ])

        this.ping().catch(e => this.log?.error('ping error: %s', e))
      } catch (e) {
        // closed - cancel the pending one-shot timer
        clearTimeout(timeoutId)
        return
      }
    }
  }

  /**
   * Dispatch a decoded frame. Stream id 0 carries session-level frames (Ping,
   * GoAway); any other id carries stream frames (Data, WindowUpdate).
   *
   * @param header - the decoded frame header
   * @param readData - reads the frame payload; passed through for stream frames
   * @throws InvalidFrameError if the frame type is not valid for the stream id
   */
  private async handleFrame (header: FrameHeader, readData?: () => Promise<Uint8ArrayList>): Promise<void> {
    const {
      streamID,
      type,
      length
    } = header

    this.log?.trace('received frame %o', header)

    if (streamID === 0) {
      switch (type) {
        case FrameType.Ping:
        {
          this.handlePing(header)
          return
        }
        case FrameType.GoAway:
        {
          // for GoAway frames the length field carries the reason code
          this.handleGoAway(length)
          return
        }
        default:
          // Invalid state
          throw new InvalidFrameError('Invalid frame type')
      }
    } else {
      switch (type) {
        case FrameType.Data:
        case FrameType.WindowUpdate:
        {
          await this.handleStreamMessage(header, readData)
          return
        }
        default:
          // Invalid state
          throw new InvalidFrameError('Invalid frame type')
      }
    }
  }

  /**
   * Handle a Ping frame: answer a request (SYN) with an ACK carrying the same
   * id, or resolve our active ping on a response (ACK).
   *
   * @throws InvalidFrameError if the flag is neither SYN nor ACK
   */
  private handlePing (header: FrameHeader): void {
    // for Ping frames the length field carries the ping id
    // If the ping is initiated by the sender, send a response
    if (header.flag === Flag.SYN) {
      this.log?.trace('received ping request pingId=%s', header.length)
      this.sendPing(header.length, Flag.ACK)
    } else if (header.flag === Flag.ACK) {
      this.log?.trace('received ping response pingId=%s', header.length)
      this.handlePingResponse(header.length)
    } else {
      // Invalid state
      throw new InvalidFrameError('Invalid frame flag')
    }
  }

  /**
   * Resolve the active ping if the response id matches it.
   *
   * @throws UnrequestedPingError if no ping is in flight
   * @throws NotMatchingPingError if the id differs from the active ping's id
   */
  private handlePingResponse (pingId: number): void {
    if (this.activePing === undefined) {
      // this ping was not requested
      throw new UnrequestedPingError('ping not requested')
    }
    if (this.activePing.id !== pingId) {
      // this ping doesn't match our active ping request
      throw new NotMatchingPingError('ping doesn\'t match our id')
    }

    // valid ping response
    this.activePing.resolve()
  }

  /**
   * Handle a GoAway from the remote: record the reason, reset every tracked
   * stream and close the muxer.
   */
  private handleGoAway (reason: GoAwayCode): void {
    this.log?.trace('received GoAway reason=%s', GoAwayCode[reason] ?? 'unknown')
    this.remoteGoAway = reason

    // If the other side is friendly, they would have already closed all streams before sending a GoAway
    // In case they weren't, reset all streams
    for (const stream of this._streams.values()) {
      stream.reset()
    }

    this._closeMuxer()
  }

  /**
   * Handle a Data or WindowUpdate frame addressed to a stream. A SYN flag
   * first attempts to open the inbound stream. Frames for unknown streams are
   * dropped; for Data frames the payload is still read before dropping.
   */
  private async handleStreamMessage (header: FrameHeader, readData?: () => Promise<Uint8ArrayList>): Promise<void> {
    const { streamID, flag, type } = header

    if ((flag & Flag.SYN) === Flag.SYN) {
      this.incomingStream(streamID)
    }

    const stream = this._streams.get(streamID)
    if (stream === undefined) {
      if (type === FrameType.Data) {
        this.log?.('discarding data for stream id=%s', streamID)
        if (readData === undefined) {
          throw new Error('unreachable')
        }
        await readData()
      } else {
        this.log?.trace('frame for missing stream id=%s', streamID)
      }
      return
    }

    switch (type) {
      case FrameType.WindowUpdate: {
        stream.handleWindowUpdate(header)
        return
      }
      case FrameType.Data: {
        if (readData === undefined) {
          throw new Error('unreachable')
        }

        await stream.handleData(header, readData)
        return
      }
      default:
        throw new Error('unreachable')
    }
  }

  /**
   * Register a stream opened by the remote. The stream is rejected with a RST
   * window update if we have already sent a GoAway or if `maxInboundStreams`
   * is reached. Does nothing if the id is already tracked.
   *
   * @throws InvalidParametersError if the id has the parity of our own stream ids
   */
  private incomingStream (id: number): void {
    if (this.client !== (id % 2 === 0)) {
      throw new InvalidParametersError('Both endpoints are clients')
    }
    if (this._streams.has(id)) {
      return
    }

    this.log?.trace('new incoming stream id=%s', id)

    if (this.localGoAway !== undefined) {
      // reject (reset) immediately if we are doing a go away
      this.sendFrame({
        type: FrameType.WindowUpdate,
        flag: Flag.RST,
        streamID: id,
        length: 0
      })
      return
    }

    // check against our configured maximum number of inbound streams
    if (this.numInboundStreams >= this.config.maxInboundStreams) {
      this.log?.('maxIncomingStreams exceeded, forcing stream reset')
      this.sendFrame({
        type: FrameType.WindowUpdate,
        flag: Flag.RST,
        streamID: id,
        length: 0
      })
      return
    }

    // allocate a new stream
    const stream = this._newStream(id, undefined, StreamState.SYNReceived, 'inbound')

    this.numInboundStreams++
    // the stream should now be tracked
    this._streams.set(id, stream)

    this.onIncomingStream?.(stream)
  }

  /**
   * Encode a frame and push it onto the outgoing source.
   *
   * @param header - the frame header to encode
   * @param data - the payload; required for Data frames, ignored otherwise
   * @throws InvalidFrameError if a Data frame is sent without a payload
   */
  private sendFrame (header: FrameHeader, data?: Uint8ArrayList): void {
    this.log?.trace('sending frame %o', header)
    if (header.type === FrameType.Data) {
      if (data === undefined) {
        throw new InvalidFrameError('Invalid frame')
      }
      this.source.push(
        new Uint8ArrayList(encodeHeader(header), data)
      )
    } else {
      this.source.push(encodeHeader(header))
    }
  }

  /**
   * Send a Ping frame on the session stream.
   *
   * @param pingId - the ping id, carried in the frame's length field
   * @param flag - `Flag.SYN` for a request (default), `Flag.ACK` for a response
   */
  private sendPing (pingId: number, flag: Flag = Flag.SYN): void {
    if (flag === Flag.SYN) {
      this.log?.trace('sending ping request pingId=%s', pingId)
    } else {
      this.log?.trace('sending ping response pingId=%s', pingId)
    }
    this.sendFrame({
      type: FrameType.Ping,
      flag,
      streamID: 0,
      length: pingId
    })
  }

  /**
   * Send a GoAway frame and record that we have done so, after which new
   * streams and pings are refused.
   *
   * @param reason - the GoAway code, carried in the frame's length field
   */
  private sendGoAway (reason: GoAwayCode = GoAwayCode.NormalTermination): void {
    this.log?.('sending GoAway reason=%s', GoAwayCode[reason])
    this.localGoAway = reason
    this.sendFrame({
      type: FrameType.GoAway,
      flag: 0,
      streamID: 0,
      length: reason
    })
  }
}

/** Type guard: true if `thing` is non-nullish and has a callable `then` */
function isPromise <T = unknown> (thing: any): thing is Promise<T> {
  return thing != null && typeof thing.then === 'function'
}