@juspay/neurolink
Universal AI Development Platform with working MCP integration, multi-provider support, and professional CLI. Built-in tools operational, 58+ external MCP servers discoverable. Connect to filesystem, GitHub, database operations, and more. Build, test, and
/**
 * SageMaker Model Detection and Streaming Capability Discovery
 *
 * This module provides intelligent detection of SageMaker endpoint capabilities
 * including model type identification and streaming protocol support.
 */
import { SageMakerRuntimeClient } from "./client.js";
import { logger } from "../../utils/logger.js";
/**
 * Configurable constants for detection timing and performance
 */
const DETECTION_TEST_DELAY_MS = 100; // Base delay between detection tests (ms)
const DETECTION_STAGGER_DELAY_MS = 25; // Delay between staggered test starts (ms)
const DETECTION_RATE_LIMIT_BACKOFF_MS = 200; // Initial backoff on rate limit detection (ms)
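// With these defaults, the nth rate-limit retry performed by retryWithBackoff below
// waits DETECTION_RATE_LIMIT_BACKOFF_MS * 2^n ms, i.e. 200 ms for the first retry,
// then 400 ms.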
/**
 * SageMaker Model Detection and Capability Discovery Service
 */
export class SageMakerDetector {
    client;
    config;
    constructor(config) {
        this.client = new SageMakerRuntimeClient(config);
        this.config = config;
    }
    /**
     * Detect streaming capabilities for a given endpoint
     */
    async detectStreamingCapability(endpointName) {
        logger.debug("Starting streaming capability detection", { endpointName });
        try {
            // Step 1: Check endpoint health and gather metadata
            const health = await this.checkEndpointHealth(endpointName);
            if (health.status !== "healthy") {
                return this.createNoStreamingCapability("custom", "Endpoint not healthy");
            }
            // Step 2: Detect model type
            const modelDetection = await this.detectModelType(endpointName);
            logger.debug("Model type detection result", {
                endpointName,
                type: modelDetection.type,
                confidence: modelDetection.confidence,
            });
            // Step 3: Test streaming support based on model type
            const streamingSupport = await this.testStreamingSupport(endpointName, modelDetection.type);
            // Step 4: Determine streaming protocol
            const protocol = await this.detectStreamingProtocol(modelDetection.type);
            return {
                supported: streamingSupport.supported,
                protocol,
                modelType: modelDetection.type,
                confidence: Math.min(modelDetection.confidence, streamingSupport.confidence),
                parameters: streamingSupport.parameters,
                metadata: {
                    modelName: health.modelInfo?.name,
                    framework: health.modelInfo?.framework,
                    version: health.modelInfo?.version,
                },
            };
        }
        catch (error) {
            logger.warn("Streaming capability detection failed", {
                endpointName,
                error: error instanceof Error ? error.message : String(error),
            });
            return this.createNoStreamingCapability("custom", "Detection failed, assuming custom model");
        }
    }
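    /*
     * Illustrative shape of the capability object resolved above; the field names
     * come from this method, while the concrete values shown are hypothetical and
     * depend on the endpoint:
     *
     *   {
     *     supported: true,
     *     protocol: "sse",
     *     modelType: "huggingface",
     *     confidence: 0.8,
     *     parameters: { stream: true },
     *     metadata: { modelName: "...", framework: "...", version: "..." },
     *   }
     */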
    /**
     * Detect the model type/framework for an endpoint
     */
    async detectModelType(endpointName) {
        const evidence = [];
        const detectionTests = [
            () => this.testHuggingFaceSignature(endpointName, evidence),
            () => this.testLlamaSignature(endpointName, evidence),
            () => this.testPyTorchSignature(endpointName, evidence),
            () => this.testTensorFlowSignature(endpointName, evidence),
        ];
        // Run detection tests in parallel with intelligent rate limiting
        const testNames = ["HuggingFace", "LLaMA", "PyTorch", "TensorFlow"];
        const results = await this.runDetectionTestsInParallel(detectionTests, testNames, endpointName);
        // Analyze results and determine most likely model type
        const scores = {
            huggingface: 0,
            llama: 0,
            pytorch: 0,
            tensorflow: 0,
            custom: 0.1, // Base score for custom models
        };
        // Process evidence and calculate scores
        evidence.forEach((item) => {
            if (item.includes("huggingface") || item.includes("transformers")) {
                scores.huggingface += 0.3;
            }
            if (item.includes("llama") || item.includes("openai-compatible")) {
                scores.llama += 0.3;
            }
            if (item.includes("pytorch") || item.includes("torch")) {
                scores.pytorch += 0.2;
            }
            if (item.includes("tensorflow") || item.includes("serving")) {
                scores.tensorflow += 0.2;
            }
        });
        // Find highest scoring model type
        const maxScore = Math.max(...Object.values(scores));
        const detectedType = Object.entries(scores).find(([, score]) => score === maxScore)?.[0] || "custom";
        return {
            type: detectedType,
            confidence: maxScore,
            evidence,
            suggestedConfig: this.getSuggestedConfig(detectedType),
        };
    }
    /**
     * Check endpoint health and gather metadata
     */
    async checkEndpointHealth(endpointName) {
        const startTime = Date.now();
        try {
            // Simple health check with minimal payload
            const testPayload = JSON.stringify({ inputs: "test" });
            const response = await this.client.invokeEndpoint({
                EndpointName: endpointName,
                Body: testPayload,
                ContentType: "application/json",
            });
            const responseTime = Date.now() - startTime;
            return {
                status: "healthy",
                responseTime,
                metadata: response.CustomAttributes
                    ? JSON.parse(response.CustomAttributes)
                    : undefined,
                modelInfo: this.extractModelInfo(response),
            };
        }
        catch (error) {
            const responseTime = Date.now() - startTime;
            logger.warn("Endpoint health check failed", {
                endpointName,
                responseTime,
                error: error instanceof Error ? error.message : String(error),
            });
            return {
                status: "unhealthy",
                responseTime,
            };
        }
    }
    /**
     * Test if endpoint supports streaming for given model type
     */
    async testStreamingSupport(endpointName, modelType) {
        const testCases = this.getStreamingTestCases(modelType);
        for (const testCase of testCases) {
            try {
                const response = await this.client.invokeEndpoint({
                    EndpointName: endpointName,
                    Body: JSON.stringify(testCase.payload),
                    ContentType: "application/json",
                    Accept: testCase.acceptHeader,
                });
                // Check response headers for streaming indicators
                if (this.indicatesStreamingSupport(response)) {
                    return {
                        supported: true,
                        confidence: testCase.confidence,
                        parameters: testCase.parameters,
                    };
                }
            }
            catch (error) {
                // Streaming test failed, continue to next test case
                logger.debug("Streaming test failed", {
                    endpointName,
                    error: error instanceof Error ? error.message : String(error),
                });
            }
        }
        return { supported: false, confidence: 0.9 };
    }
    /**
     * Detect streaming protocol used by endpoint
     */
    async detectStreamingProtocol(modelType) {
        // Protocol mapping based on model type
        const protocolMap = {
            huggingface: "sse", // Server-Sent Events
            llama: "jsonl", // JSON Lines
            pytorch: "none", // Usually no streaming
            tensorflow: "none", // Usually no streaming
            custom: "chunked", // Generic chunked transfer
        };
        return protocolMap[modelType] || "none";
    }
    /**
     * Test for HuggingFace Transformers signature
     */
    async testHuggingFaceSignature(endpointName, evidence) {
        try {
            const testPayload = {
                inputs: "test",
                parameters: { return_full_text: false, max_new_tokens: 1 },
            };
            const response = await this.client.invokeEndpoint({
                EndpointName: endpointName,
                Body: JSON.stringify(testPayload),
                ContentType: "application/json",
            });
            const responseText = new TextDecoder().decode(response.Body);
            const parsedResponse = JSON.parse(responseText);
            if (parsedResponse[0]?.generated_text !== undefined) {
                evidence.push("huggingface: generated_text field found");
            }
            if (parsedResponse.error?.includes("transformers")) {
                evidence.push("huggingface: transformers error message");
            }
        }
        catch (error) {
            // Test failed, no evidence
        }
    }
    /**
     * Test for LLaMA model signature
     */
    async testLlamaSignature(endpointName, evidence) {
        try {
            const testPayload = {
                prompt: "test",
                max_tokens: 1,
                temperature: 0,
            };
            const response = await this.client.invokeEndpoint({
                EndpointName: endpointName,
                Body: JSON.stringify(testPayload),
                ContentType: "application/json",
            });
            const responseText = new TextDecoder().decode(response.Body);
            const parsedResponse = JSON.parse(responseText);
            if (parsedResponse.choices) {
                evidence.push("llama: openai-compatible choices field");
            }
            if (parsedResponse.object === "text_completion") {
                evidence.push("llama: openai text_completion object");
            }
        }
        catch (error) {
            // Test failed, no evidence
        }
    }
    /**
     * Test for PyTorch model signature
     */
    async testPyTorchSignature(endpointName, evidence) {
        try {
            const testPayload = { input: "test" };
            const response = await this.client.invokeEndpoint({
                EndpointName: endpointName,
                Body: JSON.stringify(testPayload),
                ContentType: "application/json",
            });
            const responseText = new TextDecoder().decode(response.Body);
            if (responseText.includes("prediction") ||
                responseText.includes("output")) {
                evidence.push("pytorch: prediction/output field pattern");
            }
        }
        catch (error) {
            // Test failed, no evidence
        }
    }
    /**
     * Test for TensorFlow Serving signature
     */
    async testTensorFlowSignature(endpointName, evidence) {
        try {
            const testPayload = {
                instances: [{ input: "test" }],
                signature_name: "serving_default",
            };
            const response = await this.client.invokeEndpoint({
                EndpointName: endpointName,
                Body: JSON.stringify(testPayload),
                ContentType: "application/json",
            });
            const responseText = new TextDecoder().decode(response.Body);
            const parsedResponse = JSON.parse(responseText);
            if (parsedResponse.predictions) {
                evidence.push("tensorflow: serving predictions field");
            }
        }
        catch (error) {
            // Test failed, no evidence
        }
    }
    /**
     * Get streaming test cases for a model type
     */
    getStreamingTestCases(modelType) {
        const testCases = {
            huggingface: [
                {
                    name: "HF streaming test",
                    payload: {
                        inputs: "test",
                        parameters: { stream: true, max_new_tokens: 5 },
                    },
                    acceptHeader: "text/event-stream",
                    confidence: 0.8,
                    parameters: { stream: true },
                },
            ],
            llama: [
                {
                    name: "LLaMA streaming test",
                    payload: { prompt: "test", stream: true, max_tokens: 5 },
                    acceptHeader: "application/x-ndjson",
                    confidence: 0.8,
                    parameters: { stream: true },
                },
            ],
            pytorch: [],
            tensorflow: [],
            custom: [
                {
                    name: "Generic streaming test",
                    payload: { input: "test", stream: true },
                    acceptHeader: "application/json",
                    confidence: 0.3,
                    parameters: { stream: true },
                },
            ],
        };
        return testCases[modelType] || [];
    }
    /**
     * Check if response indicates streaming support
     */
    indicatesStreamingSupport(response) {
        // Check content type for streaming indicators
        const contentType = response.ContentType || "";
        if (contentType.includes("event-stream") ||
            contentType.includes("x-ndjson") ||
            contentType.includes("chunked")) {
            return true;
        }
        // Note: InvokeEndpointResponse doesn't include headers
        // Streaming detection is based on ContentType only
        logger.debug("Testing streaming support", {
            contentType,
        });
        return false;
    }
    /**
     * Extract model information from response
     */
    extractModelInfo(response) {
        try {
            const customAttributes = response.CustomAttributes
                ? JSON.parse(response.CustomAttributes)
                : {};
            return {
                name: customAttributes.model_name,
                version: customAttributes.model_version,
                framework: customAttributes.framework,
                architecture: customAttributes.architecture,
            };
        }
        catch {
            return undefined;
        }
    }
    /**
     * Get suggested configuration for detected model type
     */
    getSuggestedConfig(modelType) {
        const configs = {
            huggingface: {
                modelType: "huggingface",
                inputFormat: "huggingface",
                outputFormat: "huggingface",
                contentType: "application/json",
                accept: "text/event-stream",
            },
            llama: {
                modelType: "llama",
                contentType: "application/json",
                accept: "application/x-ndjson",
            },
            pytorch: {
                modelType: "custom",
                contentType: "application/json",
                accept: "application/json",
            },
            tensorflow: {
                modelType: "custom",
                contentType: "application/json",
                accept: "application/json",
            },
            custom: {
                modelType: "custom",
                contentType: "application/json",
                accept: "application/json",
            },
        };
        return configs[modelType] || configs.custom;
    }
    /**
     * Run detection tests in parallel with rate-limit-aware retries and backoff
     * Accepts a configuration object for concurrency and retry limits
     */
    async runDetectionTestsInParallel(detectionTests, testNames, endpointName, config = {
        maxConcurrentTests: 2,
        maxRateLimitRetries: 2,
        initialRateLimitCount: 0,
    }) {
        // Use configurable concurrency limit from config
        const semaphore = this.createDetectionSemaphore(config.maxConcurrentTests);
        // Use a mutable object so the closures below never see stale rate-limit state
        const rateLimitState = { count: config.initialRateLimitCount };
        const wrappedTests = detectionTests.map((test, index) => this.wrapDetectionTest({
            test,
            index,
            testName: testNames[index],
            endpointName,
            semaphore,
            incrementRateLimit: () => rateLimitState.count++,
            maxRateLimitRetries: config.maxRateLimitRetries,
            rateLimitState,
        }));
        const results = await this.executeTestsWithConcurrencyControl(wrappedTests);
        this.logDetectionResults(endpointName, testNames, results, rateLimitState.count > 0);
        return results;
    }
    /**
     * Create a semaphore for detection test concurrency control
     */
    createDetectionSemaphore(maxConcurrent) {
        return {
            count: maxConcurrent,
            waiters: [],
            async acquire() {
                return new Promise((resolve) => {
                    if (this.count > 0) {
                        this.count--;
                        resolve();
                    }
                    else {
                        // Wait for a permit; release() hands it over directly,
                        // so the count is not decremented again here.
                        this.waiters.push(() => {
                            resolve();
                        });
                    }
                });
            },
            release() {
                if (this.waiters.length > 0) {
                    // Transfer the permit straight to the next waiter
                    const waiter = this.waiters.shift();
                    waiter();
                }
                else {
                    this.count++;
                }
            },
        };
    }
    /**
     * Wrap a detection test with error handling, rate limiting, and retry logic
     * Takes a single configuration object instead of multiple positional parameters
     */
    wrapDetectionTest(config) {
        return async () => {
            await config.semaphore.acquire();
            try {
                await this.executeWithStaggeredStart(config.test, config.index);
                return { status: "fulfilled", value: undefined };
            }
            catch (error) {
                const result = await this.handleDetectionTestError(error, config.test, config.testName, config.endpointName, config.incrementRateLimit, config.maxRateLimitRetries, config.rateLimitState.count);
                return result;
            }
            finally {
                config.semaphore.release();
            }
        };
    }
    /**
     * Execute a test with staggered start to spread load
     */
    async executeWithStaggeredStart(test, index) {
        const staggerDelay = index * DETECTION_STAGGER_DELAY_MS;
        if (staggerDelay > 0) {
            await new Promise((resolve) => setTimeout(resolve, staggerDelay));
        }
        await test();
    }
    /**
     * Handle detection test errors with rate limiting and retry logic
     */
    async handleDetectionTestError(error, test, testName, endpointName, incrementRateLimit, maxRateLimitRetries, rateLimitCount) {
        const isRateLimit = this.isRateLimitError(error);
        if (isRateLimit && rateLimitCount < maxRateLimitRetries) {
            return await this.retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount);
        }
        this.logDetectionTestFailure(testName, endpointName, error);
        return { status: "rejected", reason: error };
    }
    /**
     * Check if an error indicates rate limiting
     */
    isRateLimitError(error) {
        return (error instanceof Error &&
            (error.message.toLowerCase().includes("throttl") ||
                error.message.toLowerCase().includes("rate limit") ||
                error.message.toLowerCase().includes("too many requests")));
    }
    /**
     * Retry a test with exponential backoff
     */
    async retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount) {
        incrementRateLimit();
        logger.debug(`Rate limit detected for ${testName}, applying backoff`, {
            endpointName,
            attempt: rateLimitCount + 1,
        });
        await new Promise((resolve) => setTimeout(resolve, DETECTION_RATE_LIMIT_BACKOFF_MS * Math.pow(2, rateLimitCount)));
        try {
            await test();
            return { status: "fulfilled", value: undefined };
        }
        catch (retryError) {
            this.logDetectionTestRetryFailure(testName, endpointName, retryError);
            return { status: "rejected", reason: retryError };
        }
    }
    /**
     * Execute wrapped tests with concurrency control
     */
    async executeTestsWithConcurrencyControl(wrappedTests) {
        const testPromises = wrappedTests.map((wrappedTest) => wrappedTest());
        return await Promise.all(testPromises);
    }
    /**
     * Log detection test failure
     */
    logDetectionTestFailure(testName, endpointName, error) {
        logger.debug(`${testName} detection test failed`, {
            endpointName,
            error: error instanceof Error ? error.message : String(error),
        });
    }
    /**
     * Log detection test retry failure
     */
    logDetectionTestRetryFailure(testName, endpointName, error) {
        logger.debug(`${testName} detection test retry failed`, {
            endpointName,
            error: error instanceof Error ? error.message : String(error),
        });
    }
    /**
     * Log final detection results
     */
    logDetectionResults(endpointName, testNames, results, rateLimitEncountered) {
        logger.debug("Parallel detection tests completed", {
            endpointName,
            totalTests: testNames.length,
            successCount: results.filter((r) => r.status === "fulfilled").length,
            rateLimitEncountered,
        });
    }
    /**
     * Create a no-streaming capability result
     */
    createNoStreamingCapability(modelType, reason) {
        logger.debug("No streaming capability detected", { modelType, reason });
        return {
            supported: false,
            protocol: "none",
            modelType,
            confidence: 0.9,
            metadata: {
                // The capability metadata interface has no dedicated reason field,
                // so the reason is surfaced via the framework field for debugging
                framework: reason,
            },
        };
    }
}
/**
 * Create a detector instance with configuration
 */
export function createSageMakerDetector(config) {
    return new SageMakerDetector(config);
}
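/*
 * Minimal usage sketch. The endpoint name below is hypothetical, and the config
 * shape is whatever the local SageMakerRuntimeClient in ./client.js expects (a
 * region plus ambient AWS credentials is assumed here):
 *
 *   import { createSageMakerDetector } from "./detector.js";
 *
 *   const detector = createSageMakerDetector({ region: "us-east-1" });
 *   const capability = await detector.detectStreamingCapability("my-llm-endpoint");
 *   if (capability.supported) {
 *     // per detectStreamingProtocol above, protocol is "sse" for detected
 *     // HuggingFace models and "jsonl" for LLaMA-style (OpenAI-compatible) ones
 *   }
 */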