@huggingface/jinja
Version:
A minimalistic JavaScript implementation of the Jinja templating engine, specifically designed for parsing and rendering ML chat templates.
671 lines (594 loc) • 19.9 kB
text/typescript
import { Token, TOKEN_TYPES } from "./lexer";
import type { TokenType } from "./lexer";
import type { Statement } from "./ast";
import {
Program,
If,
For,
Break,
Continue,
SetStatement,
MemberExpression,
CallExpression,
Identifier,
StringLiteral,
ArrayLiteral,
ObjectLiteral,
BinaryExpression,
FilterExpression,
TestExpression,
UnaryExpression,
SliceExpression,
KeywordArgumentExpression,
TupleLiteral,
Macro,
SelectExpression,
CallStatement,
FilterStatement,
SpreadExpression,
IntegerLiteral,
FloatLiteral,
Ternary,
Comment,
} from "./ast";
/**
* Generate the Abstract Syntax Tree (AST) from a list of tokens.
* Operator precedence can be found here: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_precedence#table
*/
export function parse(tokens: Token[]): Program {
const program = new Program([]);
let current = 0;
/**
 * Consume the next token if it matches the expected type, otherwise throw an error.
 * The cursor is advanced before the check, so it moves even when the check fails.
 * @param type The expected token type
 * @param error The error message to throw if the token does not match the expected type
 * @returns The consumed token
 * @throws {Error} If the token stream is exhausted or the next token has a different type
 */
function expect(type: string, error: string): Token {
  const prev = tokens[current++];
  if (!prev || prev.type !== type) {
    // `prev` is undefined when the input ends early, so it must not be dereferenced
    // while building the message (doing so would surface as an unrelated TypeError).
    const actual = prev ? prev.type : "end of input";
    throw new Error(`Parser Error: ${error}. ${actual} !== ${type}.`);
  }
  return prev;
}
/**
 * Require the token at the cursor to be an identifier with the given value and step past it.
 * @param name The identifier value that must appear at the cursor
 * @throws {SyntaxError} If the token at the cursor is not that identifier
 */
function expectIdentifier(name: string): void {
  if (isIdentifier(name)) {
    current += 1;
    return;
  }
  throw new SyntaxError(`Expected ${name}`);
}
/**
 * Parse one template-level node starting at the cursor: a comment, a run of raw text,
 * a `{% ... %}` statement or a `{{ ... }}` expression.
 * @returns The parsed node
 * @throws {SyntaxError} If the input is exhausted or the token cannot start a node
 */
function parseAny(): Statement {
  const token = tokens[current];
  if (!token) {
    // The block-body loops call this until they see their closing tag; if that tag is
    // missing the cursor runs past the last token. Report it instead of dereferencing undefined.
    throw new SyntaxError("Unexpected end of input while parsing a block body (missing closing tag?)");
  }
  switch (token.type) {
    case TOKEN_TYPES.Comment:
      ++current; // consume the comment token
      return new Comment(token.value);
    case TOKEN_TYPES.Text:
      return parseText();
    case TOKEN_TYPES.OpenStatement:
      return parseJinjaStatement();
    case TOKEN_TYPES.OpenExpression:
      return parseJinjaExpression();
    default:
      throw new SyntaxError(`Unexpected token type: ${token.type}`);
  }
}
/**
 * Check whether the upcoming tokens have exactly the given types, in order, without consuming them.
 * @param types Token types expected at the cursor, the cursor + 1, and so on
 * @returns `true` only if enough tokens remain and every type matches
 */
function is(...types: TokenType[]): boolean {
  // Not enough tokens left to match the whole sequence
  if (current + types.length > tokens.length) {
    return false;
  }
  for (let offset = 0; offset < types.length; ++offset) {
    if (types[offset] !== tokens[current + offset].type) {
      return false;
    }
  }
  return true;
}
/**
 * Check, without consuming anything, whether the cursor sits on `{%` followed by an
 * identifier whose value is one of the given names.
 * @param names Accepted statement keywords (e.g. "endif", "else")
 * @returns `true` if the next two tokens open one of those statements
 */
function isStatement(...names: string[]): boolean {
  const opener = tokens[current];
  if (opener?.type !== TOKEN_TYPES.OpenStatement) {
    return false;
  }
  const keyword = tokens[current + 1];
  if (keyword?.type !== TOKEN_TYPES.Identifier) {
    return false;
  }
  return names.includes(keyword?.value);
}
/**
 * Check, without consuming anything, whether the upcoming tokens are identifiers with
 * exactly the given values, in order (e.g. `isIdentifier("not", "in")`).
 * @param names Identifier values expected at the cursor, the cursor + 1, and so on
 * @returns `true` only if enough tokens remain and every one matches
 */
function isIdentifier(...names: string[]): boolean {
  if (current + names.length > tokens.length) {
    return false;
  }
  return names.every((expected, offset) => {
    const candidate = tokens[current + offset];
    return candidate.type === "Identifier" && expected === candidate.value;
  });
}
/**
 * Consume a raw text token and wrap its value in a string literal node.
 * @returns A `StringLiteral` holding the text
 */
function parseText(): StringLiteral {
  const textToken = expect(TOKEN_TYPES.Text, "Expected text token");
  return new StringLiteral(textToken.value);
}
/**
 * Parse a `{% ... %}` statement. The identifier right after `{%` selects the statement
 * kind; the recognised keywords are `set`, `if`, `macro`, `for`, `call`, `break`,
 * `continue` and `filter`.
 *
 * For `if`, `macro` and `for` the dedicated helper parses the header and body, and the
 * matching `{% end... %}` tag is consumed here. `set` consumes its own closing tokens
 * inside `parseSetStatement`.
 * @returns The parsed statement node
 * @throws {SyntaxError} If `{%` is not followed by a known statement keyword
 * @throws {Error} If an expected delimiter token is missing (raised by `expect`)
 */
function parseJinjaStatement(): Statement {
  // Consume {% token
  expect(TOKEN_TYPES.OpenStatement, "Expected opening statement token");

  // next token must be Identifier whose .value tells us which statement.
  // It may be missing entirely when the template ends right after `{%`.
  const keyword = tokens[current];
  if (!keyword || keyword.type !== TOKEN_TYPES.Identifier) {
    throw new SyntaxError(`Unknown statement, got ${keyword ? keyword.type : "end of input"}`);
  }

  const name = keyword.value;
  let result: Statement;
  switch (name) {
    case "set":
      ++current;
      result = parseSetStatement();
      break;

    case "if":
      ++current;
      result = parseIfStatement();
      // expect {% endif %}
      expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
      expectIdentifier("endif");
      expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
      break;

    case "macro":
      ++current;
      result = parseMacroStatement();
      // expect {% endmacro %}
      expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
      expectIdentifier("endmacro");
      expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
      break;

    case "for":
      ++current;
      result = parseForStatement();
      // expect {% endfor %}
      expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
      expectIdentifier("endfor");
      expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
      break;

    case "call": {
      ++current; // consume 'call'
      let callerArgs: Statement[] | null = null;
      if (is(TOKEN_TYPES.OpenParen)) {
        // Optional caller arguments, e.g. {% call(user) dump_users(...) %}
        callerArgs = parseArgs();
      }
      const callee = parsePrimaryExpression();
      if (callee.type !== "Identifier") {
        throw new SyntaxError(`Expected identifier following call statement`);
      }
      const callArgs = parseArgs();
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

      // Everything up to {% endcall %} is the body of the call block
      const body: Statement[] = [];
      while (!isStatement("endcall")) {
        body.push(parseAny());
      }
      expect(TOKEN_TYPES.OpenStatement, "Expected '{%'");
      expectIdentifier("endcall");
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

      const callExpr = new CallExpression(callee, callArgs);
      result = new CallStatement(callExpr, callerArgs, body);
      break;
    }

    case "break":
      ++current;
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
      result = new Break();
      break;

    case "continue":
      ++current;
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
      result = new Continue();
      break;

    case "filter": {
      ++current; // consume 'filter'
      let filterNode = parsePrimaryExpression();
      if (filterNode instanceof Identifier && is(TOKEN_TYPES.OpenParen)) {
        // Filter invoked with arguments, e.g. {% filter replace("a", "b") %}
        filterNode = parseCallExpression(filterNode);
      }
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

      // Everything up to {% endfilter %} is the content the filter is applied to
      const filterBody: Statement[] = [];
      while (!isStatement("endfilter")) {
        filterBody.push(parseAny());
      }
      expect(TOKEN_TYPES.OpenStatement, "Expected '{%'");
      expectIdentifier("endfilter");
      expect(TOKEN_TYPES.CloseStatement, "Expected '%}'");
      result = new FilterStatement(filterNode as Identifier | CallExpression, filterBody);
      break;
    }

    default:
      throw new SyntaxError(`Unknown statement type: ${name}`);
  }
  return result;
}
/**
 * Parse an output expression: `{{`, one expression, then `}}`.
 * @returns The expression found between the delimiters
 */
function parseJinjaExpression(): Statement {
  expect(TOKEN_TYPES.OpenExpression, "Expected opening expression token"); // {{
  const inner = parseExpression();
  expect(TOKEN_TYPES.CloseExpression, "Expected closing expression token"); // }}
  return inner;
}
/**
 * Parse the remainder of a `set` statement (the `set` keyword is already consumed).
 *
 * Two forms are handled:
 *  - inline: `{% set target = value %}` produces a value and an empty body
 *  - block: `{% set target %}...{% endset %}` produces a `null` value and the parsed body
 *
 * NOTE: `set` acts as both declaration statement and assignment expression
 * @returns A `SetStatement` node
 */
function parseSetStatement(): Statement {
  const target = parseExpressionSequence();

  if (is(TOKEN_TYPES.Equals)) {
    // Inline assignment
    current += 1; // consume '='
    const assigned = parseExpressionSequence();
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    return new SetStatement(target, assigned, []);
  }

  // Block (multiline) form: collect nodes until {% endset %}
  expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
  const blockBody: Statement[] = [];
  while (!isStatement("endset")) {
    blockBody.push(parseAny());
  }
  expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
  expectIdentifier("endset");
  expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
  return new SetStatement(target, null, blockBody);
}
/**
 * Parse an `if` (or `elif`) clause: the condition, the closing `%}`, the body, and any
 * `elif`/`else` branch. The cursor is left on the `{% endif %}` tag, which the caller consumes.
 *
 * An `elif` is represented as a nested `If` node stored as the only entry of the alternate.
 * @returns The `If` node for this clause
 */
function parseIfStatement(): If {
  const condition = parseExpression();
  expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

  // Body runs until the first {% elif %}, {% else %} or {% endif %}
  const consequent: Statement[] = [];
  while (!isStatement("elif", "else", "endif")) {
    consequent.push(parseAny());
  }

  const otherwise: Statement[] = [];
  if (isStatement("elif")) {
    current += 2; // consume `{%` and `elif`
    otherwise.push(parseIfStatement()); // nested If
  } else if (isStatement("else")) {
    current += 2; // consume `{%` and `else`
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    // Else branch runs until {% endif %}
    while (!isStatement("endif")) {
      otherwise.push(parseAny());
    }
  }

  return new If(condition, consequent, otherwise);
}
/**
 * Parse a macro definition after the `macro` keyword: its name, parenthesised parameter
 * list, closing `%}` and body. The cursor is left on `{% endmacro %}` for the caller.
 * @returns The `Macro` node
 * @throws {SyntaxError} If the macro name is not an identifier
 */
function parseMacroStatement(): Macro {
  const macroName = parsePrimaryExpression();
  if (macroName.type !== "Identifier") {
    throw new SyntaxError(`Expected identifier following macro statement`);
  }
  const params = parseArgs();
  expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

  // Collect the macro body up to (not including) {% endmacro
  const macroBody: Statement[] = [];
  while (!isStatement("endmacro")) {
    macroBody.push(parseAny());
  }
  return new Macro(macroName as Identifier, params, macroBody);
}
/**
 * Parse one expression, or a comma-separated sequence of them.
 * @param primary When `true`, each element is parsed with `parsePrimaryExpression`
 *   instead of `parseExpression`
 * @returns The single expression when no comma follows it, otherwise a `TupleLiteral`
 *   of all elements
 */
function parseExpressionSequence(primary = false): Statement {
  const parseElement = primary ? parsePrimaryExpression : parseExpression;
  const first = parseElement();
  if (!is(TOKEN_TYPES.Comma)) {
    return first;
  }

  const elements = [first];
  do {
    ++current; // consume comma
    elements.push(parseElement());
  } while (is(TOKEN_TYPES.Comma));
  return new TupleLiteral(elements);
}
/**
 * Parse a `for` loop after the `for` keyword: loop variable(s), `in`, the iterable, the
 * closing `%}`, the body and an optional `{% else %}` block. The cursor is left on
 * `{% endfor %}` for the caller.
 * @returns The `For` node
 * @throws {SyntaxError} If the loop target is not an identifier/tuple or `in` is missing
 */
function parseForStatement(): For {
  // The target, e.g. `message` in `for message in messages`
  const target = parseExpressionSequence(true);
  if (!(target instanceof Identifier || target instanceof TupleLiteral)) {
    throw new SyntaxError(`Expected identifier/tuple for the loop variable, got ${target.type} instead`);
  }

  if (!isIdentifier("in")) {
    throw new SyntaxError("Expected `in` keyword following loop variable");
  }
  current += 1; // consume 'in'

  // The collection, e.g. `messages` in `for message in messages`
  const source = parseExpression();
  expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");

  // Loop body runs until {% endfor or {% else
  const loopBody: Statement[] = [];
  while (!isStatement("endfor", "else")) {
    loopBody.push(parseAny());
  }

  // Optional else block runs until {% endfor
  const fallback: Statement[] = [];
  if (isStatement("else")) {
    current += 2; // consume `{%` and `else`
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    while (!isStatement("endfor")) {
      fallback.push(parseAny());
    }
  }

  return new For(target, source, loopBody, fallback);
}
/**
 * Entry point for parsing a full expression. Delegates to the lowest-precedence rule
 * (the inline `if` expression), which in turn descends into tighter-binding rules.
 * @returns The parsed expression
 */
function parseExpression(): Statement {
  const expression = parseIfExpression();
  return expression;
}
/**
 * Parse an inline conditional.
 *  - `a if test else b` produces a `Ternary` (the else branch recurses, so ternaries chain)
 *  - `a if test` produces a `SelectExpression`
 *  - anything else is returned as parsed by the logical-or rule
 * @returns The parsed expression
 */
function parseIfExpression(): Statement {
  const value = parseLogicalOrExpression();
  if (!isIdentifier("if")) {
    return value;
  }

  current += 1; // consume 'if'
  const condition = parseLogicalOrExpression();
  if (!isIdentifier("else")) {
    // No else branch: select expression
    return new SelectExpression(value, condition);
  }

  current += 1; // consume 'else'
  const fallback = parseIfExpression();
  return new Ternary(condition, value, fallback);
}
/**
 * Parse a chain of `or` operations, folding them left to right into `BinaryExpression` nodes.
 * @returns The parsed expression
 */
function parseLogicalOrExpression(): Statement {
  let result = parseLogicalAndExpression();
  while (isIdentifier("or")) {
    const orToken = tokens[current++]; // the 'or' identifier doubles as the operator token
    const rhs = parseLogicalAndExpression();
    result = new BinaryExpression(orToken, result, rhs);
  }
  return result;
}
/**
 * Parse a chain of `and` operations, folding them left to right into `BinaryExpression` nodes.
 * @returns The parsed expression
 */
function parseLogicalAndExpression(): Statement {
  let result = parseLogicalNegationExpression();
  while (isIdentifier("and")) {
    const andToken = tokens[current++]; // the 'and' identifier doubles as the operator token
    const rhs = parseLogicalNegationExpression();
    result = new BinaryExpression(andToken, result, rhs);
  }
  return result;
}
/**
 * Parse an optional prefix `not`. Repeated negation (`not not x`) is handled by the
 * recursive call, so each `not` wraps everything that follows it: `not test.x` is
 * parsed as `not (test.x)`.
 *
 * Exactly one `not` is handled per invocation. Looping here would let a stray `not`
 * after the operand (e.g. `not a not b`) overwrite, and silently drop, the expression
 * that was already parsed; the leftover token is instead left for the caller to reject.
 * @returns A `UnaryExpression` when the cursor is on `not`, otherwise the comparison expression
 */
function parseLogicalNegationExpression(): Statement {
  if (isIdentifier("not")) {
    const operator = tokens[current];
    ++current; // consume 'not'
    const arg = parseLogicalNegationExpression(); // not test.x === not (test.x)
    return new UnaryExpression(operator, arg);
  }
  return parseComparisonExpression();
}
/**
 * Parse comparison and membership operators (`in`, `not in`, and comparison-operator tokens).
 *
 * NOTE: membership shares this precedence level with comparison, and the chain is folded
 * left to right, e.g. `a in b == c` becomes `((a in b) == c)`.
 * `not in` arrives as two identifier tokens and is merged into one synthetic operator token.
 * @returns The parsed expression
 */
function parseComparisonExpression(): Statement {
  let result = parseAdditiveExpression();
  for (;;) {
    let op: Token;
    if (isIdentifier("not", "in")) {
      op = new Token("not in", TOKEN_TYPES.Identifier);
      current += 2; // consume both 'not' and 'in'
    } else if (isIdentifier("in") || is(TOKEN_TYPES.ComparisonBinaryOperator)) {
      op = tokens[current++];
    } else {
      // No further comparison/membership operator
      return result;
    }
    const rhs = parseAdditiveExpression();
    result = new BinaryExpression(op, result, rhs);
  }
}
/**
 * Parse a chain of additive-operator tokens, folding left to right into `BinaryExpression` nodes.
 * @returns The parsed expression
 */
function parseAdditiveExpression(): Statement {
  let result = parseMultiplicativeExpression();
  while (is(TOKEN_TYPES.AdditiveBinaryOperator)) {
    const op = tokens[current++];
    const rhs = parseMultiplicativeExpression();
    result = new BinaryExpression(op, result, rhs);
  }
  return result;
}
/**
 * Parse a primary expression followed by any member accesses (`foo.x`, `foo[0]`) and,
 * if an opening parenthesis follows, a call on the result (`foo.x()`).
 * @returns The parsed expression
 */
function parseCallMemberExpression(): Statement {
  const primary = parsePrimaryExpression();
  const target = parseMemberExpression(primary); // foo.x
  return is(TOKEN_TYPES.OpenParen) ? parseCallExpression(target) /* foo.x() */ : target;
}
/**
 * Parse one or more consecutive calls on `callee`; the cursor must be on `(`.
 * After each call, trailing member accesses are absorbed (`foo.x().y`), and another
 * `(` starts a further call on the result (`foo.x()()`).
 * @param callee The expression being called
 * @returns The resulting call/member expression
 */
function parseCallExpression(callee: Statement): Statement {
  let node: Statement = callee;
  do {
    const call: Statement = new CallExpression(node, parseArgs());
    node = parseMemberExpression(call); // foo.x().y
  } while (is(TOKEN_TYPES.OpenParen)); // foo.x()()
  return node;
}
/**
 * Parse a parenthesised argument list, e.g. `(x + 5, foo())`.
 * @returns The argument nodes found between the parentheses
 */
function parseArgs(): Statement[] {
  expect(TOKEN_TYPES.OpenParen, "Expected opening parenthesis for arguments list");
  const parsed = parseArgumentsList();
  expect(TOKEN_TYPES.CloseParen, "Expected closing parenthesis for arguments list");
  return parsed;
}
/**
 * Parse arguments up to (not including) the closing parenthesis. Each argument is one of:
 *  - `*expr`, producing a `SpreadExpression`
 *  - `name = expr`, producing a `KeywordArgumentExpression`
 *  - a plain expression
 * A comma after an argument is consumed when present.
 * @returns The parsed argument nodes, in source order
 * @throws {SyntaxError} If the left-hand side of `=` is not an identifier
 */
function parseArgumentsList(): Statement[] {
  const collected = [];
  while (!is(TOKEN_TYPES.CloseParen)) {
    const head = tokens[current];
    if (head.type === TOKEN_TYPES.MultiplicativeBinaryOperator && head.value === "*") {
      // unpacking: *expr
      current += 1;
      collected.push(new SpreadExpression(parseExpression()));
    } else {
      const parsed = parseExpression();
      if (is(TOKEN_TYPES.Equals)) {
        // keyword argument, e.g., func(x = 5, y = a or b)
        current += 1; // consume equals
        if (!(parsed instanceof Identifier)) {
          throw new SyntaxError(`Expected identifier for keyword argument`);
        }
        const assigned = parseExpression();
        collected.push(new KeywordArgumentExpression(parsed as Identifier, assigned));
      } else {
        collected.push(parsed);
      }
    }

    if (is(TOKEN_TYPES.Comma)) {
      current += 1; // consume comma
    }
  }
  return collected;
}
/**
 * Parse the contents of a bracket access up to (not including) `]`.
 * Handles both plain subscripts and colon-separated slices,
 * e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3].
 * An omitted slice component before a colon is recorded as `undefined`
 * (so `[:2]` yields `[undefined, 2]`).
 * @returns A `SliceExpression` if any colon was seen, otherwise the first parsed expression
 * @throws {SyntaxError} If the brackets are empty or a slice has more than three parts
 */
function parseMemberExpressionArgumentsList(): Statement {
  const parts: (Statement | undefined)[] = [];
  let sawColon = false;

  while (!is(TOKEN_TYPES.CloseSquareBracket)) {
    const omitted = is(TOKEN_TYPES.Colon);
    // A leading colon means this component uses its default
    parts.push(omitted ? undefined : parseExpression());
    if (omitted || is(TOKEN_TYPES.Colon)) {
      ++current; // consume the colon that terminates this component
      sawColon = true;
    }
  }

  if (parts.length === 0) {
    // []
    throw new SyntaxError(`Expected at least one argument for member/slice expression`);
  }
  if (!sawColon) {
    return parts[0] as Statement; // normal member expression
  }
  if (parts.length > 3) {
    throw new SyntaxError(`Expected 0-3 arguments for slice expression`);
  }
  return new SliceExpression(...parts);
}
/**
 * Wrap `object` in `MemberExpression` nodes for every `.name` or `[expr]` access that follows it.
 * @param object The expression being accessed
 * @returns `object` itself if no access follows, otherwise the outermost `MemberExpression`
 * @throws {SyntaxError} If a dot is not followed by an identifier
 */
function parseMemberExpression(object: Statement): Statement {
  let node = object;
  while (is(TOKEN_TYPES.Dot) || is(TOKEN_TYPES.OpenSquareBracket)) {
    const accessor = tokens[current++]; // . or [
    const computed = accessor.type === TOKEN_TYPES.OpenSquareBracket;

    let property: Statement;
    if (!computed) {
      // dot notation: obj.name
      property = parsePrimaryExpression();
      if (property.type !== "Identifier") {
        throw new SyntaxError(`Expected identifier following dot operator`);
      }
    } else {
      // bracket notation: obj[expr] or obj[a:b]
      property = parseMemberExpressionArgumentsList();
      expect(TOKEN_TYPES.CloseSquareBracket, "Expected closing square bracket");
    }

    node = new MemberExpression(node, property, computed);
  }
  return node;
}
/**
 * Parse a chain of multiplicative-operator tokens, folding left to right into
 * `BinaryExpression` nodes.
 *
 * Operands are parsed with the test rule, so `is` tests bind tighter than these
 * operators: `4 * 4 is foo` groups as `4 * (4 is foo)`.
 * @returns The parsed expression
 */
function parseMultiplicativeExpression(): Statement {
  let result = parseTestExpression();
  while (is(TOKEN_TYPES.MultiplicativeBinaryOperator)) {
    const op = tokens[current++];
    const rhs = parseTestExpression();
    result = new BinaryExpression(op, result, rhs);
  }
  return result;
}
/**
 * Parse zero or more `is [not] <test>` suffixes applied to a filter expression.
 * Tests may be chained; each one wraps the result of the previous.
 * @returns The operand itself when no `is` follows, otherwise the outermost `TestExpression`
 * @throws {SyntaxError} If the test name is not an identifier
 */
function parseTestExpression(): Statement {
  let subject = parseFilterExpression();
  while (isIdentifier("is")) {
    current += 1; // consume is
    const negated = isIdentifier("not");
    if (negated) {
      current += 1; // consume not
    }

    const testName = parsePrimaryExpression();
    if (!(testName instanceof Identifier)) {
      throw new SyntaxError(`Expected identifier for the test`);
    }
    // TODO: Add support for non-identifier tests
    subject = new TestExpression(subject, negated, testName);
  }
  return subject;
}
/**
 * Parse zero or more `| filter` / `| filter(args)` suffixes applied to a call/member expression.
 * Filters may be chained; each one wraps the result of the previous.
 * @returns The operand itself when no pipe follows, otherwise the outermost `FilterExpression`
 * @throws {SyntaxError} If the token after a pipe is not an identifier
 */
function parseFilterExpression(): Statement {
  let subject = parseCallMemberExpression();
  while (is(TOKEN_TYPES.Pipe)) {
    current += 1; // consume pipe

    let filter = parsePrimaryExpression(); // should be an identifier
    if (!(filter instanceof Identifier)) {
      throw new SyntaxError(`Expected identifier for the filter`);
    }
    if (is(TOKEN_TYPES.OpenParen)) {
      // Filter with arguments
      filter = parseCallExpression(filter);
    }
    subject = new FilterExpression(subject, filter as Identifier | CallExpression);
  }
  return subject;
}
/**
 * Parse a primary expression: a numeric literal, string literal, identifier,
 * parenthesised expression (or tuple), array literal or object literal.
 *
 * The token at the cursor is always consumed, even when it turns out to be invalid.
 * @returns The parsed node
 * @throws {SyntaxError} If the input is exhausted or the token cannot start a primary expression
 * @throws {Error} If a required delimiter is missing (raised by `expect`)
 */
function parsePrimaryExpression(): Statement {
  const token = tokens[current++];
  if (!token) {
    // Ran past the last token (e.g. an unterminated `[`, `{` or `(`); report it
    // rather than crashing on `token.type`.
    throw new SyntaxError("Unexpected end of input");
  }
  switch (token.type) {
    case TOKEN_TYPES.NumericLiteral: {
      const num = token.value;
      // A decimal point distinguishes floats from integers
      return num.includes(".") ? new FloatLiteral(Number(num)) : new IntegerLiteral(Number(num));
    }
    case TOKEN_TYPES.StringLiteral: {
      // Adjacent string literals are concatenated into a single literal
      let value = token.value;
      while (is(TOKEN_TYPES.StringLiteral)) {
        value += tokens[current++].value;
      }
      return new StringLiteral(value);
    }
    case TOKEN_TYPES.Identifier:
      return new Identifier(token.value);
    case TOKEN_TYPES.OpenParen: {
      const expression = parseExpressionSequence();
      // Template literal so the offending token type is actually interpolated; the
      // argument is evaluated before `expect` advances, so `tokens[current]` is the
      // token being checked (or undefined at end of input).
      expect(
        TOKEN_TYPES.CloseParen,
        `Expected closing parenthesis, got ${tokens[current]?.type ?? "end of input"} instead`
      );
      return expression;
    }
    case TOKEN_TYPES.OpenSquareBracket: {
      const values = [];
      while (!is(TOKEN_TYPES.CloseSquareBracket)) {
        values.push(parseExpression());
        if (is(TOKEN_TYPES.Comma)) {
          ++current; // consume comma
        }
      }
      ++current; // consume closing square bracket
      return new ArrayLiteral(values);
    }
    case TOKEN_TYPES.OpenCurlyBracket: {
      const values = new Map();
      while (!is(TOKEN_TYPES.CloseCurlyBracket)) {
        const key = parseExpression();
        expect(TOKEN_TYPES.Colon, "Expected colon between key and value in object literal");
        const value = parseExpression();
        values.set(key, value);
        if (is(TOKEN_TYPES.Comma)) {
          ++current; // consume comma
        }
      }
      ++current; // consume closing curly bracket
      return new ObjectLiteral(values);
    }
    default:
      throw new SyntaxError(`Unexpected token: ${token.type}`);
  }
}
while (current < tokens.length) {
program.body.push(parseAny());
}
return program;
}