@huggingface/jinja
Version:
A minimalistic JavaScript implementation of the Jinja templating engine, specifically designed for parsing and rendering ML chat templates.
259 lines (239 loc) • 8.32 kB
text/typescript
import type {
Program,
Statement,
If,
For,
SetStatement,
Macro,
Expression,
MemberExpression,
CallExpression,
Identifier,
NumericLiteral,
StringLiteral,
BooleanLiteral,
ArrayLiteral,
TupleLiteral,
ObjectLiteral,
BinaryExpression,
FilterExpression,
SelectExpression,
TestExpression,
UnaryExpression,
LogicalNegationExpression,
SliceExpression,
KeywordArgumentExpression,
} from "./ast";
const NEWLINE = "\n";
const OPEN_STATEMENT = "{%- ";
const CLOSE_STATEMENT = " -%}";
const OPERATOR_PRECEDENCE: Record<string, number> = {
MultiplicativeBinaryOperator: 2,
AdditiveBinaryOperator: 1,
ComparisonBinaryOperator: 0,
};
export function format(program: Program, indent: string | number = "\t"): string {
const indentStr = typeof indent === "number" ? " ".repeat(indent) : indent;
const body = formatStatements(program.body, 0, indentStr);
return body.replace(/\n$/, "");
}
function createStatement(...text: string[]): string {
return OPEN_STATEMENT + text.join(" ") + CLOSE_STATEMENT;
}
function formatStatements(stmts: Statement[], depth: number, indentStr: string): string {
return stmts.map((stmt) => formatStatement(stmt, depth, indentStr)).join(NEWLINE);
}
function formatStatement(node: Statement, depth: number, indentStr: string): string {
const pad = indentStr.repeat(depth);
switch (node.type) {
case "Program":
return formatStatements((node as Program).body, depth, indentStr);
case "If":
return formatIf(node as If, depth, indentStr);
case "For":
return formatFor(node as For, depth, indentStr);
case "Set":
return formatSet(node as SetStatement, depth, indentStr);
case "Macro":
return formatMacro(node as Macro, depth, indentStr);
default:
return pad + "{{- " + formatExpression(node as Expression) + " -}}";
}
}
function formatIf(node: If, depth: number, indentStr: string): string {
const pad = indentStr.repeat(depth);
const clauses: { test: Expression; body: Statement[] }[] = [];
let current: If | undefined = node;
while (current) {
clauses.push({ test: current.test, body: current.body });
if (current.alternate.length === 1 && current.alternate[0].type === "If") {
current = current.alternate[0] as If;
} else {
break;
}
}
// IF
let out =
pad +
createStatement("if", formatExpression(clauses[0].test)) +
NEWLINE +
formatStatements(clauses[0].body, depth + 1, indentStr);
// ELIF(s)
for (let i = 1; i < clauses.length; i++) {
out +=
NEWLINE +
pad +
createStatement("elif", formatExpression(clauses[i].test)) +
NEWLINE +
formatStatements(clauses[i].body, depth + 1, indentStr);
}
// ELSE
if (current && current.alternate.length > 0) {
out +=
NEWLINE + pad + createStatement("else") + NEWLINE + formatStatements(current.alternate, depth + 1, indentStr);
}
// ENDIF
out += NEWLINE + pad + createStatement("endif");
return out;
}
function formatFor(node: For, depth: number, indentStr: string): string {
const pad = indentStr.repeat(depth);
let formattedIterable = "";
if (node.iterable.type === "SelectExpression") {
// Handle special case: e.g., `for x in [1, 2, 3] if x > 2`
const n = node.iterable as SelectExpression;
formattedIterable = `${formatExpression(n.iterable)} if ${formatExpression(n.test)}`;
} else {
formattedIterable = formatExpression(node.iterable);
}
let out =
pad +
createStatement("for", formatExpression(node.loopvar), "in", formattedIterable) +
NEWLINE +
formatStatements(node.body, depth + 1, indentStr);
if (node.defaultBlock.length > 0) {
out +=
NEWLINE + pad + createStatement("else") + NEWLINE + formatStatements(node.defaultBlock, depth + 1, indentStr);
}
out += NEWLINE + pad + createStatement("endfor");
return out;
}
function formatSet(node: SetStatement, depth: number, indentStr: string): string {
const pad = indentStr.repeat(depth);
const left = formatExpression(node.assignee);
const right = node.value ? formatExpression(node.value) : "";
const value = pad + createStatement("set", `${left}${node.value ? " = " + right : ""}`);
if (node.body.length === 0) {
return value;
}
return (
value + NEWLINE + formatStatements(node.body, depth + 1, indentStr) + NEWLINE + pad + createStatement("endset")
);
}
function formatMacro(node: Macro, depth: number, indentStr: string): string {
const pad = indentStr.repeat(depth);
const args = node.args.map(formatExpression).join(", ");
return (
pad +
createStatement("macro", `${node.name.value}(${args})`) +
NEWLINE +
formatStatements(node.body, depth + 1, indentStr) +
NEWLINE +
pad +
createStatement("endmacro")
);
}
function formatExpression(node: Expression, parentPrec: number = -1): string {
switch (node.type) {
case "Identifier":
return (node as Identifier).value;
case "NullLiteral":
return "none";
case "NumericLiteral":
case "BooleanLiteral":
return `${(node as NumericLiteral | BooleanLiteral).value}`;
case "StringLiteral":
return JSON.stringify((node as StringLiteral).value);
case "BinaryExpression": {
const n = node as BinaryExpression;
const thisPrecedence = OPERATOR_PRECEDENCE[n.operator.type] ?? 0;
const left = formatExpression(n.left, thisPrecedence);
const right = formatExpression(n.right, thisPrecedence + 1);
const expr = `${left} ${n.operator.value} ${right}`;
return thisPrecedence < parentPrec ? `(${expr})` : expr;
}
case "UnaryExpression": {
const n = node as UnaryExpression;
const val = n.operator.value + (n.operator.value === "not" ? " " : "") + formatExpression(n.argument, Infinity);
return val;
}
case "LogicalNegationExpression":
return `not ${formatExpression((node as LogicalNegationExpression).argument, Infinity)}`;
case "CallExpression": {
const n = node as CallExpression;
const args = n.args.map((a) => formatExpression(a, -1)).join(", ");
return `${formatExpression(n.callee, -1)}(${args})`;
}
case "MemberExpression": {
const n = node as MemberExpression;
let obj = formatExpression(n.object, -1);
if (n.object.type !== "Identifier") {
obj = `(${obj})`;
}
let prop = formatExpression(n.property, -1);
if (!n.computed && n.property.type !== "Identifier") {
prop = `(${prop})`;
}
return n.computed ? `${obj}[${prop}]` : `${obj}.${prop}`;
}
case "FilterExpression": {
const n = node as FilterExpression;
const operand = formatExpression(n.operand, Infinity);
if (n.filter.type === "CallExpression") {
return `${operand} | ${formatExpression(n.filter, -1)}`;
}
return `${operand} | ${(n.filter as Identifier).value}`;
}
case "SelectExpression": {
const n = node as SelectExpression;
return `${formatExpression(n.iterable, -1)} | select(${formatExpression(n.test, -1)})`;
}
case "TestExpression": {
const n = node as TestExpression;
return `${formatExpression(n.operand, -1)} is${n.negate ? " not" : ""} ${n.test.value}`;
}
case "ArrayLiteral":
case "TupleLiteral": {
const elems = ((node as ArrayLiteral | TupleLiteral).value as Expression[]).map((e) => formatExpression(e, -1));
const brackets = node.type === "ArrayLiteral" ? "[]" : "()";
return `${brackets[0]}${elems.join(", ")}${brackets[1]}`;
}
case "ObjectLiteral": {
const entries = Array.from((node as ObjectLiteral).value.entries()).map(
([k, v]) => `${formatExpression(k, -1)}: ${formatExpression(v, -1)}`
);
return `{ ${entries.join(", ")} }`;
}
case "SliceExpression": {
const n = node as SliceExpression;
const s = n.start ? formatExpression(n.start, -1) : "";
const t = n.stop ? formatExpression(n.stop, -1) : "";
const st = n.step ? `:${formatExpression(n.step, -1)}` : "";
return `${s}:${t}${st}`;
}
case "KeywordArgumentExpression": {
const n = node as KeywordArgumentExpression;
return `${n.key.value}=${formatExpression(n.value, -1)}`;
}
case "If": {
// Special case for ternary operator (If as an expression, not a statement)
const n = node as If;
const test = formatExpression(n.test, -1);
const body = formatExpression(n.body[0], 0); // Ternary operators have a single body and alternate
const alternate = formatExpression(n.alternate[0], -1);
return `${body} if ${test} else ${alternate}`;
}
default:
throw new Error(`Unknown expression type: ${node.type}`);
}
}