proteus-hd
Version:
Signal Protocol (with header encryption) implementation for JavaScript. Based on Proteus.js.
399 lines (337 loc) • 13.2 kB
JavaScript
/*
* Wire
* Copyright (C) 2016 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*
*/
'use strict';
const CBOR = require('wire-webapp-cbor');
const ArrayUtil = require('../util/ArrayUtil');
const ClassUtil = require('../util/ClassUtil');
const DontCallConstructor = require('../errors/DontCallConstructor');
const MemoryUtil = require('../util/MemoryUtil');
const TypeUtil = require('../util/TypeUtil');
const DecryptError = require('../errors/DecryptError');
const DerivedSecrets = require('../derived/DerivedSecrets');
const HeadKey = require('../derived/HeadKey');
const IdentityKey = require('../keys/IdentityKey');
const IdentityKeyPair = require('../keys/IdentityKeyPair');
const KeyPair = require('../keys/KeyPair');
const PreKeyBundle = require('../keys/PreKeyBundle');
const PublicKey = require('../keys/PublicKey');
const Header = require('../message/Header');
const HeaderMessage = require('../message/HeaderMessage');
const Envelope = require('../message/Envelope');
const PreKeyMessage = require('../message/PreKeyMessage');
const ChainKey = require('./ChainKey');
const RecvChain = require('./RecvChain');
const RootKey = require('./RootKey');
const SendChain = require('./SendChain');
const Session = require('./Session');
/** @module session */
/** @class SessionState */
class SessionState {
constructor() {
this.recv_chains = [];
this.send_chain = null;
this.root_key = null;
this.prev_counter = null;
this.next_send_head_key = null;
this.next_recv_head_key = null;
throw new DontCallConstructor(this);
}
/**
* @param {!keys.IdentityKeyPair} alice_identity_pair
* @param {!keys.PublicKey} alice_base
* @param {!keys.PreKeyBundle} bob_pkbundle
* @returns {SessionState}
*/
static init_as_alice(alice_identity_pair, alice_base, bob_pkbundle) {
TypeUtil.assert_is_instance(IdentityKeyPair, alice_identity_pair);
TypeUtil.assert_is_instance(KeyPair, alice_base);
TypeUtil.assert_is_instance(PreKeyBundle, bob_pkbundle);
const master_key = ArrayUtil.concatenate_array_buffers([
alice_identity_pair.secret_key.shared_secret(bob_pkbundle.public_key),
alice_base.secret_key.shared_secret(bob_pkbundle.identity_key.public_key),
alice_base.secret_key.shared_secret(bob_pkbundle.public_key),
]);
const derived_secrets = DerivedSecrets.kdf_without_salt(master_key, 'handshake');
MemoryUtil.zeroize(master_key);
const rootkey = RootKey.from_cipher_key(derived_secrets.cipher_key);
const chainkey = ChainKey.from_mac_key(derived_secrets.mac_key, 0);
const head_key_alice = derived_secrets.head_key_alice;
const next_head_key_Bob = derived_secrets.next_head_key_bob;
const send_ratchet = KeyPair.new();
const [rok, chk, nextHeadKey] = rootkey.dh_ratchet(send_ratchet, bob_pkbundle.public_key);
const recv_chains = [RecvChain.new(chainkey, bob_pkbundle.public_key, next_head_key_Bob)];
const send_chain = SendChain.new(chk, send_ratchet, head_key_alice);
const state = ClassUtil.new_instance(SessionState);
state.next_send_head_key = nextHeadKey;
state.recv_chains = recv_chains;
state.next_recv_head_key = next_head_key_Bob;
state.send_chain = send_chain;
state.root_key = rok;
state.prev_counter = 0;
return state;
}
/**
* @param {!keys.IdentityKeyPair} bob_ident
* @param {!keys.KeyPair} bob_prekey
* @param {!keys.IdentityKey} alice_ident
* @param {!keys.PublicKey} alice_base
* @returns {SessionState}
*/
static init_as_bob(bob_ident, bob_prekey, alice_ident, alice_base) {
TypeUtil.assert_is_instance(IdentityKeyPair, bob_ident);
TypeUtil.assert_is_instance(KeyPair, bob_prekey);
TypeUtil.assert_is_instance(IdentityKey, alice_ident);
TypeUtil.assert_is_instance(PublicKey, alice_base);
const master_key = ArrayUtil.concatenate_array_buffers([
bob_prekey.secret_key.shared_secret(alice_ident.public_key),
bob_ident.secret_key.shared_secret(alice_base),
bob_prekey.secret_key.shared_secret(alice_base),
]);
const derived_secrets = DerivedSecrets.kdf_without_salt(master_key, 'handshake');
MemoryUtil.zeroize(master_key);
const rootkey = RootKey.from_cipher_key(derived_secrets.cipher_key);
const chainkey = ChainKey.from_mac_key(derived_secrets.mac_key, 0);
const head_key_alice = derived_secrets.head_key_alice;
const next_head_key_bob = derived_secrets.next_head_key_bob;
const send_chain = SendChain.new(chainkey, bob_prekey, next_head_key_bob);
const state = ClassUtil.new_instance(SessionState);
state.next_send_head_key = next_head_key_bob;
state.next_recv_head_key = head_key_alice;
state.send_chain = send_chain;
state.root_key = rootkey;
state.prev_counter = 0;
return state;
}
/**
* @param {!keys.KeyPair} ratchet_key
* @param {!number} prev_counter
* @returns {void}
*/
ratchet(ratchet_key, prev_counter) {
const new_ratchet = KeyPair.new();
const [recv_root_key, recv_chain_key, next_recv_head_key] =
this.root_key.dh_ratchet(this.send_chain.ratchet_key, ratchet_key);
const [send_root_key, send_chain_key, next_send_head_key] =
recv_root_key.dh_ratchet(new_ratchet, ratchet_key);
const recv_chain = RecvChain.new(recv_chain_key, ratchet_key, this.next_recv_head_key);
const send_chain = SendChain.new(send_chain_key, new_ratchet, this.next_send_head_key);
this.root_key = send_root_key;
this.prev_counter = this.send_chain.chain_key.idx;
this.send_chain = send_chain;
this.next_send_head_key = next_send_head_key;
this.next_recv_head_key = next_recv_head_key;
// save last chains counter
const last_chain = this.recv_chains[0];
if (last_chain) {
last_chain.final_count = prev_counter;
}
this.recv_chains.unshift(recv_chain);
if (this.recv_chains.length > Session.MAX_RECV_CHAINS) {
for (let index = Session.MAX_RECV_CHAINS; index < this.recv_chains.length; index++) {
MemoryUtil.zeroize(this.recv_chains[index]);
}
this.recv_chains = this.recv_chains.slice(0, Session.MAX_RECV_CHAINS);
}
}
/**
* @param {!keys.IdentityKey} identity_key - Public identity key of the local identity key pair
* @param {!Array<number|keys.PublicKey>} pending - Pending pre-key
* @param {!(string|Uint8Array)} plaintext - The plaintext to encrypt
* @param {number} confuse_pre_key_id - Use to create confused pre-key message
* @returns {message.Envelope}
*/
encrypt(identity_key, pending, plaintext, confuse_pre_key_id) {
if (pending) {
TypeUtil.assert_is_integer(pending[0]);
TypeUtil.assert_is_instance(PublicKey, pending[1]);
}
TypeUtil.assert_is_instance(IdentityKey, identity_key);
const message_index = this.send_chain.chain_key.idx;
const msgkeys = this.send_chain.chain_key.message_keys();
const head_key = this.send_chain.head_key;
const header = Header.new(
message_index,
this.prev_counter,
this.send_chain.ratchet_key.public_key
).serialise();
let message = HeaderMessage.new(
head_key.encrypt(header, HeadKey.index_as_nonce(message_index)),
msgkeys.encrypt(plaintext)
);
if (pending) {
message = PreKeyMessage.new(pending[0], pending[1], identity_key, message);
} else if (confuse_pre_key_id !== undefined) {
// create a confused pre-key message
message = PreKeyMessage.new(confuse_pre_key_id, KeyPair.new().public_key, identity_key, message);
}
const env = Envelope.new(msgkeys.mac_key, message);
this.send_chain.chain_key = this.send_chain.chain_key.next();
return env;
}
/**
* @param {!message.Envelope} envelope
* @param {!message.HeaderMessage} msg
* @returns {Uint8Array}
*/
decrypt(envelope, msg) {
TypeUtil.assert_is_instance(Envelope, envelope);
TypeUtil.assert_is_instance(HeaderMessage, msg);
const encrypted_header = msg.header;
const [header, recv_chain] = (() => {
// Try next_head_key first, run a DH-ratchet step and create a new receiving chain if it succeeded.
try {
return [RecvChain.try_next_head_key(encrypted_header, this.next_recv_head_key), null];
} catch (err) {
handleHeaderDecryptionError(err);
}
// Otherwise, try old receving chains.
const recv_chains_length = this.recv_chains.length;
let idx = 0;
for (; idx < recv_chains_length; idx++) {
const _recv_chain = this.recv_chains[idx];
try {
return [_recv_chain.try_head_key(encrypted_header), _recv_chain];
} catch (err) {
handleHeaderDecryptionError(err);
}
}
return [null, null];
function handleHeaderDecryptionError(err) {
if (!(err instanceof DecryptError.HeaderDecryptionFailed)) {
throw err;
}
}
})();
if (!header) {
throw new DecryptError.HeaderDecryptionFailed('All chains failed', DecryptError.CODE.CASE_215);
}
const rc = (() => {
if (!recv_chain) {
this.ratchet(header.ratchet_key, header.prev_counter);
return this.recv_chains[0];
}
return recv_chain;
})();
const cipher_text = msg.cipher_text;
const counter = header.counter;
const recv_chain_index = rc.chain_key.idx;
if (counter < recv_chain_index) {
return rc.try_message_keys(envelope, header, cipher_text);
} else if (counter == recv_chain_index) {
const mks = rc.chain_key.message_keys();
if (!envelope.verify(mks.mac_key)) {
throw new DecryptError.InvalidSignature(`Envelope verification failed for message with counters in sync at '${counter}'`, DecryptError.CODE.CASE_206);
}
const plain = mks.decrypt(cipher_text);
rc.chain_key = rc.chain_key.next();
return plain;
} else if (counter > recv_chain_index) {
const [chk, mk, mks] = rc.stage_message_keys(header);
if (!envelope.verify(mk.mac_key)) {
throw new DecryptError.InvalidSignature(`Envelope verification failed for message with counter ahead. Message index is '${counter}' while receive chain index is '${recv_chain_index}'.`, DecryptError.CODE.CASE_207);
}
const plain = mk.decrypt(cipher_text);
rc.chain_key = chk.next();
rc.commit_message_keys(mks);
return plain;
}
}
/** @returns {ArrayBuffer} */
serialise() {
const e = new CBOR.Encoder();
this.encode(e);
return e.get_buffer();
}
static deserialise(buf) {
TypeUtil.assert_is_instance(ArrayBuffer, buf);
return SessionState.decode(new CBOR.Decoder(buf));
}
/**
* @param {!CBOR.Encoder} e
* @returns {CBOR.Encoder}
*/
encode(e) {
e.object(6);
e.u8(0);
this.next_send_head_key.encode(e);
e.u8(1);
this.next_recv_head_key.encode(e);
e.u8(2);
e.array(this.recv_chains.length);
this.recv_chains.map((rch) => rch.encode(e));
e.u8(3);
this.send_chain.encode(e);
e.u8(4);
this.root_key.encode(e);
e.u8(5);
return e.u32(this.prev_counter);
}
/**
* @param {!CBOR.Decoder} d
* @returns {SessionState}
*/
static decode(d) {
TypeUtil.assert_is_instance(CBOR.Decoder, d);
const self = ClassUtil.new_instance(SessionState);
const nprops = d.object();
for (let i = 0; i <= nprops - 1; i++) {
switch (d.u8()) {
case 0: {
self.next_send_head_key = HeadKey.decode(d);
break;
}
case 1: {
self.next_recv_head_key = HeadKey.decode(d);
break;
}
case 2: {
self.recv_chains = [];
let len = d.array();
while (len--) {
self.recv_chains.push(RecvChain.decode(d));
}
break;
}
case 3: {
self.send_chain = SendChain.decode(d);
break;
}
case 4: {
self.root_key = RootKey.decode(d);
break;
}
case 5: {
self.prev_counter = d.u32();
break;
}
default: {
d.skip();
}
}
}
TypeUtil.assert_is_instance(HeadKey, self.next_send_head_key);
TypeUtil.assert_is_instance(HeadKey, self.next_recv_head_key);
TypeUtil.assert_is_instance(SendChain, self.send_chain);
TypeUtil.assert_is_instance(Array, self.recv_chains);
TypeUtil.assert_is_instance(RootKey, self.root_key);
TypeUtil.assert_is_integer(self.prev_counter);
return self;
}
}
module.exports = SessionState;