zerolabel
Version:
Zero-shot multimodal classification SDK - classify text and images with custom labels, no training required
260 lines (250 loc) • 8.65 kB
JavaScript
// src/classifier.ts
import OpenAI from "openai";
// src/config.ts
/**
 * Built-in settings used by the classifier. The endpoint and model are fixed
 * by the SDK; the remaining values are the defaults applied to each request.
 */
var DEFAULT_CONFIG = {
  // inference.net API endpoint (fixed, not customizable)
  BASE_URL: "https://api.inference.net/v1",
  // Model used for every classification request (fixed, not customizable)
  MODEL: "google/gemma-3-27b-instruct/bf-16",
  // How many times a failed request is retried when the caller gives no value
  MAX_RETRIES: 3,
  // Sampling temperature sent with each request
  TEMPERATURE: 0.5,
  // Upper bound on generated tokens per request
  MAX_TOKENS: 1,
  // How many top log-probabilities are requested for the generated token
  TOP_LOGPROBS: 20
};
// src/types.ts
/**
 * Error type raised by the SDK.
 *
 * Besides the standard `message`, each instance carries a machine-readable
 * `code` and an optional `statusCode`; `name` is always "ZeroLabelError".
 */
var ZeroLabelError = class extends Error {
  /**
   * @param message - Human-readable description of the failure.
   * @param code - Short identifier for the failure category.
   * @param statusCode - Optional numeric status associated with the failure.
   */
  constructor(message, code, statusCode) {
    super(message);
    // Assigned in one step; property order matches code, statusCode, name.
    Object.assign(this, { code, statusCode, name: "ZeroLabelError" });
  }
};
// src/classifier.ts
/**
 * Client that classifies texts and/or images against caller-supplied labels.
 *
 * Each input is turned into a prompt listing the labels as numbered
 * categories; the model's top log-probabilities for the first generated token
 * are then mapped back from category numbers to labels.
 *
 * NOTE(review): only the first content token of the response is inspected, so
 * category numbers that the model emits as more than one token may never be
 * matched - confirm against the model's tokenizer before using long label lists.
 */
var ZeroLabelClient = class {
  /**
   * @param config - Object with a required `apiKey` and optional `maxRetries`
   *   (defaults to DEFAULT_CONFIG.MAX_RETRIES).
   * @throws ZeroLabelError ("MISSING_API_KEY") when `config.apiKey` is falsy.
   */
  constructor(config) {
    if (!config.apiKey) {
      throw new ZeroLabelError("Inference API key is required", "MISSING_API_KEY");
    }
    this.config = {
      apiKey: config.apiKey,
      maxRetries: config.maxRetries ?? DEFAULT_CONFIG.MAX_RETRIES
    };
    this.client = new OpenAI({
      apiKey: this.config.apiKey,
      baseURL: DEFAULT_CONFIG.BASE_URL
    });
  }

  /**
   * Classify text and/or images against provided labels.
   *
   * `texts[i]` is paired with `images[i]`; the shorter list is padded (empty
   * text / no image). All pairs are submitted concurrently.
   *
   * @param input - `{ labels, texts?, images?, criteria?, additionalInstructions? }`.
   * @param options - Per-call overrides: `criteria`, `additionalInstructions`, `maxRetries`.
   * @returns One result per input pair, in input order.
   * @throws ZeroLabelError when the input is invalid or a request ultimately fails.
   */
  async classify(input, options = {}) {
    this.validateInput(input);
    // `??` rather than destructuring defaults: validateInput tolerates null
    // lists, and a destructuring default would leave null in place.
    const texts = input.texts ?? [];
    const images = input.images ?? [];
    const { labels } = input;
    const pairCount = Math.max(texts.length, images.length);
    // A zero count simply produces an empty task list, so no special case is needed.
    const tasks = Array.from({ length: pairCount }, (_, idx) =>
      this.classifySingle(texts[idx] || "", images[idx], labels, input, options)
    );
    return Promise.all(tasks);
  }

  /**
   * Classify a single text/image pair.
   *
   * When there is an image but no text, the prompt's content slot is filled
   * with the placeholder "(image input)"; the returned `text` stays empty.
   */
  async classifySingle(text, image, labels, input, options) {
    const isImageOnly = !text && Boolean(image);
    const displayText = isImageOnly ? "(image input)" : text;
    const prompt = this.constructPrompt(
      displayText,
      labels,
      options.criteria ?? input.criteria ?? "",
      options.additionalInstructions ?? input.additionalInstructions ?? ""
    );
    const response = await this.getCompletion(
      prompt,
      image,
      options.maxRetries ?? this.config.maxRetries
    );
    return this.processResponse(response, text, labels, prompt);
  }

  /**
   * Validate input parameters.
   *
   * @throws ZeroLabelError ("INVALID_LABELS") when `labels` is missing or empty.
   * @throws ZeroLabelError ("INVALID_INPUT") when neither `texts` nor `images`
   *   contains at least one element.
   */
  validateInput(input) {
    if (!input.labels || input.labels.length === 0) {
      throw new ZeroLabelError("Labels array cannot be empty", "INVALID_LABELS");
    }
    const hasTexts = input.texts && input.texts.length > 0;
    const hasImages = input.images && input.images.length > 0;
    if (!hasTexts && !hasImages) {
      throw new ZeroLabelError("Either texts or images must be provided", "INVALID_INPUT");
    }
  }

  /**
   * Construct the classification prompt.
   *
   * Sections (joined by blank lines): fixed header, optional criteria, the
   * numbered categories (1-based), optional extra instructions, the answer
   * format cue, and the content fenced in triple backticks.
   *
   * @returns The full prompt string.
   */
  constructPrompt(text, labels, criteria, additionalInstructions) {
    const numberedLabels = labels.map((label, idx) => `${idx + 1} - ${label}`).join("\n");
    const header = `You are a top-tier domain expert and world-class text-classification assistant. Your goal is to read the given CONTENT and decide which single CATEGORY best matches the overall meaning. Even if multiple categories seem to fit, there is always a best match\u2014select the one that most accurately captures the primary intent or theme. Think through the semantics, context, nuance and any implicit information before deciding.`;
    const reasoningCue = `Before you answer you may reason internally, but DO NOT output that reasoning. After finishing your reasoning, output **all category numbers in descending order of likelihood** (most likely first), separated by single spaces. Example: "3 1 2 4 5". Output nothing else\u2014no words, punctuation, or code fences.`;
    const parts = [
      header,
      criteria ? `\nADDITIONAL EVALUATION CRITERIA:\n${criteria}` : "",
      `\nCATEGORIES (number \u2014 label):\n${numberedLabels}`,
      additionalInstructions ? `\nEXTRA INSTRUCTIONS:\n${additionalInstructions}` : "",
      reasoningCue,
      `\nCONTENT (delimited by triple backticks):\n\`\`\`\n${text}\n\`\`\`\nYOUR ANSWER:`
    ].filter(Boolean); // drop the optional sections that were left empty
    return parts.join("\n\n");
  }

  /**
   * Make API call to get completion.
   * Uses fixed inference.net model - no customization allowed.
   *
   * Streams the response and returns the first chunk that carries top
   * log-probabilities. Any failure (including a stream that ends without such
   * a chunk) is retried up to `maxRetries` times with exponential backoff
   * (1s, 2s, 4s, ...).
   *
   * NOTE(review): the underlying SDK client may apply its own retry policy on
   * top of this loop - confirm whether that compounding is intended.
   *
   * @param prompt - Prompt text to send.
   * @param imageDataUri - Optional image URL/data URI sent alongside the prompt.
   * @param maxRetries - Number of retries after the first attempt.
   * @returns The first streamed chunk containing top log-probabilities.
   * @throws ZeroLabelError ("NO_LOGPROBS") when the last attempt's stream had no log-probabilities.
   * @throws ZeroLabelError ("COMPLETION_FAILED") for any other final failure;
   *   `statusCode` is set when the underlying error exposes a numeric `status`.
   */
  async getCompletion(prompt, imageDataUri, maxRetries) {
    // The message payload does not change between attempts, so build it once.
    const userMessage = imageDataUri ? [
      {
        type: "image_url",
        image_url: { url: imageDataUri }
      },
      {
        type: "text",
        text: prompt
      }
    ] : prompt;
    let attempt = 0;
    while (attempt <= maxRetries) {
      try {
        const stream = await this.client.chat.completions.create({
          model: DEFAULT_CONFIG.MODEL,
          // Fixed model - no customization
          messages: [{ role: "user", content: userMessage }],
          temperature: DEFAULT_CONFIG.TEMPERATURE,
          max_tokens: DEFAULT_CONFIG.MAX_TOKENS,
          logprobs: true,
          top_logprobs: DEFAULT_CONFIG.TOP_LOGPROBS,
          stream: true
        });
        for await (const chunk of stream) {
          if (chunk.choices?.[0]?.logprobs?.content?.[0]?.top_logprobs?.length) {
            return chunk;
          }
        }
        throw new ZeroLabelError("Stream ended without data", "NO_LOGPROBS");
      } catch (error) {
        attempt += 1;
        if (attempt > maxRetries) {
          if (error instanceof ZeroLabelError) {
            throw error;
          }
          // Surface the HTTP status of the underlying failure when there is one.
          const statusCode = error instanceof Error && "status" in error && typeof error.status === "number" ? error.status : void 0;
          throw new ZeroLabelError(
            `Failed to get completion after ${maxRetries} retries: ${error instanceof Error ? error.message : "Unknown error"}`,
            "COMPLETION_FAILED",
            statusCode
          );
        }
        const delay = 2 ** (attempt - 1) * 1e3;
        await new Promise((resolve) => setTimeout(resolve, delay));
      }
    }
    // Reached only when the loop never ran (e.g. a negative or NaN maxRetries).
    throw new ZeroLabelError("Failed to obtain completion", "COMPLETION_FAILED");
  }

  /**
   * Process the API response into a classification result.
   *
   * Each top-logprob token is trimmed and reduced to its digits (when it has
   * any); a token whose digits equal a 1-based category number counts as a
   * vote for that label. When several token variants map to the same number,
   * the highest log-probability is kept.
   *
   * @returns An object with:
   *   - `text`: the input text as given;
   *   - `predicted_label`: the label with the highest log-probability, or "Unknown" when no label token was seen;
   *   - `confidence`: logistic of the gap between the best and second-best label log-probabilities, as a percentage
   *     (100 when only one label was seen, 0 when none was);
   *   - `probabilities`: per-label probability in percent (0 for unseen labels);
   *   - `logprobs`: per-label log-probability rounded to 4 decimals (null for unseen labels);
   *   - `all_logprobs`: raw log-probability per normalised token;
   *   - `prompt`: the prompt that was sent.
   */
  processResponse(response, text, labels, prompt) {
    // A Map is used instead of a plain object so that tokens such as
    // "constructor" or "toString" cannot match inherited Object members.
    const labelByNumber = new Map();
    labels.forEach((label, idx) => {
      labelByNumber.set(String(idx + 1), label);
    });
    const topLogprobs = response.choices?.[0]?.logprobs?.content?.[0]?.top_logprobs ?? [];
    const labelLogprobs = {};
    const probabilities = {};
    const allLogprobs = {};
    labels.forEach((label) => {
      probabilities[label] = 0;
      labelLogprobs[label] = null;
    });
    const ownKey = Object.prototype.hasOwnProperty;
    // Highest raw log-probability seen for each category number.
    const bestByNumber = new Map();
    for (const lp of topLogprobs) {
      const raw = lp.token.trim();
      const token = raw.replace(/[^0-9]/g, "") || raw;
      if (!ownKey.call(allLogprobs, token) || lp.logprob > allLogprobs[token]) {
        allLogprobs[token] = lp.logprob;
      }
      if (labelByNumber.has(token)) {
        const previous = bestByNumber.get(token);
        if (previous === void 0 || lp.logprob > previous) {
          bestByNumber.set(token, lp.logprob);
        }
      }
    }
    let bestLogProb = -Infinity;
    let secondLogProb = -Infinity;
    let bestNumber = null;
    for (const [number, logprob] of bestByNumber) {
      const label = labelByNumber.get(number);
      labelLogprobs[label] = parseFloat(logprob.toFixed(4));
      probabilities[label] = this.round(Math.exp(logprob) * 100);
      if (logprob > bestLogProb) {
        secondLogProb = bestLogProb;
        bestLogProb = logprob;
        bestNumber = number;
      } else if (logprob > secondLogProb) {
        secondLogProb = logprob;
      }
    }
    const predicted = bestNumber !== null ? labelByNumber.get(bestNumber) : "Unknown";
    // With no matching label the margin would be -Infinity - -Infinity = NaN,
    // so report zero confidence explicitly.
    const confidence = bestNumber === null ? 0 : this.round(1 / (1 + Math.exp(-(bestLogProb - secondLogProb))) * 100);
    return {
      text,
      predicted_label: predicted,
      confidence,
      probabilities,
      logprobs: labelLogprobs,
      all_logprobs: allLogprobs,
      prompt
    };
  }

  /**
   * Round number to 3 decimal places.
   */
  round(num) {
    return Math.round((num + Number.EPSILON) * 1e3) / 1e3;
  }
};
// src/index.ts
/**
 * One-shot convenience wrapper: builds a client from `input.apiKey` and runs a
 * single classification.
 *
 * Only the optional fields that are actually defined on `input` are forwarded,
 * so the client never sees keys explicitly set to `undefined`.
 *
 * @param input - `{ apiKey, labels, texts?, images?, criteria?, additionalInstructions? }`.
 * @returns Whatever the client's `classify` resolves to.
 */
var classify = async (input) => {
  const { apiKey, texts, images, labels, criteria, additionalInstructions } = input;
  const optionalFields = { texts, images, criteria, additionalInstructions };
  const client = new ZeroLabelClient({ apiKey });
  const classifyInput = { labels };
  for (const [field, value] of Object.entries(optionalFields)) {
    if (value !== undefined) {
      classifyInput[field] = value;
    }
  }
  return client.classify(classifyInput);
};
export {
ZeroLabelError,
classify
};