UNPKG

chromadb-default-embed

Version:

Chroma's fork of @xenova/transformers serving as our default embedding function

662 lines (574 loc) 26.2 kB
/**
 * @file Utility functions to interact with the Hugging Face Hub (https://huggingface.co/models)
 *
 * @module utils/hub
 */

import fs from 'fs';
import path from 'path';
import stream from 'stream/web';

import { env } from '../env.js';
import { dispatchCallback } from './core.js';

if (!globalThis.ReadableStream) {
    // @ts-ignore
    globalThis.ReadableStream = stream.ReadableStream; // ReadableStream is not a global with Node 16
}

/**
 * @typedef {Object} PretrainedOptions Options for loading a pretrained model.
 * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files).
 * @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
 * @property {Object} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
 * - The model is a model provided by the library (loaded with the *model id* string of a pretrained model).
 * - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory.
 * @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
 * @property {boolean} [local_files_only=false] Whether or not to only look at local files (e.g., not try downloading the model).
 * @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
 * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
 * NOTE: This setting is ignored for local requests.
 * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models.
 */

/**
 * A minimal, `Response`-like wrapper around a file on the local file system.
 * It exposes the subset of the Fetch `Response` interface used in this module
 * (`status`, `statusText`, `headers`, `body`, `arrayBuffer`, `blob`, `text`, `json`, `clone`).
 */
class FileResponse {
    /**
     * Mapping from file extensions to MIME types.
     */
    _CONTENT_TYPE_MAP = {
        'txt': 'text/plain',
        'html': 'text/html',
        'css': 'text/css',
        'js': 'text/javascript',
        'json': 'application/json',
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'gif': 'image/gif',
    }

    /**
     * Creates a new `FileResponse` object.
     *
     * If the file exists, the response has status 200, a `content-length` header taken from
     * the file's size on disk, and a `body` stream that yields the file's contents as a single chunk.
     * Otherwise the response has status 404 and a `null` body.
     *
     * @param {string|URL} filePath
     */
    constructor(filePath) {
        this.filePath = filePath;
        this.headers = new Headers();

        this.exists = fs.existsSync(filePath);
        if (this.exists) {
            this.status = 200;
            this.statusText = 'OK';

            const stats = fs.statSync(filePath);
            this.headers.set('content-length', stats.size.toString());

            this.updateContentType();

            // The file is read lazily: with `highWaterMark: 0`, `pull` is only invoked once a
            // consumer actually requests data. This avoids loading the whole file into memory
            // for responses whose body is never consumed (e.g., cache lookups or when
            // `arrayBuffer()` is used directly). If reading fails, the promise returned by
            // `pull` rejects, which errors the stream so that the reader sees the failure
            // instead of it becoming an unhandled rejection.
            this.body = new ReadableStream({
                pull: async (controller) => {
                    const contents = await this.arrayBuffer();
                    controller.enqueue(new Uint8Array(contents));
                    controller.close();
                },
            }, { highWaterMark: 0 });
        } else {
            this.status = 404;
            this.statusText = 'Not Found';
            this.body = null;
        }
    }

    /**
     * Updates the 'content-type' header property of the response based on the extension of
     * the file specified by the filePath property of the current object.
     * Unknown extensions fall back to 'application/octet-stream'.
     * @returns {void}
     */
    updateContentType() {
        // Set content-type header based on file extension
        const extension = this.filePath.toString().split('.').pop().toLowerCase();
        this.headers.set('content-type', this._CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream');
    }

    /**
     * Clone the current FileResponse object.
     * @returns {FileResponse} A new FileResponse object with the same properties as the current object.
     */
    clone() {
        const response = new FileResponse(this.filePath);
        response.exists = this.exists;
        response.status = this.status;
        response.statusText = this.statusText;
        response.headers = new Headers(this.headers);
        return response;
    }

    /**
     * Reads the contents of the file specified by the filePath property and returns a Promise that
     * resolves with an ArrayBuffer containing the file's contents.
     * @returns {Promise<ArrayBuffer>} A Promise that resolves with an ArrayBuffer containing the file's contents.
     * @throws {Error} If the file cannot be read.
     */
    async arrayBuffer() {
        const data = await fs.promises.readFile(this.filePath);

        // A Node.js `Buffer` is not guaranteed to span its whole backing `ArrayBuffer`.
        // Only hand out the backing store directly when it contains exactly the file's bytes;
        // otherwise copy out the relevant range so callers never see unrelated data.
        if (data.byteOffset === 0 && data.byteLength === data.buffer.byteLength) {
            return data.buffer;
        }
        return data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
    }

    /**
     * Reads the contents of the file specified by the filePath property and returns a Promise that
     * resolves with a Blob containing the file's contents.
     * @returns {Promise<Blob>} A Promise that resolves with a Blob containing the file's contents.
     * @throws {Error} If the file cannot be read.
     */
    async blob() {
        const data = await fs.promises.readFile(this.filePath);
        return new Blob([data], { type: this.headers.get('content-type') });
    }

    /**
     * Reads the contents of the file specified by the filePath property and returns a Promise that
     * resolves with a string containing the file's contents.
     * @returns {Promise<string>} A Promise that resolves with a string containing the file's contents.
     * @throws {Error} If the file cannot be read.
     */
    async text() {
        const data = await fs.promises.readFile(this.filePath, 'utf8');
        return data;
    }

    /**
     * Reads the contents of the file specified by the filePath property and returns a Promise that
     * resolves with a parsed JavaScript object containing the file's contents.
     *
     * @returns {Promise<Object>} A Promise that resolves with a parsed JavaScript object containing the file's contents.
     * @throws {Error} If the file cannot be read.
     */
    async json() {
        return JSON.parse(await this.text());
    }
}

/**
 * Determines whether the given string is a valid HTTP or HTTPS URL.
 * @param {string|URL} string The string to test for validity as an HTTP or HTTPS URL.
 * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list.
 * @returns {boolean} True if the string is a valid HTTP or HTTPS URL, false otherwise.
 */
function isValidHttpUrl(string, validHosts = null) {
    // https://stackoverflow.com/a/43467144
    let url;
    try {
        url = new URL(string);
    } catch (_) {
        // Not parseable as an absolute URL (e.g., a relative or local path).
        return false;
    }
    if (validHosts && !validHosts.includes(url.hostname)) {
        return false;
    }
    return url.protocol === "http:" || url.protocol === "https:";
}

/**
 * Helper function to get a file, using either the Fetch API or FileSystem API.
 *
 * @param {URL|string} urlOrPath The URL/path of the file to get.
 * @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
 */
export async function getFile(urlOrPath) {

    if (env.useFS && !isValidHttpUrl(urlOrPath)) {
        return new FileResponse(urlOrPath);

    } else if (typeof process !== 'undefined' && process?.release?.name === 'node') {
        const IS_CI = !!process.env?.TESTING_REMOTELY;
        const version = env.version;

        const headers = new Headers();
        headers.set('User-Agent', `transformers.js/${version}; is_ci/${IS_CI};`);

        // Check whether we are making a request to the Hugging Face Hub.
        const isHFURL = isValidHttpUrl(urlOrPath, ['huggingface.co', 'hf.co']);
        if (isHFURL) {
            // If an access token is present in the environment variables,
            // we add it to the request headers.
            // NOTE: We keep `HF_ACCESS_TOKEN` for backwards compatibility (as a fallback).
            const token = process.env?.HF_TOKEN ?? process.env?.HF_ACCESS_TOKEN;
            if (token) {
                headers.set('Authorization', `Bearer ${token}`);
            }
        }
        return fetch(urlOrPath, { headers });
    } else {
        // Running in a browser-environment, so we use default headers
        // NOTE: We do not allow passing authorization headers in the browser,
        // since this would require exposing the token to the client.
        return fetch(urlOrPath);
    }
}

/**
 * Human-readable error messages, keyed by HTTP status code.
 */
const ERROR_MAPPING = {
    // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses)
    400: 'Bad request error occurred while trying to load file',
    401: 'Unauthorized access to file',
    403: 'Forbidden access to file',
    404: 'Could not locate file',
    408: 'Request timeout error occurred while trying to load file',

    // 5xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#server_error_responses)
    500: 'Internal server error error occurred while trying to load file',
    502: 'Bad gateway error occurred while trying to load file',
    503: 'Service unavailable error occurred while trying to load file',
    504: 'Gateway timeout error occurred while trying to load file',
}

/**
 * Helper method to handle errors that occur while trying to load a file from the Hugging Face Hub.
 * @param {number} status The HTTP status code of the error.
 * @param {string} remoteURL The URL of the file that could not be loaded.
 * @param {boolean} fatal Whether to raise an error if the file could not be loaded.
 * @returns {null} Returns `null` if `fatal = false` (i.e., the file is optional).
 * @throws {Error} If `fatal = true`.
 */
function handleError(status, remoteURL, fatal) {
    if (!fatal) {
        // File was not loaded correctly, but it is optional.
        // TODO in future, cache the response?
        return null;
    }

    const message = ERROR_MAPPING[status] ?? `Error (${status}) occurred while trying to load file`;
    throw new Error(`${message}: "${remoteURL}".`);
}

/**
 * A file-system backed cache exposing the `match` and `put` methods of the Web Cache API.
 * Cache keys are interpreted as paths relative to the cache's root directory.
 */
class FileCache {
    /**
     * Instantiate a `FileCache` object.
     * @param {string} path The root directory of the cache.
     */
    constructor(path) {
        this.path = path;
    }

    /**
     * Checks whether the given request is in the cache.
     * @param {string} request
     * @returns {Promise<FileResponse | undefined>} The cached file, or `undefined` if it does not exist.
     */
    async match(request) {
        const filePath = path.join(this.path, request);
        const file = new FileResponse(filePath);

        return file.exists ? file : undefined;
    }

    /**
     * Adds the given response to the cache.
     * Failures while writing are logged as warnings and do not throw.
     * @param {string} request
     * @param {Response|FileResponse} response
     * @returns {Promise<void>}
     */
    async put(request, response) {
        const buffer = Buffer.from(await response.arrayBuffer());

        const outputPath = path.join(this.path, request);

        try {
            await fs.promises.mkdir(path.dirname(outputPath), { recursive: true });
            await fs.promises.writeFile(outputPath, buffer);

        } catch (err) {
            console.warn('An error occurred while writing the file to cache:', err)
        }
    }

    // TODO add the rest?
    // addAll(requests: RequestInfo[]): Promise<void>;
    // delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<boolean>;
    // keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Request>>;
    // match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise<Response | undefined>;
    // matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise<ReadonlyArray<Response>>;
}

/**
 * Looks up each of the given names in the cache, in order, and returns the first hit.
 * A lookup that throws is treated as a miss.
 *
 * @param {FileCache|Cache} cache The cache to search
 * @param {string[]} names The names of the item to search for
 * @returns {Promise<FileResponse|Response|undefined>} The item from the cache, or undefined if not found.
 */
async function tryCache(cache, ...names) {
    for (const name of names) {
        try {
            const result = await cache.match(name);
            if (result) return result;
        } catch (e) {
            // Deliberately ignored: some environments reject certain cache keys
            // (e.g., non-URL keys), in which case we simply try the next name.
            continue;
        }
    }
    return undefined;
}

/**
 * Selects the caching backend to use, based on the `env` settings.
 * Backends are tried in this order: browser cache, file system cache, custom cache.
 *
 * @param {PretrainedOptions} options Options which may specify `cache_dir` for the file system cache.
 * @returns {Promise<FileCache|Cache|undefined>} The cache to use, or `undefined` if no caching mechanism is available.
 * @throws {Error} If the browser cache is requested but `caches` is not defined, or if the custom cache is requested but missing or incomplete.
 */
async function resolveCache(options) {
    let cache;

    if (env.useBrowserCache) {
        if (typeof caches === 'undefined') {
            throw new Error('Browser cache is not available in this environment.')
        }
        try {
            // In some cases, the browser cache may be visible, but not accessible due to security restrictions.
            // For example, when running an application in an iframe, if a user attempts to load the page in
            // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage':
            // An attempt was made to break through the security policy of the user agent.`
            // So, instead of crashing, we just ignore the error and continue without using the cache.
            cache = await caches.open('transformers-cache');
        } catch (e) {
            console.warn('An error occurred while opening the browser cache:', e);
        }
    }

    if (!cache && env.useFSCache) {
        // TODO throw error if not available

        // If `cache_dir` is not specified, use the default cache directory
        cache = new FileCache(options.cache_dir ?? env.cacheDir);
    }

    if (!cache && env.useCustomCache) {
        // Allow the user to specify a custom cache system.
        if (!env.customCache) {
            throw new Error('`env.useCustomCache=true`, but `env.customCache` is not defined.')
        }

        // Check that the required methods are defined:
        if (!env.customCache.match || !env.customCache.put) {
            throw new Error(
                "`env.customCache` must be an object which implements the `match` and `put` functions of the Web Cache API. " +
                "For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache"
            )
        }
        cache = env.customCache;
    }

    return cache;
}

/**
 *
 * Retrieves a file from either a remote URL using the Fetch API or from the local file system using the FileSystem API.
 * If the filesystem is available and `env.useCache = true`, the file will be downloaded and cached.
 *
 * @param {string} path_or_repo_id This can be either:
 * - a string, the *model id* of a model repo on huggingface.co.
 * - a path to a *directory* potentially containing the file.
 * @param {string} filename The name of the file to locate in `path_or_repo`.
 * @param {boolean} [fatal=true] Whether to throw an error if the file is not found.
 * @param {PretrainedOptions} [options] An object containing optional parameters.
 *
 * @throws Will throw an error if the file is not found and `fatal` is true.
 * @returns {Promise<Uint8Array|null>} A Promise that resolves with the file content as a buffer,
 * or `null` if the file could not be loaded and `fatal` is false.
 */
export async function getModelFile(path_or_repo_id, filename, fatal = true, options = {}) {

    if (!env.allowLocalModels) {
        // User has disabled local models, so we just make sure other settings are correct.

        if (options.local_files_only) {
            throw new Error("Invalid configuration detected: local models are disabled (`env.allowLocalModels=false`) but you have requested to only use local models (`local_files_only=true`).")
        } else if (!env.allowRemoteModels) {
            throw new Error("Invalid configuration detected: both local and remote models are disabled. Fix by setting `env.allowLocalModels` or `env.allowRemoteModels` to `true`.")
        }
    }

    // Initiate file retrieval
    dispatchCallback(options.progress_callback, {
        status: 'initiate',
        name: path_or_repo_id,
        file: filename
    })

    // First, check if a caching backend is available
    // If no caching mechanism available, will download the file every time
    const cache = await resolveCache(options);

    const revision = options.revision ?? 'main';

    const requestURL = pathJoin(path_or_repo_id, filename);
    const localPath = pathJoin(env.localModelPath, requestURL);

    const remoteURL = pathJoin(
        env.remoteHost,
        env.remotePathTemplate
            .replaceAll('{model}', path_or_repo_id)
            .replaceAll('{revision}', revision),
        filename
    );

    // Choose cache key for filesystem cache
    // When using the main revision (default), we use the request URL as the cache key.
    // If a specific revision is requested, we account for this in the cache key.
    const fsCacheKey = revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename);

    /** @type {string} */
    let cacheKey;
    const proposedCacheKey = cache instanceof FileCache ? fsCacheKey : remoteURL;

    // Whether to cache the final response in the end.
    let toCacheResponse = false;

    /** @type {Response|FileResponse|undefined} */
    let response;

    if (cache) {
        // A caching system is available, so we try to get the file from it.
        //  1. We first try to get from cache using the local path. In some environments (like deno),
        //     non-URL cache keys are not allowed. In these cases, `response` will be undefined.
        //  2. If no response is found, we try to get from cache using the remote URL or file system cache.
        response = await tryCache(cache, localPath, proposedCacheKey);
    }

    const cacheHit = response !== undefined;

    if (response === undefined) {
        // Caching not available, or file is not cached, so we perform the request

        if (env.allowLocalModels) {
            // Accessing local models is enabled, so we try to get the file locally.
            // If request is a valid HTTP URL, we skip the local file check. Otherwise, we try to get the file locally.
            const isURL = isValidHttpUrl(requestURL);
            if (!isURL) {
                try {
                    response = await getFile(localPath);
                    cacheKey = localPath; // Update the cache key to be the local path
                } catch (e) {
                    // Something went wrong while trying to get the file locally.
                    // NOTE: error handling is done in the next step (since `response` will be undefined)
                    console.warn(`Unable to load from local path "${localPath}": "${e}"`);
                }
            } else if (options.local_files_only) {
                throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${requestURL}.`);
            } else if (!env.allowRemoteModels) {
                throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${requestURL}.`);
            }
        }

        if (response === undefined || response.status === 404) {
            // File not found locally. This means either:
            // - The user has disabled local file access (`env.allowLocalModels=false`)
            // - the path is a valid HTTP url (`response === undefined`)
            // - the path is not a valid HTTP url and the file is not present on the file system or local server (`response.status === 404`)

            if (options.local_files_only || !env.allowRemoteModels) {
                // User requested local files only, but the file is not found locally.
                if (fatal) {
                    throw new Error(`\`local_files_only=true\` or \`env.allowRemoteModels=false\` and file was not found locally at "${localPath}".`);
                } else {
                    // File not found, but this file is optional.
                    // TODO in future, cache the response?
                    return null;
                }
            }

            // File not found locally, so we try to download it from the remote server
            response = await getFile(remoteURL);

            if (response.status !== 200) {
                return handleError(response.status, remoteURL, fatal);
            }

            // Success! We use the proposed cache key from earlier
            cacheKey = proposedCacheKey;
        }

        // Only cache the response if:
        toCacheResponse =
            cache                              // 1. A caching system is available
            && typeof Response !== 'undefined' // 2. `Response` is defined (i.e., we are in a browser-like environment)
            && response instanceof Response    // 3. result is a `Response` object (i.e., not a `FileResponse`)
            && response.status === 200         // 4. request was successful (status code 200)
    }

    // Start downloading
    dispatchCallback(options.progress_callback, {
        status: 'download',
        name: path_or_repo_id,
        file: filename
    })

    const progressInfo = {
        status: 'progress',
        name: path_or_repo_id,
        file: filename
    }

    /** @type {Uint8Array} */
    let buffer;

    if (!options.progress_callback) {
        // If no progress callback is specified, we can use the `.arrayBuffer()`
        // method to read the response.
        buffer = new Uint8Array(await response.arrayBuffer());

    } else if (
        cacheHit // The item is being read from the cache
        &&
        typeof navigator !== 'undefined' && /firefox/i.test(navigator.userAgent) // We are in Firefox
    ) {
        // Due to bug in Firefox, we cannot display progress when loading from cache.
        // Fortunately, since this should be instantaneous, this should not impact users too much.
        buffer = new Uint8Array(await response.arrayBuffer());

        // For completeness, we still fire the final progress callback
        dispatchCallback(options.progress_callback, {
            ...progressInfo,
            progress: 100,
            loaded: buffer.length,
            total: buffer.length,
        })
    } else {
        buffer = await readResponse(response, data => {
            dispatchCallback(options.progress_callback, {
                ...progressInfo,
                ...data,
            })
        })
    }

    if (
        // Only cache web responses
        // i.e., do not cache FileResponses (prevents duplication)
        toCacheResponse && cacheKey
        &&
        // Check again whether request is in cache. If not, we add the response to the cache
        (await cache.match(cacheKey) === undefined)
    ) {
        // NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files
        await cache.put(cacheKey, new Response(buffer, {
            headers: response.headers
        }))
            .catch(err => {
                // Do not crash if unable to add to cache (e.g., QuotaExceededError).
                // Rather, log a warning and proceed with execution.
                console.warn(`Unable to add response to browser cache: ${err}.`);
            });
    }

    dispatchCallback(options.progress_callback, {
        status: 'done',
        name: path_or_repo_id,
        file: filename
    });

    return buffer;
}

/**
 * Fetches a JSON file from a given path and file name.
 *
 * @param {string} modelPath The path to the directory containing the file.
 * @param {string} fileName The name of the file to fetch.
 * @param {boolean} [fatal=true] Whether to throw an error if the file is not found.
 * @param {PretrainedOptions} [options] An object containing optional parameters.
 * @returns {Promise<Object>} The JSON data parsed into a JavaScript object
 * (an empty object if the file could not be loaded and `fatal` is false).
 * @throws Will throw an error if the file is not found and `fatal` is true.
 */
export async function getModelJSON(modelPath, fileName, fatal = true, options = {}) {
    const buffer = await getModelFile(modelPath, fileName, fatal, options);
    if (buffer === null) {
        // Return empty object
        return {}
    }

    const decoder = new TextDecoder('utf-8');
    const jsonData = decoder.decode(buffer);

    return JSON.parse(jsonData);
}

/**
 * Read and track progress when reading a Response object
 *
 * @param {any} response The Response object to read
 * @param {function} progress_callback The function to call with progress updates.
 * It receives an object with `progress` (percentage), `loaded` (bytes read so far) and `total` (expected bytes).
 * @returns {Promise<Uint8Array>} A Promise that resolves with the Uint8Array buffer
 */
async function readResponse(response, progress_callback) {

    const contentLength = response.headers.get('Content-Length');
    if (contentLength === null) {
        console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.')
    }

    // A missing, malformed or negative header is treated as "unknown" (0),
    // in which case the buffer is grown as data arrives.
    const declaredLength = Number.parseInt(contentLength ?? '0', 10);
    let total = Number.isFinite(declaredLength) && declaredLength > 0 ? declaredLength : 0;
    let buffer = new Uint8Array(total);
    let loaded = 0;

    const reader = response.body.getReader();
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        const newLoaded = loaded + value.length;
        if (newLoaded > total) {
            total = newLoaded;

            // Adding the new data will overflow buffer.
            // In this case, we extend the buffer
            const newBuffer = new Uint8Array(total);

            // copy contents
            newBuffer.set(buffer);

            buffer = newBuffer;
        }
        buffer.set(value, loaded)
        loaded = newLoaded;

        // An empty body (`total === 0`) is reported as complete rather than as NaN.
        const progress = total > 0 ? (loaded / total) * 100 : 100;

        progress_callback({
            progress: progress,
            loaded: loaded,
            total: total,
        })
    }

    // If fewer bytes were received than the declared length, drop the unused
    // (zero-filled) tail so that the result contains only the data actually read.
    return loaded < buffer.length ? buffer.slice(0, loaded) : buffer;
}

/**
 * Joins multiple parts of a path into a single path, while handling leading and trailing slashes.
 * A single leading slash is removed from every part except the first, and a single trailing
 * slash is removed from every part except the last.
 *
 * @param {...string} parts Multiple parts of a path.
 * @returns {string} A string representing the joined path.
 */
function pathJoin(...parts) {
    // https://stackoverflow.com/a/55142565
    const lastIndex = parts.length - 1;
    return parts
        .map((part, index) => {
            let cleaned = part;
            if (index !== 0) {
                cleaned = cleaned.replace(/^\//, '');
            }
            if (index !== lastIndex) {
                cleaned = cleaned.replace(/\/$/, '');
            }
            return cleaned;
        })
        .join('/');
}