inference-server
Version:
Libraries and a server for building AI applications. Includes adapters to various native bindings that allow local inference. Integrate it into your application, or run it as a microservice.
423 lines • 15.3 kB
JavaScript
import crypto from 'node:crypto';
import { customAlphabet } from 'nanoid';
import { calculateChatContextIdentity } from './lib/calculateChatContextIdentity.js';
import { LogLevels, createLogger, withLogMeta } from './lib/logger.js';
import { elapsedMillis, mergeAbortSignals } from './lib/util.js';
import { getLargestCommonPrefix } from './lib/getLargestCommonPrefix.js';
// Alphanumeric alphabet (upper case, lower case, digits) used for the random
// suffix of instance and task ids.
const idAlphabet = [
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
    'abcdefghijklmnopqrstuvwxyz',
    '0123456789',
].join('');
// Id generator over that alphabet, created with a default size of 8.
const generateId = customAlphabet(idAlphabet, 8);
/**
 * One instance of a model, bound to the engine that creates, runs and
 * disposes it.
 *
 * The instance tracks its own lifecycle through `status` (this class assigns
 * 'preparing', 'loading', 'idle', 'busy' and 'error'), remembers an identity
 * for the context it last processed so callers can match follow-up requests
 * to it, and wraps engine task calls with cancellation, timeout and logging.
 */
export class ModelInstance {
    /** Instance id in the form `<modelId>:<random suffix>`. */
    id;
    /** Current lifecycle state, see the class description. */
    status;
    /** Id of the model this instance serves (taken from `options.id`). */
    modelId;
    /** The options passed to the constructor, minus `log` and `gpu`. */
    config;
    /** SHA-1 hex digest of the JSON-serialized options. */
    fingerprint;
    /** Date at which the instance object was constructed. */
    createdAt;
    /** `Date.now()` timestamp of the most recent task, 0 if never used. */
    lastUsed = 0;
    /** Truthy when the `gpu` constructor option was set; logged as `hasGpuLock`. */
    gpu;
    /** Time-to-live taken from `options.ttl`, defaulting to 300 (unit not visible here). */
    ttl;
    /** Logger carrying this instance's id as metadata. */
    log;
    /** Engine implementation used for loading, task processing and disposal. */
    engine;
    /** Handle returned by `engine.createInstance`, unset until `load()` succeeds. */
    engineRef;
    /** Identity of the context state last processed (hash for chat, raw text for prompts). */
    contextIdentity;
    /** When true, the next chat/text completion asks the engine to reset its context. */
    needsContextReset = false;
    /** Request currently holding the lock, `null` after `unlock()`. */
    currentRequest;
    /** Aborted by `dispose()`; merged into the load signal and every task signal. */
    shutdownController;

    /**
     * @param {object} engine - Engine implementation (createInstance, disposeInstance, process*Task).
     * @param {object} args - Model options. `log` and `gpu` are split off, the
     *   remainder is stored as `config` and hashed into `fingerprint`.
     */
    constructor(engine, { log, gpu, ...options }) {
        this.modelId = options.id;
        // The instance id embeds modelId, so it has to be generated after it.
        this.id = this.generateInstanceId();
        this.engine = engine;
        this.config = options;
        this.gpu = gpu;
        this.ttl = options.ttl ?? 300;
        this.status = 'preparing';
        this.createdAt = new Date();
        this.log = withLogMeta(log ?? createLogger(LogLevels.warn), {
            instance: this.id,
        });
        this.shutdownController = new AbortController();
        // TODO to implement this properly we should only include what changes the "behavior" of the model
        this.fingerprint = crypto.createHash('sha1').update(JSON.stringify(options)).digest('hex');
        this.log(LogLevels.info, 'Initializing new instance', {
            model: this.modelId,
            engine: this.config.engine,
            device: this.config.device,
            hasGpuLock: this.gpu,
        });
    }

    /** @returns {string} A new id of the form `<modelId>:<random suffix>`. */
    generateInstanceId() {
        return `${this.modelId}:${generateId(8)}`;
    }

    /** @returns {string} A new task id of the form `<instanceId>-<random suffix>`. */
    generateTaskId() {
        return `${this.id}-${generateId(8)}`;
    }

    /** @returns {*} The engine handle, or undefined if the instance is not loaded. */
    getEngineRef() {
        return this.engineRef;
    }

    /**
     * Creates the engine-side instance and moves this instance to 'idle'.
     *
     * @param {AbortSignal} [signal] - Merged with the shutdown signal and
     *   handed to `engine.createInstance`.
     * @returns {Promise<void>}
     * @throws {Error} If already loaded. Errors from the engine are logged,
     *   set `status` to 'error' and are rethrown.
     */
    async load(signal) {
        if (this.engineRef) {
            throw new Error('Instance is already loaded');
        }
        this.status = 'loading';
        const loadBegin = process.hrtime.bigint();
        const abortSignal = mergeAbortSignals([this.shutdownController.signal, signal]);
        try {
            this.engineRef = await this.engine.createInstance({
                log: withLogMeta(this.log, {
                    instance: this.id,
                }),
                config: {
                    ...this.config,
                    device: {
                        ...this.config.device,
                        // Only pass the configured gpu setting on when this instance was given the gpu.
                        gpu: this.gpu ? this.config.device?.gpu : false,
                    },
                },
            }, abortSignal);
            this.status = 'idle';
            if (this.config.initialMessages?.length) {
                this.contextIdentity = calculateChatContextIdentity({
                    messages: this.config.initialMessages,
                });
            }
            // A configured prefix takes precedence over initial messages.
            if (this.config.prefix) {
                this.contextIdentity = this.config.prefix;
            }
            this.log(LogLevels.debug, 'Instance loaded', {
                elapsed: elapsedMillis(loadBegin),
            });
        }
        catch (error) {
            this.status = 'error';
            this.log(LogLevels.error, 'Failed to load instance:', {
                error,
            });
            throw error;
        }
    }

    /**
     * Aborts everything tied to the shutdown signal and releases the engine
     * instance if one was created.
     *
     * The abort happens before the `engineRef` check so that a `load()` that
     * is still in flight (and therefore has no `engineRef` yet) is signalled
     * as well, instead of completing on an instance that was disposed.
     *
     * @returns {Promise<*>} Resolves immediately when nothing was loaded,
     *   otherwise the result of `engine.disposeInstance`.
     */
    dispose() {
        this.status = 'busy';
        this.shutdownController.abort();
        if (!this.engineRef) {
            return Promise.resolve();
        }
        return this.engine.disposeInstance(this.engineRef);
    }

    /**
     * Marks the instance as busy on behalf of `request`.
     *
     * @param {object} request - Stored as `currentRequest`; its `sequence` is used as log metadata by tasks.
     * @throws {Error} If the instance is not 'idle'.
     */
    lock(request) {
        if (this.status !== 'idle') {
            throw new Error(`Cannot lock: Instance ${this.id} is not idle`);
        }
        this.currentRequest = request;
        this.status = 'busy';
    }

    /** Returns the instance to 'idle' and clears the current request. */
    unlock() {
        this.status = 'idle';
        this.currentRequest = null;
    }

    /** Flags the instance so the next chat/text completion resets the context. */
    resetContext() {
        this.needsContextReset = true;
    }

    /** @returns {string|undefined} The identity of the last processed context. */
    getContextStateIdentity() {
        return this.contextIdentity;
    }

    /** @returns {boolean} True when a context identity has been recorded. */
    hasContextState() {
        return this.contextIdentity !== undefined;
    }

    /**
     * Checks whether `request` continues the context this instance holds.
     *
     * Chat requests match when the identity of their messages (computed with
     * `dropLastUserMessage: true`) equals the stored identity. Prompt
     * requests match when they share a non-empty common prefix with it.
     *
     * @param {object} request - Request carrying `messages` or `prompt`.
     * @returns {boolean}
     */
    matchesContextState(request) {
        if (!this.contextIdentity) {
            return false;
        }
        if ('messages' in request && request.messages?.length) {
            const incomingContextIdentity = calculateChatContextIdentity({
                messages: request.messages,
                dropLastUserMessage: true,
            });
            return this.contextIdentity === incomingContextIdentity;
        }
        if ('prompt' in request && request.prompt) {
            const commonPrefix = getLargestCommonPrefix(this.contextIdentity, request.prompt);
            return commonPrefix.length > 0;
        }
        return false;
    }

    /**
     * Checks whether this instance can serve `request`: the model id must
     * match, and if the config demands a gpu (any truthy `device.gpu` other
     * than 'auto') the instance must actually have been given one.
     *
     * @param {object} request - Request carrying the wanted `model` id.
     * @returns {boolean} Always a strict boolean.
     */
    matchesRequirements(request) {
        const configuredGpu = this.config.device?.gpu;
        const requiresGpu = !!configuredGpu && configuredGpu !== 'auto';
        const modelMatches = this.modelId === request.model;
        // Coerced so an unset `gpu` yields false rather than leaking undefined.
        const gpuMatches = requiresGpu ? Boolean(this.gpu) : true;
        return modelMatches && gpuMatches;
    }

    /**
     * Builds the cancellation plumbing for one task.
     *
     * @param {object} args
     * @param {AbortSignal} [args.signal] - Caller supplied signal to merge in.
     * @param {number} [args.timeout] - Delay passed to `setTimeout` after which the task is aborted.
     * @returns {{cancel: Function, complete: Function, signal: AbortSignal,
     *   timeoutSignal: AbortSignal, cancelSignal: AbortSignal}}
     *   `signal` merges cancel, shutdown, caller and timeout signals;
     *   `cancel` aborts the task; `complete` only clears the timeout timer.
     */
    createTaskController(args) {
        const cancelController = new AbortController();
        const timeoutController = new AbortController();
        const abortSignals = [cancelController.signal, this.shutdownController.signal];
        if (args.signal) {
            abortSignals.push(args.signal);
        }
        let timeout;
        if (args.timeout) {
            timeout = setTimeout(() => {
                timeoutController.abort('timeout');
            }, args.timeout);
            abortSignals.push(timeoutController.signal);
        }
        const clearTimer = () => {
            if (timeout) {
                clearTimeout(timeout);
            }
        };
        return {
            cancel: () => {
                cancelController.abort('cancel');
                clearTimer();
            },
            complete: clearTimer,
            signal: mergeAbortSignals(abortSignals),
            timeoutSignal: timeoutController.signal,
            cancelSignal: cancelController.signal,
        };
    }

    /**
     * Runs a chat completion on the engine.
     *
     * On success the context identity is updated from the request messages
     * plus the returned message. If the engine rejects with an `AbortError`
     * the promise resolves with an empty assistant message whose
     * `finishReason` is 'timeout', 'cancel' or 'abort'; other errors are
     * logged and rethrown.
     *
     * @param {object} args - Task arguments; `messages` is required,
     *   `timeout` and `signal` are optional.
     * @returns {{id: string, model: string, createdAt: Date, result: Promise<object>, cancel: Function}}
     * @throws {Error} If the engine has no `processChatCompletionTask` or no messages were given.
     */
    processChatCompletionTask(args) {
        if (!('processChatCompletionTask' in this.engine)) {
            throw new Error(`Engine "${this.config.engine}" does not implement chat completions`);
        }
        if (!args.messages?.length) {
            throw new Error('Messages are required for chat completions');
        }
        const id = this.generateTaskId();
        this.lastUsed = Date.now();
        const taskLogger = withLogMeta(this.log, {
            // Log metadata only, so a missing lock must not crash the task.
            sequence: this.currentRequest?.sequence,
            task: id,
        });
        // checking if this instance has been flagged for reset
        let resetContext = false;
        if (this.needsContextReset) {
            this.contextIdentity = undefined;
            this.needsContextReset = false;
            resetContext = true;
        }
        const controller = this.createTaskController({
            timeout: args?.timeout,
            signal: args?.signal,
        });
        // start completion processing
        taskLogger(LogLevels.info, 'Processing chat completion task');
        const taskBegin = process.hrtime.bigint();
        const taskContext = {
            instance: this.engineRef,
            config: this.config,
            resetContext,
            log: taskLogger,
        };
        const completionPromise = this.engine.processChatCompletionTask(args, taskContext, controller.signal)
            .then((result) => {
                // The engine may resolve with a partial result after an abort.
                if (controller.timeoutSignal.aborted) {
                    result.finishReason = 'timeout';
                }
                else if (controller.cancelSignal.aborted) {
                    result.finishReason = 'cancel';
                }
                this.contextIdentity = calculateChatContextIdentity({
                    messages: [...args.messages, result.message],
                });
                return result;
            })
            .catch((error) => {
                if (error?.name === 'AbortError') {
                    let finishReason = 'abort';
                    if (controller.timeoutSignal.aborted) {
                        finishReason = 'timeout';
                    }
                    else if (controller.cancelSignal.aborted) {
                        finishReason = 'cancel';
                    }
                    return {
                        finishReason,
                        message: {
                            role: 'assistant',
                            content: '',
                        },
                        promptTokens: 0,
                        completionTokens: 0,
                        contextTokens: 0,
                    };
                }
                taskLogger(LogLevels.error, 'Error while processing task - ', {
                    error,
                });
                throw error;
            })
            .finally(() => {
                const elapsedTime = elapsedMillis(taskBegin);
                controller.complete();
                taskLogger(LogLevels.info, 'Chat completion task done', {
                    elapsed: elapsedTime,
                });
            });
        return {
            id,
            model: this.modelId,
            createdAt: new Date(),
            result: completionPromise,
            cancel: controller.cancel,
        };
    }

    /**
     * Runs a text completion on the engine.
     *
     * On success the context identity becomes `prompt + result.text`. If the
     * engine rejects with an `AbortError` the promise resolves with an empty
     * text whose `finishReason` is 'timeout', 'cancel' or 'abort'; other
     * errors are logged and rethrown.
     *
     * @param {object} args - Task arguments; `prompt` is required,
     *   `timeout` and `signal` are optional.
     * @returns {{id: string, model: string, createdAt: Date, cancel: Function, result: Promise<object>}}
     * @throws {Error} If the engine has no `processTextCompletionTask` or no prompt was given.
     */
    processTextCompletionTask(args) {
        if (!('processTextCompletionTask' in this.engine)) {
            throw new Error(`Engine "${this.config.engine}" does not implement text completion`);
        }
        if (!args.prompt) {
            throw new Error('Prompt is required for text completion');
        }
        this.lastUsed = Date.now();
        const id = this.generateTaskId();
        const taskLogger = withLogMeta(this.log, {
            // Log metadata only, so a missing lock must not crash the task.
            sequence: this.currentRequest?.sequence,
            task: id,
        });
        const controller = this.createTaskController({
            timeout: args?.timeout,
            signal: args?.signal,
        });
        taskLogger(LogLevels.info, 'Processing text completion task');
        // pass on resetContext if this instance has been flagged for reset
        let resetContext = false;
        if (this.needsContextReset) {
            this.contextIdentity = undefined;
            this.needsContextReset = false;
            resetContext = true;
        }
        const taskBegin = process.hrtime.bigint();
        const taskContext = {
            instance: this.engineRef,
            config: this.config,
            resetContext,
            log: taskLogger,
        };
        const completionPromise = this.engine.processTextCompletionTask(args, taskContext, controller.signal)
            .then((result) => {
                // The engine may resolve with a partial result after an abort.
                if (controller.timeoutSignal.aborted) {
                    result.finishReason = 'timeout';
                }
                else if (controller.cancelSignal.aborted) {
                    result.finishReason = 'cancel';
                }
                this.contextIdentity = args.prompt + result.text;
                return result;
            })
            .catch((error) => {
                if (error?.name === 'AbortError') {
                    let finishReason = 'abort';
                    if (controller.timeoutSignal.aborted) {
                        finishReason = 'timeout';
                    }
                    else if (controller.cancelSignal.aborted) {
                        finishReason = 'cancel';
                    }
                    return {
                        finishReason,
                        text: '',
                        promptTokens: 0,
                        completionTokens: 0,
                        contextTokens: 0,
                    };
                }
                taskLogger(LogLevels.error, 'Error while processing task - ', {
                    error,
                });
                throw error;
            })
            .finally(() => {
                const elapsedTime = elapsedMillis(taskBegin);
                controller.complete();
                taskLogger(LogLevels.info, 'Text completion task done', {
                    elapsed: elapsedTime,
                });
            });
        return {
            id,
            model: this.modelId,
            createdAt: new Date(),
            cancel: controller.cancel,
            result: completionPromise,
        };
    }

    /**
     * Shared runner for the task types that carry no context state.
     *
     * @param {string} taskType - Human readable task name used in log and error messages.
     * @param {string} processorName - Name of the engine method to invoke.
     * @param {object} args - Task arguments; `timeout` and `signal` are optional.
     * @returns {{id: string, model: string, createdAt: Date, cancel: Function, result: Promise<*>}}
     *   `result` rejects with whatever the engine rejects with, after logging it.
     * @throws {Error} If the engine does not provide `processorName`.
     */
    processTask(taskType, processorName, args) {
        if (!(processorName in this.engine)) {
            throw new Error(`Engine "${this.config.engine}" does not implement ${taskType}`);
        }
        this.lastUsed = Date.now();
        const id = this.generateTaskId();
        const taskLogger = withLogMeta(this.log, {
            // Log metadata only, so a missing lock must not crash the task.
            sequence: this.currentRequest?.sequence,
            task: id,
        });
        const controller = this.createTaskController({
            timeout: args?.timeout,
            signal: args?.signal,
        });
        const taskBegin = process.hrtime.bigint();
        const taskContext = {
            instance: this.engineRef,
            config: this.config,
            log: taskLogger,
        };
        taskLogger(LogLevels.info, `Processing ${taskType} task`);
        // Called through the engine so the method keeps its receiver, like
        // the chat and text completion paths do.
        const result = this.engine[processorName](args, taskContext, controller.signal)
            .then((taskResult) => {
                const timeElapsed = elapsedMillis(taskBegin);
                if (controller.timeoutSignal.aborted) {
                    taskLogger(LogLevels.warn, `${taskType} task timed out`);
                }
                taskLogger(LogLevels.info, `${taskType} task done`, {
                    elapsed: timeElapsed,
                });
                return taskResult;
            })
            .catch((error) => {
                taskLogger(LogLevels.error, 'Task failed - ', {
                    error,
                });
                throw error;
            })
            .finally(() => {
                // Clear the timeout timer on failure as well as on success.
                controller.complete();
            });
        return {
            id,
            model: this.modelId,
            createdAt: new Date(),
            cancel: controller.cancel,
            result: result,
        };
    }

    /** Runs `engine.processEmbeddingTask` via {@link ModelInstance#processTask}. */
    processEmbeddingTask(args) {
        return this.processTask('embedding', 'processEmbeddingTask', args);
    }

    /** Runs `engine.processImageToTextTask` via {@link ModelInstance#processTask}. */
    processImageToTextTask(args) {
        return this.processTask('image-to-text', 'processImageToTextTask', args);
    }

    /** Runs `engine.processImageToImageTask` via {@link ModelInstance#processTask}. */
    processImageToImageTask(args) {
        return this.processTask('image-to-image', 'processImageToImageTask', args);
    }

    /** Runs `engine.processSpeechToTextTask` via {@link ModelInstance#processTask}. */
    processSpeechToTextTask(args) {
        return this.processTask('speech-to-text', 'processSpeechToTextTask', args);
    }

    /** Runs `engine.processTextToSpeechTask` via {@link ModelInstance#processTask}. */
    processTextToSpeechTask(args) {
        return this.processTask('text-to-speech', 'processTextToSpeechTask', args);
    }

    /** Runs `engine.processTextToImageTask` via {@link ModelInstance#processTask}. */
    processTextToImageTask(args) {
        return this.processTask('text-to-image', 'processTextToImageTask', args);
    }

    /** Runs `engine.processTextClassificationTask` via {@link ModelInstance#processTask}. */
    processTextClassificationTask(args) {
        return this.processTask('text-classification', 'processTextClassificationTask', args);
    }

    /** Runs `engine.processObjectDetectionTask` via {@link ModelInstance#processTask}. */
    processObjectDetectionTask(args) {
        return this.processTask('object-detection', 'processObjectDetectionTask', args);
    }
}
//# sourceMappingURL=instance.js.map