chromadb
Version:
A JavaScript interface for chroma
514 lines (435 loc) • 14.3 kB
text/typescript
import type { SparseVector } from "../../api";
import { deepClone, isPlainObject, IterableInput } from "./common";
import { Key } from "./key";
/**
 * Plain-object form of a rank expression, such as the `{ $val: … }` or
 * `{ $sum: [...] }` objects returned by the `toJSON()` methods in this file.
 */
export type RankLiteral = Record<string, unknown>;
/**
 * Anything `RankExpression.from` accepts: an existing expression, a literal
 * object (wrapped as-is), or a number (turned into a constant). `null` and
 * `undefined` are accepted by the type and make `from` return `undefined`.
 */
export type RankInput = RankExpression | RankLiteral | number | null | undefined;
/**
 * Asserts that `value` is a finite number and hands it back.
 *
 * @param value - Candidate of unknown type.
 * @param message - Message of the error raised when the check fails.
 * @returns `value`, narrowed to `number`.
 * @throws TypeError when `value` is not a number, is `NaN`, or is infinite.
 */
const requireNumber = (value: unknown, message: string): number => {
  // Number.isFinite is false for NaN as well as ±Infinity, so one test covers both.
  if (typeof value === "number" && Number.isFinite(value)) {
    return value;
  }
  throw new TypeError(message);
};
/**
 * Fluent arithmetic helpers shared by rank expressions.
 *
 * None of the methods mutate the receiver. Operands are converted with
 * `requireRank`, so numbers and literal objects are accepted and
 * `null`/`undefined` are rejected with a `TypeError`.
 *
 * The `this as unknown as RankExpression` casts exist because this base class
 * is declared before `RankExpression`; the concrete classes in this file all
 * extend `RankExpression`.
 */
abstract class RankExpressionBase {
  /** Serializes the expression to its plain-object form. */
  public abstract toJSON(): Record<string, unknown>;

  /** Sums this expression with `others`; returns the receiver when called with no arguments. */
  public add(...others: RankInput[]): RankExpression {
    const self = this as unknown as RankExpression;
    if (others.length === 0) {
      return self;
    }
    const operands: RankExpression[] = [self];
    others.forEach((item, index) => {
      operands.push(requireRank(item, `add operand ${index}`));
    });
    return SumRankExpression.create(operands);
  }

  /** Builds `this - other`. */
  public subtract(other: RankInput): RankExpression {
    const self = this as unknown as RankExpression;
    const rhs = requireRank(other, "subtract operand");
    return new SubRankExpression(self, rhs);
  }

  /** Multiplies this expression by `others`; returns the receiver when called with no arguments. */
  public multiply(...others: RankInput[]): RankExpression {
    const self = this as unknown as RankExpression;
    if (others.length === 0) {
      return self;
    }
    const operands: RankExpression[] = [self];
    others.forEach((item, index) => {
      operands.push(requireRank(item, `multiply operand ${index}`));
    });
    return MulRankExpression.create(operands);
  }

  /** Builds `this / other`. */
  public divide(other: RankInput): RankExpression {
    const self = this as unknown as RankExpression;
    const rhs = requireRank(other, "divide operand");
    return new DivRankExpression(self, rhs);
  }

  /** Negation, expressed as multiplication by the constant `-1`. */
  public negate(): RankExpression {
    const minusOne = -1;
    return this.multiply(minusOne);
  }

  /** Wraps this expression in an `$abs` node. */
  public abs(): RankExpression {
    const self = this as unknown as RankExpression;
    return new AbsRankExpression(self);
  }

  /** Wraps this expression in an `$exp` node. */
  public exp(): RankExpression {
    const self = this as unknown as RankExpression;
    return new ExpRankExpression(self);
  }

  /** Wraps this expression in a `$log` node. */
  public log(): RankExpression {
    const self = this as unknown as RankExpression;
    return new LogRankExpression(self);
  }

  /** Maximum of this expression and `others`; returns the receiver when called with no arguments. */
  public max(...others: RankInput[]): RankExpression {
    const self = this as unknown as RankExpression;
    if (others.length === 0) {
      return self;
    }
    const operands: RankExpression[] = [self];
    others.forEach((item, index) => {
      operands.push(requireRank(item, `max operand ${index}`));
    });
    return MaxRankExpression.create(operands);
  }

  /** Minimum of this expression and `others`; returns the receiver when called with no arguments. */
  public min(...others: RankInput[]): RankExpression {
    const self = this as unknown as RankExpression;
    if (others.length === 0) {
      return self;
    }
    const operands: RankExpression[] = [self];
    others.forEach((item, index) => {
      operands.push(requireRank(item, `min operand ${index}`));
    });
    return MinRankExpression.create(operands);
  }
}
/**
 * Public base type of every rank expression.
 */
export abstract class RankExpression extends RankExpressionBase {
  /**
   * Converts a loosely-typed input into a rank expression.
   *
   * @param input - An expression (returned unchanged), a finite number
   *   (becomes a constant), a plain object (wrapped as a raw literal), or
   *   `null`/`undefined`.
   * @returns The expression, or `undefined` when `input` is `null`/`undefined`.
   * @throws TypeError when `input` is a non-finite number or any other
   *   unsupported value.
   */
  public static from(input: RankInput): RankExpression | undefined {
    if (input instanceof RankExpression) {
      return input;
    }
    if (input === null || input === undefined) {
      return undefined;
    }
    if (typeof input === "number") {
      // Same validation as `Val`: NaN/±Infinity have no JSON representation
      // (JSON.stringify emits null), so reject them here instead of building
      // an expression that cannot be serialized faithfully.
      return new ValueRankExpression(
        requireNumber(input, "Rank input number must be finite"),
      );
    }
    if (isPlainObject(input)) {
      return new RawRankExpression(input);
    }
    throw new TypeError("Rank input must be a RankExpression, number, or plain object");
  }
}
/** Rank expression backed by a caller-supplied literal object. */
class RawRankExpression extends RankExpression {
  private readonly raw: RankLiteral;

  /** Stores `raw` by reference; copying happens in `toJSON`. */
  constructor(raw: RankLiteral) {
    super();
    this.raw = raw;
  }

  /** Returns the result of `deepClone` applied to the wrapped literal. */
  public toJSON(): RankLiteral {
    const snapshot = deepClone(this.raw);
    return snapshot;
  }
}
/** Constant numeric rank expression, serialized as `{ $val: <number> }`. */
class ValueRankExpression extends RankExpression {
  private readonly value: number;

  constructor(value: number) {
    super();
    this.value = value;
  }

  /** Emits the constant under the `$val` key. */
  public toJSON(): RankLiteral {
    const constant = this.value;
    return { $val: constant };
  }
}
/** N-ary sum of rank expressions, serialized as `{ $sum: [...] }`. */
class SumRankExpression extends RankExpression {
  private readonly ranks: RankExpression[];

  constructor(ranks: RankExpression[]) {
    super();
    this.ranks = ranks;
  }

  /**
   * Builds a sum from `ranks`, splicing in the operands of any direct
   * `SumRankExpression` child. When exactly one operand remains after
   * flattening, that operand is returned without a wrapper.
   */
  public static create(ranks: RankExpression[]): RankExpression {
    const terms = ranks.flatMap((rank) =>
      rank instanceof SumRankExpression ? rank.operands : [rank],
    );
    return terms.length === 1 ? terms[0] : new SumRankExpression(terms);
  }

  /** A shallow copy of the operand list. */
  public get operands(): RankExpression[] {
    return [...this.ranks];
  }

  /** Serializes every operand, in order, under the `$sum` key. */
  public toJSON(): RankLiteral {
    const serialized = this.ranks.map((rank) => rank.toJSON());
    return { $sum: serialized };
  }
}
/** Binary subtraction `left - right`, serialized as `{ $sub: { left, right } }`. */
class SubRankExpression extends RankExpression {
  private readonly left: RankExpression;
  private readonly right: RankExpression;

  constructor(left: RankExpression, right: RankExpression) {
    super();
    this.left = left;
    this.right = right;
  }

  /** Serializes the left operand first, then the right, under the `$sub` key. */
  public toJSON(): RankLiteral {
    const left = this.left.toJSON();
    const right = this.right.toJSON();
    return { $sub: { left, right } };
  }
}
/** N-ary product of rank expressions, serialized as `{ $mul: [...] }`. */
class MulRankExpression extends RankExpression {
  private readonly ranks: RankExpression[];

  constructor(ranks: RankExpression[]) {
    super();
    this.ranks = ranks;
  }

  /**
   * Builds a product from `ranks`, splicing in the operands of any direct
   * `MulRankExpression` child. When exactly one operand remains after
   * flattening, that operand is returned without a wrapper.
   */
  public static create(ranks: RankExpression[]): RankExpression {
    const factors = ranks.flatMap((rank) =>
      rank instanceof MulRankExpression ? rank.operands : [rank],
    );
    return factors.length === 1 ? factors[0] : new MulRankExpression(factors);
  }

  /** A shallow copy of the operand list. */
  public get operands(): RankExpression[] {
    return [...this.ranks];
  }

  /** Serializes every operand, in order, under the `$mul` key. */
  public toJSON(): RankLiteral {
    const serialized = this.ranks.map((rank) => rank.toJSON());
    return { $mul: serialized };
  }
}
/** Binary division `left / right`, serialized as `{ $div: { left, right } }`. */
class DivRankExpression extends RankExpression {
  private readonly left: RankExpression;
  private readonly right: RankExpression;

  constructor(left: RankExpression, right: RankExpression) {
    super();
    this.left = left;
    this.right = right;
  }

  /** Serializes the left operand first, then the right, under the `$div` key. */
  public toJSON(): RankLiteral {
    const left = this.left.toJSON();
    const right = this.right.toJSON();
    return { $div: { left, right } };
  }
}
/** Unary node serialized as `{ $abs: <operand> }`. */
class AbsRankExpression extends RankExpression {
  private readonly operand: RankExpression;

  constructor(operand: RankExpression) {
    super();
    this.operand = operand;
  }

  /** Serializes the operand under the `$abs` key. */
  public toJSON(): RankLiteral {
    const inner = this.operand.toJSON();
    return { $abs: inner };
  }
}
/** Unary node serialized as `{ $exp: <operand> }`. */
class ExpRankExpression extends RankExpression {
  private readonly operand: RankExpression;

  constructor(operand: RankExpression) {
    super();
    this.operand = operand;
  }

  /** Serializes the operand under the `$exp` key. */
  public toJSON(): RankLiteral {
    const inner = this.operand.toJSON();
    return { $exp: inner };
  }
}
/** Unary node serialized as `{ $log: <operand> }`. */
class LogRankExpression extends RankExpression {
  private readonly operand: RankExpression;

  constructor(operand: RankExpression) {
    super();
    this.operand = operand;
  }

  /** Serializes the operand under the `$log` key. */
  public toJSON(): RankLiteral {
    const inner = this.operand.toJSON();
    return { $log: inner };
  }
}
/** N-ary maximum of rank expressions, serialized as `{ $max: [...] }`. */
class MaxRankExpression extends RankExpression {
  private readonly ranks: RankExpression[];

  constructor(ranks: RankExpression[]) {
    super();
    this.ranks = ranks;
  }

  /**
   * Builds a maximum from `ranks`, splicing in the operands of any direct
   * `MaxRankExpression` child. When exactly one operand remains after
   * flattening, that operand is returned without a wrapper.
   */
  public static create(ranks: RankExpression[]): RankExpression {
    const candidates = ranks.flatMap((rank) =>
      rank instanceof MaxRankExpression ? rank.operands : [rank],
    );
    return candidates.length === 1 ? candidates[0] : new MaxRankExpression(candidates);
  }

  /** A shallow copy of the operand list. */
  public get operands(): RankExpression[] {
    return [...this.ranks];
  }

  /** Serializes every operand, in order, under the `$max` key. */
  public toJSON(): RankLiteral {
    const serialized = this.ranks.map((rank) => rank.toJSON());
    return { $max: serialized };
  }
}
/** N-ary minimum of rank expressions, serialized as `{ $min: [...] }`. */
class MinRankExpression extends RankExpression {
  private readonly ranks: RankExpression[];

  constructor(ranks: RankExpression[]) {
    super();
    this.ranks = ranks;
  }

  /**
   * Builds a minimum from `ranks`, splicing in the operands of any direct
   * `MinRankExpression` child. When exactly one operand remains after
   * flattening, that operand is returned without a wrapper.
   */
  public static create(ranks: RankExpression[]): RankExpression {
    const candidates = ranks.flatMap((rank) =>
      rank instanceof MinRankExpression ? rank.operands : [rank],
    );
    return candidates.length === 1 ? candidates[0] : new MinRankExpression(candidates);
  }

  /** A shallow copy of the operand list. */
  public get operands(): RankExpression[] {
    return [...this.ranks];
  }

  /** Serializes every operand, in order, under the `$min` key. */
  public toJSON(): RankLiteral {
    const serialized = this.ranks.map((rank) => rank.toJSON());
    return { $min: serialized };
  }
}
/** Rank expression built from normalized KNN options, serialized as `{ $knn: {...} }`. */
class KnnRankExpression extends RankExpression {
  private readonly config: KnnOptionsNormalized;

  constructor(config: KnnOptionsNormalized) {
    super();
    this.config = config;
  }

  /**
   * Emits `query`, `key` and `limit` always; `default` only when a default
   * value was configured; `return_rank: true` only when `returnRank` is set.
   */
  public toJSON(): RankLiteral {
    const { query, key, limit, defaultValue, returnRank } = this.config;
    return {
      $knn: {
        query,
        key,
        limit,
        // Conditional spreads keep the optional keys out of the payload entirely.
        ...(defaultValue !== undefined ? { default: defaultValue } : {}),
        ...(returnRank ? { return_rank: true } : {}),
      },
    };
  }
}
/** Validated, fully-defaulted form of `KnnOptions`, as returned by `normalizeKnnOptions`. */
interface KnnOptionsNormalized {
/** Query as a dense number array, a sparse vector, or a string. */
query: number[] | SparseVector | string;
/** Key as a plain string; a `Key` instance has been reduced to its `name`. */
key: string;
/** Positive integer limit. */
limit: number;
/** Serialized as `default` in the `$knn` payload; left out when undefined. */
defaultValue?: number;
/** When true, `return_rank: true` is added to the `$knn` payload. */
returnRank: boolean;
}
/** Caller-facing options accepted by `Knn`. */
export interface KnnOptions {
/** Query: a dense vector (array or other iterable input), a sparse vector, or a string. */
query: IterableInput<number> | SparseVector | string;
/** Key as a string or `Key` instance; defaults to `"#embedding"`. */
key?: string | Key;
/** Positive integer; defaults to 128. */
limit?: number;
/** Finite number; `null` or `undefined` means no default is sent. */
default?: number | null;
/** Defaults to `false`. */
returnRank?: boolean;
}
/**
 * Copies a dense query vector into a fresh `number[]`, validating every element.
 *
 * Arrays go through the same validation as other iterables; previously an
 * array was returned via `slice()` without any element check, so `NaN`,
 * `Infinity` or non-number entries slipped through.
 *
 * @param vector - Dense vector as an array or other iterable/array-like input.
 * @returns A new array holding the same values in the same order.
 * @throws TypeError when any element is not a finite number (array holes
 *   count as `undefined` and are rejected).
 */
const normalizeDenseVector = (vector: IterableInput<number>): number[] =>
  Array.from(vector as Iterable<number>, (value) => {
    // Number.isFinite is false for NaN and ±Infinity.
    if (typeof value !== "number" || !Number.isFinite(value)) {
      throw new TypeError("Dense query vector values must be finite numbers");
    }
    return value;
  });
/**
 * Validates `KnnOptions` and fills in defaults.
 *
 * Defaults: `limit` 128, `key` `"#embedding"`, `returnRank` `false`; a
 * `default` of `null`/`undefined` is dropped.
 *
 * @param options - Caller-supplied KNN options.
 * @returns The normalized options, with the query copied.
 * @throws TypeError when `limit` is not a positive integer, when `query` is
 *   missing or is neither a string, a sparse vector, nor an iterable or
 *   array-like dense vector, when `key` is not a string or `Key`, or when
 *   `default` is not a finite number.
 */
const normalizeKnnOptions = (options: KnnOptions): KnnOptionsNormalized => {
  const limit = options.limit ?? 128;
  if (!Number.isInteger(limit) || limit <= 0) {
    throw new TypeError("Knn limit must be a positive integer");
  }

  const queryInput = options.query;
  let query: number[] | SparseVector | string;
  if (typeof queryInput === "string") {
    query = queryInput;
  } else if (queryInput === null || queryInput === undefined) {
    // Fail with a clear message instead of the opaque error Array.from would raise.
    throw new TypeError("Knn query is required");
  } else if (
    isPlainObject(queryInput) &&
    Array.isArray((queryInput as SparseVector).indices) &&
    Array.isArray((queryInput as SparseVector).values)
  ) {
    const sparse = queryInput as SparseVector;
    // Keep only `indices`/`values`, copied so the stored query does not share
    // arrays with the caller's object.
    query = deepClone({
      indices: sparse.indices.slice(),
      values: sparse.values.slice(),
    });
  } else {
    const candidate: unknown = queryInput;
    const isIterable =
      typeof (candidate as { [Symbol.iterator]?: unknown })[Symbol.iterator] === "function";
    const isArrayLike = typeof (candidate as { length?: unknown }).length === "number";
    // Array.from silently turns a value that is neither iterable nor
    // array-like (e.g. a malformed sparse-vector object) into [], which would
    // hide the mistake behind an empty dense query.
    if (typeof candidate !== "object" || (!isIterable && !isArrayLike)) {
      throw new TypeError(
        "Knn query must be a string, a dense vector, or a sparse vector with indices and values arrays",
      );
    }
    query = normalizeDenseVector(queryInput as IterableInput<number>);
  }

  // `??` binds tighter than `?:`, so the fallback applies to the non-Key branch only.
  const key = options.key instanceof Key ? options.key.name : options.key ?? "#embedding";
  if (typeof key !== "string") {
    throw new TypeError("Knn key must be a string or Key instance");
  }

  // requireNumber already rejects NaN and ±Infinity, so no separate finite check is needed.
  const defaultValue =
    options.default === null || options.default === undefined
      ? undefined
      : requireNumber(options.default, "Knn default must be a number");

  return {
    query,
    key,
    limit,
    defaultValue,
    returnRank: options.returnRank ?? false,
  };
};
/**
 * Converts `input` with `RankExpression.from` and insists on a result.
 *
 * @param input - Value to convert.
 * @param context - Label used as the prefix of the error message.
 * @returns The converted expression.
 * @throws TypeError when the conversion yields no expression (`null` or
 *   `undefined` input), or when `RankExpression.from` itself rejects the input.
 */
const requireRank = (input: RankInput, context: string): RankExpression => {
  const expression = RankExpression.from(input);
  if (expression === undefined) {
    throw new TypeError(`${context} must be a rank expression`);
  }
  return expression;
};
/**
 * Creates a constant rank expression.
 *
 * @param value - Finite number to wrap.
 * @throws TypeError when `value` is not a finite number.
 */
export const Val = (value: number): RankExpression => {
  const checked = requireNumber(value, "Val requires a numeric value");
  return new ValueRankExpression(checked);
};
/**
 * Creates a KNN rank expression from `options`, which are validated and
 * defaulted by `normalizeKnnOptions` first.
 */
export const Knn = (options: KnnOptions): RankExpression => {
  const normalized = normalizeKnnOptions(options);
  return new KnnRankExpression(normalized);
};
/** Options accepted by `Rrf`. */
export interface RrfOptions {
/** Non-empty list of rank inputs; each entry must convert to a rank expression. */
ranks: RankInput[];
/** Positive integer added to each rank in the denominator; defaults to 60. */
k?: number;
/** One non-negative weight per rank; defaults to 1 for every rank. */
weights?: number[];
/** When true, weights are divided by their sum, which must be positive; defaults to false. */
normalize?: boolean;
}
/**
 * Builds the expression `-( Σ weight_i / (rank_i + k) )` over the given ranks.
 *
 * @param options - See `RrfOptions`; `k` defaults to 60, `weights` to all
 *   ones, `normalize` to `false`.
 * @returns The negated sum of the weighted terms.
 * @throws TypeError when `k` is not a positive integer, `ranks` is empty or
 *   not an array, a rank cannot be converted, or a weight is not a finite
 *   non-negative number.
 * @throws Error when the number of weights differs from the number of ranks,
 *   or when `normalize` is set and the weights do not sum to a positive value.
 */
export const Rrf = ({ ranks, k = 60, weights, normalize = false }: RrfOptions): RankExpression => {
  if (!Number.isInteger(k) || k <= 0) {
    throw new TypeError("Rrf k must be a positive integer");
  }
  if (!Array.isArray(ranks) || ranks.length === 0) {
    throw new TypeError("Rrf requires at least one rank expression");
  }
  const expressions = ranks.map((rank, index) => requireRank(rank, `ranks[${index}]`));

  let weightValues = weights ? weights.slice() : new Array(expressions.length).fill(1);
  if (weightValues.length !== expressions.length) {
    throw new Error("Number of weights must match number of ranks");
  }
  // `value < 0` alone lets NaN and Infinity through; reject them here so the
  // failure names the weights instead of surfacing later from `Val`.
  if (
    weightValues.some(
      (value) => typeof value !== "number" || !Number.isFinite(value) || value < 0,
    )
  ) {
    throw new TypeError("Weights must be non-negative numbers");
  }
  if (normalize) {
    const total = weightValues.reduce((sum, value) => sum + value, 0);
    if (total <= 0) {
      throw new Error("Weights must sum to a positive value when normalize=true");
    }
    weightValues = weightValues.map((value) => value / total);
  }

  const terms = expressions.map((rank, index) => {
    const weight = weightValues[index];
    return Val(weight).divide(rank.add(k));
  });
  // Every term is a division node, so a single create() yields the same
  // flattened sum as folding with add(), without re-copying the operand list
  // on every step. A lone term is returned unwrapped either way.
  const fused = SumRankExpression.create(terms);
  return fused.negate();
};
/**
 * Sums one or more rank inputs.
 *
 * @throws Error when called with no arguments.
 * @throws TypeError when an input cannot be converted to a rank expression.
 */
export const Sum = (...inputs: RankInput[]): RankExpression => {
  if (inputs.length < 1) {
    throw new Error("Sum requires at least one rank expression");
  }
  const operands: RankExpression[] = [];
  inputs.forEach((rank, index) => {
    operands.push(requireRank(rank, `Sum operand ${index}`));
  });
  return SumRankExpression.create(operands);
};
/**
 * Builds `left - right`.
 *
 * @throws TypeError when either side cannot be converted to a rank expression.
 */
export const Sub = (left: RankInput, right: RankInput): RankExpression => {
  const lhs = requireRank(left, "Sub left");
  const rhs = requireRank(right, "Sub right");
  return new SubRankExpression(lhs, rhs);
};
/**
 * Multiplies one or more rank inputs.
 *
 * @throws Error when called with no arguments.
 * @throws TypeError when an input cannot be converted to a rank expression.
 */
export const Mul = (...inputs: RankInput[]): RankExpression => {
  if (inputs.length < 1) {
    throw new Error("Mul requires at least one rank expression");
  }
  const operands: RankExpression[] = [];
  inputs.forEach((rank, index) => {
    operands.push(requireRank(rank, `Mul operand ${index}`));
  });
  return MulRankExpression.create(operands);
};
/**
 * Builds `left / right`.
 *
 * @throws TypeError when either side cannot be converted to a rank expression.
 */
export const Div = (left: RankInput, right: RankInput): RankExpression => {
  const lhs = requireRank(left, "Div left");
  const rhs = requireRank(right, "Div right");
  return new DivRankExpression(lhs, rhs);
};
/**
 * Wraps `input` in an `$abs` node.
 *
 * @throws TypeError when `input` cannot be converted to a rank expression.
 */
export const Abs = (input: RankInput): RankExpression => {
  const operand = requireRank(input, "Abs");
  return operand.abs();
};
/**
 * Wraps `input` in an `$exp` node.
 *
 * @throws TypeError when `input` cannot be converted to a rank expression.
 */
export const Exp = (input: RankInput): RankExpression => {
  const operand = requireRank(input, "Exp");
  return operand.exp();
};
/**
 * Wraps `input` in a `$log` node.
 *
 * @throws TypeError when `input` cannot be converted to a rank expression.
 */
export const Log = (input: RankInput): RankExpression => {
  const operand = requireRank(input, "Log");
  return operand.log();
};
/**
 * Takes the maximum of one or more rank inputs.
 *
 * @throws Error when called with no arguments.
 * @throws TypeError when an input cannot be converted to a rank expression.
 */
export const Max = (...inputs: RankInput[]): RankExpression => {
  if (inputs.length < 1) {
    throw new Error("Max requires at least one rank expression");
  }
  const operands: RankExpression[] = [];
  inputs.forEach((rank, index) => {
    operands.push(requireRank(rank, `Max operand ${index}`));
  });
  return MaxRankExpression.create(operands);
};
/**
 * Takes the minimum of one or more rank inputs.
 *
 * @throws Error when called with no arguments.
 * @throws TypeError when an input cannot be converted to a rank expression.
 */
export const Min = (...inputs: RankInput[]): RankExpression => {
  if (inputs.length < 1) {
    throw new Error("Min requires at least one rank expression");
  }
  const operands: RankExpression[] = [];
  inputs.forEach((rank, index) => {
    operands.push(requireRank(rank, `Min operand ${index}`));
  });
  return MinRankExpression.create(operands);
};