@lewist9x/distil
Version:
An opinionated library for managing LLM pipelines. Define, track, rate, and curate prompt–completion pairs for fine-tuning.
465 lines (464 loc) • 17.9 kB
JavaScript
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.markGenerationsForFinetuning = exports.rateGeneration = exports.getGenerationById = exports.getGenerationsForVersion = exports.markPipelineVersionAsFinetuned = exports.ratePipelineVersion = exports.addTagToPipelineVersion = exports.getAllPipelineVersions = exports.DistilPipeline = void 0;
// src/pipeline.ts
const elasticsearch_1 = require("@elastic/elasticsearch");
const utils_1 = require("./utils");
const inference_1 = require("./inference");
const logger_1 = require("./logger");
const config_1 = require("./config");
let esClient;
// Create one shared Elasticsearch client only when a host is configured;
// otherwise esClient stays undefined and the persistence helpers below
// fall through to their error paths.
if (config_1.config.elastic.host) {
    const { host, user, password } = config_1.config.elastic;
    const clientConfig = { node: host };
    // Anonymous access is permitted: attach basic auth only when both
    // credentials are present.
    if (user && password) {
        clientConfig.auth = { username: user, password };
    }
    esClient = new elasticsearch_1.Client(clientConfig);
}
class DistilPipeline {
    /**
     * A named, versioned LLM pipeline.
     *
     * @param {object} config - Pipeline definition: pipelineName, modelName,
     *   systemPrompt / userPrompt templates (with `{key}` placeholders),
     *   defaultParameters, and optional preprocess / postprocess hooks.
     * @param {string} [logLevel] - Logger verbosity; defaults to "DEBUG".
     */
    constructor(config, logLevel) {
        this.logger = new logger_1.Logger(logLevel || "DEBUG");
        this.inferenceEngine = new inference_1.InferenceEngine();
        this.pipelineName = config.pipelineName;
        this.modelName = config.modelName;
        this.systemPrompt = config.systemPrompt;
        this.userPrompt = config.userPrompt;
        this.defaultParameters = config.defaultParameters;
        // Default: identity preprocessing.
        this.preprocessFn = config.preprocess !== null && config.preprocess !== void 0 ? config.preprocess : ((input) => input);
        // Default: generic postprocessing.
        this.postprocessFn = config.postprocess !== null && config.postprocess !== void 0 ? config.postprocess : utils_1.postprocess;
    }
    /**
     * Runs the pipeline:
     * 1. Validate input and merge default parameters.
     * 2. Apply custom preprocessing.
     * 3. Compute the template hash.
     * 4. Run inference and apply postprocessing.
     * 5. Return output and metadata, or null on any error.
     */
    async generate(inputData) {
        const startTime = Date.now();
        let totalCost = 0;
        try {
            // Pin modelName from the pipeline config so callers can't override it.
            const input = {
                ...inputData,
                modelName: this.modelName,
            };
            // Log the merged input once (previously this line was duplicated 4x).
            await this.logger.info("Input:" + JSON.stringify(input));
            // Defaults first so per-call values win on key collisions.
            const parameters = {
                ...(this.defaultParameters || {}),
                ...(inputData || {}),
            };
            await this.logger.info("Parameters:" + JSON.stringify(parameters));
            // Replace every "{key}" placeholder in a template with its value.
            const substituteParameters = (template, params) => {
                let result = template;
                for (const [key, value] of Object.entries(params)) {
                    const placeholder = `{${key}}`;
                    // Objects are embedded as JSON; everything else is stringified.
                    const replacement = typeof value === "object" ? JSON.stringify(value) : String(value);
                    // Escape regex metacharacters so the placeholder matches literally.
                    result = result.replace(new RegExp(placeholder.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "g"), replacement);
                }
                return result;
            };
            // Hash the raw (unsubstituted) templates + parameters to version this run.
            const templateHash = (0, utils_1.computeTemplateHash)({
                systemPrompt: this.systemPrompt,
                userPrompt: this.userPrompt,
                parameters,
                pipelineName: this.pipelineName,
                modelName: this.modelName,
            });
            await this.logger.info("Template hash:" + templateHash);
            // Apply parameter substitution to prompt templates.
            const systemPrompt = substituteParameters(this.systemPrompt, parameters);
            const userPrompt = substituteParameters(this.userPrompt, parameters);
            await this.logger.info("System prompt:" + systemPrompt);
            // Add processed prompts to input.
            input.systemPrompt = systemPrompt;
            input.userPrompt = userPrompt;
            input.parameters = parameters;
            await this.logger.info("Validating input..." + JSON.stringify(input));
            let validInput = (0, utils_1.validateInput)(input);
            await this.logger.info("Input validated.");
            // Store raw input before preprocessing.
            const rawInput = { ...validInput };
            // Run preprocessing.
            validInput = await this.preprocessFn(validInput);
            validInput.preprocessFn = this.preprocessFn; // Pass preprocessing function
            validInput.postprocessFn = this.postprocessFn; // Pass postprocessing function
            await this.logger.info("Valid input:" + JSON.stringify(validInput));
            // Run inference.
            const { detail, rawOutput, processedOutput, cost } = await this.inferenceEngine.callInference({
                ...validInput,
                templateHash,
                originalInput: inputData,
                pipelineName: this.pipelineName,
            });
            totalCost += cost;
            await this.logger.info("Processed output:" + processedOutput);
            const timeTaken = (Date.now() - startTime) / 1000;
            return {
                processedOutput,
                metadata: {
                    generationCost: totalCost,
                    timeTaken,
                    rawInput,
                    preprocessedInput: validInput,
                    rawOutput,
                    templateHash,
                    pipelineName: this.pipelineName,
                    generationId: detail, // detail is now the document ID
                },
            };
        }
        catch (error) {
            // JSON.stringify(new Error(...)) serializes to "{}" because Error
            // properties are non-enumerable — log message/stack instead.
            const detail = error instanceof Error ? (error.stack || error.message) : JSON.stringify(error);
            await this.logger.error("Generation error: " + detail);
            return null;
        }
    }
}
exports.DistilPipeline = DistilPipeline;
/**
 * Type guard: a search hit is usable only when it exists and carries a
 * truthy `_source` payload.
 */
function isValidHit(hit) {
    return Boolean(hit && hit._source);
}
/**
 * Get all pipeline versions across every non-system index.
 *
 * Aggregates each index by `pipelineHash`, taking the latest document for
 * metadata and the earliest for the creation date. Returns newest-first;
 * returns [] on any top-level failure.
 */
async function getAllPipelineVersions() {
    try {
        const indexMap = await esClient.indices.get({ index: "*" });
        // Skip system indices such as ".kibana".
        const pipelineIndices = Object.keys(indexMap).filter((name) => !name.startsWith("."));
        if (pipelineIndices.length === 0) {
            return [];
        }
        const pipelines = [];
        for (const index of pipelineIndices) {
            try {
                const result = await esClient.search({
                    index,
                    body: {
                        size: 0,
                        aggs: {
                            unique_pipelines: {
                                terms: {
                                    field: "pipelineHash.keyword",
                                    size: 1000,
                                },
                                aggs: {
                                    latest_doc: {
                                        top_hits: {
                                            size: 1,
                                            _source: [
                                                "pipelineName",
                                                "pipelineHash",
                                                "parameterKeys",
                                                "input",
                                                "template",
                                                "tags",
                                                "rating",
                                                "isFinetuned",
                                                "timestamp",
                                            ],
                                        },
                                    },
                                    first_generation: {
                                        top_hits: {
                                            size: 1,
                                            sort: [{ timestamp: "asc" }],
                                        },
                                    },
                                },
                            },
                        },
                    },
                });
                const buckets = result.aggregations.unique_pipelines?.buckets || [];
                for (const bucket of buckets) {
                    const [latest] = bucket.latest_doc.hits.hits;
                    if (!latest) {
                        continue;
                    }
                    const source = latest._source;
                    // Only documents that actually carry pipeline metadata count.
                    if (!source || typeof source !== "object" || !("pipelineName" in source)) {
                        continue;
                    }
                    const firstHit = bucket.first_generation?.hits.hits[0];
                    const firstGenerationDate = firstHit?._source?.timestamp;
                    const preprocessed = source.input.preprocessed;
                    pipelines.push({
                        id: bucket.key,
                        pipelineName: source.pipelineName,
                        template: {
                            systemPrompt: preprocessed.systemPrompt,
                            userPrompt: preprocessed.userPrompt,
                            parameterKeys: preprocessed.parameters
                                ? Object.keys(preprocessed.parameters)
                                : [],
                        },
                        tags: source.tags || [],
                        rating: source.rating,
                        isFinetuned: source.isFinetuned,
                        createdAt: firstGenerationDate || source.timestamp || new Date().toISOString(),
                        generations: [],
                    });
                }
            }
            catch (error) {
                // One broken index must not hide the others.
                console.error(`Error searching index ${index}:`, error);
                continue;
            }
        }
        // Newest first.
        return pipelines.sort((a, b) => new Date(b.createdAt).getTime() - new Date(a.createdAt).getTime());
    }
    catch (error) {
        console.error("Error fetching pipeline versions:", error);
        return [];
    }
}
exports.getAllPipelineVersions = getAllPipelineVersions;
/**
 * Add a tag to a pipeline version, looked up by document id across all
 * non-system indices. Tags are deduplicated. Returns true on success,
 * false when the document is missing or on any error.
 */
async function addTagToPipelineVersion(id, tag) {
    try {
        const indexMap = await esClient.indices.get({ index: "*" });
        const searchable = Object.keys(indexMap).filter((name) => !name.startsWith("."));
        if (searchable.length === 0) {
            return false;
        }
        const response = await esClient.search({
            index: searchable.join(","),
            body: {
                query: {
                    ids: {
                        values: [id],
                    },
                },
            },
        });
        const hit = response.hits?.hits?.[0];
        if (!hit || hit._source === undefined) {
            return false;
        }
        // Merge via a Set so re-adding an existing tag is a no-op.
        const mergedTags = Array.from(new Set([...(hit._source.tags || []), tag]));
        await esClient.update({
            index: hit._index,
            id: hit._id,
            body: {
                doc: {
                    tags: mergedTags,
                },
            },
        });
        return true;
    }
    catch (error) {
        console.error("Error adding tag:", error);
        return false;
    }
}
exports.addTagToPipelineVersion = addTagToPipelineVersion;
/**
 * Rate a pipeline version (1-5 stars).
 *
 * @param {string} id - Pipeline version id.
 * @param {number} rating - Integer or fractional star value in [1, 5].
 * @returns {Promise<boolean>} true on success, false when the version is
 *   missing or the update fails.
 * @throws {Error} when rating is outside [1, 5] or not a finite number.
 */
async function ratePipelineVersion(id, rating) {
    try {
        // Number.isFinite guards against NaN: `NaN < 1` and `NaN > 5` are both
        // false, so the bare range check silently accepted NaN.
        if (!Number.isFinite(rating) || rating < 1 || rating > 5) {
            throw new Error("Rating must be between 1 and 5.");
        }
        const response = await esClient.search({
            index: "pipeline_versions",
            body: {
                query: {
                    term: { id },
                },
            },
        });
        const hit = response.hits?.hits?.[0];
        if (!hit) {
            return false;
        }
        await esClient.update({
            index: "pipeline_versions",
            id: hit._id,
            body: {
                doc: { rating },
            },
        });
        return true;
    }
    catch (error) {
        // The validation error is a caller contract violation — propagate it;
        // everything else (network, missing index) degrades to false.
        if (error instanceof Error &&
            error.message === "Rating must be between 1 and 5.") {
            throw error;
        }
        console.error("Error rating pipeline version:", error);
        return false;
    }
}
exports.ratePipelineVersion = ratePipelineVersion;
/**
 * Mark a pipeline version as finetuned. Returns true on success, false when
 * the version cannot be found or the update fails.
 */
async function markPipelineVersionAsFinetuned(id) {
    try {
        const searchResult = await esClient.search({
            index: "pipeline_versions",
            body: {
                query: {
                    term: { id },
                },
            },
        });
        const topHit = searchResult.hits?.hits?.[0];
        if (!topHit) {
            return false;
        }
        await esClient.update({
            index: "pipeline_versions",
            id: topHit._id,
            body: {
                doc: { isFinetuned: true },
            },
        });
        return true;
    }
    catch (error) {
        console.error("Error marking pipeline version as finetuned:", error);
        return false;
    }
}
exports.markPipelineVersionAsFinetuned = markPipelineVersionAsFinetuned;
/**
 * Get up to 50 generations for a pipeline version (matched on pipelineHash).
 * Returns [] on any error.
 */
async function getGenerationsForVersion(pipelineName, versionId) {
    try {
        const response = await esClient.search({
            index: pipelineName.toLowerCase(),
            query: {
                bool: {
                    must: [{ term: { pipelineHash: versionId } }],
                },
            },
            size: 50,
        });
        const generations = [];
        for (const hit of response.hits.hits) {
            const source = hit._source;
            generations.push({
                id: hit._id,
                processedOutput: source.output,
                metadata: {
                    input: source.input,
                    timeTaken: source.timeTaken || 0,
                    generationCost: source.generationCost || 0,
                    pipelineHash: source.pipelineHash,
                    rawOutput: source.output,
                    pipelineName: source.pipelineName,
                },
                rating: source.rating,
                isFinetuned: source.isFinetuned,
            });
        }
        return generations;
    }
    catch (error) {
        console.error("Error fetching generations for pipeline version:", error);
        return [];
    }
}
exports.getGenerationsForVersion = getGenerationsForVersion;
/**
 * Get a specific generation by document id.
 *
 * Unlike the other helpers this does NOT swallow errors: a missing document
 * (or any client failure) propagates to the caller.
 */
async function getGenerationById(pipelineName, id) {
    const response = await esClient.get({
        index: pipelineName.toLowerCase(),
        id,
    });
    const source = response._source;
    if (!source) {
        throw new Error(`Generation ${id} not found`);
    }
    return {
        id: response._id,
        processedOutput: source.output,
        metadata: {
            input: source.input,
            timeTaken: source.timeTaken || 0,
            generationCost: source.generationCost || 0,
            pipelineHash: source.pipelineHash,
            rawOutput: source.output,
            pipelineName: source.pipelineName,
        },
        rating: source.rating,
        isFinetuned: source.isFinetuned,
    };
}
exports.getGenerationById = getGenerationById;
/**
 * Rate a generation document.
 *
 * @param {string} pipelineName - Index name (lowercased).
 * @param {string} id - Generation document id.
 * @param {number} rating - Star value; truncated to an integer before storing.
 * @returns {Promise<boolean>} true on success, false on any error.
 */
async function rateGeneration(pipelineName, id, rating) {
    try {
        await esClient.update({
            index: pipelineName.toLowerCase(),
            id,
            doc: {
                // Math.trunc(Number(...)) replaces parseInt(rating.toString()),
                // which had no radix and mangles numbers that stringify in
                // exponential form (parseInt("1e+21") === 1).
                rating: Math.trunc(Number(rating)),
            },
        });
        return true;
    }
    catch (error) {
        console.error("Failed to rate generation:", error);
        return false;
    }
}
exports.rateGeneration = rateGeneration;
/**
 * Mark a batch of generations for finetuning in the configured log index.
 * All updates run in parallel; returns true only if every one succeeds.
 */
async function markGenerationsForFinetuning(ids) {
    try {
        const updates = ids.map((generationId) => esClient.update({
            index: config_1.config.elastic.logIndex,
            id: generationId,
            doc: {
                isFinetuned: true,
            },
        }));
        await Promise.all(updates);
        return true;
    }
    catch (error) {
        console.error("Failed to mark generations for finetuning:", error);
        return false;
    }
}
exports.markGenerationsForFinetuning = markGenerationsForFinetuning;