UNPKG

chromadb

Version:

A JavaScript interface for chroma

608 lines (556 loc) 18.8 kB
import { createClient, createConfig } from "@hey-api/client-fetch";
import {
  defaultChromaClientArgs as defaultArgs,
  HttpMethod,
  normalizeMethod,
  parseConnectionPath,
  deserializeMetadata,
  serializeMetadata,
} from "./utils";
import { DefaultService as Api, ChecklistResponse } from "./api";
import { CollectionMetadata, UserIdentity } from "./types";
import { Collection, CollectionImpl } from "./collection";
import { EmbeddingFunction, getEmbeddingFunction } from "./embedding-function";
import { chromaFetch } from "./chroma-fetch";
import * as process from "node:process";
import {
  ChromaConnectionError,
  ChromaUnauthorizedError,
  ChromaValueError,
} from "./errors";
import {
  CreateCollectionConfiguration,
  processCreateCollectionConfig,
} from "./collection-configuration";
import { EMBEDDING_KEY, Schema } from "./schema";
import { client } from "./api/client.gen";

/**
 * Picks the embedding function recorded on a schema, if any.
 * The float-list vector index under `EMBEDDING_KEY` takes precedence; the
 * schema-wide float-list defaults are consulted only when that yields nothing.
 * @param schema - The schema to inspect, or undefined
 * @returns The embedding function found on the schema, or undefined
 */
const resolveSchemaEmbeddingFunction = (
  schema: Schema | undefined,
): EmbeddingFunction | undefined => {
  if (!schema) {
    return undefined;
  }

  const embeddingOverride =
    schema.keys[EMBEDDING_KEY]?.floatList?.vectorIndex?.config
      .embeddingFunction ?? undefined;
  if (embeddingOverride) {
    return embeddingOverride;
  }

  return (
    schema.defaults.floatList?.vectorIndex?.config.embeddingFunction ??
    undefined
  );
};

/**
 * Configuration options for the ChromaClient.
 */
export interface ChromaClientArgs {
  /** The host address of the Chroma server. Defaults to 'localhost' */
  host?: string;
  /** The port number of the Chroma server. Defaults to 8000 */
  port?: number;
  /** Whether to use SSL/HTTPS for connections. Defaults to false */
  ssl?: boolean;
  /** The tenant name in the Chroma server to connect to */
  tenant?: string;
  /** The database name to connect to */
  database?: string;
  /** Additional HTTP headers to send with requests */
  headers?: Record<string, string>;
  /** Additional fetch options for HTTP requests */
  fetchOptions?: RequestInit;
  /** @deprecated Use host, port, and ssl instead */
  path?: string;
  /** @deprecated */
  auth?: Record<string, string>;
}

/**
 * Main client class for interacting with ChromaDB.
 * Provides methods for managing collections and performing operations on them.
 */
export class ChromaClient {
  private _tenant: string | undefined;
  private _database: string | undefined;
  private _preflightChecks: ChecklistResponse | undefined;
  private _headers: Record<string, string> | undefined;
  private readonly apiClient: ReturnType<typeof createClient>;

  /**
   * Creates a new ChromaClient instance.
   *
   * The deprecated `path` argument, when given, overrides `ssl`, `host` and
   * `port`. The deprecated `auth` argument, when it describes an
   * `X_CHROMA_TOKEN` credential, is translated into an `x-chroma-token`
   * header unless such a header is already present.
   * @param args - Configuration options for the client
   */
  constructor(args: Partial<ChromaClientArgs> = {}) {
    let {
      host = defaultArgs.host,
      port = defaultArgs.port,
      ssl = defaultArgs.ssl,
      tenant = defaultArgs.tenant,
      database = defaultArgs.database,
      headers = defaultArgs.headers,
      fetchOptions = defaultArgs.fetchOptions,
    } = args;

    if (args.path) {
      console.warn(
        "The 'path' argument is deprecated. Please use 'ssl', 'host', and 'port' instead",
      );
      const parsedPath = parseConnectionPath(args.path);
      ssl = parsedPath.ssl;
      host = parsedPath.host;
      port = parsedPath.port;
    }

    if (args.auth) {
      console.warn(
        "The 'auth' argument is deprecated. Please use 'headers' instead",
      );
      // Copy before writing: `headers` is either the caller's own object or
      // the shared default, and injecting the token into either would leak
      // it beyond this client instance.
      headers = { ...(headers ?? {}) };
      if (
        !headers["x-chroma-token"] &&
        args.auth.tokenHeaderType === "X_CHROMA_TOKEN" &&
        args.auth.credentials
      ) {
        headers["x-chroma-token"] = args.auth.credentials;
      }
    }

    const baseUrl = `${ssl ? "https" : "http"}://${host}:${port}`;

    // Explicit arguments win; the environment is only a fallback.
    this._tenant = tenant || process.env.CHROMA_TENANT;
    this._database = database || process.env.CHROMA_DATABASE;
    this._headers = headers;

    const configOptions = {
      ...fetchOptions,
      method: normalizeMethod(fetchOptions?.method) as HttpMethod,
      baseUrl,
      headers,
    };

    this.apiClient = createClient(createConfig(configOptions));
    this.apiClient.setConfig({ fetch: chromaFetch });
  }

  /**
   * Gets the current tenant name.
   * @returns The tenant name or undefined if not set
   */
  public get tenant(): string | undefined {
    return this._tenant;
  }

  protected set tenant(tenant: string | undefined) {
    this._tenant = tenant;
  }

  /**
   * Gets the current database name.
   * @returns The database name or undefined if not set
   */
  public get database(): string | undefined {
    return this._database;
  }

  protected set database(database: string | undefined) {
    this._database = database;
  }

  /**
   * Gets the preflight checks
   * @returns The preflight checks or undefined if not set
   */
  public get preflightChecks(): ChecklistResponse | undefined {
    return this._preflightChecks;
  }

  protected set preflightChecks(
    preflightChecks: ChecklistResponse | undefined,
  ) {
    this._preflightChecks = preflightChecks;
  }

  /**
   * Gets the HTTP headers this client was configured with.
   * @returns The headers or undefined if none were set
   */
  public get headers(): Record<string, string> | undefined {
    return this._headers;
  }

  /**
   * Resolves the tenant and database used in request paths.
   *
   * When either is unset, the user identity is fetched: its tenant is stored
   * on the client, and the database is taken from the identity only if it
   * lists exactly one database that is not the `"*"` wildcard.
   * @returns Promise resolving to the tenant and database names
   * @throws ChromaUnauthorizedError if the identity lists no databases
   * @throws ChromaValueError if the identity lists several databases or `"*"`
   * @ignore
   */
  public async _path(): Promise<{ tenant: string; database: string }> {
    if (!this._tenant || !this._database) {
      const { tenant, databases } = await this.getUserIdentity();
      const uniqueDBs = [...new Set(databases)];
      this._tenant = tenant;
      if (uniqueDBs.length === 0) {
        throw new ChromaUnauthorizedError(
          `Your API key does not have access to any DBs for tenant ${this.tenant}`,
        );
      }
      if (uniqueDBs.length > 1 || uniqueDBs[0] === "*") {
        throw new ChromaValueError(
          "Your API key is scoped to more than 1 DB. Please provide a DB name to the CloudClient constructor",
        );
      }
      this._database = uniqueDBs[0];
    }
    return { tenant: this._tenant, database: this._database };
  }

  /**
   * Gets the user identity information including tenant and accessible databases.
   * @returns Promise resolving to user identity data
   */
  public async getUserIdentity(): Promise<UserIdentity> {
    const { data } = await Api.getUserIdentity({
      client: this.apiClient,
    });
    return data;
  }

  /**
   * Sends a heartbeat request to check server connectivity.
   * @returns Promise resolving to the server's nanosecond heartbeat timestamp
   */
  public async heartbeat(): Promise<number> {
    const { data } = await Api.heartbeat({
      client: this.apiClient,
    });
    return data["nanosecond heartbeat"];
  }

  /**
   * Lists all collections in the current database.
   * @param args - Optional pagination parameters
   * @param args.limit - Maximum number of collections to return (default: 100)
   * @param args.offset - Number of collections to skip (default: 0)
   * @returns Promise resolving to an array of Collection instances
   */
  public async listCollections(
    args?: Partial<{
      limit: number;
      offset: number;
    }>,
  ): Promise<Collection[]> {
    const { limit = 100, offset = 0 } = args || {};

    const { data } = await Api.listCollections({
      client: this.apiClient,
      path: await this._path(),
      query: { limit, offset },
    });

    return Promise.all(
      data.map(async (collection) => {
        const schema = await Schema.deserializeFromJSON(
          collection.schema ?? null,
          this,
        );
        const schemaEmbeddingFunction = resolveSchemaEmbeddingFunction(schema);
        // The configuration's embedding function is preferred; the schema's
        // is used only when the configuration yields none.
        const resolvedEmbeddingFunction =
          (await getEmbeddingFunction({
            collectionName: collection.name,
            client: this,
            efConfig:
              collection.configuration_json.embedding_function ?? undefined,
          })) ?? schemaEmbeddingFunction;

        return new CollectionImpl({
          chromaClient: this,
          apiClient: this.apiClient,
          tenant: collection.tenant,
          database: collection.database,
          name: collection.name,
          id: collection.id,
          embeddingFunction: resolvedEmbeddingFunction,
          configuration: collection.configuration_json,
          metadata:
            deserializeMetadata(collection.metadata ?? undefined) ?? undefined,
          schema,
        });
      }),
    );
  }

  /**
   * Gets the total number of collections in the current database.
   * @returns Promise resolving to the collection count
   */
  public async countCollections(): Promise<number> {
    const { data } = await Api.countCollections({
      client: this.apiClient,
      path: await this._path(),
    });
    return data;
  }

  /**
   * Creates a new collection with the specified configuration.
   * @param options - Collection creation options
   * @param options.name - The name of the collection
   * @param options.configuration - Optional collection configuration
   * @param options.metadata - Optional metadata for the collection
   * @param options.embeddingFunction - Optional embedding function to use. Defaults to `DefaultEmbeddingFunction` from @chroma-core/default-embed
   * @param options.schema - Optional schema; sent to the server in its JSON-serialized form
   * @returns Promise resolving to the created Collection instance
   * @throws Error if a collection with the same name already exists
   */
  public async createCollection({
    name,
    configuration,
    metadata,
    embeddingFunction,
    schema,
  }: {
    name: string;
    configuration?: CreateCollectionConfiguration;
    metadata?: CollectionMetadata;
    embeddingFunction?: EmbeddingFunction | null;
    schema?: Schema;
  }): Promise<Collection> {
    const collectionConfig = await processCreateCollectionConfig({
      configuration,
      embeddingFunction,
      metadata,
      schema,
    });

    const { data } = await Api.createCollection({
      client: this.apiClient,
      path: await this._path(),
      body: {
        name,
        configuration: collectionConfig,
        metadata: serializeMetadata(metadata),
        get_or_create: false,
        schema: schema ? schema.serializeToJSON() : undefined,
      },
    });

    // The schema is rebuilt from the server response rather than reusing the
    // caller's instance.
    const serverSchema = await Schema.deserializeFromJSON(
      data.schema ?? null,
      this,
    );
    const schemaEmbeddingFunction =
      resolveSchemaEmbeddingFunction(serverSchema);
    // Precedence: explicit argument, then server configuration, then schema.
    const resolvedEmbeddingFunction =
      embeddingFunction ??
      (await getEmbeddingFunction({
        collectionName: data.name,
        client: this,
        efConfig: data.configuration_json.embedding_function ?? undefined,
      })) ??
      schemaEmbeddingFunction;

    return new CollectionImpl({
      chromaClient: this,
      apiClient: this.apiClient,
      name,
      tenant: data.tenant,
      database: data.database,
      configuration: data.configuration_json,
      metadata: deserializeMetadata(data.metadata ?? undefined) ?? undefined,
      embeddingFunction: resolvedEmbeddingFunction,
      id: data.id,
      schema: serverSchema,
    });
  }

  /**
   * Retrieves an existing collection by name.
   * @param options - Collection retrieval options
   * @param options.name - The name of the collection to retrieve
   * @param options.embeddingFunction - Optional embedding function. Should match the one used to create the collection.
   * @returns Promise resolving to the Collection instance
   * @throws Error if the collection does not exist
   */
  public async getCollection({
    name,
    embeddingFunction,
  }: {
    name: string;
    embeddingFunction?: EmbeddingFunction;
  }): Promise<Collection> {
    const { data } = await Api.getCollection({
      client: this.apiClient,
      path: { ...(await this._path()), collection_id: name },
    });

    const schema = await Schema.deserializeFromJSON(data.schema ?? null, this);
    const schemaEmbeddingFunction = resolveSchemaEmbeddingFunction(schema);
    // Precedence: explicit argument, then server configuration, then schema.
    const resolvedEmbeddingFunction =
      embeddingFunction ??
      (await getEmbeddingFunction({
        collectionName: data.name,
        client: this,
        efConfig: data.configuration_json.embedding_function ?? undefined,
      })) ??
      schemaEmbeddingFunction;

    return new CollectionImpl({
      chromaClient: this,
      apiClient: this.apiClient,
      name,
      tenant: data.tenant,
      database: data.database,
      configuration: data.configuration_json,
      metadata: deserializeMetadata(data.metadata ?? undefined) ?? undefined,
      embeddingFunction: resolvedEmbeddingFunction,
      id: data.id,
      schema,
    });
  }

  /**
   * Retrieves an existing collection by its Chroma Resource Name (CRN).
   * @param crn - The Chroma Resource Name of the collection to retrieve
   * @returns Promise resolving to the Collection instance
   * @throws Error if the collection does not exist
   */
  public async getCollectionByCrn(crn: string): Promise<Collection> {
    const { data } = await Api.getCollectionByCrn({
      client: this.apiClient,
      path: { crn },
    });

    const schema = await Schema.deserializeFromJSON(data.schema ?? null, this);
    const schemaEmbeddingFunction = resolveSchemaEmbeddingFunction(schema);
    const resolvedEmbeddingFunction =
      (await getEmbeddingFunction({
        collectionName: data.name,
        efConfig: data.configuration_json.embedding_function ?? undefined,
        client: this,
      })) ?? schemaEmbeddingFunction;

    return new CollectionImpl({
      chromaClient: this,
      apiClient: this.apiClient,
      name: data.name,
      tenant: data.tenant,
      database: data.database,
      configuration: data.configuration_json,
      metadata: deserializeMetadata(data.metadata ?? undefined) ?? undefined,
      embeddingFunction: resolvedEmbeddingFunction,
      id: data.id,
      schema,
    });
  }

  /**
   * Retrieves multiple collections by name.
   * @param items - Array of collection names or objects with name and optional embedding function (should match the ones used to create the collections)
   * @returns Promise resolving to an array of Collection instances
   */
  public async getCollections(
    items: string[] | { name: string; embeddingFunction?: EmbeddingFunction }[],
  ): Promise<Collection[]> {
    if (items.length === 0) return [];

    // The first element decides how the whole array is interpreted.
    let requestedCollections = items;
    if (typeof items[0] === "string") {
      requestedCollections = (items as string[]).map((item) => {
        return { name: item, embeddingFunction: undefined };
      });
    }

    const collections = requestedCollections as {
      name: string;
      embeddingFunction?: EmbeddingFunction;
    }[];

    return Promise.all(
      collections.map(async (collection) => {
        return this.getCollection({ ...collection });
      }),
    );
  }

  /**
   * Gets an existing collection or creates it if it doesn't exist.
   * @param options - Collection options
   * @param options.name - The name of the collection
   * @param options.configuration - Optional collection configuration (used only if creating)
   * @param options.metadata - Optional metadata for the collection (used only if creating)
   * @param options.embeddingFunction - Optional embedding function to use
   * @param options.schema - Optional schema; sent to the server in its JSON-serialized form
   * @returns Promise resolving to the Collection instance
   */
  public async getOrCreateCollection({
    name,
    configuration,
    metadata,
    embeddingFunction,
    schema,
  }: {
    name: string;
    configuration?: CreateCollectionConfiguration;
    metadata?: CollectionMetadata;
    embeddingFunction?: EmbeddingFunction | null;
    schema?: Schema;
  }): Promise<Collection> {
    const collectionConfig = await processCreateCollectionConfig({
      configuration,
      embeddingFunction,
      metadata,
      schema,
    });

    const { data } = await Api.createCollection({
      client: this.apiClient,
      path: await this._path(),
      body: {
        name,
        configuration: collectionConfig,
        metadata: serializeMetadata(metadata),
        get_or_create: true,
        schema: schema ? schema.serializeToJSON() : undefined,
      },
    });

    const serverSchema = await Schema.deserializeFromJSON(
      data.schema ?? null,
      this,
    );
    const schemaEmbeddingFunction =
      resolveSchemaEmbeddingFunction(serverSchema);
    // Precedence: explicit argument, then server configuration, then schema.
    const resolvedEmbeddingFunction =
      embeddingFunction ??
      (await getEmbeddingFunction({
        collectionName: name,
        efConfig: data.configuration_json.embedding_function ?? undefined,
        client: this,
      })) ??
      schemaEmbeddingFunction;

    return new CollectionImpl({
      chromaClient: this,
      apiClient: this.apiClient,
      name,
      tenant: data.tenant,
      database: data.database,
      configuration: data.configuration_json,
      metadata: deserializeMetadata(data.metadata ?? undefined) ?? undefined,
      embeddingFunction: resolvedEmbeddingFunction,
      id: data.id,
      schema: serverSchema,
    });
  }

  /**
   * Deletes a collection and all its data.
   * @param options - Deletion options
   * @param options.name - The name of the collection to delete
   */
  public async deleteCollection({ name }: { name: string }): Promise<void> {
    await Api.deleteCollection({
      client: this.apiClient,
      path: { ...(await this._path()), collection_id: name },
    });
  }

  /**
   * Resets the entire database, deleting all collections and data.
   * @returns Promise that resolves when the reset is complete
   * @warning This operation is irreversible and will delete all data
   */
  public async reset(): Promise<void> {
    await Api.reset({
      client: this.apiClient,
    });
  }

  /**
   * Gets the version of the Chroma server.
   * @returns Promise resolving to the server version string
   */
  public async version(): Promise<string> {
    const { data } = await Api.version({
      client: this.apiClient,
    });
    return data;
  }

  /**
   * Gets the preflight checks, fetching them from the server on first use
   * and returning the stored result afterwards.
   * @returns Promise resolving to the preflight checks
   */
  public async getPreflightChecks(): Promise<ChecklistResponse> {
    if (!this.preflightChecks) {
      const { data } = await Api.preFlightChecks({
        client: this.apiClient,
      });
      this.preflightChecks = data;
      return this.preflightChecks;
    }
    return this.preflightChecks;
  }

  /**
   * Gets the max batch size
   * @returns Promise resolving to the max batch size, or -1 when the preflight checks do not report one
   */
  public async getMaxBatchSize(): Promise<number> {
    const preflightChecks = await this.getPreflightChecks();
    return preflightChecks.max_batch_size ?? -1;
  }

  /**
   * Gets whether base64_encoding is supported by the connected server
   * @returns Promise resolving to whether base64_encoding is supported (false when not reported)
   */
  public async supportsBase64Encoding(): Promise<boolean> {
    const preflightChecks = await this.getPreflightChecks();
    return preflightChecks.supports_base64_encoding ?? false;
  }
}