@statezero/core
Version:
The type-safe frontend client for StateZero - connect directly to your backend models with zero boilerplate
285 lines (284 loc) • 12.1 kB
JavaScript
// Compiler-emitted (TypeScript-style downlevel) helper for reading lowered `#private` class members.
//   receiver - the object or class the member is read from
//   state    - for static members, the declaring class itself (must be === receiver);
//              otherwise a store with has()/get() keyed by receiver (presumably a WeakMap)
//   kind     - "f" field, "a" accessor, "m" method
//   f        - field descriptor ({ value }), accessor getter, or the method function
// In an ES module top-level `this` is undefined, so the local function below is always used.
var __classPrivateFieldGet = (this && this.__classPrivateFieldGet) || function (receiver, state, kind, f) {
// An accessor read requires a getter.
if (kind === "a" && !f) throw new TypeError("Private accessor was defined without a getter");
// Brand check: static members only resolve on the declaring class; instance members only on registered receivers.
if (typeof state === "function" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError("Cannot read private member from an object whose class did not declare it");
// Methods are returned as-is, accessors are invoked, fields read from the descriptor (or the store).
return kind === "m" ? f : kind === "a" ? f.call(receiver) : f ? f.value : state.get(receiver);
};
// Module-level slots for MetricStrategyFactory's lowered static private members; assigned after the class is defined.
var _a, _MetricStrategyFactory_customStrategies, _MetricStrategyFactory_defaultStrategies, _MetricStrategyFactory_generateStrategyKey;
import { Status, Type } from '../stores/operation';
import { isNil } from 'lodash-es';
/**
 * Base class for metric calculation strategies with operations.
 *
 * Subclasses implement reduceOperation() to fold a single per-instance
 * operation into a running metric value; calculateWithOperations() replays
 * a list of operations through it.
 */
export class MetricCalculationStrategy {
    /**
     * Calculate a metric by replaying operations on top of a ground-truth value.
     *
     * @param {number|null|undefined} baseValue - Ground-truth metric value. null or
     *   undefined means "no value" and the calculation starts from getInitialValue().
     * @param {Iterable<Object>} operations - Operations to replay, in order. Each is read
     *   for `status`, `type`, `frozenInstances` (pre-operation instance data) and
     *   `instances` (post-operation data, matched to frozenInstances by index).
     * @param {string} field - Field the metric is computed over; passed through to
     *   reduceOperation().
     * @param {Function} ModelClass - Model class of the instances. Not used by the base
     *   implementation; kept for interface compatibility.
     * @returns {number} The metric after applying every non-rejected operation.
     */
    calculateWithOperations(baseValue, operations, field, ModelClass) {
        // isNil covers undefined as well as null; a bare `=== null` check let an
        // undefined base value through, which made the reducers produce NaN.
        let currentValue = isNil(baseValue) ? this.getInitialValue() : baseValue;
        for (const op of operations) {
            // Rejected operations do not contribute to the metric.
            if (op.status === Status.REJECTED) {
                continue;
            }
            // Tolerate operations that carry no post-operation data at all.
            const updatedInstances = op.instances || [];
            op.frozenInstances.forEach((originalData, index) => {
                // One instance's slice of the operation, in the shape reducers expect.
                const singleOp = {
                    originalData, // Pre-operation state
                    updatedData: updatedInstances[index] || null, // Post-operation state, if any
                    type: op.type,
                    status: op.status
                };
                currentValue = this.reduceOperation(currentValue, singleOp, field);
            });
        }
        return currentValue;
    }
    /**
     * Initial value used when there is no ground-truth value.
     * Override in subclasses if needed.
     *
     * @returns {number} 0, which suits count and sum.
     */
    getInitialValue() {
        return 0;
    }
    /**
     * Fold a single per-instance operation into the running value.
     * Must be implemented by subclasses.
     *
     * @param {number} currentValue - Value before this operation.
     * @param {Object} operation - { originalData, updatedData, type, status }.
     * @param {string} field - Field the metric is computed over.
     * @returns {number} Value after this operation.
     * @throws {Error} Always, in the base class.
     */
    reduceOperation(currentValue, operation, field) {
        throw new Error('reduceOperation must be implemented by subclass');
    }
    /**
     * Safely read a numeric value from an object.
     *
     * @param {Object|null|undefined} obj - Source object.
     * @param {string} field - Property to read.
     * @returns {number} The parsed number, or 0 when the object/field is missing,
     *   the value is null/undefined, or it does not parse as a number.
     */
    safeGetValue(obj, field) {
        if (!obj || !field) {
            return 0;
        }
        const value = obj[field];
        if (isNil(value)) {
            return 0;
        }
        const numValue = parseFloat(value);
        return Number.isNaN(numValue) ? 0 : numValue;
    }
}
/**
 * Count strategy: the running total moves up on creates and down on deletes.
 */
export class CountStrategy extends MetricCalculationStrategy {
    /**
     * Apply one per-instance operation to the running count.
     *
     * @param {number} currentCount - Count before this operation.
     * @param {Object} operation - Single-instance operation; only `type` and `status` are read.
     * @param {string} field - Ignored; counting does not depend on a field.
     * @returns {number} Updated count. Deletes never drive it below zero.
     */
    reduceOperation(currentCount, operation, field) {
        const { status, type } = operation;
        if (status === Status.REJECTED) {
            return currentCount;
        }
        switch (type) {
            case Type.CREATE:
                return currentCount + 1;
            case Type.DELETE:
            case Type.DELETE_INSTANCE:
                return Math.max(0, currentCount - 1);
            default:
                // Updates, checkpoints, etc. leave the count alone.
                return currentCount;
        }
    }
}
/**
 * Sum strategy: keeps a running total of a numeric field across instances.
 */
export class SumStrategy extends MetricCalculationStrategy {
    /**
     * Apply one per-instance operation to the running sum.
     *
     * @param {number} currentSum - Sum before this operation.
     * @param {Object} operation - { originalData, updatedData, type, status }.
     * @param {string} field - Numeric field being summed (required).
     * @returns {number} Sum after this operation.
     * @throws {Error} When `field` is missing for a non-rejected operation.
     */
    reduceOperation(currentSum, operation, field) {
        if (operation.status === Status.REJECTED) {
            return currentSum;
        }
        if (!field) {
            throw new Error('SumStrategy requires a field parameter');
        }
        const { type, originalData, updatedData } = operation;
        if (type === Type.CREATE) {
            // A new instance contributes its value.
            return currentSum + this.safeGetValue(originalData, field);
        }
        if (type === Type.CHECKPOINT || type === Type.UPDATE) {
            if (!updatedData) {
                return currentSum;
            }
            // Swap the instance's previous contribution for its new one.
            const previous = this.safeGetValue(originalData, field);
            const next = this.safeGetValue(updatedData, field);
            return currentSum - previous + next;
        }
        if (type === Type.DELETE || type === Type.DELETE_INSTANCE) {
            // A removed instance takes its value with it.
            return currentSum - this.safeGetValue(originalData, field);
        }
        return currentSum;
    }
}
/**
 * Min strategy: tracks the minimum of a numeric field across instances.
 *
 * Only the running minimum is known here, not every instance's value, so some
 * transitions are approximations (see reduceOperation).
 */
export class MinStrategy extends MetricCalculationStrategy {
    /**
     * Start from +Infinity so the first observed value becomes the minimum.
     *
     * @returns {number} Infinity
     */
    getInitialValue() {
        return Infinity;
    }
    /**
     * Apply one per-instance operation to the running minimum.
     *
     * - CREATE: the new instance's value may become the minimum.
     * - UPDATE / CHECKPOINT (with updated data): if the instance held the current
     *   minimum, its new value is taken as the minimum (an approximation when the
     *   value grew, since another instance may now be smaller); otherwise the new
     *   value replaces the minimum only if it is smaller.
     * - Anything else (including DELETE) leaves the minimum unchanged.
     *
     * @param {number} currentMin - Minimum before this operation.
     * @param {Object} operation - { originalData, updatedData, type, status }.
     * @param {string} field - Numeric field being tracked (required).
     * @returns {number} Minimum after this operation.
     * @throws {Error} When `field` is missing for a non-rejected operation.
     */
    reduceOperation(currentMin, operation, field) {
        if (operation.status === Status.REJECTED) {
            return currentMin;
        }
        if (!field) {
            throw new Error('MinStrategy requires a field parameter');
        }
        const { type, originalData, updatedData } = operation;
        if (type === Type.CREATE) {
            return Math.min(currentMin, this.safeGetValue(originalData, field));
        }
        // This previously read `type.CHECKPOINT` (a property lookup on the type value
        // itself, not a comparison), so CHECKPOINT operations were silently ignored.
        // It now matches MaxStrategy and SumStrategy.
        if ((type === Type.UPDATE || type === Type.CHECKPOINT) && updatedData) {
            const oldValue = this.safeGetValue(originalData, field);
            const newValue = this.safeGetValue(updatedData, field);
            if (oldValue === currentMin) {
                // The current minimum itself changed. If it grew, the true minimum may
                // belong to another instance, which is not visible here, so the new
                // value is used as an approximation.
                return newValue;
            }
            return newValue < currentMin ? newValue : currentMin;
        }
        return currentMin;
    }
}
/**
 * Max strategy: follows the largest value of a numeric field across instances.
 */
export class MaxStrategy extends MetricCalculationStrategy {
    /**
     * Seed value: -Infinity, so any real value replaces it.
     *
     * @returns {number} -Infinity
     */
    getInitialValue() {
        return -Infinity;
    }
    /**
     * Apply one per-instance operation to the running maximum.
     *
     * When an update touches the instance that currently holds the maximum, its new
     * value is adopted even if it shrank (other instances are not visible here).
     *
     * @param {number} currentMax - Maximum before this operation.
     * @param {Object} operation - { originalData, updatedData, type, status }.
     * @param {string} field - Numeric field being tracked (required).
     * @returns {number} Maximum after this operation.
     * @throws {Error} When `field` is missing for a non-rejected operation.
     */
    reduceOperation(currentMax, operation, field) {
        if (operation.status === Status.REJECTED) {
            return currentMax;
        }
        if (!field) {
            throw new Error('MaxStrategy requires a field parameter');
        }
        const { type, originalData, updatedData } = operation;
        switch (type) {
            case Type.CREATE:
                return Math.max(currentMax, this.safeGetValue(originalData, field));
            case Type.UPDATE:
            case Type.CHECKPOINT: {
                if (!updatedData) {
                    return currentMax;
                }
                const previous = this.safeGetValue(originalData, field);
                const next = this.safeGetValue(updatedData, field);
                const heldTheMax = previous === currentMax;
                return heldTheMax || next > currentMax ? next : currentMax;
            }
            default:
                return currentMax;
        }
    }
}
/**
 * Factory for resolving the metric calculation strategy to use.
 *
 * getStrategy() resolution order: a model-specific override, then a metric-wide
 * override, then the built-in default for the metric type (falling back to the
 * 'count' default for unknown types).
 */
export class MetricStrategyFactory {
    /**
     * Clear all custom strategy overrides. Built-in default strategies are not affected.
     */
    static clearCustomStrategies() {
        __classPrivateFieldGet(this, _a, "f", _MetricStrategyFactory_customStrategies).clear();
    }
    /**
     * Override a strategy for a specific metric type and model class.
     *
     * The metric type is stored lowercased so overrides are matched the same
     * case-insensitive way getStrategy() looks them up.
     *
     * @param {string} metricType - The type of metric (count, sum, min, max)
     * @param {Function|null} ModelClass - The model class, or null/undefined for an
     *   override that applies to every model
     * @param {MetricCalculationStrategy} strategy - The strategy instance to use
     * @throws {Error} If metricType or strategy is missing, or strategy is not a
     *   MetricCalculationStrategy instance
     */
    static overrideStrategy(metricType, ModelClass, strategy) {
        if (!metricType || !strategy) {
            throw new Error('overrideStrategy requires metricType and strategy');
        }
        if (!(strategy instanceof MetricCalculationStrategy)) {
            throw new Error('strategy must be an instance of MetricCalculationStrategy');
        }
        // getStrategy() lowercases before lookup; without the same normalization here
        // an override registered as e.g. 'Sum' could never be found.
        const normalizedMetricType = String(metricType).toLowerCase();
        const key = ModelClass
            // Model-specific override
            ? __classPrivateFieldGet(this, _a, "m", _MetricStrategyFactory_generateStrategyKey).call(this, normalizedMetricType, ModelClass)
            // Generic override for all models
            : `${normalizedMetricType}::*::*`;
        __classPrivateFieldGet(this, _a, "f", _MetricStrategyFactory_customStrategies).set(key, strategy);
    }
    /**
     * Get the appropriate strategy for a model class and metric type.
     *
     * @param {string} metricType - The type of metric (count, sum, min, max); case-insensitive
     * @param {Function} ModelClass - The model class; its configKey and modelName form
     *   the model-specific override key
     * @returns {MetricCalculationStrategy} The model-specific override if registered,
     *   else the metric-wide override, else the result of the default strategy
     *   creator for the type ('count' when the type is unknown)
     */
    static getStrategy(metricType, ModelClass) {
        const normalizedMetricType = metricType.toLowerCase();
        const customStrategies = __classPrivateFieldGet(this, _a, "f", _MetricStrategyFactory_customStrategies);
        // Check for model-specific override first
        const specificKey = __classPrivateFieldGet(this, _a, "m", _MetricStrategyFactory_generateStrategyKey).call(this, normalizedMetricType, ModelClass);
        if (customStrategies.has(specificKey)) {
            return customStrategies.get(specificKey);
        }
        // Then a metric-only override (applies across all models)
        const genericKey = `${normalizedMetricType}::*::*`;
        if (customStrategies.has(genericKey)) {
            return customStrategies.get(genericKey);
        }
        // Otherwise build the default strategy for the metric type
        const defaultStrategies = __classPrivateFieldGet(this, _a, "f", _MetricStrategyFactory_defaultStrategies);
        const strategyCreator = defaultStrategies.get(normalizedMetricType) || defaultStrategies.get('count');
        return strategyCreator();
    }
}
// Initialise MetricStrategyFactory's lowered static private members. This must run after the
// class declarations above: `_a` aliases the class so __classPrivateFieldGet's receiver check
// (receiver === state) passes for static calls.
// generateStrategyKey: builds the model-specific override key "<metricType>::<configKey>::<modelName>".
_a = MetricStrategyFactory, _MetricStrategyFactory_generateStrategyKey = function _MetricStrategyFactory_generateStrategyKey(metricType, ModelClass) {
return `${metricType}::${ModelClass.configKey}::${ModelClass.modelName}`;
};
// Custom strategy overrides, keyed by generateStrategyKey output or "<metricType>::*::*" for
// metric-wide overrides. Wrapped in { value } because __classPrivateFieldGet reads fields via f.value.
_MetricStrategyFactory_customStrategies = { value: new Map() };
// Built-in strategy creators keyed by lowercase metric type. getStrategy() calls the creator,
// so each default lookup yields a fresh strategy instance.
_MetricStrategyFactory_defaultStrategies = { value: new Map([
['count', () => new CountStrategy()],
['sum', () => new SumStrategy()],
['min', () => new MinStrategy()],
['max', () => new MaxStrategy()]
]) };