@towns-protocol/sdk
Version:
For more details, visit the following resources:
678 lines • 28.1 kB
JavaScript
import { SessionKeysSchema, } from '@towns-protocol/proto';
import { shortenHexString, dlog, dlogError, check, bin_toHexString, } from '@towns-protocol/dlog';
import { GroupEncryptionAlgorithmId, parseGroupEncryptionAlgorithmId, } from '@towns-protocol/encryption';
import { create, fromJsonString } from '@bufbuild/protobuf';
import { sortedArraysEqual } from './observable/utils';
/**
 * Lifecycle states reported by the decryption extensions via the
 * 'decryptionExtStatusChanged' event. Each member's value equals its name
 * (the compiled form of a TypeScript string enum).
 */
export var DecryptionStatus;
(function (statusEnum) {
    const members = ['initializing', 'updating', 'working', 'idle', 'done'];
    for (const member of members) {
        statusEnum[member] = member;
    }
})(DecryptionStatus || (DecryptionStatus = {}));
/**
 * Pending decryption work for a single stream.
 */
class StreamTasks {
    // encrypted items waiting to be decrypted
    encryptedContent = [];
    // key requests from other users, answered once their respondAfter time has passed
    keySolicitations = [];
    // set when decryption failed for lack of a session key; triggers a key solicitation
    isMissingKeys = false;
    // set whenever keySolicitations is appended to, cleared by sortKeySolicitations()
    keySolicitationsNeedsSort = false;
    /**
     * Sort keySolicitations in place so the earliest `respondAfter` comes first,
     * then clear the needs-sort flag.
     */
    sortKeySolicitations() {
        const byRespondTime = (left, right) => left.respondAfter - right.respondAfter;
        this.keySolicitations.sort(byRespondTime);
        this.keySolicitationsNeedsSort = false;
    }
    /**
     * @returns true when there is no content, no solicitation and no missing-keys flag.
     */
    isEmpty() {
        if (this.isMissingKeys) {
            return false;
        }
        return this.encryptedContent.length + this.keySolicitations.length === 0;
    }
}
/**
 * Map of stream id -> StreamTasks, created lazily on first access.
 */
class StreamQueues {
    streams = new Map();
    /** @returns the stream ids in insertion order. */
    getStreamIds() {
        return [...this.streams.keys()];
    }
    /**
     * Returns the task queue for `streamId`, creating and registering an empty one if needed.
     */
    getQueue(streamId) {
        const existing = this.streams.get(streamId);
        if (existing) {
            return existing;
        }
        const created = new StreamTasks();
        this.streams.set(streamId, created);
        return created;
    }
    /** @returns true when every registered stream queue is empty (or none exist). */
    isEmpty() {
        return [...this.streams.values()].every((tasks) => tasks.isEmpty());
    }
    /**
     * Summarize queued work across all streams for logging.
     * Returns '' when no streams are registered.
     */
    toString() {
        if (this.streams.size === 0) {
            return '';
        }
        let encryptedContent = 0;
        let streamsMissingKeys = 0;
        let keySolicitations = 0;
        for (const stream of this.streams.values()) {
            encryptedContent += stream.encryptedContent.length;
            streamsMissingKeys += stream.isMissingKeys ? 1 : 0;
            keySolicitations += stream.keySolicitations.length;
        }
        return [
            `encryptedContent: ${encryptedContent}`,
            `streamsMissingKeys: ${streamsMissingKeys}`,
            `keySolicitations: ${keySolicitations}`,
        ].join(', ');
    }
}
/**
 * Base class that drives end-to-end decryption work for one user device.
 *
 * Responsibilities:
 * 1. Download new to-device messages that happened while we were offline
 * 2. Decrypt new to-device messages
 * 3. Decrypt encrypted content
 * 4. Retry decryption failures, request keys for failed decryption
 * 5. Respond to key solicitations
 *
 * Work is queued by the `enqueue*` methods and drained one item at a time by `tick()`, which is
 * scheduled with `setTimeout` from `checkStartTicking()`. Several methods called through `this`
 * (e.g. `uploadDeviceKeys`, `downloadNewMessages`, `decryptGroupEvent`, `sendKeyFulfillment`,
 * `getPriorityForStream`, `shouldPauseTicking`) are not defined in this class and are presumably
 * supplied by a subclass.
 *
 * Notes:
 * If in the future we started snapshotting the eventNum of the last message sent by every user,
 * we could use that to determine the order we send out keys, and the order that we reply to key solicitations.
 *
 * It should be easy to introduce a priority stream, where we decrypt messages from that stream first, before
 * anything else, so the messages show up quickly in the UI that the user is looking at.
 *
 * We need code to purge bad sessions (if someone sends us the wrong key, or a key that doesn't decrypt the message)
 */
export class BaseDecryptionExtensions {
    _status = DecryptionStatus.initializing;
    // Queues not tied to a single stream; tick() drains them in this declaration order.
    mainQueues = {
        priorityTasks: [],
        newGroupSession: [],
        ownKeySolicitations: [],
    };
    // Per-stream work: encrypted content, missing-key flags and other users' key solicitations.
    streamQueues = new StreamQueues();
    // Per-stream work is only processed for stream ids in this set.
    upToDateStreams = new Set();
    highPriorityIds = new Set();
    // Sliding window (max numRecentStreamIds, duplicates allowed) of dm/channel/gdm stream ids
    // that recently received encrypted content; passed to getPriorityForStream.
    recentStreamIds = [];
    decryptionFailures = {}; // streamId: sessionId: EncryptedContentItem[]
    inProgressTick;
    // set while a tick is scheduled or running; cleared once that tick settles
    timeoutId;
    delayMs = 1;
    started = false;
    numRecentStreamIds = 5;
    emitter;
    _onStopFn;
    log;
    crypto;
    entitlementDelegate;
    userDevice;
    userId;
    /**
     * @param emitter - receives 'decryptionExtStatusChanged' events from setStatus()
     * @param crypto - crypto backend used to look up, import, export and decrypt session keys
     * @param entitlementDelegate - stored on the instance; not read elsewhere in this class
     * @param userDevice - this device; its `deviceKey` identifies our ciphertexts and solicitations
     * @param userId - our user id; solicitations from this id go to `ownKeySolicitations`
     * @param upToDateStreams - initial set of streams ready for processing (kept by reference)
     * @param inLogId - prefix for the log namespaces
     */
    constructor(emitter, crypto, entitlementDelegate, userDevice, userId, upToDateStreams, inLogId) {
        this.emitter = emitter;
        this.crypto = crypto;
        this.entitlementDelegate = entitlementDelegate;
        this.userDevice = userDevice;
        this.userId = userId;
        // initialize with a set of up-to-date streams
        // ready for processing
        this.upToDateStreams = upToDateStreams;
        const shortKey = shortenHexString(userDevice.deviceKey);
        const logId = `${inLogId}:${shortKey}`;
        this.log = {
            debug: dlog('csb:decryption:debug', { defaultEnabled: false }).extend(logId),
            info: dlog('csb:decryption', { defaultEnabled: true }).extend(logId),
            error: dlogError('csb:decryption:error').extend(logId),
        };
        this.log.debug('new DecryptionExtensions', { userDevice });
    }
    /**
     * Queue group sessions received in our inbox for import by processNewGroupSession().
     * @param sessions - group-session payload; its binary `streamId` is converted to hex here
     * @param _senderId - unused
     */
    enqueueNewGroupSessions(sessions, _senderId) {
        this.log.debug('enqueueNewGroupSessions', sessions);
        const streamId = bin_toHexString(sessions.streamId);
        this.mainQueues.newGroupSession.push({ streamId, sessions });
        this.checkStartTicking();
    }
    /**
     * Queue an encrypted event for decryption in its stream's queue.
     * @param kind - kind of encrypted data, passed through to decryptGroupEvent
     */
    enqueueNewEncryptedContent(streamId, eventId, kind, // kind of encrypted data
    encryptedData) {
        // dms, channels, gdms ("we're in the wrong package")
        if (streamId.startsWith('20') || streamId.startsWith('88') || streamId.startsWith('77')) {
            this.recentStreamIds.push(streamId);
            if (this.recentStreamIds.length > this.numRecentStreamIds) {
                this.recentStreamIds.shift();
            }
        }
        this.streamQueues.getQueue(streamId).encryptedContent.push({
            streamId,
            eventId,
            kind,
            encryptedData,
        });
        this.checkStartTicking();
    }
    /**
     * Replace every pending key solicitation for `streamId` with the ones listed in `members`.
     * Solicitations from our own device, or with no session ids, are skipped. Solicitations from
     * our own user id go to `mainQueues.ownKeySolicitations`, all others to the stream's queue.
     */
    enqueueInitKeySolicitations(streamId, eventHashStr, members, sigBundle) {
        const streamQueue = this.streamQueues.getQueue(streamId);
        streamQueue.keySolicitations = [];
        this.mainQueues.ownKeySolicitations = this.mainQueues.ownKeySolicitations.filter((x) => x.streamId !== streamId);
        for (const member of members) {
            const { userId: fromUserId, userAddress: fromUserAddress } = member;
            for (const keySolicitation of member.solicitations) {
                if (keySolicitation.deviceKey === this.userDevice.deviceKey) {
                    continue;
                }
                // NOTE(review): enqueueKeySolicitation keeps solicitations with empty sessionIds when
                // isNewDevice is set, but this path drops them; confirm the difference is intended.
                if (keySolicitation.sessionIds.length === 0) {
                    continue;
                }
                const selectedQueue = fromUserId === this.userId
                    ? this.mainQueues.ownKeySolicitations
                    : streamQueue.keySolicitations;
                selectedQueue.push({
                    streamId,
                    fromUserId,
                    fromUserAddress,
                    solicitation: keySolicitation,
                    respondAfter: Date.now() + this.getRespondDelayMSForKeySolicitation(streamId, fromUserId),
                    sigBundle,
                    hashStr: eventHashStr,
                });
            }
        }
        streamQueue.keySolicitationsNeedsSort = true;
        this.checkStartTicking();
    }
    /**
     * Add or replace a single key solicitation. Any pending solicitation for the same stream and
     * device key is removed first. A replacement is queued only if it asks for session ids or is a
     * new-device request, so a solicitation with neither simply clears the pending one.
     * Solicitations from our own device are ignored.
     */
    enqueueKeySolicitation(streamId, eventHashStr, fromUserId, fromUserAddress, keySolicitation, sigBundle) {
        if (keySolicitation.deviceKey === this.userDevice.deviceKey) {
            return;
        }
        const streamQueue = this.streamQueues.getQueue(streamId);
        const selectedQueue = fromUserId === this.userId
            ? this.mainQueues.ownKeySolicitations
            : streamQueue.keySolicitations;
        const index = selectedQueue.findIndex((x) => x.streamId === streamId && x.solicitation.deviceKey === keySolicitation.deviceKey);
        if (index > -1) {
            selectedQueue.splice(index, 1);
        }
        if (keySolicitation.sessionIds.length > 0 || keySolicitation.isNewDevice) {
            streamQueue.keySolicitationsNeedsSort = true;
            selectedQueue.push({
                streamId,
                fromUserId,
                fromUserAddress,
                solicitation: keySolicitation,
                respondAfter: Date.now() + this.getRespondDelayMSForKeySolicitation(streamId, fromUserId),
                sigBundle,
                hashStr: eventHashStr,
            });
            this.checkStartTicking();
        }
    }
    /** Mark a stream as ready so its queued work becomes eligible in tick(). */
    setStreamUpToDate(streamId) {
        this.upToDateStreams.add(streamId);
        this.checkStartTicking();
    }
    /** Forget every up-to-date stream; per-stream work pauses until streams are marked again. */
    resetUpToDateStreams() {
        this.upToDateStreams.clear();
        this.checkStartTicking();
    }
    /**
     * If the stream has recorded decryption failures, flag it as missing keys so a later tick
     * sends a fresh key solicitation. Streams without failures are left untouched (no queue
     * entry is created for them).
     */
    retryDecryptionFailures(streamId) {
        const failures = this.decryptionFailures[streamId];
        if (failures && Object.keys(failures).length > 0) {
            this.log.debug('membership change, re-enqueuing decryption failures for stream', streamId);
            this.streamQueues.getQueue(streamId).isMissingKeys = true;
            this.checkStartTicking();
        }
    }
    /**
     * Start processing: runs onStart(), queues device-key upload and a new-message download
     * as priority tasks, then tries to start ticking. May only be called once per instance.
     */
    start() {
        check(!this.started, 'start() called twice, please re-instantiate instead');
        this.log.debug('starting');
        this.started = true;
        // let the subclass override and do any custom startup tasks
        this.onStart();
        // enqueue a task to upload device keys
        this.mainQueues.priorityTasks.push(() => this.uploadDeviceKeys());
        // enqueue a task to download new to-device messages
        this.enqueueNewMessageDownload();
        // start the tick loop
        this.checkStartTicking();
    }
    // enqueue a task to download new to-device messages, should be safe to call multiple times
    enqueueNewMessageDownload() {
        this.mainQueues.priorityTasks.push(() => this.downloadNewMessages());
    }
    onStart() {
        // let the subclass override and do any custom startup tasks
    }
    /**
     * Stop processing: invokes and clears `_onStopFn` (which prevents further scheduling),
     * runs onStop(), then cancels any scheduled tick and waits for a running one to settle.
     */
    async stop() {
        this._onStopFn?.();
        this._onStopFn = undefined;
        // let the subclass override and do any custom shutdown tasks
        await this.onStop();
        await this.stopTicking();
    }
    onStop() {
        // let the subclass override and do any custom shutdown tasks
        return Promise.resolve();
    }
    get status() {
        return this._status;
    }
    /** Update the status and emit 'decryptionExtStatusChanged' only when it actually changes. */
    setStatus(status) {
        if (this._status !== status) {
            this.log.debug(`status changed ${status}`);
            this._status = status;
            this.emitter.emit('decryptionExtStatusChanged', status);
        }
    }
    /**
     * Sort comparator for stream ids: lower getPriorityForStream() values sort first.
     * @param recentStreamIds - optional precomputed `new Set(this.recentStreamIds)`; callers that
     *   sort many ids pass it so the set is not rebuilt on every comparison
     */
    compareStreamIds(a, b, recentStreamIds = new Set(this.recentStreamIds)) {
        return (this.getPriorityForStream(a, this.highPriorityIds, recentStreamIds) -
            this.getPriorityForStream(b, this.highPriorityIds, recentStreamIds));
    }
    lastPrintedAt = 0;
    /**
     * Schedule the next tick() if the extension is running and has work.
     * Returns without scheduling when: not started, a tick is already scheduled or running,
     * stop() has cleared `_onStopFn`, the user inbox stream is not up to date, or
     * shouldPauseTicking() is true. With no queued work at all, the status becomes `done`.
     * Queue sizes are logged at most once every 30 seconds.
     */
    checkStartTicking() {
        if (!this.started ||
            this.timeoutId ||
            !this._onStopFn ||
            !this.isUserInboxStreamUpToDate(this.upToDateStreams) ||
            this.shouldPauseTicking()) {
            return;
        }
        const hasMainQueueWork = Object.values(this.mainQueues).some((q) => q.length > 0);
        if (!hasMainQueueWork && this.streamQueues.isEmpty()) {
            this.setStatus(DecryptionStatus.done);
            return;
        }
        if (Date.now() - this.lastPrintedAt > 30000) {
            this.log.info(`status: ${this.status} queues: ${Object.entries(this.mainQueues)
                .map(([key, q]) => `${key}: ${q.length}`)
                .join(', ')} ${this.streamQueues.toString()}`);
            const recentStreamIds = new Set(this.recentStreamIds);
            const streamIds = Array.from(this.streamQueues.streams.entries())
                .filter(([_, value]) => !value.isEmpty())
                .map(([key, _]) => key)
                .sort((a, b) => this.compareStreamIds(a, b, recentStreamIds));
            const first4Priority = streamIds
                .filter((x) => this.upToDateStreams.has(x))
                .slice(0, 4)
                .join(', ');
            const first4Blocked = streamIds
                .filter((x) => !this.upToDateStreams.has(x))
                .slice(0, 4)
                .join(', ');
            if (first4Priority.length > 0 || first4Blocked.length > 0) {
                this.log.info(`priorityTasks: ${first4Priority} waitingFor: ${first4Blocked}`);
            }
            this.lastPrintedAt = Date.now();
        }
        this.timeoutId = setTimeout(() => {
            let tickPromise;
            try {
                // Promise.resolve guards against a task that returns a non-promise value
                tickPromise = Promise.resolve(this.tick());
            }
            catch (e) {
                // a synchronous throw would otherwise leave timeoutId set forever and stall the loop
                tickPromise = Promise.reject(e);
            }
            this.inProgressTick = tickPromise;
            this.inProgressTick
                .catch((e) => this.log.error('ProcessTick Error', e))
                .finally(() => {
                this.timeoutId = undefined;
                setTimeout(() => this.checkStartTicking());
            });
        }, this.getDelayMs());
    }
    /** Cancel a scheduled tick and wait for any running tick to settle (errors are logged). */
    async stopTicking() {
        if (this.timeoutId) {
            clearTimeout(this.timeoutId);
            this.timeoutId = undefined;
        }
        if (this.inProgressTick) {
            try {
                await this.inProgressTick;
            }
            catch (e) {
                this.log.error('ProcessTick Error while stopping', e);
            }
            finally {
                this.inProgressTick = undefined;
            }
        }
    }
    /** Delay before the next tick: 0 while new group sessions are queued, otherwise `delayMs`. */
    getDelayMs() {
        if (this.mainQueues.newGroupSession.length > 0) {
            return 0;
        }
        else {
            return this.delayMs;
        }
    }
    /**
     * Perform at most one unit of work and return its promise ("just do one thing then return").
     * Order: priority tasks, new group sessions, our own key solicitations, then for each
     * up-to-date stream (sorted by compareStreamIds): encrypted content, missing keys, and due
     * key solicitations. Sets the status to `idle` when nothing is runnable.
     */
    tick() {
        const now = Date.now();
        const priorityTask = this.mainQueues.priorityTasks.shift();
        if (priorityTask) {
            this.setStatus(DecryptionStatus.updating);
            return priorityTask();
        }
        // update any new group sessions
        const session = this.mainQueues.newGroupSession.shift();
        if (session) {
            this.setStatus(DecryptionStatus.working);
            return this.processNewGroupSession(session);
        }
        const ownSolicitation = this.mainQueues.ownKeySolicitations.shift();
        if (ownSolicitation) {
            this.log.debug(' processing own key solicitation');
            this.setStatus(DecryptionStatus.working);
            return this.processKeySolicitation(ownSolicitation);
        }
        // build the recent-stream set once rather than once per comparison
        const recentStreamIds = new Set(this.recentStreamIds);
        const streamIds = this.streamQueues.getStreamIds();
        streamIds.sort((a, b) => this.compareStreamIds(a, b, recentStreamIds));
        for (const streamId of streamIds) {
            if (!this.upToDateStreams.has(streamId)) {
                continue;
            }
            const streamQueue = this.streamQueues.getQueue(streamId);
            const encryptedContent = streamQueue.encryptedContent.shift();
            if (encryptedContent) {
                this.setStatus(DecryptionStatus.working);
                return this.processEncryptedContentItem(encryptedContent);
            }
            if (streamQueue.isMissingKeys) {
                this.setStatus(DecryptionStatus.working);
                streamQueue.isMissingKeys = false;
                return this.processMissingKeys(streamId);
            }
            if (streamQueue.keySolicitationsNeedsSort) {
                streamQueue.sortKeySolicitations();
            }
            const keySolicitation = dequeueUpToDate(streamQueue.keySolicitations, now, (x) => x.respondAfter, this.upToDateStreams);
            if (keySolicitation) {
                this.setStatus(DecryptionStatus.working);
                return this.processKeySolicitation(keySolicitation);
            }
        }
        this.setStatus(DecryptionStatus.idle);
        return Promise.resolve();
    }
    /**
     * processNewGroupSession
     * Process group sessions that were sent to our to-device inbox stream: decrypt the ciphertext
     * addressed to our device key, import the sessions we don't already have, and re-enqueue
     * any decryption failures with matching session ids. When the newGroupSession queue is empty
     * afterwards, ackNewGroupSession() is called.
     */
    async processNewGroupSession(sessionItem) {
        const { streamId, sessions: session } = sessionItem;
        // check if this message is to our device
        const ciphertext = session.ciphertexts[this.userDevice.deviceKey];
        if (!ciphertext) {
            this.log.debug('skipping, no session for our device');
            return;
        }
        this.log.debug('processNewGroupSession', session);
        // check if it contains any keys we need, default to GroupEncryption if the algorithm is not set
        const parsed = parseGroupEncryptionAlgorithmId(session.algorithm, GroupEncryptionAlgorithmId.GroupEncryption);
        if (parsed.kind === 'unrecognized') {
            // todo dispatch event to update the error message
            this.log.error('skipping, invalid algorithm', session.algorithm);
            return;
        }
        const algorithm = parsed.value;
        const neededKeyIndexs = [];
        for (let i = 0; i < session.sessionIds.length; i++) {
            const sessionId = session.sessionIds[i];
            const hasKeys = await this.crypto.hasSessionKey(streamId, sessionId, algorithm);
            if (!hasKeys) {
                neededKeyIndexs.push(i);
            }
        }
        if (!neededKeyIndexs.length) {
            this.log.debug('skipping, we have all the keys');
            return;
        }
        // decrypt the message
        const cleartext = await this.crypto.decryptWithDeviceKey(ciphertext, session.senderKey);
        const sessionKeys = fromJsonString(SessionKeysSchema, cleartext);
        // keys[i] must line up with sessionIds[i]
        check(sessionKeys.keys.length === session.sessionIds.length, 'bad sessionKeys');
        // make group sessions
        const sessions = neededKeyIndexs.map((i) => ({
            streamId: streamId,
            sessionId: session.sessionIds[i],
            sessionKey: sessionKeys.keys[i],
            algorithm: algorithm,
        }));
        // import the sessions
        this.log.debug('importing group sessions streamId:', streamId, 'count: ', sessions.length, session.sessionIds);
        try {
            await this.crypto.importSessionKeys(streamId, sessions);
            // re-enqueue any decryption failures with these ids
            const streamQueue = this.streamQueues.getQueue(streamId);
            for (const imported of sessions) {
                const failures = this.decryptionFailures[streamId]?.[imported.sessionId];
                if (failures) {
                    streamQueue.encryptedContent.push(...failures);
                    delete this.decryptionFailures[streamId][imported.sessionId];
                }
            }
        }
        catch (e) {
            // don't re-enqueue to prevent infinite loops if this session is truly corrupted
            // we will keep requesting it on each boot until it goes out of the scroll window
            this.log.error('failed to import sessions', { sessionItem, error: e });
        }
        // if we processed them all, ack the stream
        if (this.mainQueues.newGroupSession.length === 0) {
            await this.ackNewGroupSession(session);
        }
    }
    /**
     * processEncryptedContentItem
     * Try to decrypt encrypted content. Every failure is reported through onDecryptionError().
     * "session not found" failures are recorded in decryptionFailures (keyed by stream and session
     * id) and flag the stream as missing keys; other failures are only logged.
     */
    async processEncryptedContentItem(item) {
        this.log.debug('processEncryptedContentItem', item);
        try {
            await this.decryptGroupEvent(item.streamId, item.eventId, item.kind, item.encryptedData);
        }
        catch (err) {
            const sessionNotFound = isSessionNotFoundError(err);
            this.onDecryptionError(item, {
                missingSession: sessionNotFound,
                kind: item.kind,
                encryptedData: item.encryptedData,
                error: err,
            });
            if (sessionNotFound) {
                const streamId = item.streamId;
                // prefer the string session id, fall back to hex-encoding the binary form
                const sessionId = item.encryptedData.sessionId && item.encryptedData.sessionId.length > 0
                    ? item.encryptedData.sessionId
                    : bin_toHexString(item.encryptedData.sessionIdBytes);
                if (sessionId.length === 0) {
                    this.log.error('session id length is 0 for failed decryption', {
                        err,
                        streamId: item.streamId,
                        eventId: item.eventId,
                    });
                    return;
                }
                if (!this.decryptionFailures[streamId]) {
                    this.decryptionFailures[streamId] = { [sessionId]: [item] };
                }
                else if (!this.decryptionFailures[streamId][sessionId]) {
                    this.decryptionFailures[streamId][sessionId] = [item];
                }
                else if (!this.decryptionFailures[streamId][sessionId].includes(item)) {
                    this.decryptionFailures[streamId][sessionId].push(item);
                }
                const streamQueue = this.streamQueues.getQueue(streamId);
                streamQueue.isMissingKeys = true;
            }
            else {
                this.log.info('failed to decrypt', err, 'streamId', item.streamId);
            }
        }
    }
    /**
     * processMissingKeys
     * Send a key solicitation for (up to 100 of) the session ids this stream failed to decrypt.
     * Skips when there is nothing missing, the stream is unknown, we are not entitled, or our
     * device already has an equivalent (or new-device) solicitation outstanding.
     */
    async processMissingKeys(streamId) {
        this.log.debug('processing missing keys', streamId);
        const missingSessionIds = takeFirst(100, Object.keys(this.decryptionFailures[streamId] ?? {}).sort());
        // limit to 100 keys for now todo revisit https://linear.app/hnt-labs/issue/HNT-3936/revisit-how-we-limit-the-number-of-session-ids-that-we-request
        if (!missingSessionIds.length) {
            this.log.debug('processing missing keys', streamId, 'no missing keys');
            return;
        }
        if (!this.hasStream(streamId)) {
            this.log.debug('processing missing keys', streamId, 'stream not found');
            return;
        }
        const isEntitled = await this.isUserEntitledToKeyExchange(streamId, this.userId, {
            skipOnChainValidation: true,
        });
        if (!isEntitled) {
            this.log.debug('processing missing keys', streamId, 'user is not member of stream');
            return;
        }
        const solicitedEvents = this.getKeySolicitations(streamId);
        const existingKeyRequest = solicitedEvents.find((x) => x.deviceKey === this.userDevice.deviceKey);
        if (existingKeyRequest?.isNewDevice ||
            sortedArraysEqual(existingKeyRequest?.sessionIds ?? [], missingSessionIds)) {
            this.log.debug('processing missing keys already requested keys for this session', existingKeyRequest);
            return;
        }
        const knownSessionIds = await this.crypto.getGroupSessionIds(streamId);
        // with no sessions at all for this stream, request everything as a new device
        const isNewDevice = knownSessionIds.length === 0;
        this.log.debug('requesting keys', streamId, 'isNewDevice', isNewDevice, 'sessionIds:', missingSessionIds.length);
        await this.sendKeySolicitation({
            streamId,
            isNewDevice,
            missingSessionIds,
        });
    }
    /**
     * processKeySolicitation
     * Answer a key solicitation: validate the event, pick the requested (or, for new devices, all)
     * known session ids, check entitlement, send a single key fulfillment, and on success share
     * the exported sessions once per algorithm via encryptAndShareGroupSessions().
     * Throws (via `check`) if the stream is not known.
     */
    async processKeySolicitation(item) {
        this.log.debug('processing key solicitation', item.streamId, item);
        const streamId = item.streamId;
        check(this.hasStream(streamId), 'stream not found');
        const { isValid, reason } = this.isValidEvent(item);
        if (!isValid) {
            this.log.error('processing key solicitation: invalid event id', {
                streamId,
                eventId: item.hashStr,
                reason,
            });
            return;
        }
        const knownSessionIds = await this.crypto.getGroupSessionIds(streamId);
        // todo split this up by algorithm so that we can send all the new hybrid keys
        knownSessionIds.sort();
        // a Set needs no ordering, so don't sort (and thereby mutate) the caller's solicitation
        const requestedSessionIds = new Set(item.solicitation.sessionIds);
        const replySessionIds = item.solicitation.isNewDevice
            ? knownSessionIds
            : knownSessionIds.filter((x) => requestedSessionIds.has(x));
        if (replySessionIds.length === 0) {
            this.log.debug('processing key solicitation: no keys to reply with');
            return;
        }
        const isUserEntitledToKeyExchange = await this.isUserEntitledToKeyExchange(streamId, item.fromUserId);
        if (!isUserEntitledToKeyExchange) {
            return;
        }
        const allSessions = [];
        for (const sessionId of replySessionIds) {
            const groupSession = await this.crypto.exportGroupSession(streamId, sessionId);
            if (groupSession) {
                allSessions.push(groupSession);
            }
        }
        this.log.debug('processing key solicitation with', item.streamId, {
            to: item.fromUserId,
            toDevice: item.solicitation.deviceKey,
            requestedCount: item.solicitation.sessionIds.length,
            replyIds: replySessionIds.length,
            sessions: allSessions.length,
        });
        if (allSessions.length === 0) {
            return;
        }
        // send a single key fulfillment for all algorithms
        const { error } = await this.sendKeyFulfillment({
            streamId,
            userAddress: item.fromUserAddress,
            deviceKey: item.solicitation.deviceKey,
            sessionIds: allSessions
                .map((x) => x.sessionId)
                .filter((x) => requestedSessionIds.has(x))
                .sort(),
        });
        // if the key fulfillment failed, someone else may already have sent a key fulfillment
        if (error) {
            if (!error.msg.includes('DUPLICATE_EVENT') && !error.msg.includes('NOT_FOUND')) {
                // duplicate events are expected, we can ignore them, others are not
                this.log.error('failed to send key fulfillment', error);
            }
            return;
        }
        // if the key fulfillment succeeded, group the exported sessions by algorithm
        const sessionsByAlgorithm = allSessions.reduce((acc, exported) => {
            if (!acc[exported.algorithm]) {
                acc[exported.algorithm] = [];
            }
            acc[exported.algorithm].push(exported);
            return acc;
        }, {});
        // send one group session payload for each algorithm
        for (const [algorithm, sessions] of Object.entries(sessionsByAlgorithm)) {
            await this.encryptAndShareGroupSessions({
                streamId,
                item,
                sessions,
                algorithm,
            });
        }
    }
    /**
     * can be overridden to add a delay to the key solicitation response
     */
    getRespondDelayMSForKeySolicitation(_streamId, _userId) {
        return 0;
    }
    /** Replace the set of high-priority stream ids passed to getPriorityForStream. */
    setHighPriorityStreams(streamIds) {
        this.highPriorityIds = new Set(streamIds);
    }
}
/**
 * Build a SessionKeys protobuf message whose `keys` are the `sessionKey` of each
 * session, in input order.
 */
export function makeSessionKeys(sessions) {
    const keys = sessions.map(({ sessionKey }) => sessionKey);
    return create(SessionKeysSchema, { keys });
}
/**
 * Remove and return the first item that is due and belongs to an up-to-date stream.
 *
 * `items` is assumed to be sorted ascending by `dateFn`, so if the first item is not yet
 * due, nothing is, and the function returns immediately.
 *
 * @param items - queue to dequeue from; mutated in place when an item is returned
 * @param now - current time, compared against `dateFn(item)`
 * @param dateFn - optional; returns the time at which an item becomes due. When omitted,
 *   every item is treated as due.
 * @param upToDateStreams - only items whose `streamId` is in this Set are eligible
 * @returns the removed item, or undefined when none qualifies
 */
function dequeueUpToDate(items, now, dateFn, upToDateStreams) {
    if (items.length === 0) {
        return undefined;
    }
    // sorted by due time: a head that is not yet due means no item is due
    if (dateFn !== undefined && dateFn(items[0]) > now) {
        return undefined;
    }
    const index = items.findIndex((x) => (dateFn === undefined || dateFn(x) <= now) &&
        upToDateStreams.has(x.streamId));
    if (index === -1) {
        return undefined;
    }
    return items.splice(index, 1)[0];
}
/**
 * Return a new array holding the first `count` elements of `array`
 * (all of them when `count` exceeds its length, none when `count` <= 0).
 */
function takeFirst(count, array) {
    const limit = Math.min(array.length, Math.max(0, Math.ceil(count)));
    return Array.from({ length: limit }, (_, index) => array[index]);
}
/**
 * Returns true when `err` is an object whose string `message` contains
 * "session not found" (case-insensitive). Non-objects, null, and objects whose
 * `message` is missing or not a string return false instead of throwing.
 */
function isSessionNotFoundError(err) {
    if (err !== null && typeof err === 'object' && typeof err.message === 'string') {
        return err.message.toLowerCase().includes('session not found');
    }
    return false;
}
//# sourceMappingURL=decryptionExtensions.js.map