inference-server
Version:
Libraries and server to build AI applications. Adapters to various native bindings allowing local inference. Integrate it with your application, or use as a microservice.
511 lines (472 loc) • 14.1 kB
text/typescript
import crypto from 'node:crypto'
import { customAlphabet } from 'nanoid'
import {
ModelEngine,
ModelConfig,
ModelInstanceRequest,
ChatCompletionTaskResult,
TextCompletionTaskResult,
ChatCompletionTaskArgs,
TextCompletionTaskArgs,
EmbeddingTaskArgs,
ImageToTextTaskArgs,
ImageToImageTaskArgs,
SpeechToTextTaskArgs,
TextToSpeechTaskArgs,
TextToImageTaskArgs,
ObjectDetectionTaskArgs,
TextClassificationTaskArgs,
TaskKind,
TaskProcessor,
TaskProcessorName,
TaskArgs,
EmbeddingTaskResult,
TaskResult,
InferenceTask,
ImageToTextTaskResult,
ImageToImageTaskResult,
SpeechToTextTaskResult,
TextToSpeechTaskResult,
TextToImageTaskResult,
TextClassificationTaskResult,
ObjectDetectionTaskResult,
} from '#package/types/index.js'
import { calculateChatContextIdentity } from '#package/lib/calculateChatContextIdentity.js'
import { LogLevels, Logger, createLogger, withLogMeta } from '#package/lib/logger.js'
import { elapsedMillis, mergeAbortSignals } from '#package/lib/util.js'
import { getLargestCommonPrefix } from '#package/lib/getLargestCommonPrefix.js'
const idAlphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
const generateId = customAlphabet(idAlphabet, 8)
type ModelInstanceStatus = 'idle' | 'busy' | 'error' | 'loading' | 'preparing'
interface ModelInstanceOptions extends ModelConfig {
log?: Logger
gpu: boolean
}
export class ModelInstance<TEngineRef = unknown> {
id: string
status: ModelInstanceStatus
modelId: string
config: ModelConfig
fingerprint: string
createdAt: Date
lastUsed: number = 0
gpu: boolean
ttl: number
log: Logger
private engine: ModelEngine
private engineRef?: TEngineRef | unknown
private contextIdentity?: string
private needsContextReset: boolean = false
private currentRequest?: ModelInstanceRequest | null
private shutdownController: AbortController
constructor(engine: ModelEngine, { log, gpu, ...options }: ModelInstanceOptions) {
this.modelId = options.id
this.id = this.generateInstanceId()
this.engine = engine
this.config = options
this.gpu = gpu
this.ttl = options.ttl ?? 300
this.status = 'preparing'
this.createdAt = new Date()
this.log = withLogMeta(log ?? createLogger(LogLevels.warn), {
instance: this.id,
})
this.shutdownController = new AbortController()
// TODO to implement this properly we should only include what changes the "behavior" of the model
this.fingerprint = crypto.createHash('sha1').update(JSON.stringify(options)).digest('hex')
this.log(LogLevels.info, 'Initializing new instance', {
model: this.modelId,
engine: this.config.engine,
device: this.config.device,
hasGpuLock: this.gpu,
})
}
private generateInstanceId() {
return this.modelId + ':' + generateId(8)
}
private generateTaskId() {
return this.id + '-' + generateId(8)
}
getEngineRef() {
return this.engineRef
}
async load(signal?: AbortSignal) {
if (this.engineRef) {
throw new Error('Instance is already loaded')
}
this.status = 'loading'
const loadBegin = process.hrtime.bigint()
const abortSignal = mergeAbortSignals([this.shutdownController.signal, signal])
try {
this.engineRef = await this.engine.createInstance(
{
log: withLogMeta(this.log, {
instance: this.id,
}),
config: {
...this.config,
device: {
...this.config.device,
gpu: this.gpu ? this.config.device?.gpu : false,
},
},
},
abortSignal,
)
this.status = 'idle'
if (this.config.initialMessages?.length) {
this.contextIdentity = calculateChatContextIdentity({
messages: this.config.initialMessages,
})
}
if (this.config.prefix) {
this.contextIdentity = this.config.prefix
}
this.log(LogLevels.debug, 'Instance loaded', {
elapsed: elapsedMillis(loadBegin),
})
} catch (error: any) {
this.status = 'error'
this.log(LogLevels.error, 'Failed to load instance:', {
error,
})
throw error
}
}
dispose() {
this.status = 'busy'
if (!this.engineRef) {
return Promise.resolve()
}
this.shutdownController.abort()
return this.engine.disposeInstance(this.engineRef)
}
lock(request: ModelInstanceRequest) {
if (this.status !== 'idle') {
throw new Error(`Cannot lock: Instance ${this.id} is not idle`)
}
this.currentRequest = request
this.status = 'busy'
}
unlock() {
this.status = 'idle'
this.currentRequest = null
}
resetContext() {
this.needsContextReset = true
}
getContextStateIdentity() {
return this.contextIdentity
}
hasContextState() {
return this.contextIdentity !== undefined
}
matchesContextState(request: ModelInstanceRequest) {
if (!this.contextIdentity) {
return false
}
if ('messages' in request && request.messages?.length) {
const incomingContextIdentity = calculateChatContextIdentity({
messages: request.messages,
dropLastUserMessage: true,
})
return this.contextIdentity === incomingContextIdentity
} else if ('prompt' in request && request.prompt) {
const commonPrefix = getLargestCommonPrefix(this.contextIdentity, request.prompt)
return commonPrefix.length > 0
}
return false
}
matchesRequirements(request: ModelInstanceRequest) {
const requiresGpu = !!this.config.device?.gpu && this.config.device?.gpu !== 'auto'
const modelMatches = this.modelId === request.model
const gpuMatches = requiresGpu ? this.gpu : true
return modelMatches && gpuMatches
}
private createTaskController(args: { timeout?: number; signal?: AbortSignal }) {
const cancelController = new AbortController()
const timeoutController = new AbortController()
const abortSignals = [cancelController.signal, this.shutdownController.signal]
if (args.signal) {
abortSignals.push(args.signal)
}
let timeout: NodeJS.Timeout | undefined
if (args.timeout) {
timeout = setTimeout(() => {
timeoutController.abort('timeout')
}, args.timeout)
abortSignals.push(timeoutController.signal)
}
return {
cancel: () => {
cancelController.abort('cancel')
if (timeout) {
clearTimeout(timeout)
}
},
complete: () => {
if (timeout) {
clearTimeout(timeout)
}
},
signal: mergeAbortSignals(abortSignals),
timeoutSignal: timeoutController.signal,
cancelSignal: cancelController.signal,
}
}
processChatCompletionTask(args: ChatCompletionTaskArgs) {
if (!('processChatCompletionTask' in this.engine)) {
throw new Error(`Engine "${this.config.engine}" does not implement chat completions`)
}
if (!args.messages?.length) {
throw new Error('Messages are required for chat completions')
}
const id = this.generateTaskId()
this.lastUsed = Date.now()
const taskLogger = withLogMeta(this.log, {
sequence: this.currentRequest!.sequence,
task: id,
})
// checking if this instance has been flagged for reset
let resetContext = false
if (this.needsContextReset) {
this.contextIdentity = undefined
this.needsContextReset = false
resetContext = true
}
const controller = this.createTaskController({
timeout: args?.timeout,
signal: args?.signal,
})
// start completion processing
taskLogger(LogLevels.info, 'Processing chat completion task')
const taskBegin = process.hrtime.bigint()
const taskContext = {
instance: this.engineRef,
config: this.config,
resetContext,
log: taskLogger,
}
const completionPromise = this.engine.processChatCompletionTask!(
args,
taskContext,
controller.signal,
)
.then((result) => {
if (controller.timeoutSignal.aborted) {
result.finishReason = 'timeout'
} else if (controller.cancelSignal.aborted) {
result.finishReason = 'cancel'
}
this.contextIdentity = calculateChatContextIdentity({
messages: [...args.messages, result.message],
})
return result
})
.catch((error) => {
if (error.name === 'AbortError') {
const emptyResponse: ChatCompletionTaskResult = {
finishReason: 'abort',
message: {
role: 'assistant',
content: '',
},
promptTokens: 0,
completionTokens: 0,
contextTokens: 0,
}
if (controller.timeoutSignal.aborted) {
emptyResponse.finishReason = 'timeout'
return emptyResponse
}
if (controller.cancelSignal.aborted) {
emptyResponse.finishReason = 'cancel'
return emptyResponse
}
return emptyResponse
}
taskLogger(LogLevels.error, 'Error while processing task - ', {
error,
})
throw error
})
.finally(() => {
const elapsedTime = elapsedMillis(taskBegin)
controller.complete()
taskLogger(LogLevels.info, 'Chat completion task done', {
elapsed: elapsedTime,
})
})
return {
id,
model: this.modelId,
createdAt: new Date(),
result: completionPromise,
cancel: controller.cancel,
}
}
processTextCompletionTask(args: TextCompletionTaskArgs) {
if (!('processTextCompletionTask' in this.engine)) {
throw new Error(`Engine "${this.config.engine}" does not implement text completion`)
}
if (!args.prompt) {
throw new Error('Prompt is required for text completion')
}
this.lastUsed = Date.now()
const id = this.generateTaskId()
const taskLogger = withLogMeta(this.log, {
sequence: this.currentRequest!.sequence,
task: id,
})
const controller = this.createTaskController({
timeout: args?.timeout,
signal: args?.signal,
})
taskLogger(LogLevels.info, 'Processing text completion task')
// pass on resetContext if this instance has been flagged for reset
let resetContext = false
if (this.needsContextReset) {
this.contextIdentity = undefined
this.needsContextReset = false
resetContext = true
}
const taskBegin = process.hrtime.bigint()
const taskContext = {
instance: this.engineRef,
config: this.config,
resetContext,
log: taskLogger,
}
const completionPromise = this.engine.processTextCompletionTask!(args,
taskContext,
controller.signal,
)
.then((result) => {
if (controller.timeoutSignal.aborted) {
result.finishReason = 'timeout'
} else if (controller.cancelSignal.aborted) {
result.finishReason = 'cancel'
}
this.contextIdentity = args.prompt + result.text
return result
})
.catch((error) => {
if (error.name === 'AbortError') {
const emptyResponse: TextCompletionTaskResult = {
finishReason: 'abort',
text: '',
promptTokens: 0,
completionTokens: 0,
contextTokens: 0,
}
if (controller.timeoutSignal.aborted) {
emptyResponse.finishReason = 'timeout'
return emptyResponse
}
if (controller.cancelSignal.aborted) {
emptyResponse.finishReason = 'cancel'
return emptyResponse
}
return emptyResponse
}
taskLogger(LogLevels.error, 'Error while processing task - ', {
error,
})
throw error
})
.finally(() => {
const elapsedTime = elapsedMillis(taskBegin)
controller.complete()
taskLogger(LogLevels.info, 'Text completion task done', {
elapsed: elapsedTime,
})
})
return {
id,
model: this.modelId,
createdAt: new Date(),
cancel: controller.cancel,
result: completionPromise,
}
}
private processTask<T extends TaskResult>(
taskType: TaskKind,
processorName: TaskProcessorName,
args: TaskArgs,
): InferenceTask<T> {
if (!(processorName in this.engine)) {
throw new Error(`Engine "${this.config.engine}" does not implement ${taskType}`)
}
this.lastUsed = Date.now()
const id = this.generateTaskId()
const taskLogger = withLogMeta(this.log, {
sequence: this.currentRequest!.sequence,
task: id,
})
const controller = this.createTaskController({
timeout: args?.timeout,
signal: args?.signal,
})
const taskBegin = process.hrtime.bigint()
const taskContext = {
instance: this.engineRef,
config: this.config,
log: taskLogger,
}
taskLogger(LogLevels.info, `Processing ${taskType} task`)
const processor = this.engine[processorName] as TaskProcessor<any, any, any>
const result: Promise<TaskResult> = processor(
args,
taskContext,
controller.signal
)
.then((result) => {
const timeElapsed = elapsedMillis(taskBegin)
controller.complete()
if (controller.timeoutSignal.aborted) {
taskLogger(LogLevels.warn, `${taskType} task timed out`)
}
taskLogger(LogLevels.info, `${taskType} task done`, {
elapsed: timeElapsed,
})
return result
})
.catch((error) => {
taskLogger(LogLevels.error, 'Task failed - ', {
error,
})
throw error
})
return {
id,
model: this.modelId,
createdAt: new Date(),
cancel: controller.cancel,
result: result as Promise<T>,
}
}
processEmbeddingTask(args: EmbeddingTaskArgs) {
return this.processTask<EmbeddingTaskResult>('embedding', 'processEmbeddingTask', args)
}
processImageToTextTask(args: ImageToTextTaskArgs) {
return this.processTask<ImageToTextTaskResult>('image-to-text', 'processImageToTextTask', args)
}
processImageToImageTask(args: ImageToImageTaskArgs) {
return this.processTask<ImageToImageTaskResult>('image-to-image', 'processImageToImageTask', args)
}
processSpeechToTextTask(args: SpeechToTextTaskArgs) {
return this.processTask<SpeechToTextTaskResult>('speech-to-text', 'processSpeechToTextTask', args)
}
processTextToSpeechTask(args: TextToSpeechTaskArgs) {
return this.processTask<TextToSpeechTaskResult>('text-to-speech', 'processTextToSpeechTask', args)
}
processTextToImageTask(args: TextToImageTaskArgs) {
return this.processTask<TextToImageTaskResult>('text-to-image', 'processTextToImageTask', args)
}
processTextClassificationTask(args: TextClassificationTaskArgs) {
return this.processTask<TextClassificationTaskResult>('text-classification', 'processTextClassificationTask', args)
}
processObjectDetectionTask(args: ObjectDetectionTaskArgs) {
return this.processTask<ObjectDetectionTaskResult>('object-detection', 'processObjectDetectionTask', args)
}
}