UNPKG

zerolabel

Version:

Zero-shot multimodal classification SDK - classify text and images with custom labels, no training required

260 lines (250 loc) 8.65 kB
// src/classifier.ts
import OpenAI from "openai";

// src/config.ts
/**
 * Fixed defaults used by the SDK. Only the retry count can be overridden by
 * callers (via client config or per-call options); everything else is constant.
 */
const DEFAULT_CONFIG = {
  /** Fixed base URL for inference.net API - not customizable */
  BASE_URL: "https://api.inference.net/v1",
  /** Fixed model for classification - not customizable */
  MODEL: "google/gemma-3-27b-instruct/bf-16",
  /** Default maximum number of retries for failed requests */
  MAX_RETRIES: 3,
  /** Default temperature for model inference */
  TEMPERATURE: 0.5,
  /** Maximum tokens for model response */
  MAX_TOKENS: 1,
  /** Number of top logprobs to request */
  TOP_LOGPROBS: 20
};

// src/types.ts
/**
 * Error type thrown by the SDK.
 *
 * `code` is a machine-readable identifier (e.g. "MISSING_API_KEY",
 * "INVALID_LABELS", "INVALID_INPUT", "NO_LOGPROBS", "COMPLETION_FAILED");
 * `statusCode` is an optional numeric status supplied by the thrower.
 */
class ZeroLabelError extends Error {
  constructor(message, code, statusCode) {
    super(message);
    this.code = code;
    this.statusCode = statusCode;
    this.name = "ZeroLabelError";
  }
}

// src/classifier.ts
/**
 * Client that classifies text and/or images against caller-supplied labels by
 * reading the top logprobs of the first token generated by the model.
 */
class ZeroLabelClient {
  /**
   * @param config - `{ apiKey, maxRetries? }`.
   * @throws ZeroLabelError ("MISSING_API_KEY") when `config` or `config.apiKey` is missing.
   */
  constructor(config) {
    // `config?.` turns a missing config object into the SDK's own error
    // instead of a raw TypeError.
    if (!config?.apiKey) {
      throw new ZeroLabelError("Inference API key is required", "MISSING_API_KEY");
    }
    this.config = {
      apiKey: config.apiKey,
      maxRetries: config.maxRetries ?? DEFAULT_CONFIG.MAX_RETRIES
    };
    this.client = new OpenAI({
      apiKey: this.config.apiKey,
      baseURL: DEFAULT_CONFIG.BASE_URL
    });
  }

  /**
   * Classify text and/or images against provided labels.
   *
   * `texts[i]` and `images[i]` are paired by index; the shorter list is padded
   * (empty text / no image). All pairs are classified concurrently.
   *
   * @param input - `{ labels, texts?, images?, criteria?, additionalInstructions? }`.
   * @param options - Per-call overrides: `criteria`, `additionalInstructions`, `maxRetries`.
   * @returns One result object per text/image pair, in input order.
   * @throws ZeroLabelError when validation fails or a completion cannot be obtained.
   */
  async classify(input, options = {}) {
    this.validateInput(input);
    const { texts = [], images = [], labels } = input;
    const maxInputs = Math.max(texts.length, images.length);
    if (maxInputs === 0) {
      return [];
    }
    const tasks = Array.from({ length: maxInputs }, (_, idx) =>
      this.classifySingle(texts[idx] || "", images[idx], labels, input, options)
    );
    return Promise.all(tasks);
  }

  /**
   * Classify a single text/image pair.
   *
   * When there is an image but no text, the prompt's content is the
   * placeholder "(image input)"; the returned `text` stays the original value.
   */
  async classifySingle(text, image, labels, input, options) {
    const isImageOnly = !text && Boolean(image);
    const displayText = isImageOnly ? "(image input)" : text;
    const prompt = this.constructPrompt(
      displayText,
      labels,
      options.criteria ?? input.criteria ?? "",
      options.additionalInstructions ?? input.additionalInstructions ?? ""
    );
    const response = await this.getCompletion(
      prompt,
      image,
      options.maxRetries ?? this.config.maxRetries
    );
    return this.processResponse(response, text, labels, prompt);
  }

  /**
   * Validate input parameters.
   *
   * @throws ZeroLabelError ("INVALID_INPUT") when `input` is missing or has neither texts nor images.
   * @throws ZeroLabelError ("INVALID_LABELS") when `labels` is not a non-empty array.
   */
  validateInput(input) {
    if (input === null || input === undefined) {
      throw new ZeroLabelError("Either texts or images must be provided", "INVALID_INPUT");
    }
    // Array.isArray: a string also has `.length` but would crash later in `.map`.
    if (!Array.isArray(input.labels) || input.labels.length === 0) {
      throw new ZeroLabelError("Labels array cannot be empty", "INVALID_LABELS");
    }
    const hasTexts = Boolean(input.texts && input.texts.length > 0);
    const hasImages = Boolean(input.images && input.images.length > 0);
    if (!hasTexts && !hasImages) {
      throw new ZeroLabelError("Either texts or images must be provided", "INVALID_INPUT");
    }
  }

  /**
   * Construct the classification prompt.
   *
   * Labels are numbered from 1 so the model can answer with a number; the
   * optional criteria / extra-instruction sections are omitted when empty.
   * Sections are separated by a blank line.
   */
  constructPrompt(text, labels, criteria, additionalInstructions) {
    const numberedLabels = labels.map((label, idx) => `${idx + 1} - ${label}`).join("\n");
    const header = `You are a top-tier domain expert and world-class text-classification assistant. Your goal is to read the given CONTENT and decide which single CATEGORY best matches the overall meaning. Even if multiple categories seem to fit, there is always a best match\u2014select the one that most accurately captures the primary intent or theme. Think through the semantics, context, nuance and any implicit information before deciding.`;
    const reasoningCue = `Before you answer you may reason internally, but DO NOT output that reasoning. After finishing your reasoning, output **all category numbers in descending order of likelihood** (most likely first), separated by single spaces. Example: "3 1 2 4 5". Output nothing else\u2014no words, punctuation, or code fences.`;
    const parts = [
      header,
      criteria ? `\nADDITIONAL EVALUATION CRITERIA:\n${criteria}` : "",
      `\nCATEGORIES (number \u2014 label):\n${numberedLabels}`,
      additionalInstructions ? `\nEXTRA INSTRUCTIONS:\n${additionalInstructions}` : "",
      reasoningCue,
      `\nCONTENT (delimited by triple backticks):\n\`\`\`\n${text}\n\`\`\`\nYOUR ANSWER:`
    ].filter(Boolean);
    return parts.join("\n\n");
  }

  /**
   * Make API call to get completion.
   * Uses fixed inference.net model - no customization allowed.
   *
   * Streams the completion and returns the first chunk that carries
   * `top_logprobs` for its first content token. Every failure (including a
   * stream that ends without logprobs) is retried with exponential backoff:
   * 1s, 2s, 4s, ...
   *
   * @param prompt - Prompt text.
   * @param imageDataUri - Optional image URL/data URI sent alongside the prompt.
   * @param maxRetries - Retries after the first attempt. Negative values are
   *   treated as 0; non-finite values fall back to the default.
   * @returns The first stream chunk containing top logprobs.
   * @throws ZeroLabelError ("NO_LOGPROBS" or "COMPLETION_FAILED") once retries are exhausted.
   */
  async getCompletion(prompt, imageDataUri, maxRetries) {
    // Normalise so a negative or NaN value can never skip the request entirely.
    const retries = Number.isFinite(maxRetries)
      ? Math.max(0, Math.floor(maxRetries))
      : DEFAULT_CONFIG.MAX_RETRIES;
    // The message content does not change between attempts.
    const userMessage = imageDataUri
      ? [
          { type: "image_url", image_url: { url: imageDataUri } },
          { type: "text", text: prompt }
        ]
      : prompt;

    for (let attempt = 1; ; attempt += 1) {
      try {
        const stream = await this.client.chat.completions.create({
          model: DEFAULT_CONFIG.MODEL, // Fixed model - no customization
          messages: [{ role: "user", content: userMessage }],
          temperature: DEFAULT_CONFIG.TEMPERATURE,
          max_tokens: DEFAULT_CONFIG.MAX_TOKENS,
          logprobs: true,
          top_logprobs: DEFAULT_CONFIG.TOP_LOGPROBS,
          stream: true
        });
        for await (const chunk of stream) {
          const hasLogprobs = Boolean(
            chunk.choices?.[0]?.logprobs?.content?.[0]?.top_logprobs?.length
          );
          if (hasLogprobs) {
            return chunk;
          }
        }
        throw new ZeroLabelError("Stream ended without data", "NO_LOGPROBS");
      } catch (error) {
        if (attempt > retries) {
          if (error instanceof ZeroLabelError) {
            throw error;
          }
          // Carry over a numeric `status` from the underlying error when present.
          const status = typeof error?.status === "number" ? error.status : undefined;
          throw new ZeroLabelError(
            `Failed to get completion after ${retries} retries: ${error instanceof Error ? error.message : "Unknown error"}`,
            "COMPLETION_FAILED",
            status
          );
        }
        const delay = 2 ** (attempt - 1) * 1e3;
        await new Promise((resolve) => setTimeout(resolve, delay));
      }
    }
  }

  /**
   * Process the API response into a classification result.
   *
   * Each candidate token is trimmed and reduced to its digits (falling back to
   * the trimmed token when it has none); a token whose digits equal a label's
   * 1-based number counts as a vote for that label. When several token
   * variants map to the same label, the highest logprob is kept.
   *
   * @returns `{ text, predicted_label, confidence, probabilities, logprobs, all_logprobs, prompt }`
   *   - `predicted_label`: label with the highest logprob, or "Unknown" if no label token was seen.
   *   - `confidence`: logistic of the logprob margin between the two best
   *     distinct labels, as a percentage; 100 when only one label was seen,
   *     0 when none was.
   *   - `probabilities`: per-label `exp(logprob) * 100` (0 when unseen).
   *   - `logprobs`: per-label logprob rounded to 4 decimals (null when unseen).
   *   - `all_logprobs`: highest logprob per normalised token.
   */
  processResponse(response, text, labels, prompt) {
    // A Map avoids matching inherited keys such as "constructor".
    const labelByToken = new Map(labels.map((label, idx) => [String(idx + 1), label]));
    const topLogprobs = response?.choices?.[0]?.logprobs?.content?.[0]?.top_logprobs ?? [];

    const tokenLogprobs = new Map(); // normalised token -> highest logprob
    const bestByLabel = new Map(); // label -> highest logprob over its token variants
    for (const lp of topLogprobs) {
      if (typeof lp?.token !== "string" || typeof lp.logprob !== "number") {
        continue;
      }
      const raw = lp.token.trim();
      const token = raw.replace(/[^0-9]/g, "") || raw;
      if (!tokenLogprobs.has(token) || lp.logprob > tokenLogprobs.get(token)) {
        tokenLogprobs.set(token, lp.logprob);
      }
      if (labelByToken.has(token)) {
        const label = labelByToken.get(token);
        if (!bestByLabel.has(label) || lp.logprob > bestByLabel.get(label)) {
          bestByLabel.set(label, lp.logprob);
        }
      }
    }

    // Rank distinct labels only, so a label never competes with itself.
    let predicted = "Unknown";
    let bestLogProb = -Infinity;
    let secondLogProb = -Infinity;
    for (const [label, logprob] of bestByLabel) {
      if (logprob > bestLogProb) {
        secondLogProb = bestLogProb;
        bestLogProb = logprob;
        predicted = label;
      } else if (logprob > secondLogProb) {
        secondLogProb = logprob;
      }
    }

    // With no label seen the margin would be -Infinity - -Infinity = NaN.
    const confidence =
      bestByLabel.size === 0
        ? 0
        : this.round((1 / (1 + Math.exp(-(bestLogProb - secondLogProb)))) * 100);

    const probabilities = Object.fromEntries(
      labels.map((label) => [
        label,
        bestByLabel.has(label) ? this.round(Math.exp(bestByLabel.get(label)) * 100) : 0
      ])
    );
    const labelLogprobs = Object.fromEntries(
      labels.map((label) => [
        label,
        bestByLabel.has(label) ? parseFloat(bestByLabel.get(label).toFixed(4)) : null
      ])
    );

    return {
      text,
      predicted_label: predicted,
      confidence,
      probabilities,
      logprobs: labelLogprobs,
      all_logprobs: Object.fromEntries(tokenLogprobs),
      prompt
    };
  }

  /**
   * Round number to 3 decimal places.
   */
  round(num) {
    return Math.round((num + Number.EPSILON) * 1e3) / 1e3;
  }
}

// src/index.ts
/**
 * One-shot convenience wrapper: builds a ZeroLabelClient from `input.apiKey`
 * and classifies the given texts/images against `input.labels`.
 *
 * Only fields that are not `undefined` are forwarded to the client.
 *
 * @param input - `{ apiKey, labels, texts?, images?, criteria?, additionalInstructions? }`.
 * @returns One classification result per text/image pair.
 * @throws ZeroLabelError when the API key is missing, the input is invalid, or the completion fails.
 */
const classify = async (input) => {
  // `?? {}` lets a missing input fail with MISSING_API_KEY rather than a TypeError.
  const { apiKey, texts, images, labels, criteria, additionalInstructions } = input ?? {};
  const client = new ZeroLabelClient({ apiKey });
  const classifyInput = { labels };
  if (texts !== void 0) classifyInput.texts = texts;
  if (images !== void 0) classifyInput.images = images;
  if (criteria !== void 0) classifyInput.criteria = criteria;
  if (additionalInstructions !== void 0) classifyInput.additionalInstructions = additionalInstructions;
  return client.classify(classifyInput);
};

export { ZeroLabelError, classify };