@sisu-ai/mw-context-compressor
Middleware that compresses long conversation context using the model itself.
JavaScript
const HEAD_TO_SUMMARY_RATIO = 5; // Feed the summarizer at most 5x the summary character budget of head text
const MAX_ARRAY_CLAMP = 50; // Maximum number of array elements to keep when clamping
export const contextCompressor = (opts = {}) => {
const maxChars = opts.maxChars ?? 140_000; // approximate context size (in characters) that triggers compression
const keepRecent = opts.keepRecent ?? 8; // number of most recent messages kept verbatim
const summaryMaxChars = opts.summaryMaxChars ?? 8_000; // character budget for the generated summary
const recentClampChars = opts.recentClampChars ?? 8_000; // character clamp applied to recent tool output
return async (ctx, next) => {
const original = ctx.model;
ctx.model = wrapModelWithCompression(original, { maxChars, keepRecent, summaryMaxChars, recentClampChars }, ctx);
await next();
};
};
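// Usage sketch (illustrative, not part of this file): the factory returns a standard
// (ctx, next) middleware, so it plugs in wherever sisu-ai middleware is composed.
// The `app.use(...)` registration below is an assumption about the host pipeline;
// the option names and defaults are the real ones defined above.
//
//   import { contextCompressor } from '@sisu-ai/mw-context-compressor';
//   app.use(contextCompressor({ maxChars: 120_000, keepRecent: 6, summaryMaxChars: 8_000 }));
//
// Once registered, ctx.model is swapped for a wrapper that summarizes old turns and
// clamps oversized tool output before every generate() call.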
function wrapModelWithCompression(model, cfg, ctx) {
const origGenerate = model.generate.bind(model);
// Provide a generate function compatible with LLM overloads.
const wrappedGenerate = ((messages, genOpts) => {
// For streaming requests, avoid async prework that would change return type.
if (genOpts?.stream) {
try {
messages = clampRecent(messages, cfg, ctx);
}
catch (e) {
ctx.log.warn?.('[context-compressor] clampRecent failed (stream path); proceeding', e);
}
return origGenerate(messages, genOpts);
}
return (async () => {
try {
// Only compress when not already summarizing and context seems large
if (!ctx.state.__compressing && approxChars(messages) > cfg.maxChars) {
ctx.log.info?.('[context-compressor] compressing conversation context');
ctx.state.__compressing = true;
try {
const compressed = await compressMessages(messages, cfg, ctx, origGenerate);
messages = compressed;
}
finally {
delete ctx.state.__compressing;
}
}
// Always clamp oversized recent tool outputs to avoid huge bodies
messages = clampRecent(messages, cfg, ctx);
}
catch (e) {
ctx.log.warn?.('[context-compressor] failed to compress; proceeding uncompressed', e);
}
return await origGenerate(messages, genOpts);
})();
});
return { ...model, generate: wrappedGenerate };
}
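// Behavior sketch: for streaming requests the wrapper only runs the synchronous
// clampRecent() pass, so the original return type (e.g. an async iterable of chunks)
// is preserved; for non-streaming requests an async IIFE may first summarize older
// turns via compressMessages(), then clamp recent messages, then delegate to the
// original generate(). Any failure is logged and the call proceeds with the messages as-is.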
function approxChars(messages) {
let n = 0;
for (const m of messages) {
const c = m.content;
if (typeof c === 'string')
n += c.length;
else if (Array.isArray(c))
n += JSON.stringify(c).length;
}
return n;
}
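// Illustrative size estimate:
//   approxChars([
//     { role: 'user', content: 'hello' },                         // +5 (string length)
//     { role: 'tool', content: [{ type: 'text', text: 'hi' }] },  // +JSON.stringify(...).length
//   ]);
// Content that is neither a string nor an array (null, plain objects) adds nothing.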
async function compressMessages(messages, cfg, ctx, gen) {
if (messages.length <= cfg.keepRecent + 1)
return messages;
let cut = Math.max(1, messages.length - cfg.keepRecent);
// Don’t split a tool-call group: if tail starts with tool messages, include the preceding
// assistant that requested tool_calls in the tail as well.
if (messages[cut]?.role === 'tool') {
const anchor = findPrevAssistantWithToolCalls(messages, cut - 1);
if (anchor > 0) {
cut = anchor; // include the assistant-with-tool_calls in the tail (anchor 0 would duplicate messages[0] below)
}
else {
// As a last resort, advance cut forward past any leading tool messages in tail
while (cut < messages.length && messages[cut].role === 'tool')
cut++;
}
}
const head = messages.slice(0, cut);
const tail = messages.slice(cut);
// Build a compression prompt
const headText = sliceAndFlatten(head, cfg.summaryMaxChars * HEAD_TO_SUMMARY_RATIO);
const prompt = [
{ role: 'system', content: 'You are a compression assistant. Summarize the following conversation and tool outputs into a compact bullet list of established facts and extracted citations (URLs). Keep it under the specified character budget. Do not invent facts.' },
{ role: 'user', content: `Character budget: ${cfg.summaryMaxChars}. Include a section "Citations:" listing unique URLs.\n\nConversation to compress:\n${headText}` },
];
const res = await gen(prompt, { toolChoice: 'none', signal: ctx.signal });
const summary = String(res?.message?.content ?? '').slice(0, cfg.summaryMaxChars);
const summaryMsg = { role: 'assistant', content: `[Summary of earlier turns]\n${summary}` };
return [messages[0], summaryMsg, ...tail];
}
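// Shape sketch of the result (illustrative): with keepRecent = 2 and a long history,
//   [sys, u1, a1, u2, a2, ..., u9, a9]
// becomes
//   [sys, { role: 'assistant', content: '[Summary of earlier turns]\n...' }, u9, a9]
// where only messages[0] (typically the system prompt) survives verbatim from the
// head; the cut point may move earlier so tool messages stay paired with the
// assistant message that issued their tool_calls.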
function sliceAndFlatten(msgs, max) {
const parts = [];
for (const m of msgs) {
const role = m.role;
const c = m.content;
let text = '';
if (typeof c === 'string')
text = c;
else if (Array.isArray(c))
text = JSON.stringify(c);
else
text = String(c ?? '');
parts.push(`--- ${role} ---\n${text}`);
const joined = parts.join('\n');
if (joined.length > max)
return joined.slice(0, max);
}
return parts.join('\n');
}
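// sliceAndFlatten yields a plain-text transcript such as (illustrative):
//   --- user ---
//   find sources on topic X
//   --- tool ---
//   {"text":"...","url":"https://example.com"}
// and stops early with a hard slice once the joined text exceeds `max` characters.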
function clampRecent(messages, cfg, ctx) {
// Create shallow copies to avoid mutating upstream state
const out = messages.map(m => ({ ...m }));
const limit = cfg.recentClampChars;
for (let i = Math.max(0, out.length - (cfg.keepRecent + 4)); i < out.length; i++) {
const m = out[i];
const c = m.content;
if (typeof c !== 'string')
continue;
if (m.role === 'tool') {
const clamped = clampToolContentString(c, limit);
if (clamped !== c) {
m.content = clamped;
ctx.log.debug?.('[context-compressor] clamped tool message', { idx: i, before: c.length, after: clamped.length });
}
}
else if (c.length > limit * 2) {
m.content = c.slice(0, limit * 2);
ctx.log.debug?.('[context-compressor] truncated long message', { idx: i, before: c.length, after: m.content.length });
}
}
return out;
}
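// Note: only roughly the last keepRecent + 4 messages are inspected here. Tool
// messages get JSON-aware clamping via clampToolContentString; any other string
// content longer than 2x recentClampChars is hard-truncated. Older messages are
// left alone; compressMessages handles those instead.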
function clampToolContentString(s, limit) {
try {
const obj = JSON.parse(s);
// Remove heavy fields commonly present in webFetch
if (obj && typeof obj === 'object') {
if (typeof obj.html === 'string')
delete obj.html;
if (typeof obj.text === 'string' && obj.text.length > limit)
obj.text = obj.text.slice(0, limit);
// Recursively clamp nested arrays/objects
return JSON.stringify(clampDeep(obj, limit));
}
}
catch { /* ignore JSON parse error */ }
return s.length > limit ? s.slice(0, limit) : s;
}
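// Illustrative before/after for a webFetch-style tool result with limit = 8000:
//   in:  '{"url":"https://example.com","html":"<50 kB of markup>","text":"<20 kB of text>"}'
//   out: '{"url":"https://example.com","text":"<first 8000 chars>"}'
// The html field is dropped entirely, long strings are sliced to the limit, and a
// payload that is not valid JSON falls back to a plain string slice.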
function clampDeep(v, limit) {
if (!v || typeof v !== 'object')
return v;
if (Array.isArray(v))
return v.slice(0, MAX_ARRAY_CLAMP).map(x => clampDeep(x, limit));
const out = {};
for (const [k, val] of Object.entries(v)) {
if (k === 'html')
continue; // drop
if (typeof val === 'string')
out[k] = val.length > limit ? val.slice(0, limit) : val;
else
out[k] = clampDeep(val, limit);
}
return out;
}
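// Illustrative: clampDeep({ html: '<markup>', items: hugeArray, note: longText }, 8000)
// drops every 'html' key, keeps at most MAX_ARRAY_CLAMP (50) elements per array,
// and slices any remaining string over the limit; numbers, booleans and null pass
// through unchanged.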
function findPrevAssistantWithToolCalls(messages, start) {
for (let i = start; i >= 0; i--) {
const m = messages[i];
if (m?.role === 'assistant' && Array.isArray(m.tool_calls) && m.tool_calls.length > 0)
return i;
}
return -1;
}
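// Why the anchor matters: chat APIs that support tool calling generally require each
// 'tool' message to follow the assistant message whose tool_calls produced it, so
// compressMessages keeps that assistant message in the uncompressed tail instead of
// splitting the pair across the summary boundary.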