/*
 * @arklabs/wallet-sdk
 * Version: (not specified)
 * Bitcoin wallet SDK with Taproot and Ark integration
 * 201 lines (200 loc) • 7.52 kB
 * JavaScript
 */
import * as musig2 from '../musig2/index.js';
import { getCosignerKeys } from './vtxoTree.js';
import { Script, SigHash, Transaction } from "@scure/btc-signer";
import { base64, hex } from "@scure/base";
import { schnorr, secp256k1 } from "@noble/curves/secp256k1";
import { randomPrivateKeyBytes } from "@scure/btc-signer/utils";
/**
 * Shared error instance thrown when a tree-signing operation runs before a
 * vtxo tree has been attached. The same object is thrown every time, so
 * callers may compare by identity.
 */
const ErrMissingVtxoTree = new Error("missing vtxo tree");
/**
 * Shared error instance for a missing aggregate key. It is exported for
 * callers and is not referenced elsewhere in this section.
 */
const ErrMissingAggregateKey = new Error("missing aggregate key");
export { ErrMissingVtxoTree, ErrMissingAggregateKey };
/**
 * One participant's MuSig2 signing session over every transaction of a vtxo tree.
 *
 * Call order enforced by the guards below:
 * `init` -> `getNonces` -> `setAggregatedNonces` -> `sign`.
 *
 * NOTE(review): `init` does not clear nonces generated earlier, and `getNonces`
 * only generates them when none are cached. Re-initialising a session with a
 * different tree and signing again would therefore reuse MuSig2 secret nonces
 * on new messages. MuSig2 (BIP-327) forbids this because it can leak the secret
 * key, so use a fresh session per tree.
 */
export class TreeSignerSession {
    /**
     * @param {Uint8Array} secretKey - this signer's private key; the public key
     *   is derived from it with `secp256k1.getPublicKey`.
     */
    constructor(secretKey) {
        this.secretKey = secretKey;
        this.myNonces = null;
        this.aggregateNonces = null;
        this.tree = null;
        this.scriptRoot = null;
        this.rootSharedOutputAmount = null;
    }
    /** Creates a session backed by a freshly generated random private key. */
    static random() {
        const secretKey = randomPrivateKeyBytes();
        return new TreeSignerSession(secretKey);
    }
    /**
     * Attaches the tree to sign and the parameters of the root input.
     * @param tree - object whose `levels` is an array of arrays of nodes; each
     *   node carries a base64 PSBT in `tx`, plus `txid` / `parentTxid`.
     * @param scriptRoot - taproot tweak used for key aggregation and signing.
     * @param rootInputAmount - amount of the shared output spent by the root tx.
     */
    init(tree, scriptRoot, rootInputAmount) {
        this.tree = tree;
        this.scriptRoot = scriptRoot;
        this.rootSharedOutputAmount = rootInputAmount;
    }
    /** Returns this signer's public key derived from the session secret key. */
    getPublicKey() {
        return secp256k1.getPublicKey(this.secretKey);
    }
    /**
     * Returns this signer's public nonces, one entry per tree node, shaped like
     * `tree.levels`. Secret nonces are generated on the first call and cached.
     * @throws ErrMissingVtxoTree if `init` has not been called.
     */
    getNonces() {
        if (!this.tree)
            throw ErrMissingVtxoTree;
        if (!this.myNonces) {
            this.myNonces = this.generateNonces();
        }
        // Only the public half of each nonce leaves the session; empty slots stay null.
        return this.myNonces.map((levelNonces) => levelNonces.map((nonce) => (nonce ? { pubNonce: nonce.pubNonce } : null)));
    }
    /**
     * Stores the aggregated nonces (same shape as `getNonces()`); may be set only once.
     * @throws {Error} "nonces already set" on a second call.
     */
    setAggregatedNonces(nonces) {
        if (this.aggregateNonces)
            throw new Error("nonces already set");
        this.aggregateNonces = nonces;
    }
    /**
     * Produces this signer's partial signature for every tree node, shaped like
     * `tree.levels`; nodes this signer has no nonce for yield `null`.
     * @throws ErrMissingVtxoTree, or an Error if nonces are missing.
     */
    sign() {
        if (!this.tree)
            throw ErrMissingVtxoTree;
        if (!this.aggregateNonces)
            throw new Error("nonces not set");
        if (!this.myNonces)
            throw new Error("nonces not generated");
        return this.tree.levels.map((level, levelIndex) => level.map((node, nodeIndex) => {
            const tx = Transaction.fromPSBT(base64.decode(node.tx));
            return this.signPartial(tx, levelIndex, nodeIndex) || null;
        }));
    }
    /** Generates one MuSig2 nonce pair per tree node, shaped like `tree.levels`. */
    generateNonces() {
        if (!this.tree)
            throw ErrMissingVtxoTree;
        const publicKey = secp256k1.getPublicKey(this.secretKey);
        return this.tree.levels.map((level) => Array.from({ length: level.length }, () => musig2.generateNonces(publicKey)));
    }
    /**
     * Creates this signer's partial signature over input 0 of `tx`, the node at
     * (`levelIndex`, `nodeIndex`).
     * @returns the partial signature, or `null` when this signer has no nonce
     *   for that node.
     * @throws TreeSignerSession.NOT_INITIALIZED, or an Error when nonces are
     *   missing.
     */
    signPartial(tx, levelIndex, nodeIndex) {
        if (!this.tree || !this.scriptRoot || !this.rootSharedOutputAmount) {
            throw TreeSignerSession.NOT_INITIALIZED;
        }
        if (!this.myNonces || !this.aggregateNonces) {
            throw new Error("session not properly initialized");
        }
        const myNonce = this.myNonces[levelIndex][nodeIndex];
        if (!myNonce)
            return null;
        // Guard the level lookup so a short aggregate-nonce array reports the
        // intended error instead of a TypeError.
        const aggLevel = this.aggregateNonces[levelIndex];
        const aggNonce = aggLevel ? aggLevel[nodeIndex] : undefined;
        if (!aggNonce)
            throw new Error("missing aggregate nonce");
        const cosigners = getCosignerKeys(tx);
        const { finalKey } = musig2.aggregateKeys(cosigners, true, {
            taprootTweak: this.scriptRoot,
        });
        // getPrevOutput always resolves input 0 of `tx`, so the result is the
        // same for every input: compute it once and repeat it per input.
        const prevout = getPrevOutput(finalKey, this.tree, this.rootSharedOutputAmount, tx);
        const prevoutAmounts = Array.from({ length: tx.inputsLength }, () => prevout.amount);
        const prevoutScripts = Array.from({ length: tx.inputsLength }, () => prevout.script);
        const message = tx.preimageWitnessV1(0, // always first input
        prevoutScripts, SigHash.DEFAULT, prevoutAmounts);
        return musig2.sign(myNonce.secNonce, this.secretKey, aggNonce.pubNonce, cosigners, message, {
            taprootTweak: this.scriptRoot,
            sortKeys: true,
        });
    }
}
TreeSignerSession.NOT_INITIALIZED = new Error("session not initialized, call init method");
// Helper function to validate tree signatures
/**
 * Checks that every transaction in `vtxoTree` carries a valid taproot
 * key-spend signature on its first input under `finalAggregatedKey`.
 *
 * @param {Uint8Array} finalAggregatedKey - aggregate public key. getPrevOutput
 *   builds the P2TR script from `finalAggregatedKey.slice(1)`, so a 33-byte
 *   compressed key is expected.
 * @param {bigint} sharedOutputAmount - amount of the shared output spent by the
 *   tree root.
 * @param vtxoTree - tree whose `levels` hold nodes with a base64 PSBT in `tx`.
 * @returns {Promise<void>} resolves when every signature verifies.
 * @throws {Error} "unsigned tree input", "invalid signature", or any error
 *   raised while resolving a node's previous output.
 */
export async function validateTreeSigs(finalAggregatedKey, sharedOutputAmount, vtxoTree) {
    // BIP-340 verification takes a 32-byte x-only key. Strip the prefix byte of
    // a 33-byte compressed key; these are the same bytes getPrevOutput puts
    // into the P2TR script.
    const xOnlyKey = finalAggregatedKey.length === 33
        ? finalAggregatedKey.slice(1)
        : finalAggregatedKey;
    for (const level of vtxoTree.levels) {
        for (const node of level) {
            const tx = Transaction.fromPSBT(base64.decode(node.tx));
            const input = tx.getInput(0);
            if (!input.tapKeySig) {
                throw new Error("unsigned tree input");
            }
            const prevout = getPrevOutput(finalAggregatedKey, vtxoTree, sharedOutputAmount, tx);
            // Recompute the sighash that was signed (first input, SIGHASH_DEFAULT).
            const message = tx.preimageWitnessV1(0, [prevout.script], SigHash.DEFAULT, [prevout.amount]);
            if (!schnorr.verify(input.tapKeySig, message, xOnlyKey)) {
                throw new Error("invalid signature");
            }
        }
    }
}
/**
 * Resolves the amount and script of the output spent by input 0 of `partial`.
 *
 * When input 0 spends the root's parent (outside the tree), the amount is
 * `sharedOutputAmount`. Otherwise the parent node is located in `vtxoTree` by
 * txid and the spent output is read from its transaction. The script is always
 * the P2TR script `OP_1 <finalKey without its first byte>`.
 *
 * @param {Uint8Array} finalKey - aggregate key; presumably 33-byte compressed
 *   (its first byte is dropped) — confirm with callers.
 * @param vtxoTree - tree whose `levels` hold nodes with `tx`, `txid`, `parentTxid`.
 * @param {bigint} sharedOutputAmount - amount returned for the root's input.
 * @param partial - parsed transaction whose first input is resolved.
 * @returns {{ amount: bigint, script: Uint8Array }}
 * @throws {Error} when the tree is empty, the parent cannot be found, or the
 *   input / parent output lacks required fields.
 */
function getPrevOutput(finalKey, vtxoTree, sharedOutputAmount, partial) {
    const pkScript = Script.encode(["OP_1", finalKey.slice(1)]);
    // Guard both lookups so an empty `levels` array reports a clear error.
    const firstLevel = vtxoTree.levels[0];
    const rootNode = firstLevel ? firstLevel[0] : undefined;
    if (!rootNode)
        throw new Error("empty vtxo tree");
    const input = partial.getInput(0);
    if (!input.txid)
        throw new Error("missing input txid");
    const parentTxID = hex.encode(input.txid);
    // The root spends the shared output, which is not itself a tree node.
    if (rootNode.parentTxid === parentTxID) {
        return {
            amount: sharedOutputAmount,
            script: pkScript,
        };
    }
    let parent;
    for (const level of vtxoTree.levels) {
        parent = level.find((node) => node.txid === parentTxID);
        if (parent)
            break;
    }
    if (!parent) {
        throw new Error("parent tx not found");
    }
    const parentTx = Transaction.fromPSBT(base64.decode(parent.tx));
    // Output index 0 is a valid vout, so only a missing index is an error
    // (a truthiness check would wrongly reject children spending output 0).
    if (input.index == null)
        throw new Error("missing input index");
    const parentOutput = parentTx.getOutput(input.index);
    if (!parentOutput)
        throw new Error("parent output not found");
    if (parentOutput.amount == null)
        throw new Error("parent output amount not found");
    return {
        amount: parentOutput.amount,
        script: pkScript,
    };
}