UNPKG

@lewist9x/distil

Version:

An opinionated library for managing LLM pipelines. Define, track, rate, and curate prompt–completion pairs for fine-tuning.

465 lines (464 loc) • 17.9 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.markGenerationsForFinetuning = exports.rateGeneration = exports.getGenerationById = exports.getGenerationsForVersion = exports.markPipelineVersionAsFinetuned = exports.ratePipelineVersion = exports.addTagToPipelineVersion = exports.getAllPipelineVersions = exports.DistilPipeline = void 0; // src/pipeline.ts const elasticsearch_1 = require("@elastic/elasticsearch"); const utils_1 = require("./utils"); const inference_1 = require("./inference"); const logger_1 = require("./logger"); const config_1 = require("./config"); let esClient; if (config_1.config.elastic.host) { const clientConfig = { node: config_1.config.elastic.host, }; // Only add authentication if credentials are provided if (config_1.config.elastic.user && config_1.config.elastic.password) { clientConfig.auth = { username: config_1.config.elastic.user, password: config_1.config.elastic.password, }; } esClient = new elasticsearch_1.Client(clientConfig); } class DistilPipeline { constructor(config, logLevel) { var _a, _b; this.logger = new logger_1.Logger(logLevel || "DEBUG"); this.inferenceEngine = new inference_1.InferenceEngine(); this.pipelineName = config.pipelineName; this.modelName = config.modelName; this.systemPrompt = config.systemPrompt; this.userPrompt = config.userPrompt; this.defaultParameters = config.defaultParameters; // Default: identity preprocessing. this.preprocessFn = (_a = config.preprocess) !== null && _a !== void 0 ? _a : ((input) => input); // Default: generic postprocessing. this.postprocessFn = (_b = config.postprocess) !== null && _b !== void 0 ? _b : utils_1.postprocess; } /** * Runs the pipeline: * 1. Validate input and merge default parameters. * 2. Apply custom preprocessing. * 3. Compute the template hash. * 4. Run inference and apply postprocessing. * 5. Return output and metadata. */ async generate(inputData) { const startTime = Date.now(); let totalCost = 0; try { // Set required fields from pipeline config, ensuring they can't be overridden const input = { ...inputData, modelName: this.modelName, }; await this.logger.info("Input:" + JSON.stringify(input)); await this.logger.info("Input:" + JSON.stringify(input)); await this.logger.info("Input:" + JSON.stringify(input)); await this.logger.info("Input:" + JSON.stringify(input)); // Merge default parameters first const parameters = { ...(this.defaultParameters || {}), ...(inputData || {}), }; await this.logger.info("Parameters:" + JSON.stringify(parameters)); // Helper function to substitute parameters in template strings const substituteParameters = (template, params) => { let result = template; for (const [key, value] of Object.entries(params)) { const placeholder = `{${key}}`; // Handle different value types const replacement = typeof value === "object" ? JSON.stringify(value) : String(value); // Use regex to replace all occurrences result = result.replace(new RegExp(placeholder.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "g"), replacement); } return result; }; const templateHash = (0, utils_1.computeTemplateHash)({ systemPrompt: this.systemPrompt, userPrompt: this.userPrompt, parameters, pipelineName: this.pipelineName, modelName: this.modelName, }); await this.logger.info("Template hash:" + templateHash); // Apply parameter substitution to prompt templates const systemPrompt = substituteParameters(this.systemPrompt, parameters); const userPrompt = substituteParameters(this.userPrompt, parameters); await this.logger.info("System prompt:" + systemPrompt); // Add processed prompts to input input.systemPrompt = systemPrompt; input.userPrompt = userPrompt; input.parameters = parameters; await this.logger.info("Validating input..." + JSON.stringify(input)); let validInput = (0, utils_1.validateInput)(input); await this.logger.info("Input validated."); // Store raw input before preprocessing const rawInput = { ...validInput }; // Run preprocessing validInput = await this.preprocessFn(validInput); validInput.preprocessFn = this.preprocessFn; // Pass preprocessing function validInput.postprocessFn = this.postprocessFn; // Pass postprocessing function await this.logger.info("Valid input:" + JSON.stringify(validInput)); // Compute template version hash. // Run inference. const { detail, rawOutput, processedOutput, cost } = await this.inferenceEngine.callInference({ ...validInput, templateHash, originalInput: inputData, pipelineName: this.pipelineName, }); totalCost += cost; await this.logger.info("Processed output:" + processedOutput); const timeTaken = (Date.now() - startTime) / 1000; return { processedOutput, metadata: { generationCost: totalCost, timeTaken, rawInput, preprocessedInput: validInput, rawOutput, templateHash, pipelineName: this.pipelineName, generationId: detail, // detail is now the document ID }, }; } catch (error) { await this.logger.error("Generation error: " + JSON.stringify(error)); return null; } } } exports.DistilPipeline = DistilPipeline; function isValidHit(hit) { return !!hit && !!hit._source; } /** * Get all pipeline versions. */ async function getAllPipelineVersions() { var _a, _b, _c; try { // Get all indices const indices = await esClient.indices.get({ index: "*" }); const pipelineIndices = Object.keys(indices).filter((index) => !index.startsWith(".")); if (pipelineIndices.length === 0) { return []; } const pipelines = []; for (const index of pipelineIndices) { try { const result = await esClient.search({ index, body: { size: 0, aggs: { unique_pipelines: { terms: { field: "pipelineHash.keyword", size: 1000, }, aggs: { latest_doc: { top_hits: { size: 1, _source: [ "pipelineName", "pipelineHash", "parameterKeys", "input", "template", "tags", "rating", "isFinetuned", "timestamp", ], }, }, first_generation: { top_hits: { size: 1, sort: [{ timestamp: "asc" }], }, }, }, }, }, }, }); const buckets = ((_a = result.aggregations.unique_pipelines) === null || _a === void 0 ? void 0 : _a.buckets) || []; for (const bucket of buckets) { const hits = bucket.latest_doc.hits.hits; if (hits.length === 0) continue; const hit = hits[0]; const source = hit._source; if (source && typeof source === "object" && "pipelineName" in source) { const typedSource = source; const firstGeneration = (_b = bucket.first_generation) === null || _b === void 0 ? void 0 : _b.hits.hits[0]; const firstGenerationDate = (_c = firstGeneration === null || firstGeneration === void 0 ? void 0 : firstGeneration._source) === null || _c === void 0 ? void 0 : _c.timestamp; pipelines.push({ id: bucket.key, pipelineName: typedSource.pipelineName, template: { systemPrompt: typedSource.input.preprocessed.systemPrompt, userPrompt: typedSource.input.preprocessed.userPrompt, parameterKeys: typedSource.input.preprocessed.parameters ? Object.keys(typedSource.input.preprocessed.parameters) : [], }, tags: typedSource.tags || [], rating: typedSource.rating, isFinetuned: typedSource.isFinetuned, createdAt: firstGenerationDate || typedSource.timestamp || new Date().toISOString(), generations: [], }); } } } catch (error) { console.error(`Error searching index ${index}:`, error); // Continue with other indices even if one fails continue; } } // Sort by timestamp return pipelines.sort((a, b) => new Date(b.createdAt).getTime() - new Date(a.createdAt).getTime()); } catch (error) { console.error("Error fetching pipeline versions:", error); return []; } } exports.getAllPipelineVersions = getAllPipelineVersions; /** * Add a tag to a pipeline version. */ async function addTagToPipelineVersion(id, tag) { var _a; try { // Get all indices const indices = await esClient.indices.get({ index: "*" }); const pipelineIndices = Object.keys(indices).filter((index) => !index.startsWith(".")); if (pipelineIndices.length === 0) { return false; } const response = await esClient.search({ index: pipelineIndices.join(","), body: { query: { ids: { values: [id], }, }, }, }); if (!((_a = response.hits) === null || _a === void 0 ? void 0 : _a.hits) || response.hits.hits.length === 0 || response.hits.hits[0]._source === undefined) { return false; } const hit = response.hits.hits[0]; const version = hit._source; const tags = new Set([...(version.tags || []), tag]); await esClient.update({ index: hit._index, id: hit._id, body: { doc: { tags: Array.from(tags), }, }, }); return true; } catch (error) { console.error("Error adding tag:", error); return false; } } exports.addTagToPipelineVersion = addTagToPipelineVersion; /** * Rate a pipeline version (1-5 stars). */ async function ratePipelineVersion(id, rating) { var _a, _b; try { if (rating < 1 || rating > 5) { throw new Error("Rating must be between 1 and 5."); } const response = await esClient.search({ index: "pipeline_versions", body: { query: { term: { id }, }, }, }); if (!((_b = (_a = response.hits) === null || _a === void 0 ? void 0 : _a.hits) === null || _b === void 0 ? void 0 : _b[0])) { return false; } const hit = response.hits.hits[0]; await esClient.update({ index: "pipeline_versions", id: hit._id, body: { doc: { rating }, }, }); return true; } catch (error) { if (error instanceof Error && error.message === "Rating must be between 1 and 5.") { throw error; } console.error("Error rating pipeline version:", error); return false; } } exports.ratePipelineVersion = ratePipelineVersion; /** * Mark a pipeline version as finetuned. */ async function markPipelineVersionAsFinetuned(id) { var _a, _b; try { const response = await esClient.search({ index: "pipeline_versions", body: { query: { term: { id }, }, }, }); if (!((_b = (_a = response.hits) === null || _a === void 0 ? void 0 : _a.hits) === null || _b === void 0 ? void 0 : _b[0])) { return false; } const hit = response.hits.hits[0]; await esClient.update({ index: "pipeline_versions", id: hit._id, body: { doc: { isFinetuned: true }, }, }); return true; } catch (error) { console.error("Error marking pipeline version as finetuned:", error); return false; } } exports.markPipelineVersionAsFinetuned = markPipelineVersionAsFinetuned; /** * Get generations for a pipeline version */ async function getGenerationsForVersion(pipelineName, versionId) { try { const response = await esClient.search({ index: pipelineName.toLowerCase(), query: { bool: { must: [{ term: { pipelineHash: versionId } }], }, }, size: 50, }); return response.hits.hits.map((hit) => { const source = hit._source; return { id: hit._id, processedOutput: source.output, metadata: { input: source.input, timeTaken: source.timeTaken || 0, generationCost: source.generationCost || 0, pipelineHash: source.pipelineHash, rawOutput: source.output, pipelineName: source.pipelineName, }, rating: source.rating, isFinetuned: source.isFinetuned, }; }); } catch (error) { console.error("Error fetching generations for pipeline version:", error); return []; } } exports.getGenerationsForVersion = getGenerationsForVersion; /** * Get a specific generation by ID */ async function getGenerationById(pipelineName, id) { const response = await esClient.get({ index: pipelineName.toLowerCase(), id, }); if (!response._source) { throw new Error(`Generation ${id} not found`); } const source = response._source; return { id: response._id, processedOutput: source.output, metadata: { input: source.input, timeTaken: source.timeTaken || 0, generationCost: source.generationCost || 0, pipelineHash: source.pipelineHash, rawOutput: source.output, pipelineName: source.pipelineName, }, rating: source.rating, isFinetuned: source.isFinetuned, }; } exports.getGenerationById = getGenerationById; /** * Rate a generation */ async function rateGeneration(pipelineName, id, rating) { try { await esClient.update({ index: pipelineName.toLowerCase(), id, doc: { rating: parseInt(rating.toString()), }, }); return true; } catch (error) { console.error("Failed to rate generation:", error); return false; } } exports.rateGeneration = rateGeneration; /** * Mark generations for finetuning */ async function markGenerationsForFinetuning(ids) { try { await Promise.all(ids.map((id) => esClient.update({ index: config_1.config.elastic.logIndex, id, doc: { isFinetuned: true, }, }))); return true; } catch (error) { console.error("Failed to mark generations for finetuning:", error); return false; } } exports.markGenerationsForFinetuning = markGenerationsForFinetuning;