UNPKG

@axflow/models

Version:

Zero-dependency, modular SDK for building robust natural language applications

122 lines (121 loc) 3.8 kB
// src/huggingface/text-generation.ts
import { HttpError, isHttpError, POST } from "@axflow/models/shared";

var HUGGING_FACE_MODEL_API_URL = "https://api-inference.huggingface.co/models/";
var HUGGING_FACE_STOP_TOKEN = "</s>";
// Fragment of the API error message used to detect that a model cannot stream.
var HUGGING_FACE_STREAM_UNSUPPORTED = "`stream` is not supported for this model";

/**
 * Builds the HTTP headers for a request.
 *
 * Custom headers may override `accept`, but `content-type` is always forced to
 * JSON, and `authorization` is set (overriding any custom value) when `apiKey`
 * is a string.
 *
 * @param apiKey Optional API key; sent as a Bearer token when it is a string.
 * @param customHeaders Optional extra headers merged into the result.
 * @returns The merged headers object.
 */
function headers(apiKey, customHeaders) {
  const headers2 = {
    accept: "application/json",
    ...customHeaders,
    "content-type": "application/json"
  };
  if (typeof apiKey === "string") {
    headers2.authorization = `Bearer ${apiKey}`;
  }
  return headers2;
}

/**
 * Sends a non-streaming request and resolves with the parsed JSON response.
 *
 * @param request Request payload; `request.model` selects the model endpoint
 *   unless `options.apiUrl` is given. `stream` is forced to `false`.
 * @param options `apiUrl`, `apiKey`, `headers`, `fetch` and `signal` are read.
 * @returns A promise for the result of `response.json()`.
 */
async function run(request, options) {
  const url = options.apiUrl || HUGGING_FACE_MODEL_API_URL + request.model;
  const headers_ = headers(options.apiKey, options.headers);
  const body = JSON.stringify({ ...request, stream: false });
  const response = await POST(url, {
    headers: headers_,
    body,
    fetch: options.fetch,
    signal: options.signal
  });
  return response.json();
}

/**
 * Reports whether an error response says the model does not support streaming.
 *
 * The `error` field of the JSON body is accepted either as a string or as an
 * array whose first element is the message. A body that cannot be parsed as
 * JSON is treated as "not a streaming-unsupported error".
 *
 * @param response The response attached to the HTTP error.
 * @returns A promise resolving to `true` only when the message matches.
 */
async function isStreamingUnsupportedResponse(response) {
  let payload;
  try {
    payload = await response.json();
  } catch {
    return false;
  }
  const error = payload?.error;
  const message = Array.isArray(error) ? error[0] : error;
  return typeof message === "string" && message.includes(HUGGING_FACE_STREAM_UNSUPPORTED);
}

/**
 * Sends a streaming request and resolves with the raw response body stream.
 *
 * @param request Request payload; `stream` is forced to `true`.
 * @param options `apiUrl`, `apiKey`, `headers`, `fetch` and `signal` are read.
 * @returns A promise for `response.body`.
 * @throws {HttpError} When the response has no body, or (with a dedicated
 *   message) when the API reports that the model does not support streaming.
 *   Any other error from the request is rethrown unchanged.
 */
async function streamBytes(request, options) {
  const url = options.apiUrl || HUGGING_FACE_MODEL_API_URL + request.model;
  const headers_ = headers(options.apiKey, options.headers);
  const body = JSON.stringify({ ...request, stream: true });
  let response;
  try {
    response = await POST(url, {
      headers: headers_,
      body,
      fetch: options.fetch,
      signal: options.signal
    });
  } catch (e) {
    // The dedicated error is thrown outside of any inner try/catch so that it
    // is not swallowed and replaced by the original error.
    if (isHttpError(e) && (await isStreamingUnsupportedResponse(e.response))) {
      throw new HttpError(`Model '${request.model}' does not support streaming`, e.response);
    }
    throw e;
  }
  if (!response.body) {
    throw new HttpError("Expected response body to be a ReadableStream", response);
  }
  return response.body;
}

/** Identity mapper: passes each parsed chunk through unchanged. */
function noop(chunk) {
  return chunk;
}

/**
 * Maps a parsed chunk to its token text.
 *
 * @param chunk Parsed chunk; `chunk.token.special` and `chunk.token.text` are read.
 * @returns An empty string when the token is special and its text contains the
 *   stop token, otherwise the token text.
 */
function chunkToToken(chunk) {
  if (chunk.token.special && chunk.token.text.includes(HUGGING_FACE_STOP_TOKEN)) {
    return "";
  }
  return chunk.token.text;
}

/**
 * Streams the response as parsed chunk objects.
 *
 * @returns A promise for the byte stream piped through the decoder with the
 *   identity mapper.
 */
async function stream(request, options) {
  const byteStream = await streamBytes(request, options);
  return byteStream.pipeThrough(new HuggingFaceDecoderStream(noop));
}

/**
 * Streams the response as token text, one string per parsed chunk.
 *
 * @returns A promise for the byte stream piped through the decoder with
 *   `chunkToToken` as the mapper.
 */
async function streamTokens(request, options) {
  const byteStream = await streamBytes(request, options);
  return byteStream.pipeThrough(new HuggingFaceDecoderStream(chunkToToken));
}

/** Public entry points, exposed as static members. */
var HuggingFaceTextGeneration = class {
  static run = run;
  static streamBytes = streamBytes;
  static stream = stream;
  static streamTokens = streamTokens;
};

/**
 * TransformStream that turns response bytes into mapped chunks: events are
 * separated by a blank line ("\n\n") and each carries a `data:` JSON payload.
 */
var HuggingFaceDecoderStream = class _HuggingFaceDecoderStream extends TransformStream {
  static LINES_RE = /data:\s*(.+)/;

  /**
   * Parses the text of one event.
   *
   * @param lines Raw text accumulated for a single event.
   * @returns The parsed JSON payload, or `null` when the text is blank.
   * @throws {Error} When no `data:` payload is found or it is not valid JSON.
   */
  static parseChunk(lines) {
    lines = lines.trim();
    if (lines.length === 0) {
      return null;
    }
    const match = lines.match(_HuggingFaceDecoderStream.LINES_RE);
    try {
      if (!match) {
        throw new Error("Missing data field");
      }
      return JSON.parse(match[1]);
    } catch (e) {
      throw new Error(`Malformed streaming data from HuggingFace: ${JSON.stringify(lines)}`);
    }
  }

  /**
   * Creates the `transform` callback for the stream.
   *
   * @param map Function applied to every parsed chunk before it is enqueued.
   * @returns A `(bytes, controller)` callback that buffers text across calls.
   */
  static transformer(map) {
    let buffer = [];
    const decoder = new TextDecoder();
    return (bytes, controller) => {
      // `stream: true` keeps an incomplete multi-byte sequence at the end of
      // `bytes` inside the decoder instead of emitting a replacement character.
      const chunk = decoder.decode(bytes, { stream: true });
      for (let i = 0, len = chunk.length; i < len; ++i) {
        const bufferLength = buffer.length;
        // Two consecutive newlines terminate an event; the first one is already
        // in the buffer and is removed by the trim in parseChunk.
        const isSeparator = chunk[i] === "\n" && buffer[bufferLength - 1] === "\n";
        if (!isSeparator) {
          buffer.push(chunk[i]);
          continue;
        }
        const parsedChunk = _HuggingFaceDecoderStream.parseChunk(buffer.join(""));
        if (parsedChunk) {
          controller.enqueue(map(parsedChunk));
        }
        buffer = [];
      }
    };
  }

  /**
   * @param map Function applied to every parsed chunk before it is enqueued.
   */
  constructor(map) {
    super({ transform: _HuggingFaceDecoderStream.transformer(map) });
  }
};

export { HuggingFaceTextGeneration };