UNPKG

auspice

Version:

Web app for visualizing pathogen evolution

436 lines (380 loc) 18.9 kB
import { ReduxNode, StreamDimensions, Visibility, Streams, sigma, weightToDisplayOrderScaleFactor, colorBySymbol, streamLabelSymbol } from "../reducers/tree/types";
import { getTraitFromNode, getDivFromNode } from "./treeMiscHelpers";
import { NODE_VISIBLE } from "./globals";
import pdf from '@stdlib/stats-base-dists-normal-pdf';

/**
 * Walk every descendant of `tree` and work out which stream (if any) each node belongs to.
 * A stream starts at any node carrying the branch label `branchLabelKey`; descendants inherit
 * that membership until another labelled node starts a new (child) stream.
 *
 * Side effects (on each visited node; `tree` itself is not visited):
 *   - sets `node.inStream` -> boolean
 *   - sets `node.streamName` -> stream name, on nodes which start a stream
 *   - deletes any stale `streamName`, `streamPivots`, `streamCategories`, `streamDimensions`
 *     and `streamMaxHeight` left on the node (using a truthy `streamName` as the sentinel)
 *
 * @param tree root node whose children seed the traversal
 * @param branchLabelKey the branch label whose values name the streams
 * @returns object keyed by stream name. Each value records the start node index, the indexes of
 *   terminal members, direct child streams, the parent stream name (or `false`) and the
 *   `num_date` / `div` domains spanned by the terminal members. `streams[streamLabelSymbol]`
 *   stores `branchLabelKey`.
 * @throws Error if a parent stream name is unexpectedly absent from the collected streams
 */
export function labelStreamMembership(tree: ReduxNode, branchLabelKey): Streams {
  const streams: Streams = {};
  streams[streamLabelSymbol] = branchLabelKey;

  // Each stack entry pairs a node with the stream membership inherited from its parent
  const stack: [ReduxNode, false|string][] = (tree.children || []).map((subtreeRootNode) => [subtreeRootNode, false]);

  while (stack.length) {
    const [node, parentStreamMembership] = stack.pop();

    let newStreamMembership = node?.branch_attrs?.labels?.[branchLabelKey];
    if (newStreamMembership && newStreamMembership in streams) {
      console.error(`Stream label ${newStreamMembership} seen more than once. Ignoring all but the first.`)
      newStreamMembership = undefined;
    }

    /* clear any previous stream-related information using `streamName` as a sentinel value */
    // note that node.inStream (which is on every node) is re-set later in this loop
    if (node.streamName) {
      delete node.streamName;
      delete node.streamPivots;
      delete node.streamCategories;
      delete node.streamDimensions;
      delete node.streamMaxHeight;
    }

    if (newStreamMembership) {
      streams[newStreamMembership] = {
        name: newStreamMembership,
        startNode: node.arrayIdx,
        members: [], // terminals only
        streamChildren: [], // direct children only
        parentStreamName: parentStreamMembership,
        domains: {
          num_date: [Infinity, -Infinity],
          div: [Infinity, -Infinity],
        }
      };
      node.streamName = newStreamMembership;
      if (parentStreamMembership) {
        if (!(parentStreamMembership in streams)) throw new Error("labelStreamMembership fatal error II");
        streams[parentStreamMembership].streamChildren.push(newStreamMembership);
      }
    }

    const currentStreamMembership = newStreamMembership || parentStreamMembership;
    node.inStream = !!currentStreamMembership;

    if (currentStreamMembership && !node.hasChildren) {
      streams[currentStreamMembership].members.push(node.arrayIdx);
      // update domains so they span every terminal member seen so far
      const domains = streams[currentStreamMembership].domains;
      const div = getDivFromNode(node);
      const num_date = getTraitFromNode(node, 'num_date');
      if (div<domains.div[0]) {domains.div[0]=div}
      if (div>domains.div[1]) {domains.div[1]=div}
      if (num_date<domains.num_date[0]) {domains.num_date[0]=num_date}
      if (num_date>domains.num_date[1]) {domains.num_date[1]=num_date}
    }

    for (const child of node.children || []) {
      stack.push([child, currentStreamMembership])
    }
  }
  return streams;
}

/**
 * (Re)compute the per-stream rendering data, storing the results on each stream's start node
 * (`streamPivots`, `streamCategories`, `streamMaxHeight`, `streamDimensions`, `streamNodeCounts`)
 * and on `streams` itself (`sigma`, `weightToDisplayOrderScaleFactor`, `colorBySymbol` symbol keys
 * and each root stream's `renderingOrder`).
 *
 * @param streams the object returned by `labelStreamMembership` (mutated)
 * @param nodes all tree nodes, indexable by `arrayIdx` (start nodes are mutated)
 * @param visibility per-node visibility, indexed by `arrayIdx`
 * @param metric which axis the streams are laid out along
 * @param colorScale the current colour scale (reads `colorBy`, `scale`, and via
 *   `observedCategories` `continuous`, `legendBounds`, `legendValues`, `genotype`)
 * @param options `skipPivots` / `skipCategories` request that the respective recalculation is skipped
 */
export function processStreams(
  streams: Streams,
  nodes: ReduxNode[],
  visibility: Visibility[],
  metric: "num_date"|"div",
  colorScale,
  {
    skipPivots=false,
    skipCategories=false
  }: {skipPivots?: boolean, skipCategories?: boolean} = {},
):void {

  /**
   * Pivots often don't need to be recalculated. Sigma is also recalculated.
   *
   * NOTE(review): `streamPivots` is stored on each stream's *start node* (see below), and the
   * stream objects built by `labelStreamMembership` never carry that key, so the `every(...)`
   * check looks like it is always false and this branch runs regardless of `skipPivots`.
   * Confirm intent before changing: `colorBySymbol` is also set inside this branch.
   */
  if (!skipPivots || !Object.values(streams).every((s) => Object.hasOwn(s, 'streamPivots'))) {
    // entire domain spanning all streams
    const domain = (Object.values(streams)).reduce((dd, stream) => {
      if (dd[0] > stream.domains[metric][0]) dd[0] = stream.domains[metric][0];
      if (dd[1] < stream.domains[metric][1]) dd[1] = stream.domains[metric][1];
      return dd;
    }, [Infinity, -Infinity]);

    const nPivots = 500 ; // which will represent the entire domain
    const nExtend = 2; // use (nPivots-2*nExtend) to span the domain into pieces (of a constant size), then extend the domain either side by nExtend pieces
    const size = Math.abs(domain[1] - domain[0])/(nPivots-1-2*nExtend);
    const pivots = Array.from(Array(nPivots), (_, i) => domain[0]-nExtend*size + i*size);

    /** Sigma calculation
     * This seems really hard to get right over every dataset
     * the following seems to work nicely over the dozen or so testing datasets
     * I used, but there may be a better way to compute this, e.g. relative to
     * the density of tips across the pivots etc.
     */
    streams[sigma] = (pivots.at(-1)-pivots.at(0))/nPivots*5;

    // Each stream sees a filtered version of these pivots
    for (const stream of Object.values(streams)) {
      const startNode = nodes[stream.startNode];
      const parentPosition: number = metric==='div' ?
        getDivFromNode(startNode.parent) :
        getTraitFromNode(startNode.parent, 'num_date');
      startNode.streamPivots = restrictPivots(pivots, stream.domains[metric], parentPosition, streams[sigma], 5);
    }

    /**
     * We define a scale factor here which is applied later on when we convert kde-weight space into display-order space.
     * This is needed because the display order (for non-stream tips) is 1 unit = 1 tip. Since kernel PDF values can be huge
     * (or tiny, depending on STDEV) the kde-weight space can be very large and thus streams take up all the display order space
     * and normal tips are all squashed together.
     *
     * The scale factor is the reciprocal of the kernel's peak (the PDF with mean 0 evaluated at x=0), i.e. the max height of
     * an individual kernel in display order space will be equivalent to what a single tip would have occupied. Because kernels
     * aren't all stacked on top of each other we multiply by a fudge factor (5) here (can be improved).
     */
    streams[weightToDisplayOrderScaleFactor] = 1 / pdf.factory(0, streams[sigma])(0) * 5;
    streams[colorBySymbol] = colorScale.colorBy;

    /** we want to ladderize each time we change metric, which is also when we need to recalculate pivots */
    Object.values(streams)
      .filter((s) => s.parentStreamName===false) // filter to the streams which represent the start of connected series of streams
      .forEach((stream) => {
        stream.renderingOrder = calcRenderingOrder(stream.name, streams, nodes, metric);
      })
  }

  for (const stream of Object.values(streams)) {
    const startNode = nodes[stream.startNode];
    const nodesThisStream = _pick(nodes, stream.members)

    /**
     * Categories only need to be recalculated when the colouring changes (or upon stream creation)
     */
    if (!skipCategories) {
      startNode.streamCategories = observedCategories(nodesThisStream, colorScale);
    }

    /**
     * When a stream is being instantiated for the first time we need to compute the max height.
     * Importantly this considers all candidate nodes to be visible. We also need to recalculate this
     * when the pivots have changed (because values in weight space are evaluated at pivots)
     */
    let dimensions: StreamDimensions;
    let everythingIsVisibleAnyway = false;
    let streamNodeCountsTotal: number;
    let streamNodeCountsVisible: number;
    if (!Object.hasOwn(startNode, "streamMaxHeight") || !skipPivots) {
      ({dimensions, streamNodeCountsTotal, streamNodeCountsVisible} =
        computeStreamDimensions(nodesThisStream, startNode.streamPivots, metric, startNode.streamCategories, true, streams[sigma]));
      startNode.streamMaxHeight = dimensions.length ? computeStreamMaxHeight(dimensions) : 0;
      /**
       * NOTE: the heights of these (i.e. in KDE weight space) can be huge if we have finely spaced pivots such that the PDFs are evaluated a large number
       * of times. We must perform some normalization of this when we go to display order space, otherwise our typical approach to spacing tips (tips
       * separated by 1 unit of display order space) means tips are right on top of each other if the display order space occupied by streams is very large.
       * (by large, I've seen examples of 10e6...)
       */

      // If every node is visible then the all-visible dimensions just computed can be reused as-is
      const visibilityValues = new Set(visibility);
      if (visibilityValues.size===1 && visibilityValues.has(NODE_VISIBLE)) {
        everythingIsVisibleAnyway = true;
      }
    }

    /**
     * Compute the dimensions of the stream, taking into account visibility
     */
    if (!everythingIsVisibleAnyway) {
      ({dimensions, streamNodeCountsTotal, streamNodeCountsVisible} =
        computeStreamDimensions(nodesThisStream, startNode.streamPivots, metric, startNode.streamCategories, visibility, streams[sigma]));
    }
    startNode.streamDimensions = dimensions;
    startNode.streamNodeCounts = {total: streamNodeCountsTotal, visible: streamNodeCountsVisible};
  }
}

/**
 * Collect all possible categories - "ribbons within a stream(tree)" - by looping
 * over all nodes in the stream. The order here reflects the order of ripples in the streamtree.
 *
 * For continuous colour scales each node is binned into the first legend bound which contains
 * its value (empty bins are dropped) and nodes with an undefined value are gathered into a
 * trailing `name: undefined` category. Otherwise categories are the distinct observed values
 * (or `currentGt` for genotype colourings), sorted by their index in `colorScale.legendValues`.
 *
 * @param nodes the (terminal) nodes of a single stream
 * @param colorScale the current colour scale
 * @returns list of `{name, color, nodes}` where `nodes` holds `arrayIdx` values
 */
function observedCategories(nodes: ReduxNode[], colorScale: any): ReduxNode['streamCategories'] {
  const colorBy: string = colorScale.colorBy;

  if (colorScale.continuous) {
    type ColorCategory = [ // intermediate type
      /** name */
      string,
      /** lower bound of category */
      number,
      /** upper bound of category */
      number,
      /** RGB string */
      string,
      /** indexes of nodes which are in this category */
      number[],
    ];
    /* NOTE: plenty of speed-ups here if it's a bottleneck */
    const categories: ColorCategory[] = Object.entries(colorScale.legendBounds)
      .map(([name, bounds]) => [name, bounds[0], bounds[1], colorScale.scale(parseFloat(name)), []])
    const undefinedNodes: number[] = [];
    for (const n of nodes) {
      const v = getTraitFromNode(n, colorBy);
      if (v===undefined) {
        undefinedNodes.push(n.arrayIdx);
        continue;
      }
      // first matching bin wins (bounds are inclusive at both ends)
      for (const c of categories) {
        if (v <= c[2] && v >= c[1]) {
          c[4].push(n.arrayIdx)
          break;
        }
      }
    }
    const streamCategories = categories
      .filter((c) => c[4].length>0)
      .map((c) => ({name: c[0], color: c[3], nodes: c[4]}));
    if (undefinedNodes.length) {
      streamCategories.push({name: undefined, color: colorScale.scale(undefined), nodes: undefinedNodes});
    }
    return streamCategories;
  }

  const getter: (n: ReduxNode) => [number, string|undefined] = colorScale.genotype ?
    (n): [number, string] => [n.arrayIdx, n.currentGt] :
    (n): [number, string|undefined] => [n.arrayIdx, getTraitFromNode(n, colorBy)];
  const nodesAndCategories: [number, string|undefined][] = nodes.map(getter);

  const orderedCategories = Array.from(new Set(nodesAndCategories.map((el) => el[1])))
    .sort((a,b) => colorScale.legendValues.indexOf(a) - colorScale.legendValues.indexOf(b));

  return orderedCategories.map((name) => ({
    name,
    color: colorScale.scale(name),
    nodes: nodesAndCategories.filter(([_, catName]) => catName===name).map(([nodeIdx,]) => nodeIdx)
  }));
}

/**
 * Returns a matrix of data intended for visualisation by d3 as a stream graph
 * The outer dimensions correspond to categories, i.e. the ribbons of a stream
 * The inner dimensions correspond to pivots. These dimensions are a KDE
 * with each tip represented by a gaussian centered on the tip with some
 * constant std-dev.
 *
 * See stream-trees.md for more explanation
 *
 * @param nodes the nodes of a single stream
 * @param pivots positions at which each kernel is evaluated
 * @param metric `'div'` uses node divergence, anything else uses `num_date`
 * @param categories output of `observedCategories` for this stream
 * @param visibility `true` treats every node as visible, otherwise indexed by `arrayIdx`
 * @param sigma standard deviation of each gaussian kernel
 * @returns the weights matrix plus counts of total / visible nodes summed across categories
 */
function computeStreamDimensions(nodes: ReduxNode[], pivots: number[], metric, categories: ReduxNode['streamCategories'], visibility: true|Visibility[], sigma:number): {dimensions: StreamDimensions, streamNodeCountsTotal: number, streamNodeCountsVisible: number} {
  let [streamNodeCountsTotal, streamNodeCountsVisible] = [0,0];
  // per-stream weight (to increase weights of small streams)
  const w = Math.exp(-(nodes.length-4)/4)+1;

  const dimensions = categories.map((categoryInfo) => {
    const mass = pivots.map(() => 0);
    // Set lookup avoids rescanning the category's index list for every node
    const categoryIdxs = new Set(categoryInfo.nodes);
    const categoryNodes = nodes.filter((node) => categoryIdxs.has(node.arrayIdx))
    streamNodeCountsTotal += categoryNodes.length;
    const visibleCategoryNodes = categoryNodes.filter((node) => visibility===true || visibility[node.arrayIdx]===NODE_VISIBLE);
    streamNodeCountsVisible += visibleCategoryNodes.length;
    for (const node of visibleCategoryNodes) {
      const mu = metric==='div' ? getDivFromNode(node) : getTraitFromNode(node, 'num_date');
      const kde = pdf.factory(mu, sigma);
      // We know that once \mu is 3*\sigma away from the pivot we don't really add any weight so could leverage this to
      // speed things up (we do this already for the pivots in this stream, but could also do it for the individual nodes)
      pivots.forEach((pivot, idx) => {
        mass[idx]+= w * kde(pivot);
      })
    }
    return mass;
  })
  return {dimensions, streamNodeCountsTotal, streamNodeCountsVisible};
}

/**
 * The maximum, over pivots, of the summed weight across all categories at that pivot
 * (never less than 0). Callers must pass a non-empty `dimensions`.
 */
function computeStreamMaxHeight(dimensions: StreamDimensions): number {
  const nPivots = dimensions[0].length;
  return Array.from(Array(nPivots), undefined)
    .map((_, pivotIdx) => _sum(dimensions.map((weightsPerPivot) => weightsPerPivot[pivotIdx])))
    .reduce((maxValue, cv) => cv > maxValue ? cv : maxValue, 0)
}

/**
 * Given a set of connected streams (with `rootName` representing the initial stream) we return a list
 * of stream names which can be rendered in order such that there are no crossings (i.e. stream connector lines
 * don't go "through" other streams). Using a toy example of stream R which has 2 child streams {A,B}
 * we want to render this as
 *
 *           ┌ AAAAAAAAAAAAAAAAAA
 *           │    ┌ BBBBBBBBBBBBBB
 *           │    │
 *  ─────────── RRRRRRRRRRRRRRRRRRRRRRRRRR
 *
 * Where we want A to be drawn above B (i.e. smaller display order) based on the numeric date of the branch leaving R.
 * This approach continues to further child streams (e.g. child streams of A). We construct this via a tree
 * structure (where nodes represent streams) and return a list of nodes ordered by a post-order traversal,
 * i.e. [A,B,R]. These streams can then be assigned display orders in a simple incremental fashion.
 *
 * For divergence trees we do the same but using divergence values. Note that this often results in a different
 * return value! E.g. B might branch off before A in divergence space.
 *
 * Both loops are bounded by an iteration counter (100000) as a guard against running forever.
 */
function calcRenderingOrder(rootName: string, streams: Streams, nodes: ReduxNode[], metric: 'num_date'|'div'): string[] {
  interface Node {
    name: string,
    children: Node[],
    parent: false|Node,
    seen: boolean
  }
  const treeOfStreams: Node = {name: rootName, parent: false, children: [], seen: false}

  // Build the tree of streams, with each node's children sorted by the position of their parent branch
  const stack = [treeOfStreams];
  let _counter = 100000
  while (stack.length && _counter>0) {
    _counter--
    const element = stack.pop();
    element.children = streams[element.name].streamChildren
      .map((name) => [
        name,
        metric==='div' ?
          getDivFromNode(nodes[streams[name].startNode].parent) :
          getTraitFromNode(nodes[streams[name].startNode].parent, 'num_date')
      ])
      .sort((a, b) => a[1]<b[1] ? -1 : a[1]>b[1] ? 1 : 0)
      .map(([name]) => ({name, parent: element, children: [], seen: false}))
    for (const el of element.children) {
      stack.push(el);
    }
  }

  /** Follow first-children down to a node with no children */
  function _topmostTerminal(n: Node): Node {
    while (n.children.length) {
      n = n.children[0];
    }
    return n;
  }

  // Iterative post-order traversal: after visiting a node move to the topmost terminal of its
  // next unseen sibling, or up to the parent once all siblings have been seen
  const postOrderStartNode = _topmostTerminal(treeOfStreams)
  const postOrder = [postOrderStartNode]
  _counter=100000
  while (_counter>0) {
    _counter--
    const currentNode = postOrder.at(-1);
    currentNode.seen=true;
    if (currentNode.parent===false) break; // We've reached the root!
    const nextSibling = currentNode.parent.children.find((c) => !c.seen);
    if (nextSibling) {
      const topmost = _topmostTerminal(nextSibling)
      postOrder.push(topmost);
      continue;
    }
    // else no siblings, take parent!
    postOrder.push(currentNode.parent);
  }
  return postOrder.map((el) => el.name);
}

/** Returns the elements of `arr` at the positions `idxs`, in the order of `idxs` */
function _pick<T>(arr:T[], idxs: number[]):T[] {
  return idxs.map((idx) => arr[idx])
}

/** Sum of a list of numbers (0 for an empty list) */
function _sum(arr: number[]): number {
  return arr.reduce((acc, cv) => acc+cv, 0)
}

/**
 * Is `node` a descendant of a node which carries the `branchLabelKey` branch label?
 * The walk up the tree stops at a node which is its own parent, or when a parent link is
 * missing; in both cases the answer is `false`.
 */
export function isNodeWithinAnotherStream(node: ReduxNode, branchLabelKey: string): boolean {
  // if the current node is a stream start then it's not _within_ another stream, for this definition of _within_
  if (node?.branch_attrs?.labels?.[branchLabelKey]) return false;
  let n = node.parent;
  while (n) {
    if (n.branch_attrs?.labels?.[branchLabelKey]) return true;
    if (n.parent===n) return false;
    n = n.parent;
  }
  // ran off the top of the tree without finding a labelled ancestor
  return false;
}

/**
 * Given the dataset's pivot array, restrict this to a subset of pivots which are applicable for this stream.
 *
 * While it's obvious that the pivots should span the stream's tips (i.e. the domain) it's a little more
 * ambiguous about how far to extend it either side. The more we extend it the more we get smooth ends to the
 * KDE however there are downsides.
 * Extending too far to the left is problematic if the pivot list ends up going over the connector branch position
 * (e.g. the pivots go back further in time than the parent node date).
 * Extending too far to the right can make it look like every stream has gradually died out.
 *
 * @param pivots the dataset-wide pivots
 * @param domain [min, max] positions of the stream's tips
 * @param parentPosition position of the stream start node's parent; pivots never go left of this
 * @param sigma kernel standard deviation
 * @param cutoff how many multiples of `sigma` to extend the domain by on each side
 */
function restrictPivots(pivots: number[], domain:[number,number], parentPosition: number, sigma:number, cutoff:number): number[] {
  let min = domain[0] - sigma*cutoff;
  if (min<parentPosition) min=parentPosition; // stops pivots (and therefore streams) going to the left of the connecting branch
  const max = domain[1] + sigma*cutoff;
  return pivots.filter((value) => value>=min && value<=max);
}

/**
 * Which branch labels can be used to define streams?
 *
 * If the dataset JSON defines stream labels then we return those which exist on the tree (in the
 * JSON's order), warning about any which don't. Otherwise we return every available branch label
 * except 'aa' and 'none', with a few preset names sorted to the front.
 */
export function availableStreamLabelKeys(availableBranchLabels: string[], jsonDefinedStreamLabels: undefined|string[]): string[] {
  if (jsonDefinedStreamLabels) {
    const labels = jsonDefinedStreamLabels.filter((l) => availableBranchLabels.includes(l));
    // report the labels which were *excluded* (i.e. not found on the tree)
    const missing = jsonDefinedStreamLabels.filter((l) => !availableBranchLabels.includes(l));
    if (missing.length) {
      console.warn("Some of the metadata-specified 'stream_labels' were not found on the tree and have been excluded: " +
        missing.join(", "));
    }
    return labels;
  }
  // Use a hardcoded list to sort labels which are present so certain ones come first
  const preset = ['stream', 'streams', 'stream_label'];
  // we may want to do something here to exclude certain branch labels, e.g. ones which are repeated many times on the tree
  return ([...availableBranchLabels])
    .sort((a, b) => {
      const [ai, bi] = [preset.indexOf(a), preset.indexOf(b)];
      if (ai===-1 && bi!==-1) return 1;
      if (ai!==-1 && bi===-1) return -1;
      return ai -bi;
    })
    .filter((l) => l!=='aa' && l!=='none');
}