@forge-ml/rag
Version:
A RAG (Retrieval-Augmented Generation) package for Forge ML
105 lines (104 loc) • 3.4 kB
JavaScript
import { Pool } from 'pg';
import { VECTOR_MODEL_DIM } from "../types";
/**
 * Options used by `PostgresVectorStore#createIndex` when the caller passes none.
 * `dim` becomes the dimension of the `vector(...)` column in the embeddings table.
 */
const defaultCreateIndexOpts = { dim: VECTOR_MODEL_DIM.NOMIC_V1_5 };
/**
 * Formats an array of numbers as the text form `[v1,v2,...]`, which the
 * queries below cast to a vector with `$n::vector`.
 *
 * @param {number[]} values - Vector components.
 * @returns {string} Bracketed, comma-separated literal.
 */
function toVectorLiteral(values) {
  return `[${values.join(',')}]`;
}

/**
 * Inserts one embedding row using the given client, so the caller decides
 * which connection (and therefore which transaction) the insert belongs to.
 *
 * @param {{ query: Function }} client - A checked-out pool client.
 * @param {{ chunkId: string, documentId: string, embedding: number[] }} embedding
 * @returns {Promise<void>}
 */
async function insertEmbedding(client, embedding) {
  await client.query(
    'INSERT INTO embeddings (chunk_id, document_id, embedding) VALUES ($1, $2, $3::vector)',
    [embedding.chunkId, embedding.documentId, toVectorLiteral(embedding.embedding)],
  );
}

/**
 * Vector store that keeps embeddings in a PostgreSQL `embeddings` table
 * (columns: id, chunk_id, document_id, embedding) and ranks them with the
 * `<=>` distance operator.
 */
class PostgresVectorStore {
  pool;

  /**
   * @param {string} connectionString - Connection string handed to the `Pool`.
   */
  constructor(connectionString) {
    this.pool = new Pool({
      connectionString,
    });
  }

  /**
   * Creates the `vector` extension, the `embeddings` table and an ivfflat
   * cosine index if they do not exist yet.
   *
   * @param {{ dim?: number }} [opts] - `dim` is the vector column dimension;
   *   falls back to `defaultCreateIndexOpts.dim` when omitted.
   * @returns {Promise<void>}
   * @throws {TypeError} If `dim` is not a positive integer.
   * @throws Rethrows any database error after logging it.
   */
  async createIndex(opts = defaultCreateIndexOpts) {
    // The dimension is spliced into the DDL text rather than bound as a
    // parameter, so it must be proven to be a plain positive integer first.
    const dim = Number(opts?.dim ?? defaultCreateIndexOpts.dim);
    if (!Number.isInteger(dim) || dim <= 0) {
      throw new TypeError(`Invalid embedding dimension: ${String(opts?.dim)}`);
    }

    const client = await this.pool.connect();
    try {
      await client.query('CREATE EXTENSION IF NOT EXISTS vector');
      await client.query(`
        CREATE TABLE IF NOT EXISTS embeddings (
          id SERIAL PRIMARY KEY,
          chunk_id TEXT NOT NULL,
          document_id TEXT NOT NULL,
          embedding vector(${dim})
        )
      `);
      await client.query('CREATE INDEX IF NOT EXISTS embedding_idx ON embeddings USING ivfflat (embedding vector_cosine_ops)');
    } catch (error) {
      console.error('Error creating index:', error);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Inserts a single embedding on its own connection (no explicit transaction).
   *
   * @param {{ chunkId: string, documentId: string, embedding: number[] }} embedding
   * @returns {Promise<void>}
   * @throws Rethrows any database error after logging it.
   */
  async addEmbedding(embedding) {
    const client = await this.pool.connect();
    try {
      await insertEmbedding(client, embedding);
    } catch (error) {
      console.error('Error adding embedding:', error);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Inserts all embeddings inside one transaction: either every row is
   * committed or none is.
   *
   * @param {Iterable<{ chunkId: string, documentId: string, embedding: number[] }>} embeddings
   * @returns {Promise<void>}
   * @throws Rethrows the first insert/commit error after rolling back.
   */
  async storeEmbeddings(embeddings) {
    const client = await this.pool.connect();
    let releaseError;
    try {
      await client.query('BEGIN');
      for (const embedding of embeddings) {
        // Must run on the client that issued BEGIN; a different pooled
        // connection would execute outside this transaction.
        await insertEmbedding(client, embedding);
      }
      await client.query('COMMIT');
    } catch (error) {
      try {
        await client.query('ROLLBACK');
      } catch (rollbackError) {
        // Keep the original failure as the thrown error, and make sure the
        // connection is discarded since its transaction state is unknown.
        console.error('Error rolling back embeddings transaction:', rollbackError);
        releaseError = rollbackError;
      }
      console.error('Error storing embeddings:', error);
      throw error;
    } finally {
      // A truthy argument tells the pool to destroy the client instead of reusing it.
      client.release(releaseError);
    }
  }

  /**
   * Returns the `k` stored chunks with the highest score, where score is
   * `1 - (embedding <=> query)`, optionally restricted to some documents.
   *
   * @param {{ query: number[], k: number, documentIds?: string[] }} params
   * @returns {Promise<Array<{ chunkId: string, documentId: string, score: number }>>}
   * @throws Rethrows any database error after logging it.
   */
  async queryEmbeddings(params) {
    const client = await this.pool.connect();
    try {
      const queryParams = [toVectorLiteral(params.query)];
      let whereClause = '';
      if (params.documentIds && params.documentIds.length > 0) {
        queryParams.push(params.documentIds);
        whereClause = `WHERE document_id = ANY($${queryParams.length})`;
      }
      queryParams.push(params.k);
      // NOTE(review): ordering by the derived score (an expression, DESC)
      // presumably keeps the planner from using the ivfflat index; confirm
      // against the pgvector docs before changing, since index scans are approximate.
      const queryString = `
        SELECT chunk_id, document_id, 1 - (embedding <=> $1::vector) AS score
        FROM embeddings
        ${whereClause}
        ORDER BY score DESC
        LIMIT $${queryParams.length}
      `;
      const result = await client.query(queryString, queryParams);
      return result.rows.map((row) => ({
        chunkId: row.chunk_id,
        documentId: row.document_id,
        score: row.score,
      }));
    } catch (error) {
      console.error('Error querying embeddings:', error);
      throw error;
    } finally {
      client.release();
    }
  }

  /**
   * Shuts down the connection pool.
   *
   * @returns {Promise<void>}
   */
  async close() {
    await this.pool.end();
  }
}
export default PostgresVectorStore;