// ai-functions — Core AI primitives for building intelligent applications
// (source listing: 196 lines (180 loc) • 6.96 kB, text/typescript)
/**
* embeddingCacheMiddleware — content-addressable cache for `wrapEmbeddingModel`
*
* Embedding-side analogue of {@link cacheMiddleware}. Wraps `doEmbed` and
* caches the resulting embeddings keyed on
* `{ values, modelId, providerOptions }` so a re-embed of the same value
* batch with the same model returns the cached vectors without hitting the
* provider.
*
* **Why a separate middleware instead of reusing `cacheMiddleware`?**
* AI SDK 6 splits language-model and embedding-model surfaces:
* `LanguageModelV3Middleware` exposes `wrapGenerate` / `wrapStream` against
* `LanguageModelV3CallOptions`, while `EmbeddingModelV3Middleware` exposes
* `wrapEmbed` against `EmbeddingModelV3CallOptions`. The cache shape
* (per-value vector vs. per-prompt completion payload) is also different —
* embeddings cache batched arrays, generations cache single result objects.
*
* - **Hit derivation:** stable hash of `{ values, modelId, providerOptions }`.
* `values` is the array as-passed (caller can pre-normalise if they want
* case/whitespace insensitivity). Generation knobs don't apply.
*
* - **Batch semantics:** the cache key is the *whole* batch. A subset hit
* doesn't trigger a partial-fill — that's a more invasive shape change
* (the legacy `EmbeddingCache.getMany` did per-text caching, but it was
* only used in the example and added 100+ LOC of bookkeeping). Callers
* that want per-text caching should embed each text as its own
* single-value batch, so that every text gets its own cache key.
*
* - **TTL:** 24h default, configurable. Lazy expiry on access.
*
* - **Pluggable store:** in-memory default (Map-backed); custom store
* honored as-is. Disk persistence is intentionally not provided here —
* embedding payloads (large `number[][]`) make on-disk JSON a bad fit;
* callers who want it should pass a custom store.
*
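* - **Env gate:** caching is on only when `process.env.V3_EVAL_CACHE` is a
* non-empty string (parity with `cacheMiddleware`); otherwise the
* middleware is a passthrough. The variable is read once, when the
* middleware is built. Override via the `enabled` option.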
*
* @packageDocumentation
*/
import type {
EmbeddingModelV3CallOptions,
EmbeddingModelV3Embedding,
EmbeddingModelV3Middleware,
EmbeddingModelV3Result,
SharedV3Warning,
} from '@ai-sdk/provider'
import { hashKey } from '../cache.js'
// ============================================================================
// Types
// ============================================================================
/**
 * One cached batch result, in the shape written to and read back from an
 * {@link EmbedCacheMiddlewareStore}.
 */
interface EmbedCacheEntry {
  /** Embedding vectors the provider returned for the batch. */
  embeddings: EmbeddingModelV3Embedding[]
  /** Warnings the provider returned with the batch; replayed on a cache hit. */
  warnings: SharedV3Warning[]
  /** `Date.now()` value (epoch ms) taken when the entry was stored; compared against the TTL. */
  createdAt: number
}
/**
* Pluggable cache store for embedding results.
*
* All three methods are synchronous: the middleware uses the return value of
* `get` directly and does not await `set` or `delete`.
*/
export interface EmbedCacheMiddlewareStore {
/** Look up an entry by cache key; return `undefined` when there is none. */
get(key: string): EmbedCacheEntry | undefined
/** Store `value` under `key`. */
set(key: string, value: EmbedCacheEntry): void
/** Remove the entry for `key`. The middleware calls this when it reads an entry that is past its TTL. */
delete(key: string): void
}
/**
* Options for {@link embeddingCacheMiddleware}. Every field is optional;
* the options are read once, when the middleware is built.
*/
export interface EmbedCacheMiddlewareOptions {
/**
* Cache backend. `'memory'` (or omitting the option) creates a fresh
* process-local, Map-backed store for this middleware instance. A custom
* {@link EmbedCacheMiddlewareStore} can be passed instead and is used as-is.
*
* @default 'memory'
*/
store?: 'memory' | EmbedCacheMiddlewareStore
/**
* TTL in milliseconds. An entry whose age is strictly greater than `ttlMs`
* is deleted when it is next looked up, and that lookup is treated as a
* miss. There is no background sweep.
*
* @default 86_400_000 (24h)
*/
ttlMs?: number
/**
* Custom hash function for cache keys. It is called with the embed call's
* `params` and the wrapped model's `modelId`. Defaults to a stable hash of
* `{ values, modelId, providerOptions }`.
*/
keyHash?: (params: EmbeddingModelV3CallOptions, modelId: string) => string
/**
* Optional override for the env gate. When `false`, the middleware acts
* as a passthrough regardless of `V3_EVAL_CACHE`. When `true`, always
* caches. When omitted, caching is on only if `process.env.V3_EVAL_CACHE`
* is a non-empty string (so values such as `'0'` or `'false'` also enable
* it); the variable is read once, when the middleware is built.
*/
enabled?: boolean
}
// ============================================================================
// Stores
// ============================================================================
/**
 * Default process-local store: a thin wrapper around a `Map` from cache key
 * to entry. It has no size bound and never removes entries on its own —
 * removal only happens through `delete`.
 */
class MemoryStore implements EmbedCacheMiddlewareStore {
  /** Backing map; one entry per cache key. */
  private readonly entries = new Map<string, EmbedCacheEntry>()

  /** Return the entry stored under `key`, or `undefined` when there is none. */
  get(key: string): EmbedCacheEntry | undefined {
    return this.entries.get(key)
  }

  /** Store `value` under `key`, replacing any previous entry for that key. */
  set(key: string, value: EmbedCacheEntry): void {
    this.entries.set(key, value)
  }

  /** Drop the entry for `key`; a no-op when the key is absent. */
  delete(key: string): void {
    this.entries.delete(key)
  }
}
// ============================================================================
// Helpers
// ============================================================================
/** Default entry lifetime: 24 hours, expressed in milliseconds. */
const DEFAULT_TTL_MS = 1000 * 60 * 60 * 24
/**
 * Default cache-key derivation: hashes the batch values, the model id and the
 * provider options together via `hashKey`. No other field of `params`
 * contributes to the key.
 *
 * @param params - Call options of the embed request being cached.
 * @param modelId - Id of the wrapped embedding model.
 * @returns The cache key for this request.
 */
function defaultKeyHash(params: EmbeddingModelV3CallOptions, modelId: string): string {
  const { values, providerOptions } = params
  const keyMaterial = { values, modelId, providerOptions }
  return hashKey(keyMaterial)
}
/**
 * Reports whether the `V3_EVAL_CACHE` environment variable turns caching on.
 * Any non-empty string counts as "on" — including values such as `'0'` or
 * `'false'`; only an unset or empty variable is "off".
 *
 * @returns `true` when `process.env.V3_EVAL_CACHE` is a non-empty string.
 */
function envGateEnabled(): boolean {
  const raw = process.env['V3_EVAL_CACHE']
  if (typeof raw !== 'string') return false
  return raw !== ''
}
/**
 * Decides whether a cached entry has outlived its TTL.
 *
 * @param entry - The cached entry; only `createdAt` is read.
 * @param ttlMs - Maximum allowed age in milliseconds.
 * @returns `true` when the entry's age is strictly greater than `ttlMs`
 *   (an entry exactly `ttlMs` old is still considered fresh).
 */
function isExpired(entry: EmbedCacheEntry, ttlMs: number): boolean {
  const ageMs = Date.now() - entry.createdAt
  return ageMs > ttlMs
}
// ============================================================================
// Middleware
// ============================================================================
/**
 * Create a caching middleware to pass to `wrapEmbeddingModel`.
 *
 * All options are resolved once, here. When caching is off the returned
 * middleware simply forwards to the provider. When it is on, each embed call
 * is keyed via `keyHash(params, model.modelId)`:
 *
 * - a fresh entry is replayed as `{ embeddings, warnings }` without calling
 *   the provider (no usage or response metadata is included on a hit);
 * - a stale entry is deleted and the call proceeds as a miss;
 * - on a miss the provider result is stored with the current timestamp and
 *   returned unchanged. A rejected `doEmbed()` is not cached.
 *
 * @param options - Store, TTL, key function and enable override; see
 *   {@link EmbedCacheMiddlewareOptions}.
 * @returns An `EmbeddingModelV3Middleware` implementing `wrapEmbed`.
 *
 * @example
 * ```ts
 * import { wrapEmbeddingModel } from 'ai'
 * import { embeddingCacheMiddleware } from 'ai-functions'
 *
 * const model = wrapEmbeddingModel({
 *   model: openai.embedding('text-embedding-3-small'),
 *   middleware: embeddingCacheMiddleware({ ttlMs: 86_400_000 }),
 * })
 * ```
 */
export function embeddingCacheMiddleware(
  options: EmbedCacheMiddlewareOptions = {}
): EmbeddingModelV3Middleware {
  const lifetimeMs = options.ttlMs ?? DEFAULT_TTL_MS
  const computeKey = options.keyHash ?? defaultKeyHash

  // A caller-supplied store is used as-is; otherwise each middleware
  // instance gets its own in-memory store.
  let backing: EmbedCacheMiddlewareStore
  if (options.store !== undefined && options.store !== 'memory') {
    backing = options.store
  } else {
    backing = new MemoryStore()
  }

  // An explicit `enabled` wins; the env var is only consulted when it is omitted.
  const cacheActive = options.enabled ?? envGateEnabled()

  return {
    specificationVersion: 'v3',

    async wrapEmbed({ doEmbed, params, model }) {
      if (!cacheActive) return doEmbed()

      const cacheKey = computeKey(params, model.modelId)
      const hit = backing.get(cacheKey)

      if (hit !== undefined) {
        if (!isExpired(hit, lifetimeMs)) {
          // Only embeddings and warnings are replayed. Provider-side
          // metadata (response headers, body, usage) is deliberately left
          // out on a hit — callers that need it should disable the cache.
          const replayed: EmbeddingModelV3Result = {
            embeddings: hit.embeddings,
            warnings: hit.warnings,
          }
          return replayed
        }
        // Lazy expiry: drop the stale entry and fall through to a miss.
        backing.delete(cacheKey)
      }

      const fresh = await doEmbed()
      backing.set(cacheKey, {
        embeddings: fresh.embeddings,
        warnings: fresh.warnings,
        createdAt: Date.now(),
      })
      return fresh
    },
  }
}