@tanstack/db
Version:
A reactive client store for building super fast apps on sync
356 lines (355 loc) • 12 kB
JavaScript
import { groupBy, map, filter, serializeValue, groupByOperators } from "@tanstack/db-ivm";
import { getHavingExpression, isExpressionLike, PropRef, Func } from "../ir.js";
import { UnsupportedAggregateFunctionError, UnknownHavingExpressionTypeError, AggregateFunctionNotInSelectError, NonAggregateExpressionNotInGroupByError } from "../../errors.js";
import { compileExpression, toBooleanPredicate } from "./evaluators.js";
// Aggregate operator factories supplied by the IVM engine; getAggregateFunction
// picks one of these for each aggregate expression.
const { avg, count, max, min, sum } = groupByOperators;
/**
 * Verify that every non-aggregate SELECT expression is one of the GROUP BY
 * expressions, and remember which GROUP BY position each such alias maps to.
 *
 * SELECT entries that are aggregates (or contain one) are skipped.
 *
 * @param groupByClause - Array of GROUP BY expressions.
 * @param selectClause - Optional map of output alias -> expression.
 * @returns `{ selectToGroupByIndex, groupByExpressions }`:
 *   - `selectToGroupByIndex` is a Map from alias to GROUP BY index.
 *   - `groupByExpressions` is a shallow copy of `groupByClause`.
 * @throws NonAggregateExpressionNotInGroupByError for a non-aggregate alias
 *   whose expression matches no GROUP BY expression.
 */
function validateAndCreateMapping(groupByClause, selectClause) {
  const groupByExpressions = groupByClause.slice();
  const selectToGroupByIndex = new Map();
  if (selectClause) {
    Object.entries(selectClause).forEach(([alias, expr]) => {
      const isAggregate = expr.type === `agg` || containsAggregate(expr);
      if (isAggregate) {
        return;
      }
      const position = groupByExpressions.findIndex((candidate) =>
        expressionsEqual(expr, candidate)
      );
      if (position < 0) {
        throw new NonAggregateExpressionNotInGroupByError(alias);
      }
      selectToGroupByIndex.set(alias, position);
    });
  }
  return { selectToGroupByIndex, groupByExpressions };
}
/**
 * Append GROUP BY aggregation and HAVING filtering stages to a query pipeline.
 *
 * Rows are grouped by the values of `groupByClause`. Aggregate expressions in
 * `selectClause` are computed per group, and each group's SELECT results are
 * written to the output row's `$selected` object. HAVING clauses are then
 * applied as filter stages evaluated against `{ $selected }`.
 *
 * Output keys:
 *   - Empty `groupByClause`: every row is aggregated into one group whose
 *     output key is `single_group`.
 *   - One GROUP BY expression: the output key is that expression's value.
 *   - Several GROUP BY expressions: the output key is `serializeValue` of the
 *     array of values.
 *
 * @param pipeline - Pipeline to extend via `.pipe(...)`.
 * @param groupByClause - Array of GROUP BY expressions (may be empty).
 * @param havingClauses - Optional HAVING clauses. Every aggregate they use must
 *   also appear as a top-level aggregate in `selectClause`.
 * @param selectClause - Optional map of output alias -> expression.
 * @param fnHavingClauses - Optional HAVING predicates written as functions.
 *   Each is called with `{ $selected }`.
 * @returns The extended pipeline.
 * @throws NonAggregateExpressionNotInGroupByError when a non-aggregate SELECT
 *   expression is not among the GROUP BY expressions.
 */
function processGroupBy(pipeline, groupByClause, havingClauses, selectClause, fnHavingClauses) {
  if (groupByClause.length === 0) {
    return processSingleGroup(pipeline, havingClauses, selectClause, fnHavingClauses);
  }
  const mapping = validateAndCreateMapping(groupByClause, selectClause);
  const compiledGroupByExpressions = groupByClause.map(
    (e) => compileExpression(e)
  );
  // The group key stores each GROUP BY value under a positional `__key_<i>` name.
  const keyExtractor = ([, row]) => {
    // Evaluate against the source row only; drop any earlier `$selected` output.
    const namespacedRow = { ...row };
    delete namespacedRow.$selected;
    const key = {};
    for (let i = 0; i < groupByClause.length; i++) {
      key[`__key_${i}`] = compiledGroupByExpressions[i](namespacedRow);
    }
    return key;
  };
  const { aggregates, wrappedAggExprs } = compileSelectAggregates(selectClause);
  pipeline = pipeline.pipe(groupBy(keyExtractor, aggregates));
  pipeline = pipeline.pipe(
    map(([, aggregatedRow]) => {
      const selectResults = aggregatedRow.$selected || {};
      const finalResults = {};
      if (selectClause) {
        for (const [alias, expr] of Object.entries(selectClause)) {
          if (expr.type === `agg`) {
            finalResults[alias] = aggregatedRow[alias];
          } else if (!wrappedAggExprs[alias]) {
            // validateAndCreateMapping guaranteed this alias is a GROUP BY
            // expression, so read it from the matching `__key_<i>` field.
            const groupIndex = mapping.selectToGroupByIndex.get(alias);
            if (groupIndex !== void 0) {
              finalResults[alias] = aggregatedRow[`__key_${groupIndex}`];
            } else {
              finalResults[alias] = selectResults[alias];
            }
          }
        }
        evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs);
      } else {
        // Without a SELECT clause, expose the raw group key values.
        for (let i = 0; i < groupByClause.length; i++) {
          finalResults[`__key_${i}`] = aggregatedRow[`__key_${i}`];
        }
      }
      let finalKey;
      if (groupByClause.length === 1) {
        finalKey = aggregatedRow[`__key_0`];
      } else {
        const keyParts = [];
        for (let i = 0; i < groupByClause.length; i++) {
          keyParts.push(aggregatedRow[`__key_${i}`]);
        }
        finalKey = serializeValue(keyParts);
      }
      return [
        finalKey,
        {
          ...aggregatedRow,
          $selected: finalResults
        }
      ];
    })
  );
  return applyHavingFilters(pipeline, havingClauses, fnHavingClauses, selectClause);
}

// Aggregate every row into one group (a query with aggregates but no GROUP BY).
function processSingleGroup(pipeline, havingClauses, selectClause, fnHavingClauses) {
  const { aggregates, wrappedAggExprs } = compileSelectAggregates(selectClause);
  const keyExtractor = () => ({ __singleGroup: true });
  pipeline = pipeline.pipe(groupBy(keyExtractor, aggregates));
  pipeline = pipeline.pipe(
    map(([, aggregatedRow]) => {
      const selectResults = aggregatedRow.$selected || {};
      const finalResults = { ...selectResults };
      if (selectClause) {
        for (const [alias, expr] of Object.entries(selectClause)) {
          if (expr.type === `agg`) {
            finalResults[alias] = aggregatedRow[alias];
          }
        }
        evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs);
      }
      return [
        `single_group`,
        {
          ...aggregatedRow,
          $selected: finalResults
        }
      ];
    })
  );
  return applyHavingFilters(pipeline, havingClauses, fnHavingClauses, selectClause);
}

// Build the groupBy operators for SELECT aggregates, plus evaluators for SELECT
// expressions that wrap aggregates (their aggregates get synthetic `__agg_N` aliases).
function compileSelectAggregates(selectClause) {
  const aggregates = {};
  const wrappedAggExprs = {};
  if (!selectClause) {
    return { aggregates, wrappedAggExprs };
  }
  const aggCounter = { value: 0 };
  for (const [alias, expr] of Object.entries(selectClause)) {
    if (expr.type === `agg`) {
      aggregates[alias] = getAggregateFunction(expr);
    } else if (containsAggregate(expr)) {
      const { transformed, extracted } = extractAndReplaceAggregates(
        expr,
        aggCounter
      );
      for (const [syntheticAlias, aggExpr] of Object.entries(extracted)) {
        aggregates[syntheticAlias] = getAggregateFunction(aggExpr);
      }
      wrappedAggExprs[alias] = compileExpression(transformed);
    }
  }
  return { aggregates, wrappedAggExprs };
}

// Append one filter stage per HAVING clause (expression form, then function form).
function applyHavingFilters(pipeline, havingClauses, fnHavingClauses, selectClause) {
  for (const havingClause of havingClauses ?? []) {
    const havingExpression = getHavingExpression(havingClause);
    const transformedHavingClause = replaceAggregatesByRefs(
      havingExpression,
      selectClause || {},
      `$selected`
    );
    const compiledHaving = compileExpression(transformedHavingClause);
    pipeline = pipeline.pipe(
      filter(([, row]) => {
        const namespacedRow = { $selected: row.$selected };
        // Normalise through toBooleanPredicate like the function-form path,
        // so both grouped and ungrouped queries treat HAVING results identically.
        return toBooleanPredicate(compiledHaving(namespacedRow));
      })
    );
  }
  for (const fnHaving of fnHavingClauses ?? []) {
    pipeline = pipeline.pipe(
      filter(([, row]) => {
        const namespacedRow = { $selected: row.$selected };
        return toBooleanPredicate(fnHaving(namespacedRow));
      })
    );
  }
  return pipeline;
}
/**
 * Structural equality for query IR expression nodes.
 *
 * Node kinds:
 *   - `ref`: equal when both paths exist and match segment-by-segment (`===`).
 *   - `val`: equal when the values are `===`.
 *   - `func` / `agg`: equal when names match, argument counts match, and every
 *     argument is recursively equal.
 *
 * Returns false for missing inputs, mismatched types, or any other node type.
 */
function expressionsEqual(expr1, expr2) {
  if (!expr1 || !expr2 || expr1.type !== expr2.type) {
    return false;
  }
  switch (expr1.type) {
    case `ref`: {
      const leftPath = expr1.path;
      const rightPath = expr2.path;
      if (!leftPath || !rightPath || leftPath.length !== rightPath.length) {
        return false;
      }
      return leftPath.every((segment, idx) => segment === rightPath[idx]);
    }
    case `val`:
      return expr1.value === expr2.value;
    case `func`:
    case `agg`: {
      if (expr1.name !== expr2.name) {
        return false;
      }
      if (expr1.args?.length !== expr2.args?.length) {
        return false;
      }
      const leftArgs = expr1.args || [];
      return leftArgs.every((arg, idx) => expressionsEqual(arg, expr2.args[idx]));
    }
    default:
      return false;
  }
}
/**
 * Build the IVM groupBy operator for a single aggregate expression.
 *
 * The aggregate's first argument is compiled once and evaluated against each
 * row's namespaced data. Values are normalised per operator:
 *   - sum / avg: numbers pass through; null/undefined become 0; anything else
 *     is coerced with `Number(...)`.
 *   - min / max: numbers, strings, bigints and Dates pass through; null/undefined
 *     become 0; anything else is coerced with `Number(...)`.
 *   - count: receives the raw evaluated value.
 * The aggregate name is matched case-insensitively.
 *
 * @param aggExpr - An `agg` expression node with `name` and `args`.
 * @returns The operator built by the matching groupByOperators factory.
 * @throws UnsupportedAggregateFunctionError for any other aggregate name.
 */
function getAggregateFunction(aggExpr) {
  const evaluateArg = compileExpression(aggExpr.args[0]);
  const coerceToNumber = (value) => (value == null ? 0 : Number(value));
  const numericExtractor = ([, namespacedRow]) => {
    const value = evaluateArg(namespacedRow);
    return typeof value === `number` ? value : coerceToNumber(value);
  };
  const orderableExtractor = ([, namespacedRow]) => {
    const value = evaluateArg(namespacedRow);
    const isOrderable =
      typeof value === `number` ||
      typeof value === `string` ||
      typeof value === `bigint` ||
      value instanceof Date;
    return isOrderable ? value : coerceToNumber(value);
  };
  const rawExtractor = ([, namespacedRow]) => evaluateArg(namespacedRow);

  const operatorName = aggExpr.name.toLowerCase();
  if (operatorName === `sum`) return sum(numericExtractor);
  if (operatorName === `count`) return count(rawExtractor);
  if (operatorName === `avg`) return avg(numericExtractor);
  if (operatorName === `min`) return min(orderableExtractor);
  if (operatorName === `max`) return max(orderableExtractor);
  throw new UnsupportedAggregateFunctionError(aggExpr.name);
}
/**
 * Rewrite a HAVING expression so that aggregate calls become references into
 * the already-computed SELECT results.
 *
 * @param havingExpr - HAVING expression tree of `agg`, `func`, `ref` and `val` nodes.
 * @param selectClause - SELECT map of alias -> expression. Each `agg` node in
 *   the HAVING tree is replaced by `PropRef([resultAlias, alias])` for the first
 *   SELECT alias whose top-level expression is an equal aggregate.
 * @param resultAlias - Namespace the replacement refs point into (default
 *   `$selected`). Applied at every nesting depth.
 * @returns The rewritten tree. `func` nodes are rebuilt; `ref`/`val` nodes are
 *   returned as-is.
 * @throws AggregateFunctionNotInSelectError if an aggregate has no matching
 *   top-level SELECT aggregate.
 * @throws UnknownHavingExpressionTypeError for any other node type.
 */
function replaceAggregatesByRefs(havingExpr, selectClause, resultAlias = `$selected`) {
  switch (havingExpr.type) {
    case `agg`: {
      const aggExpr = havingExpr;
      for (const [alias, selectExpr] of Object.entries(selectClause)) {
        if (selectExpr.type === `agg` && aggregatesEqual(aggExpr, selectExpr)) {
          return new PropRef([resultAlias, alias]);
        }
      }
      throw new AggregateFunctionNotInSelectError(aggExpr.name);
    }
    case `func`: {
      const funcExpr = havingExpr;
      // Forward resultAlias so aggregates nested inside function calls resolve
      // into the same namespace as top-level ones (previously it was dropped
      // and nested refs always fell back to `$selected`).
      const transformedArgs = funcExpr.args.map(
        (arg) => replaceAggregatesByRefs(arg, selectClause, resultAlias)
      );
      return new Func(funcExpr.name, transformedArgs);
    }
    case `ref`:
    case `val`:
      return havingExpr;
    default:
      throw new UnknownHavingExpressionTypeError(havingExpr.type);
  }
}
/**
 * Resolve SELECT expressions that wrap aggregates (e.g. `coalesce(sum(x), 0)`).
 *
 * Steps, all mutating `finalResults` in place:
 *   1. Copy every `__agg_`-prefixed key of `aggregatedRow` into `finalResults`.
 *   2. Run each wrapped evaluator in insertion order against
 *      `{ $selected: finalResults }` and store its result under its alias.
 *      A later evaluator can therefore see an earlier evaluator's result.
 *   3. Delete every `__agg_`-prefixed key from `finalResults`.
 *
 * Returns nothing.
 */
function evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs) {
  const isSynthetic = (key) => key.startsWith(`__agg_`);

  Object.keys(aggregatedRow)
    .filter(isSynthetic)
    .forEach((key) => {
      finalResults[key] = aggregatedRow[key];
    });

  Object.entries(wrappedAggExprs).forEach(([alias, evaluate]) => {
    finalResults[alias] = evaluate({ $selected: finalResults });
  });

  Object.keys(finalResults)
    .filter(isSynthetic)
    .forEach((key) => {
      delete finalResults[key];
    });
}
/**
 * Report whether an expression is, or anywhere contains, an aggregate call.
 *
 * Non-expression values (as judged by `isExpressionLike`) return false.
 * `agg` nodes return true. `func` nodes are searched recursively through
 * their arguments. Every other node type returns false.
 */
function containsAggregate(expr) {
  if (!isExpressionLike(expr)) {
    return false;
  }
  switch (expr.type) {
    case `agg`:
      return true;
    case `func`:
      return expr.args.some((child) => containsAggregate(child));
    default:
      return false;
  }
}
/**
 * Pull aggregate calls out of an expression, replacing each with a reference
 * to a synthetic `$selected.__agg_<n>` slot.
 *
 * `counter.value` supplies `<n>` and is incremented once per aggregate found,
 * so numbering continues across calls that share the same counter object.
 *
 * @returns `{ transformed, extracted }`:
 *   - `transformed` is the rewritten expression (`func` nodes are rebuilt;
 *     other non-aggregate nodes are returned unchanged).
 *   - `extracted` maps each synthetic alias to its original `agg` node.
 */
function extractAndReplaceAggregates(expr, counter) {
  switch (expr.type) {
    case `agg`: {
      const syntheticAlias = `__agg_${counter.value}`;
      counter.value += 1;
      return {
        transformed: new PropRef([`$selected`, syntheticAlias]),
        extracted: { [syntheticAlias]: expr }
      };
    }
    case `func`: {
      const collected = {};
      const rewrittenArgs = expr.args.map((child) => {
        const { transformed, extracted } = extractAndReplaceAggregates(child, counter);
        Object.assign(collected, extracted);
        return transformed;
      });
      return {
        transformed: new Func(expr.name, rewrittenArgs),
        extracted: collected
      };
    }
    default:
      return { transformed: expr, extracted: {} };
  }
}
/**
 * Two aggregate nodes are equal when they share a name, have the same number
 * of arguments, and every argument is structurally equal (via expressionsEqual).
 */
function aggregatesEqual(agg1, agg2) {
  if (agg1.name !== agg2.name) {
    return false;
  }
  if (agg1.args.length !== agg2.args.length) {
    return false;
  }
  return agg1.args.every((leftArg, idx) => expressionsEqual(leftArg, agg2.args[idx]));
}
export {
containsAggregate,
processGroupBy,
replaceAggregatesByRefs
};
//# sourceMappingURL=group-by.js.map