// @axflow/models
// Version: (not specified)
// Zero-dependency, modular SDK for building robust natural language applications
// 122 lines (121 loc) • 3.8 kB
// JavaScript
// src/huggingface/text-generation.ts
import { HttpError, isHttpError, POST } from "@axflow/models/shared";
// Default endpoint prefix. `run` and `streamBytes` append `request.model` to it
// whenever `options.apiUrl` is not provided.
var HUGGING_FACE_MODEL_API_URL = "https://api-inference.huggingface.co/models/";
// Marker text checked by `chunkToToken`: a token flagged `special` whose text
// contains this marker is emitted as an empty string.
var HUGGING_FACE_STOP_TOKEN = "</s>";
/**
 * Builds the HTTP headers for a request.
 *
 * Precedence, lowest to highest: the default `accept`, then `customHeaders`,
 * then the fixed JSON `content-type`, then the bearer `authorization`.
 *
 * @param apiKey - API key; an `authorization` header is added only when this is a string.
 * @param customHeaders - Optional extra headers merged over the default `accept` value.
 * @returns A new plain object holding the merged headers.
 */
function headers(apiKey, customHeaders) {
  // Spreading an empty object adds nothing, so a non-string key leaves the headers untouched.
  const auth = typeof apiKey === "string" ? { authorization: `Bearer ${apiKey}` } : {};
  return {
    accept: "application/json",
    ...customHeaders,
    "content-type": "application/json",
    ...auth
  };
}
/**
 * Performs a non-streaming text-generation request.
 *
 * @param request - Request payload; it is sent as the JSON body with `stream` forced to
 *   `false`, and `request.model` is appended to the default API URL.
 * @param options - Reads `apiUrl` (overrides the default URL when truthy), `apiKey`,
 *   `headers`, `fetch` and `signal`.
 * @returns The result of calling `json()` on the response.
 */
async function run(request, options) {
  const endpoint = options.apiUrl || HUGGING_FACE_MODEL_API_URL + request.model;
  const init = {
    headers: headers(options.apiKey, options.headers),
    body: JSON.stringify({ ...request, stream: false }),
    fetch: options.fetch,
    signal: options.signal
  };
  return (await POST(endpoint, init)).json();
}
/**
 * Performs a streaming text-generation request and returns the raw response body.
 *
 * @param request - Request payload; it is sent as the JSON body with `stream` forced to
 *   `true`, and `request.model` is appended to the default API URL.
 * @param options - Reads `apiUrl` (overrides the default URL when truthy), `apiKey`,
 *   `headers`, `fetch` and `signal`.
 * @returns `response.body` of the successful response.
 * @throws HttpError when the response has no body, or when the error payload reports that
 *   `stream` is not supported for the model (in that case with a clearer message).
 *   Any other failure from `POST` is rethrown unchanged.
 */
async function streamBytes(request, options) {
  const url = options.apiUrl || HUGGING_FACE_MODEL_API_URL + request.model;
  const headers_ = headers(options.apiKey, options.headers);
  const body = JSON.stringify({ ...request, stream: true });

  let response;
  try {
    response = await POST(url, {
      headers: headers_,
      body,
      fetch: options.fetch,
      signal: options.signal
    });
  } catch (e) {
    if (!isHttpError(e)) {
      throw e;
    }

    // Only *detect* the condition inside this try block. Throwing the replacement error
    // from inside it would be caught by the catch below and replaced by the original
    // error, so the clearer message would never reach the caller.
    let streamingUnsupported = false;
    try {
      const payload = await e.response.json();
      const error = payload?.error;
      // Accept both an array of messages (first entry is inspected) and a single string.
      const message = Array.isArray(error) ? error[0] : error;
      streamingUnsupported =
        typeof message === "string" &&
        message.includes("`stream` is not supported for this model");
    } catch {
      // The error body could not be read or parsed; fall through and rethrow the original.
    }

    if (streamingUnsupported) {
      throw new HttpError(`Model '${request.model}' does not support streaming`, e.response);
    }
    throw e;
  }

  if (!response.body) {
    throw new HttpError("Expected response body to be a ReadableStream", response);
  }
  return response.body;
}
/**
 * Identity mapper: returns the chunk it receives unchanged.
 * `stream` passes it to `HuggingFaceDecoderStream` so parsed chunks are emitted as-is.
 */
function noop(chunk) {
return chunk;
}
function chunkToToken(chunk) {
if (chunk.token.special && chunk.token.text.includes(HUGGING_FACE_STOP_TOKEN)) {
return "";
}
return chunk.token.text;
}
/**
 * Streams parsed chunks for a text-generation request.
 *
 * Obtains the byte stream from `streamBytes` and pipes it through a
 * `HuggingFaceDecoderStream` built with the identity mapper, so every parsed chunk is
 * emitted unchanged.
 *
 * @param request - Forwarded to `streamBytes`.
 * @param options - Forwarded to `streamBytes`.
 * @returns The stream produced by `pipeThrough`.
 */
async function stream(request, options) {
  const bytes = await streamBytes(request, options);
  const decoder = new HuggingFaceDecoderStream(noop);
  return bytes.pipeThrough(decoder);
}
/**
 * Streams only the token text for a text-generation request.
 *
 * Obtains the byte stream from `streamBytes` and pipes it through a
 * `HuggingFaceDecoderStream` built with `chunkToToken`, so each parsed chunk is reduced
 * to the value `chunkToToken` returns for it.
 *
 * @param request - Forwarded to `streamBytes`.
 * @param options - Forwarded to `streamBytes`.
 * @returns The stream produced by `pipeThrough`.
 */
async function streamTokens(request, options) {
  const bytes = await streamBytes(request, options);
  const decoder = new HuggingFaceDecoderStream(chunkToToken);
  return bytes.pipeThrough(decoder);
}
/**
 * Groups this module's request functions as static members. This is the only symbol
 * exported from the file.
 */
var HuggingFaceTextGeneration = class {
// Non-streaming request; resolves with the parsed JSON response.
static run = run;
// Streaming request; resolves with the raw response body.
static streamBytes = streamBytes;
// Streaming request; resolves with a stream of parsed chunks.
static stream = stream;
// Streaming request; resolves with a stream of token text.
static streamTokens = streamTokens;
};
var HuggingFaceDecoderStream = class _HuggingFaceDecoderStream extends TransformStream {
static LINES_RE = /data:\s*(.+)/;
static parseChunk(lines) {
lines = lines.trim();
if (lines.length === 0) {
return null;
}
const match = lines.match(_HuggingFaceDecoderStream.LINES_RE);
try {
const data = match[1];
return JSON.parse(data);
} catch (e) {
throw new Error(`Malformed streaming data from HuggingFace: ${JSON.stringify(lines)}`);
}
}
static transformer(map) {
let buffer = [];
const decoder = new TextDecoder();
return (bytes, controller) => {
const chunk = decoder.decode(bytes);
for (let i = 0, len = chunk.length; i < len; ++i) {
const bufferLength = buffer.length;
const isSeparator = chunk[i] === "\n" && buffer[bufferLength - 1] === "\n";
if (!isSeparator) {
buffer.push(chunk[i]);
continue;
}
const parsedChunk = _HuggingFaceDecoderStream.parseChunk(buffer.join(""));
if (parsedChunk) {
controller.enqueue(map(parsedChunk));
}
buffer = [];
}
};
}
constructor(map) {
super({ transform: _HuggingFaceDecoderStream.transformer(map) });
}
};
export {
HuggingFaceTextGeneration
};