wasmrs-js
Version:
A JavaScript implementation of the RSocket protocol over WebAssembly.
331 lines (330 loc) • 11.7 kB
JavaScript
import { debug } from './debug.js';
import { HostCallNotImplementedError } from './errors.js';
import { GuestProtocolMethods, HostProtocolMethods, } from './protocol.js';
import { fromU16Bytes, fromU24Bytes, fromU32Bytes, toU24Bytes, toU32Bytes, } from './utils.js';
/**
 * Per-instance host state: the host-call handler and the writer supplied through
 * WasmRsInstance options.
 *
 * The guest/host request/response/error fields are declared here but are not assigned
 * anywhere in this file's visible code.
 */
class ModuleState {
  guestRequest;
  guestResponse;
  hostResponse;
  guestError;
  hostError;
  hostCallback;
  writer;
  /**
   * @param {Function} [hostCall] - handler invoked for host calls; when falsy, a handler that
   *   throws HostCallNotImplementedError(binding, namespace, operation) is used instead.
   * @param {Function} [writer] - writer callback; when falsy, a no-op returning undefined.
   */
  constructor(hostCall, writer) {
    this.hostCallback = hostCall
      ? hostCall
      : (binding, namespace, operation) => {
          throw new HostCallNotImplementedError(binding, namespace, operation);
        };
    this.writer = writer ? writer : () => undefined;
  }
}
// Optional WASI implementation, injected through WasmRsInstance.setWasi(). It stays undefined
// until then; WasmRsModule#instantiate throws if `options.wasi` is given while this is unset.
let WASI;
/**
 * A compiled wasmRS guest module that can be instantiated into a WasmRsInstance.
 */
export class WasmRsModule {
  module;

  /**
   * @param {WebAssembly.Module} module - the compiled module to wrap.
   */
  constructor(module) {
    this.module = module;
  }

  /**
   * Coerces a value into a WasmRsModule.
   *
   * Accepts:
   * - an existing WasmRsModule, which is returned unchanged;
   * - a `WebAssembly.Module`;
   * - any object whose `module` property is a `WebAssembly.Module`.
   *
   * @param {unknown} any - the value to convert.
   * @returns {WasmRsModule}
   * @throws {Error} when the value matches none of the accepted shapes, including
   *   null/undefined and primitives.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  static from(any) {
    if (any instanceof WasmRsModule) {
      return any;
    }
    if (any instanceof WebAssembly.Module) {
      return new WasmRsModule(any);
    }
    // `'module' in x` throws a TypeError when x is null/undefined or a primitive. Only probe
    // real objects (and functions); everything else falls through to the descriptive error.
    const isObjectLike = any !== null && (typeof any === 'object' || typeof any === 'function');
    if (isObjectLike && 'module' in any && any.module instanceof WebAssembly.Module) {
      return new WasmRsModule(any.module);
    }
    // Unlike a bare template substitution, String() can also stringify Symbols.
    throw new Error(`cannot convert ${String(any)} to WasmRsModule`);
  }

  /**
   * Compiles raw wasm bytes.
   * @param {BufferSource} source - the wasm binary.
   * @returns {Promise<WasmRsModule>}
   */
  static async compile(source) {
    const mod = WebAssembly.compile(source);
    return new WasmRsModule(await mod);
  }

  /**
   * Compiles from a (promise of a) fetch Response.
   *
   * Uses WebAssembly.compileStreaming when available. Otherwise it warns, buffers the whole
   * response body, and falls back to compile().
   *
   * @param {Response | Promise<Response>} source
   * @returns {Promise<WasmRsModule>}
   */
  static async compileStreaming(source) {
    if (!WebAssembly.compileStreaming) {
      console.warn('WebAssembly.compileStreaming is not supported on this browser, wasm execution will be impacted.');
      const bytes = new Uint8Array(await (await source).arrayBuffer());
      return WasmRsModule.compile(bytes);
    }
    const mod = WebAssembly.compileStreaming(source);
    return new WasmRsModule(await mod);
  }

  /**
   * Instantiates the module and initializes the wasmRS host around it.
   *
   * When `options.wasi` is set, a WASI instance is created from the implementation registered
   * via WasmRsInstance.setWasi(). That WASI instance is linked in and initialized before the
   * host is initialized.
   *
   * @param {object} [options] - forwarded to the WasmRsInstance constructor
   *   (`hostCall`, `writer`) and, via `options.wasi`, to `WASI.create`.
   * @returns {Promise<WasmRsInstance>}
   * @throws {Error} if `options.wasi` is set but no WASI implementation was registered.
   */
  async instantiate(options = {}) {
    const host = new WasmRsInstance(options);
    let wasi = undefined;
    if (options.wasi) {
      if (!WASI) {
        throw new Error('Wasi options provided but no WASI implementation found');
      }
      wasi = await WASI.create(options.wasi);
    }
    const imports = linkImports(host, wasi);
    debug('instantiating wasm module');
    const instance = await WebAssembly.instantiate(this.module, imports);
    if (wasi) {
      wasi.initialize(instance);
    }
    await host.initialize(instance);
    return host;
  }
}
/**
 * Host-side handle for an instantiated wasmRS guest.
 *
 * - Frames from the host are written into guest memory with send().
 * - Frames the guest emits are dispatched on this EventTarget as `'frame'` events by the
 *   host `SEND` import.
 */
export class WasmRsInstance extends EventTarget {
  guestBufferStart = 0;
  hostBufferStart = 0;
  state;
  guestSend;
  guestOpListRequest;
  textEncoder;
  textDecoder;
  instance;
  operations = new OperationList([], []);

  /**
   * @param {object} [options] - `hostCall` and `writer` are passed to ModuleState.
   */
  constructor(options = {}) {
    super();
    this.state = new ModuleState(options.hostCall, options.writer);
    this.textEncoder = new TextEncoder();
    this.textDecoder = new TextDecoder('utf-8');
    this.guestSend = () => undefined;
    this.guestOpListRequest = () => undefined;
  }

  /** Registers the WASI implementation used by WasmRsModule#instantiate. */
  static setWasi(wasi) {
    WASI = wasi;
  }

  /**
   * Binds this host to an instantiated guest and runs the guest's startup sequence:
   * 1. the optional START export;
   * 2. the required INIT export;
   * 3. the optional OP_LIST_REQUEST export.
   *
   * @param {WebAssembly.Instance} instance
   * @throws {Error} if the INIT or SEND export is missing.
   */
  initialize(instance) {
    this.instance = instance;
    const guestExports = this.instance.exports;
    const start = guestExports[GuestProtocolMethods.START];
    if (start != null) {
      debug(`>>>`, `${GuestProtocolMethods.START}()`);
      start([]);
    }
    const init = this.getProtocolExport(GuestProtocolMethods.INIT);
    // NOTE(review): the same value is passed for all three INIT arguments; presumably these
    // are buffer sizes — confirm against the guest's init signature.
    const size = 512 * 1024;
    debug(`>>>`, `${GuestProtocolMethods.INIT}(${size},${size},${size})`);
    init(size, size, size);
    // OP_LIST_REQUEST is optional. It used to be fetched with getProtocolExport, which throws
    // when the export is missing, so the null check below was dead code and guests without it
    // failed to initialize.
    const opList = guestExports[GuestProtocolMethods.OP_LIST_REQUEST];
    if (opList != null) {
      debug(`>>>`, `${GuestProtocolMethods.OP_LIST_REQUEST}()`);
      opList();
    }
    this.guestSend = this.getProtocolExport(GuestProtocolMethods.SEND);
    this.guestOpListRequest = opList ?? (() => undefined);
    debug('initialized wasm module');
  }

  /**
   * Returns a required export of the guest.
   * @throws {Error} if the guest does not export `name`.
   */
  getProtocolExport(name) {
    const fn = this.instance.exports[name];
    if (fn == null) {
      throw new Error(`WebAssembly module does not export ${name}`);
    }
    return fn;
  }

  /**
   * Writes `payload` into the guest buffer and notifies the guest via its SEND export.
   *
   * The write goes to `guestBufferStart` as a 3-byte length prefix (toU24Bytes) followed by
   * the payload bytes.
   *
   * @param {Uint8Array} payload
   * @throws {RangeError} if the payload length does not fit in the 3-byte prefix.
   */
  send(payload) {
    if (payload.length > 0xffffff) {
      throw new RangeError(`payload of ${payload.length} bytes exceeds the 24-bit frame length limit`);
    }
    const memory = this.getCallerMemory();
    // Build a fresh view on every call: growing wasm memory replaces the underlying buffer.
    const buffer = new Uint8Array(memory.buffer);
    debug(`writing ${payload.length} bytes to guest memory buffer`, payload, this.guestBufferStart);
    buffer.set(toU24Bytes(payload.length), this.guestBufferStart);
    buffer.set(payload, this.guestBufferStart + 3);
    debug(`>>>`, ` ${GuestProtocolMethods.SEND}(${payload.length})`);
    this.guestSend(payload.length);
  }

  /** Returns the guest's exported `memory`. */
  getCallerMemory() {
    return this.instance.exports.memory;
  }

  /** Currently a no-op. */
  close() {
    //
  }
}
/**
 * Builds the import object passed to WebAssembly.instantiate.
 *
 * The `wasmrs` namespace always holds the host exports. `wasi_snapshot_preview1` is added
 * only when a WASI instance is provided.
 */
function linkImports(instance, wasi) {
  if (!wasi) {
    debug('disabling wasi');
    return { wasmrs: linkHostExports(instance) };
  }
  debug('enabling wasi');
  const wasiImports = wasi.getImports();
  const hostImports = linkHostExports(instance);
  return {
    wasi_snapshot_preview1: wasiImports,
    wasmrs: hostImports,
  };
}
/**
 * Event carrying one frame's bytes in `payload`; the host `SEND` import dispatches these
 * with type `'frame'`.
 */
export class FrameEvent extends Event {
  payload;
  /**
   * @param {string} eventType - the event type.
   * @param {Uint8Array} framePayload - the frame bytes.
   */
  constructor(eventType, framePayload) {
    super(eventType);
    this.payload = framePayload;
  }
}
/**
 * Renders bytes for debug output.
 *
 * Printable ASCII (0x20-0x7e) is shown as characters; every other byte is shown as a
 * two-digit `\xNN` escape.
 */
function formatBytesForDebug(bytes) {
  return Array.from(bytes, (n) =>
    n >= 0x20 && n < 0x7f ? String.fromCharCode(n) : `\\x${n.toString(16).padStart(2, '0')}`,
  ).join('');
}

/**
 * Splits `bytes` into length-prefixed frames and dispatches each payload on `instance` as a
 * `'frame'` FrameEvent.
 *
 * Each frame is a 3-byte length decoded with fromU24Bytes, followed by that many bytes.
 *
 * @throws {Error} if a frame header or payload runs past the end of `bytes`.
 */
function dispatchFrames(instance, bytes) {
  let index = 0;
  while (index < bytes.length) {
    if (index + 3 > bytes.length) {
      throw new Error(`truncated frame header at offset ${index} of ${bytes.length}-byte buffer`);
    }
    // The header is read relative to `index`; reading from a fixed offset would mis-decode
    // every frame after the first.
    const len = fromU24Bytes(bytes.slice(index, index + 3));
    const start = index + 3;
    const end = start + len;
    if (end > bytes.length) {
      throw new Error(`frame at offset ${index} declares ${len} bytes but only ${bytes.length - start} remain`);
    }
    instance.dispatchEvent(new FrameEvent('frame', bytes.slice(start, end)));
    index = end;
  }
}

/**
 * Builds the `wasmrs` host import namespace bound to a WasmRsInstance.
 *
 * @param {WasmRsInstance} instance
 * @returns {object} the imports:
 *   - CONSOLE_LOG decodes a UTF-8 string from guest memory and logs it;
 *   - INIT_BUFFERS records the guest/host buffer pointers;
 *   - SEND reads frames from the host buffer and dispatches them;
 *   - OP_LIST decodes the guest's operation list into `instance.operations`.
 */
function linkHostExports(instance) {
  return {
    [HostProtocolMethods.CONSOLE_LOG](ptr, len) {
      debug('<<< __console_log %o bytes @ %o', len, ptr);
      const buffer = new Uint8Array(instance.getCallerMemory().buffer);
      // slice() copies out of wasm memory before decoding.
      const bytes = buffer.slice(ptr, ptr + len);
      console.log(instance.textDecoder.decode(bytes));
    },
    [HostProtocolMethods.INIT_BUFFERS](guestBufferPtr, hostBufferPtr) {
      debug('<<< __init_buffers(%o, %o)', guestBufferPtr, hostBufferPtr);
      instance.guestBufferStart = guestBufferPtr;
      instance.hostBufferStart = hostBufferPtr;
    },
    [HostProtocolMethods.SEND](length) {
      debug('<<< __send(%o)', length);
      const buffer = new Uint8Array(instance.getCallerMemory().buffer);
      // Copy so dispatched payloads don't alias guest memory that the guest may reuse.
      const bytes = buffer.slice(instance.hostBufferStart, instance.hostBufferStart + length);
      debug(`'frame' event: ${bytes.length} bytes`, formatBytesForDebug(bytes));
      dispatchFrames(instance, bytes);
    },
    [HostProtocolMethods.OP_LIST](ptr, length) {
      debug('<<< __op_list(%o,%o)', ptr, length);
      if (length === 0) {
        return;
      }
      const buffer = new Uint8Array(instance.getCallerMemory().buffer);
      const bytes = buffer.slice(ptr, ptr + length);
      if (bytes.slice(0, 4).toString() !== OP_MAGIC_BYTES.toString()) {
        throw new Error('invalid op_list magic bytes');
      }
      const version = fromU16Bytes(bytes.slice(4, 6));
      debug(`op_list bytes: %o`, bytes);
      if (version === 1) {
        const ops = decodeV1Operations(bytes.slice(6), instance.textDecoder);
        debug('module operations: %o', ops);
        instance.operations = ops;
      } else {
        // Unknown versions leave `instance.operations` unchanged, as before; log it.
        debug('unsupported op_list version %o', version);
      }
    },
  };
}
function decodeV1Operations(buffer, decoder) {
const imports = [];
const exports = [];
let numOps = fromU32Bytes(buffer.slice(0, 4));
debug(`decoding %o operations`, numOps);
let index = 4;
while (numOps > 0) {
const kind = buffer[index++];
const dir = buffer[index++];
const opIndex = fromU32Bytes(buffer.slice(index, index + 4));
index += 4;
const nsLen = fromU16Bytes(buffer.slice(index, index + 2));
index += 2;
const namespace = decoder.decode(buffer.slice(index, index + nsLen));
index += nsLen;
const opLen = fromU16Bytes(buffer.slice(index, index + 2));
index += 2;
const operation = decoder.decode(buffer.slice(index, index + opLen));
index += opLen;
const reservedLen = fromU16Bytes(buffer.slice(index, index + 2));
index += 2 + reservedLen;
const op = new Operation(opIndex, kind, namespace, operation);
if (dir === 1) {
exports.push(op);
}
else {
imports.push(op);
}
numOps--;
}
return new OperationList(imports, exports);
}
/**
 * The guest's operations, split into imports and exports, with lookup by namespace and
 * operation name.
 */
export class OperationList {
  imports;
  exports;
  constructor(imports, exports) {
    this.imports = imports;
    this.exports = exports;
  }
  /**
   * @returns the first export matching `namespace` and `operation`.
   * @throws {Error} if no export matches.
   */
  getExport(namespace, operation) {
    return OperationList.#lookup(this.exports, 'exports', namespace, operation);
  }
  /**
   * @returns the first import matching `namespace` and `operation`.
   * @throws {Error} if no import matches.
   */
  getImport(namespace, operation) {
    return OperationList.#lookup(this.imports, 'imports', namespace, operation);
  }
  // Shared search for getExport/getImport; `label` names the list in the error message.
  static #lookup(candidates, label, namespace, operation) {
    const match = candidates.find((entry) => entry.namespace === namespace && entry.operation === operation);
    if (!match) {
      throw new Error(`operation ${namespace}::${operation} not found in ${label}`);
    }
    return match;
  }
}
/**
 * A single guest operation: its index, its raw kind byte, and its namespace and name.
 */
export class Operation {
  index;
  kind;
  namespace;
  operation;
  constructor(index, kind, namespace, operation) {
    Object.assign(this, { index, kind, namespace, operation });
  }
  /**
   * Encodes the operation as `toU32Bytes(index)` followed by `toU32Bytes(0)`.
   * NOTE(review): the meaning of the trailing zero word is not visible here — confirm.
   * @returns {Uint8Array}
   */
  asEncoded() {
    const head = toU32Bytes(this.index);
    const tail = toU32Bytes(0);
    const out = new Uint8Array(head.length + 4);
    out.set(head, 0);
    out.set(tail, head.length);
    return out;
  }
}
// Reverse-mapped numeric enum: each name maps to its value and each value maps back to its
// name (the shape TypeScript emits for numeric enums).
// NOTE(review): the names presumably abbreviate RSocket interaction types (request/response,
// fire-and-forget, request stream, request channel) — confirm.
var OperationType = {};
for (const [name, value] of [['RR', 0], ['FNF', 1], ['RS', 2], ['RC', 3]]) {
  OperationType[name] = value;
  OperationType[value] = name;
}
/*
fn decode_v1(mut buf: Bytes) -> Result<Self, Error> {
let num_ops = from_u32_bytes(&buf.split_to(4));
let mut imports = Vec::new();
let mut exports = Vec::new();
for _ in 0..num_ops {
let kind = buf.split_to(1)[0];
let kind: OperationType = kind.into();
let dir = buf.split_to(1)[0];
let index = from_u32_bytes(&buf.split_to(4));
let ns_len = from_u16_bytes(&buf.split_to(2));
let namespace = String::from_utf8(buf.split_to(ns_len as _).to_vec())?;
let op_len = from_u16_bytes(&buf.split_to(2));
let operation = String::from_utf8(buf.split_to(op_len as _).to_vec())?;
let _reserved_len = from_u16_bytes(&buf.split_to(2));
let op = Operation {
index,
kind,
namespace,
operation,
};
if dir == 1 {
exports.push(op);
} else {
imports.push(op);
}
}
Ok(Self { imports, exports })
}
*/
// Magic prefix of an op_list buffer: a NUL byte followed by ASCII "wrs".
const OP_MAGIC_BYTES = Uint8Array.of(0x00, 0x77, 0x72, 0x73);
//# sourceMappingURL=wasmrs.js.map