@juspay/neurolink

Universal AI Development Platform with working MCP integration, multi-provider support, and professional CLI. Built-in tools operational, 58+ external MCP servers discoverable. Connect to filesystem, GitHub, database operations, and more.

/**
 * SageMaker Model Detection and Streaming Capability Discovery
 *
 * This module provides intelligent detection of SageMaker endpoint capabilities
 * including model type identification and streaming protocol support.
 */
import { SageMakerRuntimeClient } from "./client.js";
import { logger } from "../../utils/logger.js";

/**
 * Configurable constants for detection timing and performance
 */
const DETECTION_TEST_DELAY_MS = 100; // Base delay between detection tests (ms)
const DETECTION_STAGGER_DELAY_MS = 25; // Delay between staggered test starts (ms)
const DETECTION_RATE_LIMIT_BACKOFF_MS = 200; // Initial backoff on rate limit detection (ms)

/**
 * SageMaker Model Detection and Capability Discovery Service
 */
export class SageMakerDetector {
  client;
  config;

  constructor(config) {
    this.client = new SageMakerRuntimeClient(config);
    this.config = config;
  }

  /**
   * Detect streaming capabilities for a given endpoint
   */
  async detectStreamingCapability(endpointName) {
    logger.debug("Starting streaming capability detection", { endpointName });
    try {
      // Step 1: Check endpoint health and gather metadata
      const health = await this.checkEndpointHealth(endpointName);
      if (health.status !== "healthy") {
        return this.createNoStreamingCapability("custom", "Endpoint not healthy");
      }
      // Step 2: Detect model type
      const modelDetection = await this.detectModelType(endpointName);
      logger.debug("Model type detection result", {
        endpointName,
        type: modelDetection.type,
        confidence: modelDetection.confidence,
      });
      // Step 3: Test streaming support based on model type
      const streamingSupport = await this.testStreamingSupport(endpointName, modelDetection.type);
      // Step 4: Determine streaming protocol
      const protocol = await this.detectStreamingProtocol(modelDetection.type);
      return {
        supported: streamingSupport.supported,
        protocol,
        modelType: modelDetection.type,
        confidence: Math.min(modelDetection.confidence, streamingSupport.confidence),
        parameters: streamingSupport.parameters,
        metadata: {
          modelName: health.modelInfo?.name,
          framework: health.modelInfo?.framework,
          version: health.modelInfo?.version,
        },
      };
    } catch (error) {
      logger.warn("Streaming capability detection failed", {
        endpointName,
        error: error instanceof Error ? error.message : String(error),
      });
      return this.createNoStreamingCapability("custom", "Detection failed, assuming custom model");
    }
  }

  /**
   * Detect the model type/framework for an endpoint
   */
  async detectModelType(endpointName) {
    const evidence = [];
    const detectionTests = [
      () => this.testHuggingFaceSignature(endpointName, evidence),
      () => this.testLlamaSignature(endpointName, evidence),
      () => this.testPyTorchSignature(endpointName, evidence),
      () => this.testTensorFlowSignature(endpointName, evidence),
    ];
    // Run detection tests in parallel with intelligent rate limiting
    const testNames = ["HuggingFace", "LLaMA", "PyTorch", "TensorFlow"];
    const results = await this.runDetectionTestsInParallel(detectionTests, testNames, endpointName);
    // Analyze results and determine most likely model type
    const scores = {
      huggingface: 0,
      llama: 0,
      pytorch: 0,
      tensorflow: 0,
      custom: 0.1, // Base score for custom models
    };
    // Process evidence and calculate scores
    evidence.forEach((item) => {
      if (item.includes("huggingface") || item.includes("transformers")) {
        scores.huggingface += 0.3;
      }
      if (item.includes("llama") || item.includes("openai-compatible")) {
        scores.llama += 0.3;
      }
      if (item.includes("pytorch") || item.includes("torch")) {
        scores.pytorch += 0.2;
      }
      if (item.includes("tensorflow") || item.includes("serving")) {
        scores.tensorflow += 0.2;
      }
    });
    // Find highest scoring model type
    const maxScore = Math.max(...Object.values(scores));
    const detectedType = Object.entries(scores).find(([, score]) => score === maxScore)?.[0] || "custom";
    return {
      type: detectedType,
      confidence: maxScore,
      evidence,
      suggestedConfig: this.getSuggestedConfig(detectedType),
    };
  }

  /**
   * Check endpoint health and gather metadata
   */
  async checkEndpointHealth(endpointName) {
    const startTime = Date.now();
    try {
      // Simple health check with minimal payload
      const testPayload = JSON.stringify({ inputs: "test" });
      const response = await this.client.invokeEndpoint({
        EndpointName: endpointName,
        Body: testPayload,
        ContentType: "application/json",
      });
      const responseTime = Date.now() - startTime;
      return {
        status: "healthy",
        responseTime,
        metadata: response.CustomAttributes
          ? JSON.parse(response.CustomAttributes)
          : undefined,
        modelInfo: this.extractModelInfo(response),
      };
    } catch (error) {
      const responseTime = Date.now() - startTime;
      logger.warn("Endpoint health check failed", {
        endpointName,
        responseTime,
        error: error instanceof Error ? error.message : String(error),
      });
      return {
        status: "unhealthy",
        responseTime,
      };
    }
  }

  /**
   * Test if endpoint supports streaming for given model type
   */
  async testStreamingSupport(endpointName, modelType) {
    const testCases = this.getStreamingTestCases(modelType);
    for (const testCase of testCases) {
      try {
        const response = await this.client.invokeEndpoint({
          EndpointName: endpointName,
          Body: JSON.stringify(testCase.payload),
          ContentType: "application/json",
          Accept: testCase.acceptHeader,
        });
        // Check response headers for streaming indicators
        if (this.indicatesStreamingSupport(response)) {
          return {
            supported: true,
            confidence: testCase.confidence,
            parameters: testCase.parameters,
          };
        }
      } catch (error) {
        // Streaming test failed, continue to next test case
        logger.debug("Streaming test failed", {
          endpointName,
          error: error instanceof Error ? error.message : String(error),
        });
      }
    }
    return { supported: false, confidence: 0.9 };
  }

  /**
   * Detect streaming protocol used by endpoint
   */
  async detectStreamingProtocol(modelType) {
    // Protocol mapping based on model type
    const protocolMap = {
      huggingface: "sse", // Server-Sent Events
      llama: "jsonl", // JSON Lines
      pytorch: "none", // Usually no streaming
      tensorflow: "none", // Usually no streaming
      custom: "chunked", // Generic chunked transfer
    };
    return protocolMap[modelType] || "none";
  }

  /**
   * Test for HuggingFace Transformers signature
   */
  async testHuggingFaceSignature(endpointName, evidence) {
    try {
      const testPayload = {
        inputs: "test",
        parameters: { return_full_text: false, max_new_tokens: 1 },
      };
      const response = await this.client.invokeEndpoint({
        EndpointName: endpointName,
        Body: JSON.stringify(testPayload),
        ContentType: "application/json",
      });
      const responseText = new TextDecoder().decode(response.Body);
      const parsedResponse = JSON.parse(responseText);
      if (parsedResponse[0]?.generated_text !== undefined) {
        evidence.push("huggingface: generated_text field found");
      }
      if (parsedResponse.error?.includes("transformers")) {
        evidence.push("huggingface: transformers error message");
      }
    } catch (error) {
      // Test failed, no evidence
    }
  }

  /**
   * Test for LLaMA model signature
   */
  async testLlamaSignature(endpointName, evidence) {
    try {
      const testPayload = {
        prompt: "test",
        max_tokens: 1,
        temperature: 0,
      };
      const response = await this.client.invokeEndpoint({
        EndpointName: endpointName,
        Body: JSON.stringify(testPayload),
        ContentType: "application/json",
      });
      const responseText = new TextDecoder().decode(response.Body);
      const parsedResponse = JSON.parse(responseText);
      if (parsedResponse.choices) {
        evidence.push("llama: openai-compatible choices field");
      }
      if (parsedResponse.object === "text_completion") {
        evidence.push("llama: openai text_completion object");
      }
    } catch (error) {
      // Test failed, no evidence
    }
  }

  /**
   * Test for PyTorch model signature
   */
  async testPyTorchSignature(endpointName, evidence) {
    try {
      const testPayload = { input: "test" };
      const response = await this.client.invokeEndpoint({
        EndpointName: endpointName,
        Body: JSON.stringify(testPayload),
        ContentType: "application/json",
      });
      const responseText = new TextDecoder().decode(response.Body);
      if (responseText.includes("prediction") || responseText.includes("output")) {
        evidence.push("pytorch: prediction/output field pattern");
      }
    } catch (error) {
      // Test failed, no evidence
    }
  }

  /**
   * Test for TensorFlow Serving signature
   */
  async testTensorFlowSignature(endpointName, evidence) {
    try {
      const testPayload = {
        instances: [{ input: "test" }],
        signature_name: "serving_default",
      };
      const response = await this.client.invokeEndpoint({
        EndpointName: endpointName,
        Body: JSON.stringify(testPayload),
        ContentType: "application/json",
      });
      const responseText = new TextDecoder().decode(response.Body);
      const parsedResponse = JSON.parse(responseText);
      if (parsedResponse.predictions) {
        evidence.push("tensorflow: serving predictions field");
      }
    } catch (error) {
      // Test failed, no evidence
    }
  }

  /**
   * Get streaming test cases for a model type
   */
  getStreamingTestCases(modelType) {
    const testCases = {
      huggingface: [
        {
          name: "HF streaming test",
          payload: {
            inputs: "test",
            parameters: { stream: true, max_new_tokens: 5 },
          },
          acceptHeader: "text/event-stream",
          confidence: 0.8,
          parameters: { stream: true },
        },
      ],
      llama: [
        {
          name: "LLaMA streaming test",
          payload: { prompt: "test", stream: true, max_tokens: 5 },
          acceptHeader: "application/x-ndjson",
          confidence: 0.8,
          parameters: { stream: true },
        },
      ],
      pytorch: [],
      tensorflow: [],
      custom: [
        {
          name: "Generic streaming test",
          payload: { input: "test", stream: true },
          acceptHeader: "application/json",
          confidence: 0.3,
          parameters: { stream: true },
        },
      ],
    };
    return testCases[modelType] || [];
  }

  /**
   * Check if response indicates streaming support
   */
  indicatesStreamingSupport(response) {
    // Check content type for streaming indicators
    const contentType = response.ContentType || "";
    if (contentType.includes("event-stream") ||
      contentType.includes("x-ndjson") ||
      contentType.includes("chunked")) {
      return true;
    }
    // Note: InvokeEndpointResponse doesn't include headers
    // Streaming detection is based on ContentType only
    logger.debug("Testing streaming support", {
      contentType,
    });
    return false;
  }

  /**
   * Extract model information from response
   */
  extractModelInfo(response) {
    try {
      const customAttributes = response.CustomAttributes
        ? JSON.parse(response.CustomAttributes)
        : {};
      return {
        name: customAttributes.model_name,
        version: customAttributes.model_version,
        framework: customAttributes.framework,
        architecture: customAttributes.architecture,
      };
    } catch {
      return undefined;
    }
  }

  /**
   * Get suggested configuration for detected model type
   */
  getSuggestedConfig(modelType) {
    const configs = {
      huggingface: {
        modelType: "huggingface",
        inputFormat: "huggingface",
        outputFormat: "huggingface",
        contentType: "application/json",
        accept: "text/event-stream",
      },
      llama: {
        modelType: "llama",
        contentType: "application/json",
        accept: "application/x-ndjson",
      },
      pytorch: {
        modelType: "custom",
        contentType: "application/json",
        accept: "application/json",
      },
      tensorflow: {
        modelType: "custom",
        contentType: "application/json",
        accept: "application/json",
      },
      custom: {
        modelType: "custom",
        contentType: "application/json",
        accept: "application/json",
      },
    };
    return configs[modelType] || configs.custom;
  }

  /**
   * Run detection tests in parallel with intelligent rate limiting and circuit breaker
   * Now uses configuration object for better parameter management
   */
  async runDetectionTestsInParallel(detectionTests, testNames, endpointName, config = {
    maxConcurrentTests: 2,
    maxRateLimitRetries: 2,
    initialRateLimitCount: 0,
  }) {
    // Use configurable concurrency limit from config
    const semaphore = this.createDetectionSemaphore(config.maxConcurrentTests);
    // Use mutable object to prevent closure stale state issues
    const rateLimitState = { count: config.initialRateLimitCount };
    const wrappedTests = detectionTests.map((test, index) => this.wrapDetectionTest({
      test,
      index,
      testName: testNames[index],
      endpointName,
      semaphore,
      incrementRateLimit: () => rateLimitState.count++,
      maxRateLimitRetries: config.maxRateLimitRetries,
      rateLimitState,
    }));
    const results = await this.executeTestsWithConcurrencyControl(wrappedTests);
    this.logDetectionResults(endpointName, testNames, results, rateLimitState.count > 0);
    return results;
  }

  /**
   * Create a semaphore for detection test concurrency control
   */
  createDetectionSemaphore(maxConcurrent) {
    return {
      count: maxConcurrent,
      waiters: [],
      async acquire() {
        return new Promise((resolve) => {
          if (this.count > 0) {
            this.count--;
            resolve();
          } else {
            this.waiters.push(() => {
              this.count--;
              resolve();
            });
          }
        });
      },
      release() {
        if (this.waiters.length > 0) {
          const waiter = this.waiters.shift();
          waiter();
        } else {
          this.count++;
        }
      },
    };
  }

  /**
   * Wrap a detection test with error handling, rate limiting, and retry logic
   * Now uses configuration object instead of multiple parameters
   */
  wrapDetectionTest(config) {
    return async () => {
      await config.semaphore.acquire();
      try {
        await this.executeWithStaggeredStart(config.test, config.index);
        return { status: "fulfilled", value: undefined };
      } catch (error) {
        const result = await this.handleDetectionTestError(error, config.test, config.testName, config.endpointName, config.incrementRateLimit, config.maxRateLimitRetries, config.rateLimitState.count);
        return result;
      } finally {
        config.semaphore.release();
      }
    };
  }

  /**
   * Execute a test with staggered start to spread load
   */
  async executeWithStaggeredStart(test, index) {
    const staggerDelay = index * DETECTION_STAGGER_DELAY_MS;
    if (staggerDelay > 0) {
      await new Promise((resolve) => setTimeout(resolve, staggerDelay));
    }
    await test();
  }

  /**
   * Handle detection test errors with rate limiting and retry logic
   */
  async handleDetectionTestError(error, test, testName, endpointName, incrementRateLimit, maxRateLimitRetries, rateLimitCount) {
    const isRateLimit = this.isRateLimitError(error);
    if (isRateLimit && rateLimitCount < maxRateLimitRetries) {
      return await this.retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount);
    }
    this.logDetectionTestFailure(testName, endpointName, error);
    return { status: "rejected", reason: error };
  }

  /**
   * Check if an error indicates rate limiting
   */
  isRateLimitError(error) {
    return (error instanceof Error &&
      (error.message.toLowerCase().includes("throttl") ||
        error.message.toLowerCase().includes("rate limit") ||
        error.message.toLowerCase().includes("too many requests")));
  }

  /**
   * Retry a test with exponential backoff
   */
  async retryWithBackoff(test, testName, endpointName, incrementRateLimit, rateLimitCount) {
    incrementRateLimit();
    logger.debug(`Rate limit detected for ${testName}, applying backoff`, {
      endpointName,
      attempt: rateLimitCount + 1,
    });
    await new Promise((resolve) => setTimeout(resolve, DETECTION_RATE_LIMIT_BACKOFF_MS * Math.pow(2, rateLimitCount)));
    try {
      await test();
      return { status: "fulfilled", value: undefined };
    } catch (retryError) {
      this.logDetectionTestRetryFailure(testName, endpointName, retryError);
      return { status: "rejected", reason: retryError };
    }
  }

  /**
   * Execute wrapped tests with concurrency control
   */
  async executeTestsWithConcurrencyControl(wrappedTests) {
    const testPromises = wrappedTests.map((wrappedTest) => wrappedTest());
    return await Promise.all(testPromises);
  }

  /**
   * Log detection test failure
   */
  logDetectionTestFailure(testName, endpointName, error) {
    logger.debug(`${testName} detection test failed`, {
      endpointName,
      error: error instanceof Error ? error.message : String(error),
    });
  }

  /**
   * Log detection test retry failure
   */
  logDetectionTestRetryFailure(testName, endpointName, error) {
    logger.debug(`${testName} detection test retry failed`, {
      endpointName,
      error: error instanceof Error ? error.message : String(error),
    });
  }

  /**
   * Log final detection results
   */
  logDetectionResults(endpointName, testNames, results, rateLimitEncountered) {
    logger.debug("Parallel detection tests completed", {
      endpointName,
      totalTests: testNames.length,
      successCount: results.filter((r) => r.status === "fulfilled").length,
      rateLimitEncountered,
    });
  }

  /**
   * Create a no-streaming capability result
   */
  createNoStreamingCapability(modelType, reason) {
    logger.debug("No streaming capability detected", { modelType, reason });
    return {
      supported: false,
      protocol: "none",
      modelType,
      confidence: 0.9,
      metadata: {
        // reason property not supported in interface
        // Store reason in framework field for debugging
        framework: reason,
      },
    };
  }
}

/**
 * Create a detector instance with configuration
 */
export function createSageMakerDetector(config) {
  return new SageMakerDetector(config);
}
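
A minimal usage sketch of the exported factory and detection API. Assumptions not confirmed by this file: the import path "./detector.js", the endpoint name, and the config fields (the config object is simply forwarded to SageMakerRuntimeClient, so AWS-style settings such as a region are assumed here for illustration).

// Usage sketch (illustrative only; import path, endpoint name, and config fields are assumptions)
import { createSageMakerDetector } from "./detector.js";

async function probeEndpoint() {
  // Config is passed through to SageMakerRuntimeClient; a region is assumed here
  const detector = createSageMakerDetector({ region: "us-east-1" });
  // Returns { supported, protocol, modelType, confidence, parameters, metadata }
  const capability = await detector.detectStreamingCapability("my-sagemaker-endpoint");
  if (capability.supported) {
    console.log(`Streaming supported via ${capability.protocol} (model type: ${capability.modelType})`);
  } else {
    console.log("No streaming detected; falling back to single-shot invocation");
  }
}

probeEndpoint().catch(console.error);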