// @towns-protocol/sdk
// Version:
// For more details, visit the following resources:
// 353 lines • 14.4 kB
// JavaScript
import { BaseDecryptionExtensions, DecryptionStatus, makeSessionKeys, } from '../../decryptionExtensions';
import { SessionKeysSchema, UserInboxPayload_GroupEncryptionSessionsSchema, } from '@towns-protocol/proto';
import { GroupEncryptionAlgorithmId, CryptoStore, GroupEncryptionCrypto, } from '@towns-protocol/encryption';
import { bin_fromHexString, bin_toHexString, dlog, shortenHexString } from '@towns-protocol/dlog';
import EventEmitter from 'events';
import { customAlphabet } from 'nanoid';
import { create, toJsonString } from '@bufbuild/protobuf';
const log = dlog('test:decryptionExtensions:');
describe.concurrent('TestDecryptionExtensions', () => {
    // test should iterate over all the algorithms
    it.each(Object.values(GroupEncryptionAlgorithmId))(
        'should be able to make key solicitation request',
        async (algorithm) => {
            // arrange
            const clientDiscoveryService = {};
            const streamId = genStreamId();
            const alice = genUserId('Alice');
            const aliceUserAddress = stringToArray(alice);
            const bob = genUserId('Bob');
            const bobsPlaintext = "bob's plaintext";
            const { client: aliceClient, decryptionExtension: aliceDex } = await createCryptoMocks(
                alice,
                clientDiscoveryService,
            );
            const { crypto: bobCrypto, decryptionExtension: bobDex } = await createCryptoMocks(
                bob,
                clientDiscoveryService,
            );
            // act
            aliceDex.start();
            // bob starts the decryption extension
            bobDex.start();
            let decrypted;
            let decrypted_V0;
            try {
                // bob encrypts the same message in the current (v1) and deprecated (v0) formats
                const encryptedData = await bobCrypto.encryptGroupEvent(
                    streamId,
                    new TextEncoder().encode(bobsPlaintext),
                    algorithm,
                );
                const encryptedData_V0 = await bobCrypto.encryptGroupEvent_deprecated_v0(
                    streamId,
                    bobsPlaintext,
                    algorithm,
                );
                const sessionId = encryptedData.sessionId;
                // alice doesn't have the session key,
                // so alice sends a key solicitation request
                const keySolicitationData = {
                    deviceKey: aliceDex.userDevice.deviceKey,
                    fallbackKey: aliceDex.userDevice.fallbackKey,
                    isNewDevice: true,
                    sessionIds: [sessionId],
                };
                const keySolicitation = aliceClient.sendKeySolicitation(keySolicitationData);
                // pretend bob receives a key solicitation request from alice, and starts processing it.
                await bobDex.handleKeySolicitationRequest(
                    streamId,
                    '',
                    alice,
                    aliceUserAddress,
                    keySolicitationData,
                    {
                        hash: new Uint8Array(),
                        signature: new Uint8Array(),
                        event: {
                            creatorAddress: new Uint8Array(),
                            delegateSig: new Uint8Array(),
                            delegateExpiryEpochMs: 0n,
                        },
                    },
                );
                // alice waits for the response
                await keySolicitation;
                // after alice gets the session key,
                // try to decrypt both messages
                decrypted = await aliceDex.crypto.decryptGroupEvent(streamId, encryptedData);
                decrypted_V0 = await aliceDex.crypto.decryptGroupEvent(streamId, encryptedData_V0);
            } finally {
                // stop the decryption extensions even if a step above failed,
                // so they don't keep running after the test ends
                await bobDex.stop();
                await aliceDex.stop();
            }
            // assert
            if (typeof decrypted === 'string') {
                throw new Error('decrypted is a string'); // v1 should be bytes
            }
            if (typeof decrypted_V0 !== 'string') {
                throw new Error('decrypted_V0 is not a string'); // v0 should be a string
            }
            expect(new TextDecoder().decode(decrypted)).toBe(bobsPlaintext);
            expect(decrypted_V0).toBe(bobsPlaintext);
            expect(bobDex.seenStates).toContain(DecryptionStatus.working);
            expect(aliceDex.seenStates).toContain(DecryptionStatus.working);
        },
    );
    // test should iterate over all the algorithms
    it.each(Object.values(GroupEncryptionAlgorithmId))(
        'should be able to export/import stream room key',
        async (algorithm) => {
            // arrange
            const clientDiscoveryService = {};
            const streamId = genStreamId();
            const alice = genUserId('Alice');
            const bob = genUserId('Bob');
            const bobsPlaintext = "bob's plaintext";
            const { decryptionExtension: aliceDex } = await createCryptoMocks(
                alice,
                clientDiscoveryService,
            );
            const { crypto: bobCrypto, decryptionExtension: bobDex } = await createCryptoMocks(
                bob,
                clientDiscoveryService,
            );
            // act
            aliceDex.start();
            // bob starts the decryption extension
            bobDex.start();
            let decrypted;
            let decrypted_V0;
            try {
                // bob encrypts the same message in the current (v1) and deprecated (v0) formats
                const encryptedData = await bobCrypto.encryptGroupEvent(
                    streamId,
                    new TextEncoder().encode(bobsPlaintext),
                    algorithm,
                );
                const encryptedData_V0 = await bobCrypto.encryptGroupEvent_deprecated_v0(
                    streamId,
                    bobsPlaintext,
                    algorithm,
                );
                // alice doesn't have the session key,
                // so alice imports the keys exported by bob
                const roomKeys = await bobDex.crypto.exportRoomKeys();
                if (roomKeys) {
                    await aliceDex.crypto.importRoomKeys(roomKeys);
                }
                // after alice gets the session key,
                // try to decrypt both messages
                decrypted = await aliceDex.crypto.decryptGroupEvent(streamId, encryptedData);
                decrypted_V0 = await aliceDex.crypto.decryptGroupEvent(streamId, encryptedData_V0);
            } finally {
                // stop the decryption extensions even if a step above failed
                await bobDex.stop();
                await aliceDex.stop();
            }
            // assert
            if (typeof decrypted === 'string') {
                throw new Error('decrypted is a string'); // v1 should be bytes
            }
            if (typeof decrypted_V0 !== 'string') {
                throw new Error('decrypted_V0 is not a string'); // v0 should be a string
            }
            expect(new TextDecoder().decode(decrypted)).toBe(bobsPlaintext);
            expect(decrypted_V0).toBe(bobsPlaintext);
        },
    );
});
/**
 * Wires up the crypto test doubles for a single user.
 *
 * The user's mock client is registered in `clientDiscoveryService` under the
 * user's device key, so other mocks sharing the same map can reach it.
 *
 * @param {string} userId - id of the simulated user; also names the crypto store (`db_<userId>`)
 * @param {Object} clientDiscoveryService - shared map of device key -> MockGroupEncryptionClient
 * @returns {Promise<Object>} `{ client, crypto, cryptoStore, decryptionExtension, userDevice }`
 */
async function createCryptoMocks(userId, clientDiscoveryService) {
    const store = new CryptoStore(`db_${userId}`, userId);
    const delegate = new MockEntitlementsDelegate();
    const client = new MockGroupEncryptionClient(clientDiscoveryService);
    const crypto = new GroupEncryptionCrypto(client, store);
    await crypto.init();
    const userDevice = crypto.getUserDevice();
    const decryptionExtension = new MockDecryptionExtensions(
        userId,
        crypto,
        delegate,
        userDevice,
        client,
    );
    // back-references the mock client uses when encrypting and delivering sessions
    Object.assign(client, { crypto, decryptionExtensions: decryptionExtension });
    clientDiscoveryService[userDevice.deviceKey] = client;
    return {
        client,
        crypto,
        cryptoStore: store,
        decryptionExtension,
        userDevice,
    };
}
/**
 * A latch driven by status ticks.
 *
 * It arms itself the first time it sees `startState`. Once armed, every tick
 * with `endState` calls `resolve` and marks the task as completed.
 */
class MicroTask {
    resolve;
    startState;
    endState;
    isStarted = false;
    _isCompleted = false;
    /**
     * @param {Function} resolve - called each time `endState` is observed after arming
     * @param {*} startState - status that arms the task
     * @param {*} endState - status that completes the task
     */
    constructor(resolve, startState, endState) {
        Object.assign(this, { resolve, startState, endState });
    }
    /** Whether `resolve` has been called at least once. */
    get isCompleted() {
        return this._isCompleted;
    }
    /**
     * Feed one observed status into the latch.
     * @param {*} state - the current status
     */
    tick(state) {
        // once armed, stay armed
        this.isStarted ||= state === this.startState;
        const reachedEnd = this.isStarted && state === this.endState;
        if (!reachedEnd) {
            return;
        }
        this.resolve();
        this._isCompleted = true;
    }
}
/**
 * Test double for BaseDecryptionExtensions.
 *
 * Most abstract hooks are stubbed: they log and return an immediately
 * resolved value. Group-session sharing is delegated to the mock client's
 * `encryptAndSendMock`. Whenever the client emits
 * 'decryptionExtStatusChanged', the current `status` is appended to
 * `seenStates` and fed to every MicroTask in `inProgress`. This lets tests
 * await the end of a working -> done cycle.
 */
class MockDecryptionExtensions extends BaseDecryptionExtensions {
    /**
     * Pending MicroTasks keyed by stream id. A later request for the same
     * stream overwrites the earlier entry. The overwritten task is no longer
     * ticked, so its promise will not resolve.
     */
    inProgress = {};
    client;
    _upToDateStreams;
    constructor(userId, crypto, entitlementDelegate, userDevice, client) {
        // the same Set is passed to the base class and kept locally for hasStream()
        const upToDateStreams = new Set();
        const logId = shortenHexString(userId);
        super(client, crypto, entitlementDelegate, userDevice, userId, upToDateStreams, logId);
        this._upToDateStreams = upToDateStreams;
        this.client = client;
        this._onStopFn = () => {
            log('onStopFn');
        };
        client.on('decryptionExtStatusChanged', () => {
            this.statusChangedTick();
        });
    }
    /** Every status observed by statusChangedTick(), in order. */
    seenStates = [];
    shouldPauseTicking() {
        return false;
    }
    /**
     * Simulates receiving group sessions from `senderId`.
     *
     * Marks the stream as up to date and enqueues the sessions. The returned
     * promise resolves once the observed status goes from working to done.
     */
    newGroupSessions(sessions, senderId) {
        log('newGroupSessions', sessions, senderId);
        const streamId = bin_toHexString(sessions.streamId);
        this.markStreamUpToDate(streamId);
        const p = new Promise((resolve) => {
            this.inProgress[streamId] = new MicroTask(resolve, DecryptionStatus.working, DecryptionStatus.done);
            // start processing the new sessions
            this.enqueueNewGroupSessions(sessions, senderId);
        });
        return p;
    }
    ackNewGroupSession(session) {
        log('newGroupSessionsDone', session.streamId);
        return Promise.resolve();
    }
    /**
     * Simulates receiving a key solicitation.
     *
     * Marks the stream as up to date and enqueues the request. The returned
     * promise resolves once the observed status goes from working to done.
     */
    async handleKeySolicitationRequest(streamId, eventHashStr, fromUserId, fromUserAddress, keySolicitation, sigBundle) {
        log('keySolicitationRequest', streamId, keySolicitation);
        this.markStreamUpToDate(streamId);
        const p = new Promise((resolve) => {
            this.inProgress[streamId] = new MicroTask(resolve, DecryptionStatus.working, DecryptionStatus.done);
            // start processing the request
            this.enqueueKeySolicitation(streamId, eventHashStr, fromUserId, fromUserAddress, keySolicitation, sigBundle);
        });
        return p;
    }
    /** True only for streams that were marked up to date through this mock. */
    hasStream(streamId) {
        const has = this._upToDateStreams.has(streamId);
        // log the real answer rather than a constant
        log('canProcessStream', streamId, has);
        return has;
    }
    isValidEvent(item) {
        log('isValidEvent', item);
        return { isValid: true };
    }
    decryptGroupEvent(_streamId, _eventId, _kind, _encryptedData) {
        log('decryptGroupEvent');
        return Promise.resolve();
    }
    downloadNewMessages() {
        log('downloadNewMessages');
        return Promise.resolve();
    }
    getKeySolicitations(_streamId) {
        log('getKeySolicitations');
        return [];
    }
    isUserEntitledToKeyExchange(_streamId, _userId) {
        log('isUserEntitledToKeyExchange');
        return Promise.resolve(true);
    }
    onDecryptionError(_item, _err) {
        log('onDecryptionError', 'item:', _item, 'err:', _err);
    }
    sendKeySolicitation(args) {
        log('sendKeySolicitation', args);
        return Promise.resolve();
    }
    sendKeyFulfillment(args) {
        log('sendKeyFulfillment', args);
        return Promise.resolve({});
    }
    /** Delegates to the mock client, which delivers the sessions to every other registered client. */
    encryptAndShareGroupSessions(args) {
        log('encryptAndSendToGroup');
        return this.client.encryptAndSendMock(args);
    }
    uploadDeviceKeys() {
        log('uploadDeviceKeys');
        return Promise.resolve();
    }
    isUserInboxStreamUpToDate(_upToDateStreams) {
        return true;
    }
    getPriorityForStream(_streamId, _highPriorityIds, _recentStreamIds) {
        return 0;
    }
    /** Records the stream locally and notifies the base class via setStreamUpToDate(). */
    markStreamUpToDate(streamId) {
        this._upToDateStreams.add(streamId);
        this.setStreamUpToDate(streamId);
    }
    /** Records the current status and forwards it to every registered MicroTask. */
    statusChangedTick() {
        this.seenStates.push(this.status);
        Object.values(this.inProgress).forEach((t) => {
            t.tick(this.status);
        });
    }
}
/**
 * In-memory stand-in for the group-encryption client.
 *
 * Clients find each other through a shared `clientDiscoveryService` map
 * (device key -> client). "Sending" group sessions means encrypting them for
 * every other registered client and handing them straight to that client's
 * decryption extension. A pending key solicitation resolves when a matching
 * response arrives through resolveGroupSessionResponse().
 */
class MockGroupEncryptionClient extends EventEmitter {
    clientDiscoveryService;
    /** Pending key-solicitation resolvers, keyed by the soliciting device key. */
    shareKeysResponses = {};
    constructor(clientDiscoveryService) {
        super();
        this.clientDiscoveryService = clientDiscoveryService;
    }
    crypto;
    decryptionExtensions;
    /** The device of the attached crypto, or undefined until `crypto` is set. */
    get userDevice() {
        return this.crypto ? this.crypto.getUserDevice() : undefined;
    }
    downloadUserDeviceInfo(_userIds, _forceDownload) {
        return Promise.resolve({});
    }
    encryptAndShareGroupSessions(_streamId, _sessions, _devicesInRoom) {
        return Promise.resolve();
    }
    getDevicesInStream(_streamId) {
        return Promise.resolve({});
    }
    getMiniblockInfo(_streamId) {
        return Promise.resolve({ miniblockNum: 0n, miniblockHash: new Uint8Array() });
    }
    /**
     * Pretends to send a key solicitation.
     * @returns {Promise} resolves with the share args once resolveGroupSessionResponse()
     *   is called with a matching `item.solicitation.deviceKey`
     */
    sendKeySolicitation(args) {
        // assume the request is sent
        return new Promise((resolve) => {
            // resolve when the response is received
            this.shareKeysResponses[args.deviceKey] = resolve;
        });
    }
    /**
     * Encrypts `args.sessions` for every other registered client and delivers them.
     *
     * Each client receives the sessions through its decryption extension
     * (newGroupSessions). Its pending key solicitation, if any, is then resolved.
     * @throws {Error} 'no user device' when no crypto is attached yet
     */
    async encryptAndSendMock(args) {
        const { sessions, streamId } = args;
        const myDevice = this.userDevice;
        if (!myDevice) {
            throw new Error('no user device');
        }
        // prepare the common parts of the payload
        const streamIdBytes = streamIdToBytes(streamId);
        const sessionIds = sessions.map((s) => s.sessionId);
        const payload = toJsonString(SessionKeysSchema, makeSessionKeys(sessions));
        // encrypt and send the payload to each client except ourselves
        const otherClients = Object.values(this.clientDiscoveryService).filter(
            (c) => c.userDevice?.deviceKey !== myDevice.deviceKey,
        );
        const promises = otherClients.map(async (c) => {
            // `this.crypto` is set: userDevice is only defined when crypto is
            const ciphertexts = await this.crypto.encryptWithDeviceKeys(payload, [c.userDevice]);
            const groupSession = create(UserInboxPayload_GroupEncryptionSessionsSchema, {
                streamId: streamIdBytes,
                senderKey: myDevice.deviceKey,
                sessionIds,
                ciphertexts,
                algorithm: args.algorithm,
            });
            // pretend sending the payload to the client, then receiving the response:
            // trigger new group session processing on the receiver
            await c.decryptionExtensions?.newGroupSessions(groupSession, myDevice.deviceKey);
            await c.resolveGroupSessionResponse(args);
        });
        await Promise.all(promises);
    }
    /**
     * Fakes receipt of a key-share response.
     *
     * Resolves the pending solicitation for `args.item.solicitation.deviceKey`,
     * if there is one.
     */
    resolveGroupSessionResponse(args) {
        // NOTE(review): `args.item` may be absent (presumably for shares not driven
        // by a key solicitation); there is nothing to resolve in that case.
        const deviceKey = args.item?.solicitation?.deviceKey;
        const resolve = deviceKey === undefined ? undefined : this.shareKeysResponses[deviceKey];
        if (resolve) {
            resolve(args);
        }
        return Promise.resolve();
    }
    sendKeyFulfillment(_args) {
        return Promise.resolve({});
    }
    uploadDeviceKeys() {
        return Promise.resolve();
    }
}
/** Entitlements stub that grants every request. */
class MockEntitlementsDelegate {
    /**
     * @returns {Promise<boolean>} always resolves to true, whatever the space, channel, user or permission
     */
    async isEntitled(_spaceId, _channelId, _user, _permission) {
        return true;
    }
}
/**
 * Generates a test user id of the form `0x<name><timestamp><8 random hex digits>`.
 *
 * A timestamp alone is not unique. The suite runs tests concurrently
 * (`describe.concurrent` + `it.each`), so two tests can ask for the same name
 * within one millisecond. They would then share a user id and therefore the
 * `db_${userId}` crypto store. The random suffix keeps the ids apart.
 *
 * @param {string} name - readable prefix, e.g. 'Alice'
 * @returns {string} a user id that is unique within a run with overwhelming probability
 */
function genUserId(name) {
    // not security sensitive: only needs to avoid collisions between tests
    const suffix = Math.floor(Math.random() * 2 ** 32)
        .toString(16)
        .padStart(8, '0');
    return `0x${name}${Date.now()}${suffix}`;
}
/**
 * Generates a random stream id made of 64 lowercase hex characters.
 * @returns {string}
 */
function genStreamId() {
    const hexDigits = '0123456789abcdef';
    const idLength = 64;
    return customAlphabet(hexDigits, idLength)();
}
/**
 * Encodes a string as UTF-8 bytes.
 * @param {string} fromString - text to encode
 * @returns {Uint8Array}
 */
function stringToArray(fromString) {
    return new TextEncoder().encode(fromString);
}
/**
 * Converts a hex-encoded stream id into raw bytes.
 * @param {string} streamId - hex string, e.g. as produced by genStreamId()
 * @returns {Uint8Array} the bytes returned by bin_fromHexString
 */
function streamIdToBytes(streamId) {
    const bytes = bin_fromHexString(streamId);
    return bytes;
}
//# sourceMappingURL=decryptionExtensions.test.js.map