/**
 * auspice — web app for visualizing pathogen evolution.
 * Stream-tree helpers: stream membership labelling and per-stream rendering calculations.
 */
import { ReduxNode, StreamDimensions, Visibility, Streams, sigma, weightToDisplayOrderScaleFactor,
colorBySymbol, streamLabelSymbol } from "../reducers/tree/types";
import { getTraitFromNode, getDivFromNode } from "./treeMiscHelpers"
import { NODE_VISIBLE } from "./globals";
import pdf from '@stdlib/stats-base-dists-normal-pdf';
/**
 * Walks the entire tree (iteratively, via an explicit stack) and labels stream
 * membership according to branch labels under `branchLabelKey`.
 *
 * Side effects (on nodes):
 * - sets node.inStream (boolean) on every visited node
 * - on stream-start nodes, sets node.streamName; any previously-set stream
 *   properties (streamPivots, streamCategories, streamDimensions, streamMaxHeight)
 *   are cleared first, using streamName as the sentinel.
 *
 * Returns a Streams map: stream name -> summary object (start node index,
 * terminal members, direct child streams, parent stream name, and the
 * num_date/div domains spanned by the terminal members). The symbol key
 * `streamLabelSymbol` records which branch-label key was used.
 */
export function labelStreamMembership(tree: ReduxNode, branchLabelKey): Streams {
  const streams: Streams = {};
  streams[streamLabelSymbol] = branchLabelKey;
  /* stack entries: [node, name of the stream the node's parent is in, or false if none] */
  const stack: [ReduxNode, false|string][] = tree.children.map((subtreeRootNode) => [subtreeRootNode, false]);
  while (stack.length) {
    const [node, parentStreamMembership] = stack.pop();
    let newStreamMembership = node?.branch_attrs?.labels?.[branchLabelKey];
    // Stream labels must be unique across the tree — duplicates are dropped here.
    // (The `in` check only sees string keys, so the symbol-keyed entries never collide.)
    if (newStreamMembership && newStreamMembership in streams) {
      console.error(`Stream label ${newStreamMembership} seen more than once. Ignoring all but the first.`)
      newStreamMembership = undefined;
    }
    /* clear any previous stream-related information using `streamName` as a sentinel value */
    // note that node.inStream (which is on every node) is re-set later in this loop
    if (node.streamName) {
      delete node.streamName;
      delete node.streamPivots;
      delete node.streamCategories;
      delete node.streamDimensions;
      delete node.streamMaxHeight;
    }
    if (newStreamMembership) {
      // this node starts a new stream
      streams[newStreamMembership] = {
        name: newStreamMembership,
        startNode: node.arrayIdx,
        members: [], // terminals only
        streamChildren: [], // direct children only
        parentStreamName: parentStreamMembership,
        domains: {
          // [min, max] — narrowed below as terminal members are encountered
          num_date: [Infinity, -Infinity],
          div: [Infinity, -Infinity],
        }
      };
      node.streamName = newStreamMembership;
      if (parentStreamMembership) {
        // parent stream must already exist — streams are discovered root-to-tip
        if (!(parentStreamMembership in streams)) throw new Error("labelStreamMembership fatal error II");
        streams[parentStreamMembership].streamChildren.push(newStreamMembership);
      }
    }
    const currentStreamMembership = newStreamMembership || parentStreamMembership;
    node.inStream = !!currentStreamMembership;
    // terminal nodes within a stream contribute to its membership list + domains
    if (currentStreamMembership && !node.hasChildren) {
      streams[currentStreamMembership].members.push(node.arrayIdx);
      // update domains
      const domains = streams[currentStreamMembership].domains;
      const div = getDivFromNode(node);
      const num_date = getTraitFromNode(node, 'num_date');
      if (div<domains.div[0]) {domains.div[0]=div}
      if (div>domains.div[1]) {domains.div[1]=div}
      if (num_date<domains.num_date[0]) {domains.num_date[0]=num_date}
      if (num_date>domains.num_date[1]) {domains.num_date[1]=num_date}
    }
    for (const child of node.children || []) {
      stack.push([child, currentStreamMembership])
    }
  }
  return streams;
}
/**
 * (Re)compute the data each stream needs for rendering: pivots, sigma, the
 * weight→display-order scale factor, categories (ribbons), dimensions (KDE
 * matrix) and max height.
 *
 * Side effects: mutates `streams` (symbol-keyed entries: sigma,
 * weightToDisplayOrderScaleFactor, colorBySymbol; plus renderingOrder on the
 * root-most streams) and each stream's start node (streamPivots,
 * streamCategories, streamDimensions, streamMaxHeight, streamNodeCounts).
 *
 * @param streams as returned by labelStreamMembership
 * @param nodes flat array of all tree nodes, indexed by arrayIdx
 * @param visibility per-node visibility array, parallel to `nodes`
 * @param metric which axis the tree is laid out on
 * @param colorScale current colouring (drives category/ribbon calculation)
 * @param options skipPivots / skipCategories let callers avoid recomputation
 */
export function processStreams(
  streams: Streams,
  nodes: ReduxNode[],
  visibility: Visibility[],
  metric: "num_date"|"div",
  colorScale,
  { skipPivots=false, skipCategories=false }: {skipPivots?: boolean, skipCategories?: boolean} = {},
):void {
  /**
   * Pivots often don't need to be recalculated. Sigma is also recalculated.
   */
  // NOTE(review): streamPivots is stored on each stream's *start node* (below), not on the
  // stream summary objects themselves, so this hasOwn() check looks like it is always false
  // (meaning pivots are recalculated regardless of skipPivots) — confirm intent.
  if (!skipPivots || !Object.values(streams).every((s) => Object.hasOwn(s, 'streamPivots'))) {
    // entire domain spanning all streams
    const domain = (Object.values(streams)).reduce((dd, stream) => {
      if (dd[0] > stream.domains[metric][0]) dd[0] = stream.domains[metric][0];
      if (dd[1] < stream.domains[metric][1]) dd[1] = stream.domains[metric][1];
      return dd;
    }, [Infinity, -Infinity]);
    const nPivots = 500; // which will represent the entire domain
    const nExtend = 2;
    // use (nPivots-2*nExtend) to span the domain into pieces (of a constant size), then extend the domain either side by nExtend pieces
    const size = Math.abs(domain[1] - domain[0])/(nPivots-1-2*nExtend);
    const pivots = Array.from(Array(nPivots), (_, i) => domain[0]-nExtend*size + i*size);
    /** Sigma calculation. This seems really hard to get right over every dataset —
     * the following seems to work nicely over the dozen or so testing datasets
     * I used, but there may be a better way to compute this, e.g. relative to
     * the density of tips across the pivots etc.
     */
    streams[sigma] = (pivots.at(-1)-pivots.at(0))/nPivots*5;
    // Each stream sees a filtered version of these pivots
    for (const stream of Object.values(streams)) {
      const startNode = nodes[stream.startNode];
      // position of the node the stream's connector branch leaves from
      const parentPosition: number = metric==='div' ? getDivFromNode(startNode.parent) : getTraitFromNode(startNode.parent, 'num_date');
      startNode.streamPivots = restrictPivots(pivots, stream.domains[metric], parentPosition, streams[sigma], 5);
    }
    /**
     * We define a scale factor here which is applied later on when we convert kde-weight space into display-order space.
     * This is needed because the display order (for non-stream tips) is 1 unit = 1 tip. Since kernel PDF values can be huge
     * (or tiny, depending on STDEV) the kde-weight space can be very large and thus streams take up all the display order space
     * and normal tips are all squashed together.
     *
     * The scale factor is the PDF evaluated at x=0, i.e. the max height of an individual kernel in display order space will be
     * equivalent to what a single tip would have occupied. Because kernels aren't all stacked on top of each other we add a
     * fudge factor here (can be improved).
     */
    streams[weightToDisplayOrderScaleFactor] = 1 / pdf.factory(0, streams[sigma])(0) * 5;
    streams[colorBySymbol] = colorScale.colorBy;
    /** we want to ladderize each time we change metric, which is also when we need to recalculate pivots */
    Object.values(streams)
      .filter((s) => s.parentStreamName===false) // filter to the streams which represent the start of connected series of streams
      .forEach((stream) => {
        stream.renderingOrder = calcRenderingOrder(stream.name, streams, nodes, metric);
      })
  }
  for (const stream of Object.values(streams)) {
    const startNode = nodes[stream.startNode];
    const nodesThisStream = _pick(nodes, stream.members)
    /**
     * Categories only need to be recalculated when the colouring changes (or upon stream creation)
     */
    if (!skipCategories) {
      startNode.streamCategories = observedCategories(nodesThisStream, colorScale);
    }
    /**
     * When a stream is being instantiated for the first time we need to compute the max height.
     * Importantly this considers all candidate nodes to be visible. We also need to recalculate this
     * when the pivots have changed (because values in weight space are evaluated at pivots)
     */
    let dimensions: StreamDimensions;
    let everythingIsVisibleAnyway = false;
    let streamNodeCountsTotal: number;
    let streamNodeCountsVisible: number;
    if (!Object.hasOwn(startNode, "streamMaxHeight") || !skipPivots) {
      // first pass treats every node as visible (visibility arg = true) — used for max height
      ({dimensions, streamNodeCountsTotal, streamNodeCountsVisible} = computeStreamDimensions(nodesThisStream, startNode.streamPivots, metric, startNode.streamCategories, true, streams[sigma]));
      startNode.streamMaxHeight = dimensions.length ? computeStreamMaxHeight(dimensions) : 0;
      /**
       * NOTE: the heights of these (i.e. in KDE weight space) can be huge if we have finely spaced pivots such that the PDFs are evaluated a large number
       * of times. We must perform some normalization of this when we go to display order space, otherwise our typical approach to spacing tips (tips
       * separated by 1 unit of display order space) means tips are right on top of each other if the display order space occupied by streams is very large.
       * (by large, I've seen examples of 10e6...)
       */
      // if everything is visible, the all-visible dimensions above can be reused directly
      const visibilityValues = new Set(visibility);
      if (visibilityValues.size===1 && visibilityValues.has(NODE_VISIBLE)) {
        everythingIsVisibleAnyway = true;
      }
    }
    /**
     * Compute the dimensions of the stream, taking into account visibility
     */
    if (!everythingIsVisibleAnyway) {
      ({dimensions, streamNodeCountsTotal, streamNodeCountsVisible} = computeStreamDimensions(nodesThisStream, startNode.streamPivots, metric, startNode.streamCategories, visibility, streams[sigma]));
    }
    startNode.streamDimensions = dimensions;
    startNode.streamNodeCounts = {total: streamNodeCountsTotal, visible: streamNodeCountsVisible};
  }
}
/**
 * Collect all possible categories - "ribbons within a stream(tree)" - by scanning
 * every node in the stream. The order of the returned categories reflects the
 * order of ripples in the streamtree.
 */
function observedCategories(nodes: ReduxNode[], colorScale: any): ReduxNode['streamCategories'] {
  const colorBy: string = colorScale.colorBy;

  if (colorScale.continuous) {
    /* Each bucket: [legend name, lower bound, upper bound, RGB string, matching node indexes] */
    type ColorCategory = [string, number, number, string, number[]];
    /* NOTE: plenty of speed-ups here if it's a bottleneck */
    const buckets: ColorCategory[] = Object.entries(colorScale.legendBounds)
      .map(([name, bounds]) => [name, bounds[0], bounds[1], colorScale.scale(parseFloat(name)), []])
    const missing: number[] = [];
    nodes.forEach((node) => {
      const value = getTraitFromNode(node, colorBy);
      if (value===undefined) {
        missing.push(node.arrayIdx);
        return;
      }
      // a node lands in the first bucket whose [lower, upper] range contains its value
      const bucket = buckets.find((b) => value >= b[1] && value <= b[2]);
      if (bucket) bucket[4].push(node.arrayIdx);
    });
    const result = buckets
      .filter((b) => b[4].length>0)
      .map((b) => ({name: b[0], color: b[3], nodes: b[4]}));
    if (missing.length) {
      result.push({name: undefined, color: colorScale.scale(undefined), nodes: missing});
    }
    return result;
  }

  /* categorical (incl. genotype) colourings: one category per observed value */
  const toPair: (n: ReduxNode) => [number, string|undefined] = colorScale.genotype
    ? (n): [number, string] => [n.arrayIdx, n.currentGt]
    : (n): [number, string|undefined] => [n.arrayIdx, getTraitFromNode(n, colorBy)];
  const idxCatPairs = nodes.map(toPair);
  const uniqueCats = [...new Set(idxCatPairs.map(([, cat]) => cat))];
  // order categories to match the legend
  uniqueCats.sort((a, b) => colorScale.legendValues.indexOf(a) - colorScale.legendValues.indexOf(b));
  return uniqueCats.map((name) => ({
    name,
    color: colorScale.scale(name),
    nodes: idxCatPairs.filter(([, cat]) => cat===name).map(([idx]) => idx),
  }));
}
/**
 * Returns a matrix of data intended for visualisation by d3 as a stream graph
 * The outer dimensions correspond to categories, i.e. the ribbons of a stream
 * The inner dimensions correspond to pivots. These dimensions are a KDE
 * with each tip represented by a gaussian centered on the tip with some
 * constant std-dev.
 *
 * See stream-trees.md for more explanation
 *
 * @param nodes the (terminal) nodes belonging to this stream
 * @param pivots x-values at which the KDE is evaluated
 * @param metric 'div' or 'num_date' — which node position to centre each kernel on
 * @param categories ribbons, each listing its member node indexes
 * @param visibility `true` to treat every node as visible, else a per-node visibility array
 * @param sigma std-dev of each gaussian kernel
 */
function computeStreamDimensions(nodes: ReduxNode[], pivots: number[], metric, categories: ReduxNode['streamCategories'], visibility: true|Visibility[], sigma:number):
  {dimensions: StreamDimensions, streamNodeCountsTotal: number, streamNodeCountsVisible: number} {
  let [streamNodeCountsTotal, streamNodeCountsVisible] = [0,0];
  // per-stream weight (to increase weights of small streams)
  const w = Math.exp(-(nodes.length-4)/4)+1;
  const dimensions = categories.map((categoryInfo) => {
    const mass = pivots.map(() => 0);
    // Set gives O(1) membership tests; previously Array.includes inside filter made
    // this O(|category| × |nodes|) per category. Results are identical.
    const categoryNodeIdxs = new Set(categoryInfo.nodes);
    const categoryNodes = nodes.filter((node) => categoryNodeIdxs.has(node.arrayIdx));
    streamNodeCountsTotal += categoryNodes.length;
    const visibleCategoryNodes = categoryNodes.filter((node) => visibility===true || visibility[node.arrayIdx]===NODE_VISIBLE);
    streamNodeCountsVisible += visibleCategoryNodes.length;
    for (const node of visibleCategoryNodes) {
      const mu = metric==='div' ? getDivFromNode(node) : getTraitFromNode(node, 'num_date');
      const kde = pdf.factory(mu, sigma);
      // We know that once \mu is 3*\sigma away from the pivot we don't really add any weight so could leverage this to
      // speed things up (we do this already for the pivots in this stream, but could also do it for the individual nodes)
      pivots.forEach((pivot, idx) => {
        mass[idx] += w * kde(pivot);
      })
    }
    return mass;
  })
  return {dimensions, streamNodeCountsTotal, streamNodeCountsVisible};
}
/**
 * The max height of a stream is the largest per-pivot total weight, i.e.
 * the maximum (over pivots) of the column sum across all categories.
 * Assumes at least one category is present (callers guard against empty dimensions).
 */
function computeStreamMaxHeight(dimensions: StreamDimensions): number {
  let maxHeight = 0;
  for (let pivotIdx = 0; pivotIdx < dimensions[0].length; pivotIdx++) {
    let columnTotal = 0;
    for (const weightsPerPivot of dimensions) {
      columnTotal += weightsPerPivot[pivotIdx];
    }
    if (columnTotal > maxHeight) maxHeight = columnTotal;
  }
  return maxHeight;
}
/**
 * Given a set of connected streams (with `rootName` representing the initial stream) we return a list
 * of stream names which can be rendered in order such that there are no crossings (i.e. stream connector lines
 * don't go "through" other streams). Using a toy example of stream R which has 2 child streams {A,B}
 * we want to render this as
 *
 *      ┌ AAAAAAAAAAAAAAAAAA
 *      │     ┌ BBBBBBBBBBBBBB
 *      │     │
 * ─────────── RRRRRRRRRRRRRRRRRRRRRRRRRR
 *
 * Where we want A to be drawn above B (i.e. smaller display order) based on the numeric date of the branch leaving R.
 * This approach continues to further child streams (e.g. child streams of A). We construct this via a tree
 * structure (where nodes represent streams) and return a list of nodes ordered by a post-order traversal,
 * i.e. [A,B,R]. These streams can then be assigned display orders in a simple incremental fashion.
 *
 * For divergence trees we do the same but using divergence values. Note that this often results in a different
 * return value! E.g. B might branch off before A in divergence space.
 */
function calcRenderingOrder(rootName: string, streams: Streams, nodes: ReduxNode[], metric: 'num_date'|'div'): string[] {
  /* a node in the tree-of-streams; each node represents one stream */
  interface Node {
    name: string,
    children: Node[],
    parent: false|Node, // false only for the root
    seen: boolean // bookkeeping for the post-order traversal below
  }
  /* Phase 1: build the tree of streams, ordering each stream's children by the
   * (metric-appropriate) position of the branch connecting them to their parent stream */
  const treeOfStreams: Node = {name: rootName, parent: false, children: [], seen: false}
  const stack = [treeOfStreams];
  let _counter = 100000 // guard against unbounded loops (e.g. malformed stream links)
  while (stack.length && _counter>0) {
    _counter--
    const element = stack.pop();
    element.children = streams[element.name].streamChildren
      .map((name) => [
        name,
        // position of the parent of the child stream's start node, i.e. where the connector leaves this stream
        metric==='div' ? getDivFromNode(nodes[streams[name].startNode].parent) : getTraitFromNode(nodes[streams[name].startNode].parent, 'num_date')
      ])
      .sort((a, b) => a[1]<b[1] ? -1 : a[1]>b[1] ? 1 : 0)
      .map(([name]) => ({name, parent: element, children: [], seen: false}))
    for (const el of element.children) {
      stack.push(el);
    }
  }
  /* follow first-children all the way down — where a post-order traversal starts */
  function _topmostTerminal(n: Node): Node {
    while (n.children.length) {
      n = n.children[0];
    }
    return n;
  }
  /* Phase 2: post-order traversal — children before parents, root last */
  const postOrderStartNode = _topmostTerminal(treeOfStreams)
  const postOrder = [postOrderStartNode]
  _counter=100000 // same loop guard as above
  while (true && _counter>0) {
    _counter--
    const currentNode = postOrder.at(-1);
    currentNode.seen=true;
    if (currentNode.parent===false) break; // We've reached the root!
    const nextSibling = currentNode.parent.children.filter((c) => !c.seen)[0];
    if (nextSibling) {
      // descend to the deepest first-child of the next unvisited sibling
      const topmost = _topmostTerminal(nextSibling)
      postOrder.push(topmost);
      continue;
    }
    // else no siblings, take parent!
    postOrder.push(currentNode.parent);
  }
  return postOrder.map((el) => el.name);
}
/** Select the elements of `arr` at the given indexes (order preserved). */
function _pick<T>(arr:T[], idxs: number[]):T[] {
  const picked: T[] = [];
  for (const idx of idxs) {
    picked.push(arr[idx]);
  }
  return picked;
}
/** Sum of all numbers in the array (0 for an empty array). */
function _sum(arr: number[]): number {
  let total = 0;
  for (const value of arr) {
    total += value;
  }
  return total;
}
/**
 * Is `node` inside a stream started by one of its ancestors?
 * A node which itself starts a stream is not considered to be _within_ another
 * stream, for this definition of _within_.
 */
export function isNodeWithinAnotherStream(node: ReduxNode, branchLabelKey: string): boolean {
  // stream-start nodes are excluded by definition
  if (node?.branch_attrs?.labels?.[branchLabelKey]) return false;
  // walk rootwards; the root is its own parent, which terminates the walk
  for (let ancestor = node.parent; ; ancestor = ancestor.parent) {
    if (ancestor?.branch_attrs?.labels?.[branchLabelKey]) return true;
    if (ancestor.parent===ancestor) return false;
  }
}
/**
 * Given the dataset's pivot array, restrict it to the subset applicable for one stream.
 *
 * The pivots must span the stream's tips (the domain), extended either side by
 * `cutoff` sigmas so the KDE tails off smoothly. We never extend left of the
 * parent (connector branch) position — otherwise the stream would appear to start
 * before the branch leading to it — and extending too far right would make every
 * stream look like it gradually died out.
 */
function restrictPivots(pivots: number[], domain:[number,number], parentPosition: number, sigma:number, cutoff:number): number[] {
  const margin = sigma*cutoff;
  // clamp: pivots (and therefore streams) may not go left of the connecting branch
  const lowerBound = Math.max(domain[0] - margin, parentPosition);
  const upperBound = domain[1] + margin;
  return pivots.filter((pivot) => pivot>=lowerBound && pivot<=upperBound);
}
/**
 * Determine which branch-label keys may be used as stream labels.
 *
 * If the dataset JSON specifies `stream_labels` we use those (restricted to labels
 * actually present on the tree, warning about any that are missing). Otherwise we
 * return all available branch labels, with preset stream-like names sorted first,
 * excluding 'aa' and 'none'.
 *
 * @param availableBranchLabels all branch-label keys present on the tree
 * @param jsonDefinedStreamLabels optional explicit list from the dataset JSON
 */
export function availableStreamLabelKeys(availableBranchLabels: string[], jsonDefinedStreamLabels: undefined|string[]): string[] {
  if (jsonDefinedStreamLabels) {
    const labels = jsonDefinedStreamLabels.filter((l) => availableBranchLabels.includes(l));
    if (labels.length!==jsonDefinedStreamLabels.length) {
      // bug fix: list the labels which were NOT found (previously this listed the found ones)
      console.warn("Some of the metadata-specified 'stream_labels' were not found on the tree and have been excluded: " +
        jsonDefinedStreamLabels.filter((l) => !labels.includes(l)).join(", "));
    }
    return labels;
  }
  // Use a hardcoded list to sort labels which are present so certain ones come first
  const preset = ['stream', 'streams', 'stream_label'];
  // we may want to do something here to exclude certain branch labels, e.g. ones which are repeated many times on the tree
  return ([...availableBranchLabels])
    .sort((a, b) => {
      const [ai, bi] = [preset.indexOf(a), preset.indexOf(b)];
      if (ai===-1 && bi!==-1) return 1; // only b is a preset -> b first
      if (ai!==-1 && bi===-1) return -1; // only a is a preset -> a first
      return ai - bi; // both presets ordered by preset rank; both non-presets keep original order
    })
    .filter((l) => l!=='aa' && l!=='none');
}