UNPKG

@wearesage/schema

Version:

A flexible schema definition and validation system for TypeScript with multi-database support

346 lines (288 loc) 10.5 kB
import 'reflect-metadata';
import { Type } from '../core/types';

/**
 * Configuration for automatic embeddings generation
 */
export interface EmbeddingsConfig {
  fields: string[]; // Which fields to embed
  model?: string; // Embedding model to use
  dimensions?: number; // Vector dimensions
  chunkSize?: number; // Max characters per chunk (TextChunker counts characters, not tokens)
  overlap?: number; // Number of trailing words repeated at the start of the next chunk
  vectorStore?: 'pgvector' | 'neo4j-vector' | 'memory';
  embeddingField?: string; // Where to store the embedding
  metadataFields?: string[]; // Additional metadata to store with chunks
  autoUpdate?: boolean; // Auto-update embeddings on field changes
  provider?: 'openai' | 'azure' | 'local' | 'ollama'; // Embedding provider
}

/**
 * Decorator for automatic embeddings generation.
 *
 * Merges `config` over the built-in defaults and stores the result on the
 * decorated target under the `embeddings:config` metadata key.
 *
 * @param config Partial overrides; options that are absent or explicitly
 *   `undefined` keep their default value.
 */
export function Embeddings(config: Partial<EmbeddingsConfig> = {}) {
  return function (target: any) {
    // NOTE(review): the default model is an OpenAI model name while the default
    // provider is 'ollama', whose own default in this file is
    // 'nomic-embed-text:latest'. Left unchanged to avoid silently switching the
    // vector space of existing data - confirm which pairing is intended.
    const defaultConfig: EmbeddingsConfig = {
      fields: ['content'],
      model: 'text-embedding-3-large',
      dimensions: 1536,
      chunkSize: 512,
      overlap: 50,
      vectorStore: 'pgvector',
      embeddingField: 'embedding',
      metadataFields: ['id', 'createdAt'],
      autoUpdate: true,
      provider: 'ollama'
    };

    // A plain spread would let `{ provider: undefined }` erase the default, so
    // only copy options that actually carry a value.
    const finalConfig: EmbeddingsConfig = { ...defaultConfig };
    for (const key of Object.keys(config) as Array<keyof EmbeddingsConfig>) {
      const value = config[key];
      if (value !== undefined) {
        Object.assign(finalConfig, { [key]: value });
      }
    }

    Reflect.defineMetadata('embeddings:config', finalConfig, target);
  };
}

/**
 * Interface for embedding providers
 */
export interface EmbeddingProvider {
  generateEmbedding(text: string, model?: string): Promise<number[]>;
  generateBatchEmbeddings(texts: string[], model?: string): Promise<number[][]>;
}

/**
 * Ollama embedding provider (LOCAL!)
 */
export class OllamaEmbeddingProvider implements EmbeddingProvider {
  constructor(private baseUrl = 'http://localhost:11434') {}

  /**
   * Requests a single embedding from `{baseUrl}/api/embeddings`.
   *
   * @throws Error when the HTTP response is not OK or carries no embedding array.
   */
  async generateEmbedding(text: string, model = 'nomic-embed-text:latest'): Promise<number[]> {
    // Tolerate a base URL configured with trailing slashes.
    const endpoint = `${this.baseUrl.replace(/\/+$/, '')}/api/embeddings`;
    const response = await fetch(endpoint, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model: model, prompt: text })
    });

    if (!response.ok) {
      throw new Error(`Ollama API error: ${response.statusText}`);
    }

    const payload: unknown = await response.json();
    const embedding = (payload as { embedding?: unknown } | null)?.embedding;
    if (!Array.isArray(embedding)) {
      // Fail loudly instead of handing `undefined` to callers as a vector.
      throw new Error('Ollama API error: response did not contain an embedding');
    }
    return embedding as number[];
  }

  /**
   * Embeds each text with one request per text, in order.
   * The requests are issued sequentially (could be parallelized).
   */
  async generateBatchEmbeddings(texts: string[], model = 'nomic-embed-text:latest'): Promise<number[][]> {
    const embeddings: number[][] = [];
    for (const text of texts) {
      embeddings.push(await this.generateEmbedding(text, model));
    }
    return embeddings;
  }
}

/**
 * OpenAI embedding provider
 */
export class OpenAIEmbeddingProvider implements EmbeddingProvider {
  constructor(private apiKey: string) {}

  /**
   * Requests the embedding of a single text.
   *
   * @throws Error when the request fails or the response has no embedding.
   */
  async generateEmbedding(text: string, model = 'text-embedding-3-large'): Promise<number[]> {
    const [embedding] = await this.requestEmbeddings(text, model);
    if (!embedding) {
      throw new Error('OpenAI API error: response did not contain an embedding');
    }
    return embedding;
  }

  /**
   * Requests embeddings for several texts in one API call.
   *
   * @returns One vector per input text, in input order; `[]` for an empty input.
   * @throws Error when the request fails or the response is malformed.
   */
  async generateBatchEmbeddings(texts: string[], model = 'text-embedding-3-large'): Promise<number[][]> {
    if (texts.length === 0) {
      return []; // Nothing to embed; skip the network round trip.
    }
    return this.requestEmbeddings(texts, model);
  }

  /** Sends one embeddings request and returns the vectors ordered by their `index`. */
  private async requestEmbeddings(input: string | string[], model: string): Promise<number[][]> {
    const response = await fetch('https://api.openai.com/v1/embeddings', {
      method: 'POST',
      headers: {
        'Authorization': `Bearer ${this.apiKey}`,
        'Content-Type': 'application/json'
      },
      body: JSON.stringify({ input: input, model: model })
    });

    if (!response.ok) {
      throw new Error(`OpenAI API error: ${response.statusText}`);
    }

    const payload: unknown = await response.json();
    const items = (payload as { data?: unknown } | null)?.data;
    if (!Array.isArray(items)) {
      throw new Error('OpenAI API error: response did not contain embedding data');
    }

    // Each item reports the position of its input; sort on it so the output
    // lines up with the inputs. Items without an index keep their order.
    const ordered = [...(items as Array<{ index?: number; embedding?: unknown } | null>)].sort(
      (a, b) => (a?.index ?? 0) - (b?.index ?? 0)
    );

    return ordered.map((item) => {
      const embedding = item?.embedding;
      if (!Array.isArray(embedding)) {
        throw new Error('OpenAI API error: response did not contain an embedding');
      }
      return embedding as number[];
    });
  }
}

/**
 * Text chunking utilities for large content
 */
export class TextChunker {
  /**
   * Splits text into whitespace-normalised chunks of whole words.
   *
   * @param text Text to split; empty or whitespace-only text yields `[]`.
   * @param chunkSize Approximate maximum chunk length in characters. A single
   *   word longer than this still becomes its own chunk.
   * @param overlap Number of trailing words of a chunk repeated at the start of
   *   the next one. Values that are not positive finite numbers mean no overlap.
   * @returns The chunks in document order.
   */
  static chunkText(text: string, chunkSize = 512, overlap = 50): string[] {
    if (!text || text.length === 0) return [];

    // Simple word-based chunking (could be enhanced with token counting).
    // Dropping empty pieces avoids a phantom leading "word" when the text
    // starts with whitespace.
    const words = text.split(/\s+/).filter((word) => word.length > 0);
    const overlapWords = Number.isFinite(overlap) ? Math.max(0, Math.floor(overlap)) : 0;

    const chunks: string[] = [];
    let currentChunk: string[] = [];
    let currentLength = 0;

    for (const word of words) {
      if (currentLength + word.length > chunkSize && currentChunk.length > 0) {
        chunks.push(currentChunk.join(' '));

        // `slice(-0)` would copy the whole chunk, so zero overlap is explicit.
        const carried = overlapWords > 0 ? currentChunk.slice(-overlapWords) : [];
        let carriedLength = carried.reduce((total, w) => total + w.length + 1, 0);

        // Drop leading carried words until the next word fits. Without this an
        // overlap as long as the chunk emits one near-duplicate chunk per word.
        let dropped = 0;
        while (dropped < carried.length && carriedLength + word.length > chunkSize) {
          carriedLength -= carried[dropped].length + 1;
          dropped++;
        }

        currentChunk = carried.slice(dropped);
        currentLength = carriedLength;
      }

      currentChunk.push(word);
      currentLength += word.length + 1; // +1 for space
    }

    if (currentChunk.length > 0) {
      chunks.push(currentChunk.join(' '));
    }

    return chunks;
  }

  /**
   * Collects the text of the given fields of an entity.
   *
   * Non-empty string values are used as-is, other truthy objects are
   * JSON-stringified, everything else is skipped. Parts are joined by a blank line.
   */
  static extractRelevantText(entity: any, fields: string[]): string {
    const textParts: string[] = [];

    for (const field of fields) {
      const value = entity[field];
      if (value && typeof value === 'string') {
        textParts.push(value);
      } else if (value && typeof value === 'object') {
        // Handle nested objects by JSON stringifying
        textParts.push(JSON.stringify(value));
      }
    }

    return textParts.join('\n\n');
  }
}

/**
 * Embedding service that handles the actual embedding generation and storage
 */
export class EmbeddingService {
  private providers: Map<string, EmbeddingProvider> = new Map();

  /**
   * Registers the default providers: Ollama under 'ollama' and 'local', and
   * OpenAI under 'openai' when the OPENAI_API_KEY environment variable is set.
   */
  constructor() {
    this.providers.set('ollama', new OllamaEmbeddingProvider());
    this.providers.set('local', new OllamaEmbeddingProvider()); // alias

    if (process.env.OPENAI_API_KEY) {
      this.providers.set('openai', new OpenAIEmbeddingProvider(process.env.OPENAI_API_KEY));
    }
  }

  /** Registers (or replaces) a provider under the given name. */
  addProvider(name: string, provider: EmbeddingProvider) {
    this.providers.set(name, provider);
  }

  /**
   * Looks up a registered provider.
   *
   * @throws Error when no provider is registered under `name`.
   */
  getProvider(name: string): EmbeddingProvider {
    const provider = this.providers.get(name);
    if (!provider) {
      throw new Error(`Embedding provider '${name}' not found`);
    }
    return provider;
  }

  /**
   * Generate embeddings for an entity based on its configuration.
   *
   * Does nothing when the entity's constructor has no `embeddings:config`
   * metadata or the configured fields contain no text. Otherwise writes the
   * vector to the configured embedding field; for text longer than the chunk
   * size it stores the average of the chunk vectors and the chunk texts under
   * `<embeddingField>_chunks`.
   *
   * @throws Rethrows any provider error after logging it.
   */
  async generateEmbeddings<T extends object>(entity: T): Promise<void> {
    const entityType = entity.constructor as Type<T>;
    const config = this.getEmbeddingsConfig(entityType);

    if (!config) {
      return; // No embeddings configuration
    }

    // Fall back to the decorator's defaults if the metadata lacks these options.
    const provider = this.getProvider(config.provider ?? 'ollama');
    const embeddingField = config.embeddingField ?? 'embedding';
    // Resolve once so the size test and the chunker use the same value.
    const chunkSize = config.chunkSize || 512;

    const text = TextChunker.extractRelevantText(entity, config.fields);
    if (!text || text.trim().length === 0) {
      return; // No text to embed
    }

    try {
      if (text.length <= chunkSize) {
        // Simple case: the text fits in one chunk, generate a single embedding
        const embedding = await provider.generateEmbedding(text, config.model);
        (entity as any)[embeddingField] = embedding;
      } else {
        // Large text: chunk it, embed every chunk and average the vectors
        const chunks = TextChunker.chunkText(text, chunkSize, config.overlap);
        const embeddings = await provider.generateBatchEmbeddings(chunks, config.model);

        (entity as any)[embeddingField] = this.combineEmbeddings(embeddings);
        (entity as any)[`${embeddingField}_chunks`] = chunks;
      }

      console.log(`🧠 Generated embeddings for ${entityType.name} with ${config.fields.join(', ')}`);
    } catch (error) {
      console.error(`❌ Failed to generate embeddings for ${entityType.name}:`, error);
      throw error;
    }
  }

  /**
   * Combine multiple embeddings into a single representative embedding by
   * averaging each dimension.
   *
   * @throws Error when the vectors do not all have the same length.
   */
  private combineEmbeddings(embeddings: number[][]): number[] {
    if (embeddings.length === 0) return [];
    if (embeddings.length === 1) return embeddings[0];

    // Average pooling
    const dimensions = embeddings[0].length;
    const combined: number[] = new Array<number>(dimensions).fill(0);

    for (const embedding of embeddings) {
      if (embedding.length !== dimensions) {
        // Mixed lengths would otherwise yield NaN or silently truncated averages.
        throw new Error('Cannot combine embeddings with different dimensions');
      }
      for (let i = 0; i < dimensions; i++) {
        combined[i] += embedding[i];
      }
    }

    for (let i = 0; i < dimensions; i++) {
      combined[i] /= embeddings.length;
    }

    return combined;
  }

  /**
   * Get embeddings configuration for an entity type
   */
  private getEmbeddingsConfig(entityType: Type<any>): EmbeddingsConfig | undefined {
    return Reflect.getMetadata('embeddings:config', entityType);
  }

  /**
   * Cosine similarity of two vectors.
   *
   * @returns The cosine of the angle between `a` and `b`, or 0 when either
   *   vector has zero norm.
   * @throws Error when the vectors differ in length.
   */
  static cosineSimilarity(a: number[], b: number[]): number {
    if (a.length !== b.length) {
      throw new Error('Vectors must have the same length');
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;

    for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    if (normA === 0 || normB === 0) {
      return 0;
    }

    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /**
   * Find similar entities based on embedding similarity.
   *
   * Entities whose `embeddingField` is missing or not an array are skipped.
   *
   * @param threshold Minimum cosine similarity (inclusive) for a match.
   * @returns Matches sorted by similarity, highest first.
   * @throws Error when an entity's embedding length differs from the query's.
   */
  async findSimilar<T>(
    queryEmbedding: number[],
    entities: T[],
    embeddingField: string,
    threshold = 0.7
  ): Promise<Array<{ entity: T, similarity: number }>> {
    const results: Array<{ entity: T, similarity: number }> = [];

    for (const entity of entities) {
      const entityEmbedding = (entity as any)[embeddingField];
      if (!entityEmbedding || !Array.isArray(entityEmbedding)) {
        continue;
      }

      const similarity = EmbeddingService.cosineSimilarity(queryEmbedding, entityEmbedding);
      if (similarity >= threshold) {
        results.push({ entity, similarity });
      }
    }

    // Sort by similarity (highest first)
    return results.sort((a, b) => b.similarity - a.similarity);
  }
}

// Export the singleton service
export const embeddingService = new EmbeddingService();