UNPKG

@tanstack/db

Version:

A reactive client store for building super fast apps on sync

428 lines (427 loc) 14.9 kB
import { groupBy, map, serializeValue, filter, groupByOperators } from "@tanstack/db-ivm";
import { getHavingExpression, isExpressionLike, PropRef, Func } from "../ir.js";
import {
  UnsupportedAggregateFunctionError,
  UnknownHavingExpressionTypeError,
  AggregateFunctionNotInSelectError,
  NonAggregateExpressionNotInGroupByError,
} from "../../errors.js";
import { compileExpression, toBooleanPredicate } from "./evaluators.js";

// Synthetic aggregate slots computed alongside the user's aggregates. Their
// per-group values are copied into `$synced` / `$origin` on each result row.
const VIRTUAL_SYNCED_KEY = `__virtual_synced__`;
const VIRTUAL_HAS_LOCAL_KEY = `__virtual_has_local__`;

const { sum, count, avg, min, max } = groupByOperators;

/**
 * Summarise the sync metadata carried by the per-alias entries of a row.
 *
 * Every entry except `$selected` may carry `$synced` and/or `$origin`.
 * Entries carrying neither are ignored. So are entries that are not objects
 * (null, undefined, primitives), since the `in` operator would throw on them.
 *
 * @param {Record<string, unknown>} row - Row keyed by source alias.
 * @returns {{ synced: boolean, hasLocal: boolean }} `synced` is false when any
 *   entry has `$synced === false` (true when no entry carries metadata);
 *   `hasLocal` is true when any entry has `$origin === "local"`.
 */
function getRowVirtualMetadata(row) {
  let allSynced = true;
  let hasLocal = false;
  for (const [alias, value] of Object.entries(row)) {
    if (alias === `$selected`) continue;
    // NOTE(review): presumably the absent side of an outer join shows up here
    // as undefined; either way `in` cannot be applied to a non-object.
    if (value == null || typeof value !== `object`) continue;
    if (!(`$synced` in value) && !(`$origin` in value)) continue;
    if (value.$synced === false) allSynced = false;
    if (value.$origin === `local`) hasLocal = true;
  }
  return { synced: allSynced, hasLocal };
}

/**
 * Check that every non-aggregate SELECT expression appears in GROUP BY, and
 * record which GROUP BY slot each such alias maps to.
 *
 * @param {Array<object>} groupByClause - GROUP BY expressions.
 * @param {Record<string, object>} [selectClause] - SELECT expressions by alias.
 * @returns {{ selectToGroupByIndex: Map<string, number>, groupByExpressions: Array<object> }}
 * @throws {NonAggregateExpressionNotInGroupByError} When a plain SELECT
 *   expression matches no GROUP BY expression.
 */
function validateAndCreateMapping(groupByClause, selectClause) {
  const selectToGroupByIndex = new Map();
  const groupByExpressions = [...groupByClause];
  if (!selectClause) {
    return { selectToGroupByIndex, groupByExpressions };
  }
  for (const [alias, expr] of Object.entries(selectClause)) {
    if (expr.type === `agg` || containsAggregate(expr)) continue;
    const groupIndex = groupByExpressions.findIndex((groupExpr) =>
      expressionsEqual(expr, groupExpr)
    );
    if (groupIndex === -1) {
      throw new NonAggregateExpressionNotInGroupByError(alias);
    }
    selectToGroupByIndex.set(alias, groupIndex);
  }
  return { selectToGroupByIndex, groupByExpressions };
}

/** Build the two synthetic aggregates that track per-group sync state. */
function createVirtualAggregates() {
  return {
    [VIRTUAL_SYNCED_KEY]: {
      preMap: ([, row]) => getRowVirtualMetadata(row).synced,
      // A group is synced only if no live (multiplicity > 0) row is unsynced.
      reduce: (values) => {
        for (const [isSynced, multiplicity] of values) {
          if (!isSynced && multiplicity > 0) return false;
        }
        return true;
      },
    },
    [VIRTUAL_HAS_LOCAL_KEY]: {
      preMap: ([, row]) => getRowVirtualMetadata(row).hasLocal,
      // A group is local if any live (multiplicity > 0) row is local.
      reduce: (values) => {
        for (const [isLocal, multiplicity] of values) {
          if (isLocal && multiplicity > 0) return true;
        }
        return false;
      },
    },
  };
}

/**
 * Compile the aggregate functions required by a SELECT clause.
 *
 * Top-level aggregates are registered under their own alias. Expressions that
 * merely contain aggregates (e.g. `sum(x) + 1`) have each inner aggregate
 * registered under a synthetic `__agg_N` alias. The surrounding expression is
 * compiled as an evaluator over `$selected`.
 *
 * @param {Record<string, object>} [selectClause]
 * @returns {{ aggregates: Record<string, object>, wrappedAggExprs: Record<string, Function> }}
 * @throws {UnsupportedAggregateFunctionError} For unknown aggregate names.
 */
function buildAggregates(selectClause) {
  const aggregates = createVirtualAggregates();
  const wrappedAggExprs = {};
  if (!selectClause) {
    return { aggregates, wrappedAggExprs };
  }
  // Shared across all SELECT entries so synthetic aliases never collide.
  const aggCounter = { value: 0 };
  for (const [alias, expr] of Object.entries(selectClause)) {
    if (expr.type === `agg`) {
      aggregates[alias] = getAggregateFunction(expr);
    } else if (containsAggregate(expr)) {
      const { transformed, extracted } = extractAndReplaceAggregates(expr, aggCounter);
      for (const [syntheticAlias, aggExpr] of Object.entries(extracted)) {
        aggregates[syntheticAlias] = getAggregateFunction(aggExpr);
      }
      wrappedAggExprs[alias] = compileExpression(transformed);
    }
  }
  return { aggregates, wrappedAggExprs };
}

/**
 * Assemble the emitted row for one group.
 *
 * Starts from the aggregated row. It then sets `$selected`, derives `$synced`
 * and `$origin` from the virtual aggregates, stamps `$key`, and optionally
 * overrides `$collectionId`. When a correlation key is present it also
 * re-attaches that key under the main source alias.
 */
function buildResultRow(
  aggregatedRow,
  finalResults,
  resultKey,
  { aggregateCollectionId, mainSource, correlationKey }
) {
  const resultRow = { ...aggregatedRow, $selected: finalResults };
  resultRow.$synced = aggregatedRow[VIRTUAL_SYNCED_KEY] ?? true;
  resultRow.$origin = aggregatedRow[VIRTUAL_HAS_LOCAL_KEY] ? `local` : `remote`;
  resultRow.$key = resultKey;
  resultRow.$collectionId = aggregateCollectionId ?? resultRow.$collectionId;
  if (mainSource && correlationKey !== undefined) {
    resultRow[mainSource] = { __correlationKey: correlationKey };
  }
  return resultRow;
}

/**
 * Append HAVING filters to a grouped pipeline.
 *
 * Expression clauses have their aggregates rewritten to `$selected.<alias>`
 * refs before compilation. Function clauses are called directly. Both kinds
 * see only `{ $selected }`, and both results go through `toBooleanPredicate`.
 *
 * @throws {AggregateFunctionNotInSelectError} When a HAVING aggregate has no
 *   matching SELECT aggregate.
 */
function applyHavingFilters(pipeline, havingClauses, fnHavingClauses, selectClause) {
  let result = pipeline;
  for (const havingClause of havingClauses || []) {
    const havingExpression = getHavingExpression(havingClause);
    const transformedHavingClause = replaceAggregatesByRefs(
      havingExpression,
      selectClause || {},
      `$selected`
    );
    const compiledHaving = compileExpression(transformedHavingClause);
    result = result.pipe(
      filter(([, row]) => toBooleanPredicate(compiledHaving({ $selected: row.$selected })))
    );
  }
  for (const fnHaving of fnHavingClauses || []) {
    result = result.pipe(
      filter(([, row]) => toBooleanPredicate(fnHaving({ $selected: row.$selected })))
    );
  }
  return result;
}

/**
 * Append GROUP BY, aggregation and HAVING stages to a query pipeline.
 *
 * @param pipeline - Stream of `[key, row]` pairs exposing `.pipe(...)`.
 * @param {Array<object>} groupByClause - GROUP BY expressions. An empty array
 *   aggregates all rows into a single group.
 * @param {Array<object>} [havingClauses] - Expression-based HAVING clauses.
 * @param {Record<string, object>} [selectClause] - SELECT expressions by alias.
 * @param {Array<Function>} [fnHavingClauses] - Function HAVING predicates,
 *   called with `{ $selected }`.
 * @param {string} [aggregateCollectionId] - When set, overrides
 *   `$collectionId` on every result row.
 * @param {string} [mainSource] - Alias whose `__correlationKey` is folded into
 *   the group key. NOTE(review): presumably used for correlated subqueries.
 * @returns The extended pipeline, emitting `[groupKey, resultRow]` pairs.
 * @throws {NonAggregateExpressionNotInGroupByError}
 * @throws {UnsupportedAggregateFunctionError}
 * @throws {AggregateFunctionNotInSelectError}
 */
function processGroupBy(
  pipeline,
  groupByClause,
  havingClauses,
  selectClause,
  fnHavingClauses,
  aggregateCollectionId,
  mainSource
) {
  if (groupByClause.length === 0) {
    const { aggregates, wrappedAggExprs } = buildAggregates(selectClause);
    const keyExtractor = mainSource
      ? ([, row]) => ({
          __singleGroup: true,
          __correlationKey: row?.[mainSource]?.__correlationKey,
        })
      : () => ({ __singleGroup: true });

    const grouped = pipeline.pipe(groupBy(keyExtractor, aggregates)).pipe(
      map(([, aggregatedRow]) => {
        const finalResults = { ...(aggregatedRow.$selected || {}) };
        if (selectClause) {
          for (const [alias, expr] of Object.entries(selectClause)) {
            if (expr.type === `agg`) {
              finalResults[alias] = aggregatedRow[alias];
            }
          }
          evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs);
        }
        const correlationKey = mainSource ? aggregatedRow.__correlationKey : undefined;
        const resultKey =
          correlationKey !== undefined
            ? `single_group_${serializeValue(correlationKey)}`
            : `single_group`;
        const resultRow = buildResultRow(aggregatedRow, finalResults, resultKey, {
          aggregateCollectionId,
          mainSource,
          correlationKey,
        });
        return [resultKey, resultRow];
      })
    );
    return applyHavingFilters(grouped, havingClauses, fnHavingClauses, selectClause);
  }

  // Validate before compiling aggregates so GROUP BY errors take precedence.
  const mapping = validateAndCreateMapping(groupByClause, selectClause);
  const compiledGroupByExpressions = groupByClause.map((e) => compileExpression(e));
  const keyExtractor = ([, row]) => {
    const namespacedRow = { ...row };
    delete namespacedRow.$selected;
    const key = {};
    for (let i = 0; i < compiledGroupByExpressions.length; i++) {
      key[`__key_${i}`] = compiledGroupByExpressions[i](namespacedRow);
    }
    if (mainSource) {
      key.__correlationKey = row?.[mainSource]?.__correlationKey;
    }
    return key;
  };
  const { aggregates, wrappedAggExprs } = buildAggregates(selectClause);

  const grouped = pipeline.pipe(groupBy(keyExtractor, aggregates)).pipe(
    map(([, aggregatedRow]) => {
      const selectResults = aggregatedRow.$selected || {};
      const finalResults = {};
      if (selectClause) {
        for (const [alias, expr] of Object.entries(selectClause)) {
          if (expr.type === `agg`) {
            finalResults[alias] = aggregatedRow[alias];
          } else if (!Object.prototype.hasOwnProperty.call(wrappedAggExprs, alias)) {
            // Own-property check: a plain lookup would find Object.prototype
            // members and silently drop aliases like `constructor`.
            const groupIndex = mapping.selectToGroupByIndex.get(alias);
            finalResults[alias] =
              groupIndex !== undefined
                ? aggregatedRow[`__key_${groupIndex}`]
                : selectResults[alias];
          }
        }
        evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs);
      } else {
        for (let i = 0; i < groupByClause.length; i++) {
          finalResults[`__key_${i}`] = aggregatedRow[`__key_${i}`];
        }
      }
      const correlationKey = mainSource ? aggregatedRow.__correlationKey : undefined;
      const keyParts = groupByClause.map((_, i) => aggregatedRow[`__key_${i}`]);
      if (correlationKey !== undefined) {
        keyParts.push(correlationKey);
      }
      // A single key component is used as-is; composite keys are serialized.
      const finalKey = keyParts.length === 1 ? keyParts[0] : serializeValue(keyParts);
      const resultRow = buildResultRow(aggregatedRow, finalResults, finalKey, {
        aggregateCollectionId,
        mainSource,
        correlationKey,
      });
      return [finalKey, resultRow];
    })
  );
  return applyHavingFilters(grouped, havingClauses, fnHavingClauses, selectClause);
}

/**
 * Structural equality for query IR expressions (`ref`, `val`, `func`, `agg`).
 * Values are compared with `===`; unknown expression types are never equal.
 */
function expressionsEqual(expr1, expr2) {
  if (!expr1 || !expr2) return false;
  if (expr1.type !== expr2.type) return false;
  switch (expr1.type) {
    case `ref`:
      if (!expr1.path || !expr2.path) return false;
      if (expr1.path.length !== expr2.path.length) return false;
      return expr1.path.every((segment, i) => segment === expr2.path[i]);
    case `val`:
      return expr1.value === expr2.value;
    case `func`:
    case `agg`:
      return (
        expr1.name === expr2.name &&
        expr1.args?.length === expr2.args?.length &&
        (expr1.args || []).every((arg, i) => expressionsEqual(arg, expr2.args[i]))
      );
    default:
      return false;
  }
}

/**
 * Map an `agg` IR node to the matching db-ivm groupBy operator.
 *
 * `sum`/`avg` coerce their input with `Number(...)`, with null/undefined
 * becoming 0. `min`/`max` keep numbers, strings, bigints and Dates as-is.
 * `count` receives the raw value.
 *
 * @throws {UnsupportedAggregateFunctionError} For any other aggregate name.
 */
function getAggregateFunction(aggExpr) {
  const compiledExpr = compileExpression(aggExpr.args[0]);
  const valueExtractor = ([, namespacedRow]) => {
    const value = compiledExpr(namespacedRow);
    if (typeof value === `number`) return value;
    return value != null ? Number(value) : 0;
  };
  const valueExtractorForMinMax = ([, namespacedRow]) => {
    const value = compiledExpr(namespacedRow);
    if (
      typeof value === `number` ||
      typeof value === `string` ||
      typeof value === `bigint` ||
      value instanceof Date
    ) {
      return value;
    }
    return value != null ? Number(value) : 0;
  };
  const rawValueExtractor = ([, namespacedRow]) => compiledExpr(namespacedRow);
  switch (aggExpr.name.toLowerCase()) {
    case `sum`:
      return sum(valueExtractor);
    case `count`:
      return count(rawValueExtractor);
    case `avg`:
      return avg(valueExtractor);
    case `min`:
      return min(valueExtractorForMinMax);
    case `max`:
      return max(valueExtractorForMinMax);
    default:
      throw new UnsupportedAggregateFunctionError(aggExpr.name);
  }
}

/**
 * Rewrite every aggregate inside a HAVING expression into a ref to the
 * equivalent SELECT aggregate, i.e. `PropRef([resultAlias, alias])`.
 *
 * @param {object} havingExpr - HAVING IR expression.
 * @param {Record<string, object>} selectClause - SELECT expressions by alias.
 * @param {string} [resultAlias="$selected"] - Namespace the refs point into;
 *   propagated through nested function arguments.
 * @throws {AggregateFunctionNotInSelectError} When an aggregate is not selected.
 * @throws {UnknownHavingExpressionTypeError} For unrecognised node types.
 */
function replaceAggregatesByRefs(havingExpr, selectClause, resultAlias = `$selected`) {
  switch (havingExpr.type) {
    case `agg`: {
      for (const [alias, selectExpr] of Object.entries(selectClause)) {
        if (selectExpr.type === `agg` && aggregatesEqual(havingExpr, selectExpr)) {
          return new PropRef([resultAlias, alias]);
        }
      }
      throw new AggregateFunctionNotInSelectError(havingExpr.name);
    }
    case `func`: {
      // Forward resultAlias so nested aggregates use the same namespace.
      const transformedArgs = havingExpr.args.map((arg) =>
        replaceAggregatesByRefs(arg, selectClause, resultAlias)
      );
      return new Func(havingExpr.name, transformedArgs);
    }
    case `ref`:
    case `val`:
      return havingExpr;
    default:
      throw new UnknownHavingExpressionTypeError(havingExpr.type);
  }
}

/**
 * Evaluate SELECT expressions that wrap aggregates, writing into finalResults.
 *
 * The synthetic `__agg_N` values are staged in finalResults so the evaluators
 * can read them via `$selected`. They are then removed again.
 */
function evaluateWrappedAggregates(finalResults, aggregatedRow, wrappedAggExprs) {
  for (const key of Object.keys(aggregatedRow)) {
    if (key.startsWith(`__agg_`)) {
      finalResults[key] = aggregatedRow[key];
    }
  }
  for (const [alias, evaluator] of Object.entries(wrappedAggExprs)) {
    finalResults[alias] = evaluator({ $selected: finalResults });
  }
  for (const key of Object.keys(finalResults)) {
    if (key.startsWith(`__agg_`)) delete finalResults[key];
  }
}

/** True if expr is an aggregate or a function call with an aggregate anywhere in its args. */
function containsAggregate(expr) {
  if (!isExpressionLike(expr)) return false;
  if (expr.type === `agg`) return true;
  if (expr.type === `func` && `args` in expr) {
    return expr.args.some((arg) => containsAggregate(arg));
  }
  return false;
}

/**
 * Replace each aggregate inside expr with a `$selected.__agg_N` ref.
 *
 * @param {object} expr - IR expression.
 * @param {{ value: number }} counter - Mutable counter used to number the
 *   synthetic aliases.
 * @returns {{ transformed: object, extracted: Record<string, object> }}
 */
function extractAndReplaceAggregates(expr, counter) {
  if (expr.type === `agg`) {
    const alias = `__agg_${counter.value++}`;
    return {
      transformed: new PropRef([`$selected`, alias]),
      extracted: { [alias]: expr },
    };
  }
  if (expr.type === `func`) {
    const allExtracted = {};
    const newArgs = expr.args.map((arg) => {
      const result = extractAndReplaceAggregates(arg, counter);
      Object.assign(allExtracted, result.extracted);
      return result.transformed;
    });
    return { transformed: new Func(expr.name, newArgs), extracted: allExtracted };
  }
  return { transformed: expr, extracted: {} };
}

/** Structural equality of two aggregate nodes (same name and equal args). */
function aggregatesEqual(agg1, agg2) {
  return (
    agg1.name === agg2.name &&
    agg1.args.length === agg2.args.length &&
    agg1.args.every((arg, i) => expressionsEqual(arg, agg2.args[i]))
  );
}

export { containsAggregate, processGroupBy, replaceAggregatesByRefs };
//# sourceMappingURL=group-by.js.map