UNPKG

@tanstack/db

Version:

A reactive client store for building super fast apps on sync

587 lines (530 loc) 19.5 kB
import {
  filter,
  groupBy,
  groupByOperators,
  map,
  serializeValue,
} from '@tanstack/db-ivm'
import { Func, PropRef, getHavingExpression, isExpressionLike } from '../ir.js'
import {
  AggregateFunctionNotInSelectError,
  NonAggregateExpressionNotInGroupByError,
  UnknownHavingExpressionTypeError,
  UnsupportedAggregateFunctionError,
} from '../../errors.js'
import { compileExpression, toBooleanPredicate } from './evaluators.js'
import type {
  Aggregate,
  BasicExpression,
  GroupBy,
  Having,
  Select,
} from '../ir.js'
import type { NamespacedAndKeyedStream, NamespacedRow } from '../../types.js'

const { sum, count, avg, min, max } = groupByOperators

/**
 * Interface for caching the mapping between GROUP BY expressions and SELECT expressions
 */
interface GroupBySelectMapping {
  /** Maps a SELECT alias to the index of the matching GROUP BY expression */
  selectToGroupByIndex: Map<string, number>
  /** Copy of the GROUP BY expressions, kept for reference */
  groupByExpressions: Array<any>
}

/**
 * Aggregates discovered in a SELECT clause, split by how they are consumed.
 */
interface SelectAggregates {
  /**
   * Aggregate operators handed to `groupBy`, keyed by the SELECT alias (plain
   * aggregates) or by a synthetic `__agg_N` alias (aggregates nested in a
   * wrapping expression).
   */
  aggregates: Record<string, any>
  /**
   * Expressions that wrap aggregates (e.g. coalesce(count(...), 0)).
   * Keys are the original SELECT aliases; values are pre-compiled evaluators
   * over the transformed (aggregate-free) expression.
   */
  wrappedAggExprs: Record<string, (data: any) => any>
}

/**
 * Validates that all non-aggregate expressions in SELECT are present in GROUP BY
 * and creates a cached mapping for efficient lookup during processing.
 *
 * @param groupByClause - The GROUP BY expressions
 * @param selectClause - The SELECT clause, if any
 * @returns The alias -> GROUP BY index mapping plus a copy of the GROUP BY expressions
 * @throws NonAggregateExpressionNotInGroupByError when a SELECT expression that
 *   contains no aggregate does not match any GROUP BY expression
 */
function validateAndCreateMapping(
  groupByClause: GroupBy,
  selectClause?: Select,
): GroupBySelectMapping {
  const selectToGroupByIndex = new Map<string, number>()
  const groupByExpressions = [...groupByClause]

  if (!selectClause) {
    return { selectToGroupByIndex, groupByExpressions }
  }

  // Validate each SELECT expression
  for (const [alias, expr] of Object.entries(selectClause)) {
    if (expr.type === `agg` || containsAggregate(expr)) {
      // Aggregate expressions (plain or wrapped) are allowed and don't need to be in GROUP BY
      continue
    }

    // Non-aggregate expression must be in GROUP BY
    const groupIndex = groupByExpressions.findIndex((groupExpr) =>
      expressionsEqual(expr, groupExpr),
    )

    if (groupIndex === -1) {
      throw new NonAggregateExpressionNotInGroupByError(alias)
    }

    // Cache the mapping
    selectToGroupByIndex.set(alias, groupIndex)
  }

  return { selectToGroupByIndex, groupByExpressions }
}

/**
 * Scans the SELECT clause and builds the aggregate operators for `groupBy`.
 * Plain aggregates are registered under their SELECT alias. Expressions that
 * merely contain aggregates have each aggregate extracted under a synthetic
 * alias and the remaining wrapper compiled for later evaluation.
 */
function collectSelectAggregates(selectClause?: Select): SelectAggregates {
  const aggregates: Record<string, any> = {}
  const wrappedAggExprs: Record<string, (data: any) => any> = {}

  if (!selectClause) {
    return { aggregates, wrappedAggExprs }
  }

  // One counter for the whole clause keeps synthetic aliases unique across
  // every wrapped expression.
  const aggCounter = { value: 0 }

  for (const [alias, expr] of Object.entries(selectClause)) {
    if (expr.type === `agg`) {
      aggregates[alias] = getAggregateFunction(expr)
    } else if (containsAggregate(expr)) {
      const { transformed, extracted } = extractAndReplaceAggregates(
        expr as BasicExpression | Aggregate,
        aggCounter,
      )
      for (const [syntheticAlias, aggExpr] of Object.entries(extracted)) {
        aggregates[syntheticAlias] = getAggregateFunction(aggExpr)
      }
      wrappedAggExprs[alias] = compileExpression(transformed)
    }
  }

  return { aggregates, wrappedAggExprs }
}

/**
 * Appends one filter stage per HAVING clause and per functional HAVING clause.
 * Both the single-group and the multi-group paths go through this helper so
 * they evaluate HAVING identically.
 */
function applyHavingFilters(
  pipeline: NamespacedAndKeyedStream,
  selectClause: Select | undefined,
  havingClauses?: Array<Having>,
  fnHavingClauses?: Array<(row: any) => any>,
): NamespacedAndKeyedStream {
  let filtered = pipeline

  // Apply HAVING clauses if present
  for (const havingClause of havingClauses ?? []) {
    const havingExpression = getHavingExpression(havingClause)
    const transformedHavingClause = replaceAggregatesByRefs(
      havingExpression,
      selectClause || {},
      `$selected`,
    )
    const compiledHaving = compileExpression(transformedHavingClause)

    filtered = filtered.pipe(
      filter(([, row]) => {
        // Create a namespaced row structure for HAVING evaluation
        const namespacedRow = { $selected: (row as any).$selected }
        // Always go through toBooleanPredicate so the raw evaluator result is
        // never handed to filter() directly.
        // NOTE(review): presumably this normalizes null/unknown results to a
        // strict boolean - confirm against evaluators.js.
        return toBooleanPredicate(compiledHaving(namespacedRow))
      }),
    )
  }

  // Apply functional HAVING clauses if present
  for (const fnHaving of fnHavingClauses ?? []) {
    filtered = filtered.pipe(
      filter(([, row]) => {
        // Create a namespaced row structure for functional HAVING evaluation
        const namespacedRow = { $selected: (row as any).$selected }
        return toBooleanPredicate(fnHaving(namespacedRow))
      }),
    )
  }

  return filtered
}

/**
 * Processes the GROUP BY clause with optional HAVING and SELECT.
 * Works with the new $selected structure from early SELECT processing.
 *
 * @param pipeline - The keyed stream of namespaced rows to aggregate
 * @param groupByClause - GROUP BY expressions; an empty array aggregates all
 *   rows into a single group
 * @param havingClauses - Optional HAVING expressions, applied in order
 * @param selectClause - Optional SELECT clause providing aliases and aggregates
 * @param fnHavingClauses - Optional functional HAVING predicates; each receives
 *   `{ $selected }`
 * @returns The pipeline extended with grouping, `$selected` projection and
 *   HAVING filter stages
 * @throws NonAggregateExpressionNotInGroupByError (multi-group path only) when
 *   SELECT has a non-aggregate expression missing from GROUP BY
 */
export function processGroupBy(
  pipeline: NamespacedAndKeyedStream,
  groupByClause: GroupBy,
  havingClauses?: Array<Having>,
  selectClause?: Select,
  fnHavingClauses?: Array<(row: any) => any>,
): NamespacedAndKeyedStream {
  // Handle empty GROUP BY (single-group aggregation)
  if (groupByClause.length === 0) {
    // For single-group aggregation, create a single group with all data
    const { aggregates, wrappedAggExprs } =
      collectSelectAggregates(selectClause)

    // Use a constant key for single group
    const keyExtractor = () => ({ __singleGroup: true })

    // Apply the groupBy operator with single group
    pipeline = pipeline.pipe(
      groupBy(keyExtractor, aggregates),
    ) as NamespacedAndKeyedStream

    // Update $selected to include aggregate values
    pipeline = pipeline.pipe(
      map(([, aggregatedRow]) => {
        // Start with the existing $selected from early SELECT processing
        const selectResults = (aggregatedRow as any).$selected || {}
        const finalResults: Record<string, any> = { ...selectResults }

        if (selectClause) {
          // First pass: populate plain aggregate results
          for (const [alias, expr] of Object.entries(selectClause)) {
            if (expr.type === `agg`) {
              finalResults[alias] = aggregatedRow[alias]
            }
          }

          // Second pass: evaluate expressions that wrap aggregates
          evaluateWrappedAggregates(
            finalResults,
            aggregatedRow as Record<string, any>,
            wrappedAggExprs,
          )
        }

        // Use a single key for the result and update $selected
        return [
          `single_group`,
          {
            ...aggregatedRow,
            $selected: finalResults,
          },
        ] as [unknown, Record<string, any>]
      }),
    )

    return applyHavingFilters(
      pipeline,
      selectClause,
      havingClauses,
      fnHavingClauses,
    )
  }

  // Multi-group aggregation logic...

  // Validate and create mapping for non-aggregate expressions in SELECT
  const mapping = validateAndCreateMapping(groupByClause, selectClause)

  // Pre-compile groupBy expressions
  const compiledGroupByExpressions = groupByClause.map((e) =>
    compileExpression(e),
  )

  // Create a key extractor function using simple __key_X format
  const keyExtractor = ([, row]: [
    string,
    NamespacedRow & { $selected?: any },
  ]) => {
    // Use the original namespaced row for GROUP BY expressions, not $selected
    const namespacedRow = { ...row }
    delete (namespacedRow as any).$selected

    const key: Record<string, unknown> = {}

    // Use simple __key_X format for each groupBy expression
    for (let i = 0; i < groupByClause.length; i++) {
      const compiledExpr = compiledGroupByExpressions[i]!
      const value = compiledExpr(namespacedRow)
      key[`__key_${i}`] = value
    }

    return key
  }

  // Create aggregate functions for any aggregated columns in the SELECT clause
  const { aggregates, wrappedAggExprs } = collectSelectAggregates(selectClause)

  // Apply the groupBy operator
  pipeline = pipeline.pipe(groupBy(keyExtractor, aggregates))

  // Update $selected to handle GROUP BY results
  pipeline = pipeline.pipe(
    map(([, aggregatedRow]) => {
      // Start with the existing $selected from early SELECT processing
      const selectResults = (aggregatedRow as any).$selected || {}
      const finalResults: Record<string, any> = {}

      if (selectClause) {
        // First pass: populate group keys and plain aggregates
        for (const [alias, expr] of Object.entries(selectClause)) {
          if (expr.type === `agg`) {
            finalResults[alias] = aggregatedRow[alias]
          } else if (!wrappedAggExprs[alias]) {
            // Use cached mapping to get the corresponding __key_X for non-aggregates
            const groupIndex = mapping.selectToGroupByIndex.get(alias)
            if (groupIndex !== undefined) {
              finalResults[alias] = aggregatedRow[`__key_${groupIndex}`]
            } else {
              // Fallback to original SELECT results
              finalResults[alias] = selectResults[alias]
            }
          }
        }

        // Second pass: evaluate expressions that wrap aggregates
        evaluateWrappedAggregates(
          finalResults,
          aggregatedRow as Record<string, any>,
          wrappedAggExprs,
        )
      } else {
        // No SELECT clause - just use the group keys
        for (let i = 0; i < groupByClause.length; i++) {
          finalResults[`__key_${i}`] = aggregatedRow[`__key_${i}`]
        }
      }

      // Generate a simple key for the live collection using group values
      let finalKey: unknown
      if (groupByClause.length === 1) {
        finalKey = aggregatedRow[`__key_0`]
      } else {
        const keyParts: Array<unknown> = []
        for (let i = 0; i < groupByClause.length; i++) {
          keyParts.push(aggregatedRow[`__key_${i}`])
        }
        finalKey = serializeValue(keyParts)
      }

      return [
        finalKey,
        {
          ...aggregatedRow,
          $selected: finalResults,
        },
      ] as [unknown, Record<string, any>]
    }),
  )

  return applyHavingFilters(
    pipeline,
    selectClause,
    havingClauses,
    fnHavingClauses,
  )
}

/**
 * Helper function to check if two expressions are structurally equal.
 * Refs compare path segments, values compare with `===`, and functions and
 * aggregates compare name plus arguments recursively. Any other expression
 * type, or a missing operand, compares as not equal.
 */
function expressionsEqual(expr1: any, expr2: any): boolean {
  if (!expr1 || !expr2) return false
  if (expr1.type !== expr2.type) return false

  switch (expr1.type) {
    case `ref`:
      // Compare paths as arrays
      if (!expr1.path || !expr2.path) return false
      if (expr1.path.length !== expr2.path.length) return false
      return expr1.path.every(
        (segment: string, i: number) => segment === expr2.path[i],
      )
    case `val`:
      return expr1.value === expr2.value
    // Functions and aggregates share the same shape: a name and an args list
    case `func`:
    case `agg`:
      return (
        expr1.name === expr2.name &&
        expr1.args?.length === expr2.args?.length &&
        (expr1.args || []).every((arg: any, i: number) =>
          expressionsEqual(arg, expr2.args[i]),
        )
      )
    default:
      return false
  }
}

/**
 * Helper function to get an aggregate function based on the Agg expression.
 * Only the first argument of the aggregate is compiled and evaluated.
 *
 * @param aggExpr - The aggregate expression (name is matched case-insensitively)
 * @returns The `groupByOperators` aggregate configured with a value extractor
 * @throws UnsupportedAggregateFunctionError for names other than
 *   sum, count, avg, min and max
 */
function getAggregateFunction(aggExpr: Aggregate) {
  // Pre-compile the value extractor expression
  const compiledExpr = compileExpression(aggExpr.args[0]!)

  // Create a value extractor function for the expression to aggregate
  const valueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
    const value = compiledExpr(namespacedRow)
    // Ensure we return a number for numeric aggregate functions
    if (typeof value === `number`) {
      return value
    }
    // null/undefined contribute 0; everything else is coerced with Number()
    return value != null ? Number(value) : 0
  }

  // Create a value extractor function for min/max that preserves comparable types
  const valueExtractorForMinMax = ([, namespacedRow]: [
    string,
    NamespacedRow,
  ]) => {
    const value = compiledExpr(namespacedRow)
    // Preserve strings, numbers, Dates, and bigints for comparison
    if (
      typeof value === `number` ||
      typeof value === `string` ||
      typeof value === `bigint` ||
      value instanceof Date
    ) {
      return value
    }
    return value != null ? Number(value) : 0
  }

  // Create a raw value extractor function for the expression to aggregate
  const rawValueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
    return compiledExpr(namespacedRow)
  }

  // Return the appropriate aggregate function
  switch (aggExpr.name.toLowerCase()) {
    case `sum`:
      return sum(valueExtractor)
    case `count`:
      return count(rawValueExtractor)
    case `avg`:
      return avg(valueExtractor)
    case `min`:
      return min(valueExtractorForMinMax)
    case `max`:
      return max(valueExtractorForMinMax)
    default:
      throw new UnsupportedAggregateFunctionError(aggExpr.name)
  }
}

/**
 * Transforms expressions to replace aggregate functions with references to computed values.
 *
 * For aggregate expressions, finds matching aggregates in the SELECT clause and replaces them
 * with PropRef([resultAlias, alias]) to reference the computed aggregate value.
 *
 * Ref expressions (table columns and $selected fields) and value expressions are passed through unchanged.
 * Function expressions are recursively transformed, using the same resultAlias at every depth.
 *
 * @param havingExpr - The expression to transform (can be aggregate, ref, func, or val)
 * @param selectClause - The SELECT clause containing aliases and aggregate definitions
 * @param resultAlias - The namespace alias for SELECT results (default: '$selected')
 * @returns A transformed BasicExpression that references computed values instead of raw expressions
 * @throws AggregateFunctionNotInSelectError when an aggregate has no matching
 *   top-level aggregate in the SELECT clause
 * @throws UnknownHavingExpressionTypeError for an unrecognized expression type
 */
export function replaceAggregatesByRefs(
  havingExpr: BasicExpression | Aggregate,
  selectClause: Select,
  resultAlias: string = `$selected`,
): BasicExpression {
  switch (havingExpr.type) {
    case `agg`: {
      const aggExpr = havingExpr
      // Find matching aggregate in SELECT clause
      for (const [alias, selectExpr] of Object.entries(selectClause)) {
        if (selectExpr.type === `agg` && aggregatesEqual(aggExpr, selectExpr)) {
          // Replace with a reference to the computed aggregate
          return new PropRef([resultAlias, alias])
        }
      }
      // If no matching aggregate found in SELECT, throw error
      throw new AggregateFunctionNotInSelectError(aggExpr.name)
    }

    case `func`: {
      const funcExpr = havingExpr
      // Transform function arguments recursively. resultAlias must be
      // forwarded, otherwise nested aggregates fall back to the default
      // namespace instead of the one the caller asked for.
      const transformedArgs = funcExpr.args.map(
        (arg: BasicExpression | Aggregate) =>
          replaceAggregatesByRefs(arg, selectClause, resultAlias),
      )
      return new Func(funcExpr.name, transformedArgs)
    }

    case `ref`:
      // Ref expressions are passed through unchanged - they reference either:
      // - $selected fields (which are already in the correct namespace)
      // - Table column references (which remain valid)
      return havingExpr as BasicExpression

    case `val`:
      // Return as-is
      return havingExpr as BasicExpression

    default:
      throw new UnknownHavingExpressionTypeError((havingExpr as any).type)
  }
}

/**
 * Evaluates wrapped-aggregate expressions against the aggregated row.
 * Copies synthetic __agg_N values into finalResults so the compiled wrapper
 * expressions can reference them, evaluates each wrapper, then removes the
 * synthetic keys so they don't leak onto user-visible result rows.
 *
 * @param finalResults - The `$selected` object being built; mutated in place
 * @param aggregatedRow - The row produced by `groupBy`, holding `__agg_N` values
 * @param wrappedAggExprs - Compiled wrapper evaluators keyed by SELECT alias
 */
function evaluateWrappedAggregates(
  finalResults: Record<string, any>,
  aggregatedRow: Record<string, any>,
  wrappedAggExprs: Record<string, (data: any) => any>,
): void {
  // Expose synthetic aggregate values under $selected for the wrappers
  for (const key of Object.keys(aggregatedRow)) {
    if (key.startsWith(`__agg_`)) {
      finalResults[key] = aggregatedRow[key]
    }
  }

  for (const [alias, evaluator] of Object.entries(wrappedAggExprs)) {
    finalResults[alias] = evaluator({ $selected: finalResults })
  }

  // Strip the synthetic keys again
  for (const key of Object.keys(finalResults)) {
    if (key.startsWith(`__agg_`)) delete finalResults[key]
  }
}

/**
 * Checks whether an expression contains an aggregate anywhere in its tree.
 * Returns true for a top-level Aggregate, or a Func whose args (recursively)
 * contain an Aggregate. Safely returns false for nested Select objects.
 */
export function containsAggregate(
  expr: BasicExpression | Aggregate | Select,
): boolean {
  if (!isExpressionLike(expr)) {
    return false
  }
  if (expr.type === `agg`) {
    return true
  }
  if (expr.type === `func`) {
    return expr.args.some((arg: BasicExpression | Aggregate) =>
      containsAggregate(arg),
    )
  }
  return false
}

/**
 * Walks an expression tree containing nested aggregates.
 * Each Aggregate node is extracted, assigned a synthetic alias (__agg_N),
 * and replaced with PropRef(["$selected", "__agg_N"]) so the wrapper
 * expression can be compiled as a pure BasicExpression after groupBy
 * populates the synthetic values.
 *
 * @param expr - The expression to walk
 * @param counter - Shared mutable counter supplying the next `N`; incremented
 *   once per extracted aggregate
 * @returns The aggregate-free expression and the extracted aggregates keyed by
 *   their synthetic alias
 */
function extractAndReplaceAggregates(
  expr: BasicExpression | Aggregate,
  counter: { value: number },
): {
  transformed: BasicExpression
  extracted: Record<string, Aggregate>
} {
  if (expr.type === `agg`) {
    const alias = `__agg_${counter.value++}`
    return {
      transformed: new PropRef([`$selected`, alias]),
      extracted: { [alias]: expr },
    }
  }

  if (expr.type === `func`) {
    const allExtracted: Record<string, Aggregate> = {}
    const newArgs = expr.args.map((arg: BasicExpression | Aggregate) => {
      const result = extractAndReplaceAggregates(arg, counter)
      Object.assign(allExtracted, result.extracted)
      return result.transformed
    })
    return {
      transformed: new Func(expr.name, newArgs),
      extracted: allExtracted,
    }
  }

  // ref / val – pass through unchanged
  return { transformed: expr as BasicExpression, extracted: {} }
}

/**
 * Checks if two aggregate expressions are equal: same name, same number of
 * arguments, and pairwise structurally equal arguments.
 */
function aggregatesEqual(agg1: Aggregate, agg2: Aggregate): boolean {
  return (
    agg1.name === agg2.name &&
    agg1.args.length === agg2.args.length &&
    agg1.args.every((arg, i) => expressionsEqual(arg, agg2.args[i]))
  )
}