@wearesage/schema
Version:
A flexible schema definition and validation system for TypeScript with multi-database support
346 lines (288 loc) • 10.5 kB
text/typescript
import 'reflect-metadata';
import { Type } from '../core/types';
/**
 * Configuration for automatic embeddings generation.
 *
 * Attached to an entity class by the @Embeddings decorator, which fills in
 * defaults for every optional field, and later read by EmbeddingService.
 */
export interface EmbeddingsConfig {
fields: string[]; // Entity properties whose text is concatenated and embedded
model?: string; // Embedding model name passed through to the provider
dimensions?: number; // Vector dimensions — informational; not enforced against provider output
chunkSize?: number; // Max chunk size — NOTE(review): documented as tokens, but TextChunker measures characters; confirm intent
overlap?: number; // Overlap carried into the next chunk — NOTE(review): TextChunker interprets this as a word count
vectorStore?: 'pgvector' | 'neo4j-vector' | 'memory'; // Target vector storage backend
embeddingField?: string; // Entity property where the (combined) embedding is stored
metadataFields?: string[]; // Additional metadata to store with chunks
autoUpdate?: boolean; // Auto-update embeddings on field changes
provider?: 'openai' | 'azure' | 'local' | 'ollama'; // Provider registry key ('local' aliases ollama)
}
/**
 * Class decorator that attaches an embeddings configuration to an entity type.
 *
 * The supplied partial config is merged over built-in defaults and recorded
 * as 'embeddings:config' metadata, which EmbeddingService reads back later.
 */
export function Embeddings(config: Partial<EmbeddingsConfig> = {}) {
  return function (target: any) {
    // NOTE(review): the default model is an OpenAI one while the default
    // provider is ollama — callers typically override one of the two.
    const defaults: EmbeddingsConfig = {
      fields: ['content'],
      model: 'text-embedding-3-large',
      dimensions: 1536,
      chunkSize: 512,
      overlap: 50,
      vectorStore: 'pgvector',
      embeddingField: 'embedding',
      metadataFields: ['id', 'createdAt'],
      autoUpdate: true,
      provider: 'ollama'
    };
    Reflect.defineMetadata('embeddings:config', { ...defaults, ...config }, target);
  };
}
/**
 * Interface for embedding providers.
 *
 * Each input text maps to one numeric vector. Both built-in implementations
 * return batch results in input order, which callers rely on when pairing
 * vectors back to chunks.
 */
export interface EmbeddingProvider {
generateEmbedding(text: string, model?: string): Promise<number[]>; // one text → one vector
generateBatchEmbeddings(texts: string[], model?: string): Promise<number[][]>; // many texts → vectors, input order
}
/**
 * Ollama embedding provider (runs against a local Ollama server).
 */
export class OllamaEmbeddingProvider implements EmbeddingProvider {
  /** @param baseUrl Root of the Ollama HTTP API (default local install). */
  constructor(private baseUrl = 'http://localhost:11434') {}

  /**
   * Generate a single embedding via Ollama's /api/embeddings endpoint.
   *
   * @param text  Text to embed
   * @param model Ollama model name
   * @returns the embedding vector from the response
   * @throws Error when the HTTP response is not ok
   */
  async generateEmbedding(text: string, model = 'nomic-embed-text:latest'): Promise<number[]> {
    const response = await fetch(`${this.baseUrl}/api/embeddings`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json'
      },
      body: JSON.stringify({
        model: model,
        prompt: text
      })
    });
    if (!response.ok) {
      // Include the numeric status: statusText is frequently empty (HTTP/2).
      throw new Error(`Ollama API error: ${response.status} ${response.statusText}`);
    }
    const data = await response.json();
    return data.embedding;
  }

  /**
   * Embed several texts. Ollama's embeddings endpoint takes one prompt per
   * request, so we issue the requests concurrently; Promise.all preserves
   * input order, fixing the previous one-at-a-time sequential loop.
   */
  async generateBatchEmbeddings(texts: string[], model = 'nomic-embed-text:latest'): Promise<number[][]> {
    return Promise.all(texts.map(text => this.generateEmbedding(text, model)));
  }
}
/**
 * OpenAI embedding provider.
 *
 * Single and batch embedding share one HTTP path: the API accepts either a
 * string or an array of strings as `input` and always answers with a `data`
 * array of embedding objects.
 */
export class OpenAIEmbeddingProvider implements EmbeddingProvider {
  constructor(private apiKey: string) {}

  /** POST a payload to the embeddings endpoint and return the parsed JSON body. */
  private async request(payload: { input: string | string[]; model: string }): Promise<any> {
    const response = await fetch('https://api.openai.com/v1/embeddings', {
      method: 'POST',
      headers: {
        'Authorization': `Bearer ${this.apiKey}`,
        'Content-Type': 'application/json'
      },
      body: JSON.stringify(payload)
    });
    if (!response.ok) {
      throw new Error(`OpenAI API error: ${response.statusText}`);
    }
    return response.json();
  }

  /** Embed one text and return its vector (first element of `data`). */
  async generateEmbedding(text: string, model = 'text-embedding-3-large'): Promise<number[]> {
    const data = await this.request({ input: text, model });
    return data.data[0].embedding;
  }

  /** Embed many texts in one call; vectors come back in input order. */
  async generateBatchEmbeddings(texts: string[], model = 'text-embedding-3-large'): Promise<number[][]> {
    const data = await this.request({ input: texts, model });
    return data.data.map((item: any) => item.embedding);
  }
}
/**
 * Text chunking utilities for large content.
 */
export class TextChunker {
  /**
   * Split text into word-boundary chunks.
   *
   * Sizes are measured in characters (chunkSize) while the overlap carried
   * into the next chunk is a count of trailing words — a simple heuristic,
   * not true token counting.
   */
  static chunkText(text: string, chunkSize = 512, overlap = 50): string[] {
    if (!text || text.length === 0) return [];
    const tokens = text.split(/\s+/);
    const out: string[] = [];
    let buffer: string[] = [];
    let bufferLen = 0;
    for (const token of tokens) {
      // Flush the buffer when adding this word would exceed the budget,
      // then seed the next chunk with the last `overlap` words for context.
      if (bufferLen + token.length > chunkSize && buffer.length > 0) {
        out.push(buffer.join(' '));
        const carried = buffer.slice(-overlap);
        buffer = carried;
        bufferLen = carried.join(' ').length;
      }
      buffer.push(token);
      bufferLen += token.length + 1; // +1 accounts for the joining space
    }
    if (buffer.length > 0) {
      out.push(buffer.join(' '));
    }
    return out.filter(chunk => chunk.trim().length > 0);
  }

  /**
   * Concatenate the named fields of an entity into one embeddable string.
   * String values pass through; other truthy objects are JSON-stringified;
   * everything else (numbers, booleans, null/undefined) is skipped.
   */
  static extractRelevantText(entity: any, fields: string[]): string {
    const parts = fields
      .map(field => entity[field])
      .filter(value => value)
      .map(value =>
        typeof value === 'string'
          ? value
          : typeof value === 'object'
            ? JSON.stringify(value)
            : null
      )
      .filter((value): value is string => value !== null);
    return parts.join('\n\n');
  }
}
/**
 * Embedding service that handles the actual embedding generation and storage.
 *
 * Keeps a registry of named EmbeddingProvider implementations, reads per-class
 * configuration written by the @Embeddings decorator, and offers a brute-force
 * in-memory similarity search.
 */
export class EmbeddingService {
  // Provider registry keyed by name ('ollama', 'local', 'openai', ...).
  private providers: Map<string, EmbeddingProvider> = new Map();

  constructor() {
    // Ollama runs locally, so it is always registered — plus a 'local' alias.
    this.providers.set('ollama', new OllamaEmbeddingProvider());
    this.providers.set('local', new OllamaEmbeddingProvider()); // alias
    // OpenAI is only registered when an API key is present in the environment.
    if (process.env.OPENAI_API_KEY) {
      this.providers.set('openai', new OpenAIEmbeddingProvider(process.env.OPENAI_API_KEY));
    }
  }

  /** Register (or replace) a provider under the given name. */
  addProvider(name: string, provider: EmbeddingProvider) {
    this.providers.set(name, provider);
  }

  /** Look up a provider by name. @throws Error when none is registered. */
  getProvider(name: string): EmbeddingProvider {
    const found = this.providers.get(name);
    if (!found) {
      throw new Error(`Embedding provider '${name}' not found`);
    }
    return found;
  }

  /**
   * Generate embeddings for an entity based on its @Embeddings configuration.
   *
   * Text short enough to fit in one chunk gets a single embedding stored on
   * config.embeddingField. Longer text is chunked, embedded per chunk,
   * average-pooled into one vector, and the raw chunks are stored alongside
   * on `<embeddingField>_chunks`. No-op when the class carries no config or
   * yields no embeddable text.
   *
   * @throws rethrows any provider failure after logging it
   */
  async generateEmbeddings<T extends object>(entity: T): Promise<void> {
    const entityType = entity.constructor as Type<T>;
    const config = this.getEmbeddingsConfig(entityType);
    if (!config) {
      return; // Class was never decorated — nothing to do.
    }
    const provider = this.getProvider(config.provider!);
    const text = TextChunker.extractRelevantText(entity, config.fields);
    if (!text || text.trim().length === 0) {
      return; // Nothing to embed.
    }
    try {
      // Length here is characters, matching TextChunker's measure.
      if (text.length <= (config.chunkSize || 512)) {
        (entity as any)[config.embeddingField!] = await provider.generateEmbedding(text, config.model);
      } else {
        const chunks = TextChunker.chunkText(text, config.chunkSize, config.overlap);
        const vectors = await provider.generateBatchEmbeddings(chunks, config.model);
        (entity as any)[config.embeddingField!] = this.combineEmbeddings(vectors);
        (entity as any)[`${config.embeddingField!}_chunks`] = chunks;
      }
      console.log(`🧠 Generated embeddings for ${entityType.name} with ${config.fields.join(', ')}`);
    } catch (error) {
      console.error(`❌ Failed to generate embeddings for ${entityType.name}:`, error);
      throw error;
    }
  }

  /**
   * Average-pool several chunk embeddings into one representative vector.
   * Assumes every vector has the first vector's dimensionality.
   */
  private combineEmbeddings(embeddings: number[][]): number[] {
    if (embeddings.length === 0) return [];
    if (embeddings.length === 1) return embeddings[0];
    const dims = embeddings[0].length;
    return Array.from({ length: dims }, (_, i) =>
      embeddings.reduce((sum, vec) => sum + vec[i], 0) / embeddings.length
    );
  }

  /** Read the config the @Embeddings decorator attached to the class. */
  private getEmbeddingsConfig(entityType: Type<any>): EmbeddingsConfig | undefined {
    return Reflect.getMetadata('embeddings:config', entityType);
  }

  /**
   * Cosine similarity between two equal-length vectors.
   * Returns 0 when either vector has zero magnitude.
   * @throws Error when the vector lengths differ
   */
  static cosineSimilarity(a: number[], b: number[]): number {
    if (a.length !== b.length) {
      throw new Error('Vectors must have the same length');
    }
    let dot = 0;
    let magSqA = 0;
    let magSqB = 0;
    for (let i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      magSqA += a[i] * a[i];
      magSqB += b[i] * b[i];
    }
    if (magSqA === 0 || magSqB === 0) {
      return 0;
    }
    return dot / (Math.sqrt(magSqA) * Math.sqrt(magSqB));
  }

  /**
   * Brute-force similarity search over an in-memory entity list.
   *
   * Entities lacking a valid embedding array are skipped; matches at or above
   * the threshold are returned sorted by similarity, highest first.
   */
  async findSimilar<T>(
    queryEmbedding: number[],
    entities: T[],
    embeddingField: string,
    threshold = 0.7
  ): Promise<Array<{ entity: T, similarity: number }>> {
    const matches = entities
      .map(entity => ({ entity, vector: (entity as any)[embeddingField] }))
      .filter(({ vector }) => vector && Array.isArray(vector))
      .map(({ entity, vector }) => ({
        entity,
        similarity: EmbeddingService.cosineSimilarity(queryEmbedding, vector)
      }))
      .filter(({ similarity }) => similarity >= threshold);
    return matches.sort((a, b) => b.similarity - a.similarity);
  }
}
// Export the singleton service.
// Constructed at module import time, so the constructor's provider setup
// (including the OPENAI_API_KEY environment check) runs once, on first import.
export const embeddingService = new EmbeddingService();