UNPKG

@juspay/neurolink

Universal AI Development Platform with working MCP integration, multi-provider support, and professional CLI. Built-in tools operational, 58+ external MCP servers discoverable. Connect to filesystem, GitHub, database operations, and more. Build, test, and

756 lines (755 loc) 30.6 kB
/**
 * SageMaker Language Model Implementation
 *
 * This module implements the LanguageModelV1 interface for Amazon SageMaker
 * integration with the Vercel AI SDK.
 */
import { randomUUID } from "crypto";
import { SageMakerRuntimeClient } from "./client.js";
import { handleSageMakerError } from "./errors.js";
import { estimateTokenUsage, createSageMakerStream } from "./streaming.js";
import { createAdaptiveSemaphore } from "./adaptive-semaphore.js";
import { logger } from "../../utils/logger.js";

/**
 * Base synthetic streaming delay in milliseconds for simulating real-time response
 * Can be configured via SAGEMAKER_BASE_STREAMING_DELAY_MS environment variable
 */
const BASE_SYNTHETIC_STREAMING_DELAY_MS = process.env.SAGEMAKER_BASE_STREAMING_DELAY_MS
  ? parseInt(process.env.SAGEMAKER_BASE_STREAMING_DELAY_MS, 10)
  : 50;

/**
 * Maximum synthetic streaming delay in milliseconds to prevent excessively slow streaming
 * Can be configured via SAGEMAKER_MAX_STREAMING_DELAY_MS environment variable
 */
const MAX_SYNTHETIC_STREAMING_DELAY_MS = process.env.SAGEMAKER_MAX_STREAMING_DELAY_MS
  ? parseInt(process.env.SAGEMAKER_MAX_STREAMING_DELAY_MS, 10)
  : 200;

/**
 * Calculate adaptive delay based on text size to avoid slow streaming for large texts
 * Smaller texts get longer delays for realistic feel, larger texts get shorter delays for performance
 */
function calculateAdaptiveDelay(textLength, chunkCount) {
  // Base calculation: smaller delay for larger texts
  const adaptiveDelay = Math.max(
    10, // Minimum 10ms delay
    Math.min(
      MAX_SYNTHETIC_STREAMING_DELAY_MS,
      BASE_SYNTHETIC_STREAMING_DELAY_MS * (1000 / Math.max(textLength, 100)),
    ),
  );
  // Further reduce delay if there are many chunks to process
  if (chunkCount > 20) {
    return Math.max(10, adaptiveDelay * 0.5); // Half delay for many chunks
  } else if (chunkCount > 10) {
    return Math.max(15, adaptiveDelay * 0.7); // Reduced delay for moderate chunks
  }
  return adaptiveDelay;
}

/**
 * Create an async iterator for text chunks with adaptive delay between chunks
 * Used for synthetic streaming simulation with performance optimization for large texts
 */
async function* createTextChunkIterator(text) {
  if (!text) {
    return; // No text to emit
  }
  const words = text.split(/\s+/);
  const chunkSize = Math.max(1, Math.floor(words.length / 10));
  const totalChunks = Math.ceil(words.length / chunkSize);
  // Calculate adaptive delay based on text size and chunk count
  const adaptiveDelay = calculateAdaptiveDelay(text.length, totalChunks);
  for (let i = 0; i < words.length; i += chunkSize) {
    const chunk = words.slice(i, i + chunkSize).join(" ");
    const deltaText = i === 0 ? chunk : " " + chunk;
    // Add adaptive delay between chunks for realistic streaming simulation
    // Delay is shorter for larger texts to improve performance
    if (i > 0) {
      await new Promise((resolve) => setTimeout(resolve, adaptiveDelay));
    }
    yield deltaText;
  }
}

/**
 * Batch processing concurrency constants
 */
const DEFAULT_INITIAL_CONCURRENCY = 5;
const DEFAULT_MAX_CONCURRENCY = 10;
const DEFAULT_MIN_CONCURRENCY = 1;

/**
 * SageMaker Language Model implementing LanguageModelV1 interface
 */
export class SageMakerLanguageModel {
  specificationVersion = "v1";
  provider = "sagemaker";
  modelId;
  supportsStreaming = true;
  defaultObjectGenerationMode = "json";
  client;
  config;
  modelConfig;

  constructor(modelId, config, modelConfig) {
    this.modelId = modelId;
    this.config = config;
    this.modelConfig = modelConfig;
    this.client = new SageMakerRuntimeClient(config);
    logger.debug("SageMaker Language Model initialized", {
      modelId: this.modelId,
      endpointName: this.modelConfig.endpointName,
      provider: this.provider,
      specificationVersion: this.specificationVersion,
    });
  }

  /**
   * Generate text synchronously using SageMaker endpoint
   */
  async doGenerate(options) {
    const startTime = Date.now();
    try {
      const promptText = this.extractPromptText(options);
      logger.debug("SageMaker doGenerate called", {
        endpointName: this.modelConfig.endpointName,
        promptLength: promptText.length,
        maxTokens: options.maxTokens,
        temperature: options.temperature,
      });
      // Convert AI SDK options to SageMaker request format
      const sagemakerRequest = this.convertToSageMakerRequest(options);
      // Invoke SageMaker endpoint
      const response = await this.client.invokeEndpoint({
        EndpointName: this.modelConfig.endpointName,
        Body: JSON.stringify(sagemakerRequest),
        ContentType: "application/json",
        Accept: "application/json",
      });
      // Parse SageMaker response
      const responseBody = JSON.parse(new TextDecoder().decode(response.Body));
      const generatedText = this.extractTextFromResponse(responseBody);
      // Extract tool calls if present (Phase 4 enhancement)
      const toolCalls = this.extractToolCallsFromResponse(responseBody);
      // Calculate token usage
      const usage = estimateTokenUsage(promptText, generatedText);
      // Determine finish reason based on response content
      let finishReason = "stop";
      if (toolCalls && toolCalls.length > 0) {
        finishReason = "tool-calls";
      } else if (responseBody.finish_reason) {
        finishReason = this.mapSageMakerFinishReason(responseBody.finish_reason);
      }
      const duration = Date.now() - startTime;
      logger.debug("SageMaker doGenerate completed", {
        duration,
        outputLength: generatedText.length,
        usage,
        toolCallsCount: toolCalls?.length || 0,
        finishReason,
      });
      const result = {
        text: generatedText,
        usage: {
          promptTokens: usage.promptTokens,
          completionTokens: usage.completionTokens,
          totalTokens: usage.totalTokens,
        },
        finishReason,
        rawCall: {
          rawPrompt: options.prompt,
          rawSettings: {
            maxTokens: options.maxTokens,
            temperature: options.temperature,
            topP: options.topP,
            endpointName: this.modelConfig.endpointName,
          },
        },
        rawResponse: {
          headers: {
            "content-type": response.ContentType || "application/json",
            "invoked-variant": response.InvokedProductionVariant || "",
          },
        },
        request: {
          body: JSON.stringify(sagemakerRequest),
        },
      };
      // Add tool calls to result if present
      if (toolCalls && toolCalls.length > 0) {
        result.toolCalls = toolCalls;
      }
      // Add structured data if response format was specified (Phase 4)
      const responseFormat = sagemakerRequest.response_format;
      if (responseFormat &&
        (responseFormat.type === "json_object" || responseFormat.type === "json_schema")) {
        try {
          const parsedData = JSON.parse(generatedText);
          result.object = parsedData;
          logger.debug("Extracted structured data from response", {
            responseFormat: responseFormat.type,
            hasObject: !!result.object,
          });
        } catch (parseError) {
          logger.warn("Failed to parse structured response as JSON", {
            error: parseError instanceof Error ? parseError.message : String(parseError),
            responseText: generatedText.substring(0, 200),
          });
          // Keep the text response as fallback
        }
      }
      return result;
    } catch (error) {
      const duration = Date.now() - startTime;
      logger.error("SageMaker doGenerate failed", {
        duration,
        error: error instanceof Error ? error.message : String(error),
      });
      throw handleSageMakerError(error, this.modelConfig.endpointName);
    }
  }

  /**
   * Generate text with streaming using SageMaker endpoint
   */
  async doStream(options) {
    try {
      const promptText = this.extractPromptText(options);
      logger.debug("SageMaker doStream called", {
        endpointName: this.modelConfig.endpointName,
        promptLength: promptText.length,
      });
      // Phase 2: Full streaming implementation with automatic detection
      const sagemakerRequest = this.convertToSageMakerRequest(options);
      // Add streaming parameter if model supports it
      const requestWithStreaming = {
        ...sagemakerRequest,
        parameters: {
          ...(typeof sagemakerRequest.parameters === "object" &&
          sagemakerRequest.parameters !== null
            ? sagemakerRequest.parameters
            : {}),
          stream: true, // Will be validated by detection system
        },
      };
      logger.debug("Attempting streaming generation", {
        endpointName: this.modelConfig.endpointName,
        hasStreamingFlag: true,
      });
      try {
        // First, try to invoke with streaming
        const response = await this.client.invokeEndpointWithStreaming({
          EndpointName: this.modelConfig.endpointName,
          Body: JSON.stringify(requestWithStreaming),
          ContentType: this.modelConfig.contentType || "application/json",
          Accept: this.modelConfig.accept || "application/json",
        });
        // Create intelligent streaming response
        const stream = await createSageMakerStream(response.Body, this.modelConfig.endpointName, this.config, {
          prompt: promptText,
          onChunk: (chunk) => {
            logger.debug("Streaming chunk received", {
              contentLength: chunk.content?.length || 0,
              done: chunk.done,
            });
          },
          onComplete: (usage) => {
            logger.debug("Streaming completed", {
              usage,
              endpointName: this.modelConfig.endpointName,
            });
          },
          onError: (error) => {
            logger.error("Streaming error", {
              error: error.message,
              endpointName: this.modelConfig.endpointName,
            });
          },
        });
        return {
          stream: stream,
          rawCall: {
            rawPrompt: sagemakerRequest,
            rawSettings: this.modelConfig,
          },
          rawResponse: {
            headers: {
              "Content-Type": response.ContentType || "application/json",
              "X-Invoked-Production-Variant": response.InvokedProductionVariant || "unknown",
            },
          },
        };
      } catch (streamingError) {
        logger.warn("Streaming failed, falling back to non-streaming", {
          endpointName: this.modelConfig.endpointName,
          error: streamingError instanceof Error
            ? streamingError.message
            : String(streamingError),
        });
        // Fallback: Generate normally and create synthetic stream
        const result = await this.doGenerate(options);
        // Create synthetic stream from complete result using async iterator pattern
        const syntheticStream = new ReadableStream({
          async start(controller) {
            try {
              // Create async iterator for text chunks
              const textChunks = createTextChunkIterator(result.text);
              // Process chunks with async iterator pattern
              for await (const deltaText of textChunks) {
                controller.enqueue({
                  type: "text-delta",
                  textDelta: deltaText,
                });
              }
              // Emit completion
              controller.enqueue({
                type: "finish",
                finishReason: result.finishReason,
                usage: result.usage,
              });
              controller.close();
            } catch (error) {
              controller.error(error);
            }
          },
        });
        return {
          stream: syntheticStream,
          rawCall: result.rawCall,
          rawResponse: result.rawResponse,
          request: result.request,
          warnings: [
            ...(result.warnings || []),
            {
              type: "other",
              message: "Streaming not supported, using synthetic stream",
            },
          ],
        };
      }
    } catch (error) {
      logger.error("SageMaker doStream failed", {
        error: error instanceof Error ? error.message : String(error),
      });
      throw handleSageMakerError(error, this.modelConfig.endpointName);
    }
  }

  /**
   * Convert AI SDK options to SageMaker request format
   */
  convertToSageMakerRequest(options) {
    const promptText = this.extractPromptText(options);
    // Enhanced SageMaker request format with tool support (Phase 4)
    const request = {
      inputs: promptText,
      parameters: {
        max_new_tokens: options.maxTokens || 512,
        temperature: options.temperature || 0.7,
        top_p: options.topP || 0.9,
        stop: options.stopSequences || [],
      },
    };
    // Add tool support if tools are present
    const tools = options.tools;
    if (tools && Array.isArray(tools) && tools.length > 0) {
      request.tools = this.convertToolsToSageMakerFormat(tools);
      // Add tool choice if specified
      const toolChoice = options.toolChoice;
      if (toolChoice) {
        request.tool_choice = this.convertToolChoiceToSageMakerFormat(toolChoice);
      }
      logger.debug("Added tool support to SageMaker request", {
        toolCount: tools.length,
        toolChoice: toolChoice,
      });
    }
    // Add structured output support (Phase 4)
    const responseFormat = options.responseFormat;
    if (responseFormat) {
      request.response_format = this.convertResponseFormatToSageMakerFormat(responseFormat);
      logger.debug("Added structured output support to SageMaker request", {
        responseFormat: responseFormat.type,
      });
    }
    logger.debug("Converted to SageMaker request format", {
      inputLength: promptText.length,
      parameters: request.parameters,
      hasTools: !!request.tools,
    });
    return request;
  }

  /**
   * Convert Vercel AI SDK tools to SageMaker format
   */
  convertToolsToSageMakerFormat(tools) {
    return tools.map((tool) => {
      if (tool.type === "function") {
        return {
          type: "function",
          function: {
            name: tool.function.name,
            description: tool.function.description || "",
            parameters: tool.function.parameters || {},
          },
        };
      }
      return tool; // Pass through other tool types
    });
  }

  /**
   * Convert Vercel AI SDK tool choice to SageMaker format
   */
  convertToolChoiceToSageMakerFormat(toolChoice) {
    if (typeof toolChoice === "string") {
      return toolChoice; // 'auto', 'none', etc.
    }
    if (toolChoice?.type === "function") {
      return {
        type: "function",
        function: {
          name: toolChoice.function.name,
        },
      };
    }
    return toolChoice;
  }

  /**
   * Convert Vercel AI SDK response format to SageMaker format (Phase 4)
   */
  convertResponseFormatToSageMakerFormat(responseFormat) {
    if (responseFormat.type === "json_object") {
      return {
        type: "json_object",
        schema: responseFormat.schema || undefined,
      };
    }
    if (responseFormat.type === "json_schema") {
      return {
        type: "json_schema",
        json_schema: {
          name: responseFormat.json_schema?.name || "response",
          description: responseFormat.json_schema?.description || "Generated response",
          schema: responseFormat.json_schema?.schema || {},
        },
      };
    }
    // Default to text
    return {
      type: "text",
    };
  }

  /**
   * Extract text content from AI SDK prompt format
   */
  extractPromptText(options) {
    // Check for messages first (like Ollama)
    const messages = options.messages;
    if (messages && Array.isArray(messages)) {
      return messages
        .filter((msg) => msg.role && msg.content)
        .map((msg) => {
          if (typeof msg.content === "string") {
            return `${msg.role}: ${msg.content}`;
          }
          return `${msg.role}: ${JSON.stringify(msg.content)}`;
        })
        .join("\n");
    }
    // Fallback to prompt property
    const prompt = options.prompt;
    if (typeof prompt === "string") {
      return prompt;
    }
    if (Array.isArray(prompt)) {
      return prompt
        .filter((msg) => msg.role && msg.content)
        .map((msg) => {
          if (typeof msg.content === "string") {
            return `${msg.role}: ${msg.content}`;
          }
          return `${msg.role}: ${JSON.stringify(msg.content)}`;
        })
        .join("\n");
    }
    return String(prompt);
  }

  /**
   * Extract generated text from SageMaker response
   */
  extractTextFromResponse(responseBody) {
    // Handle common SageMaker response formats
    if (typeof responseBody === "string") {
      return responseBody;
    }
    if (responseBody.generated_text) {
      return responseBody.generated_text;
    }
    if (responseBody.outputs) {
      return responseBody.outputs;
    }
    if (responseBody.text) {
      return responseBody.text;
    }
    if (Array.isArray(responseBody) && responseBody[0]?.generated_text) {
      return responseBody[0].generated_text;
    }
    // Handle response with tool calls
    if (responseBody.choices && Array.isArray(responseBody.choices)) {
      const choice = responseBody.choices[0];
      if (choice?.message?.content) {
        return choice.message.content;
      }
    }
    // Fallback: stringify the entire response
    return JSON.stringify(responseBody);
  }

  /**
   * Extract tool calls from SageMaker response (Phase 4)
   */
  extractToolCallsFromResponse(responseBody) {
    // Handle OpenAI-compatible format (common for many SageMaker models)
    if (responseBody.choices && Array.isArray(responseBody.choices)) {
      const choice = responseBody.choices[0];
      if (choice?.message?.tool_calls) {
        return choice.message.tool_calls.map((toolCall) => ({
          type: "function",
          id: String(toolCall.id || `call_${randomUUID()}`),
          function: {
            name: String(toolCall.function.name),
            arguments: String(toolCall.function.arguments),
          },
        }));
      }
    }
    // Handle custom SageMaker tool call format
    if (responseBody.tool_calls && Array.isArray(responseBody.tool_calls)) {
      return responseBody.tool_calls;
    }
    // Handle Anthropic-style tool use
    if (responseBody.content && Array.isArray(responseBody.content)) {
      const toolUses = responseBody.content.filter((item) => item.type === "tool_use");
      if (toolUses.length > 0) {
        return toolUses.map((toolUse) => ({
          type: "function",
          id: String(toolUse.id || `call_${randomUUID()}`),
          function: {
            name: String(toolUse.name),
            arguments: JSON.stringify(toolUse.input || {}),
          },
        }));
      }
    }
    return undefined;
  }

  /**
   * Map SageMaker finish reason to standardized format
   */
  mapSageMakerFinishReason(sagemakerReason) {
    switch (sagemakerReason?.toLowerCase()) {
      case "stop":
      case "end_turn":
      case "stop_sequence":
        return "stop";
      case "length":
      case "max_tokens":
      case "max_length":
        return "length";
      case "content_filter":
      case "content_filtered":
        return "content-filter";
      case "tool_calls":
      case "function_call":
        return "tool-calls";
      case "error":
        return "error";
      default:
        return "unknown";
    }
  }

  /**
   * Get model configuration summary for debugging
   */
  getModelInfo() {
    return {
      modelId: this.modelId,
      provider: this.provider,
      specificationVersion: this.specificationVersion,
      endpointName: this.modelConfig.endpointName,
      modelType: this.modelConfig.modelType,
      region: this.config.region,
    };
  }

  /**
   * Test basic connectivity to the SageMaker endpoint
   */
  async testConnectivity() {
    try {
      // Use the same pattern as Ollama - pass messages directly
      const result = await this.doGenerate({
        inputFormat: "messages",
        mode: { type: "regular" },
        prompt: [
          { role: "user", content: [{ type: "text", text: "Hello" }] },
        ],
        maxTokens: 10,
      });
      return {
        success: !!result.text,
      };
    } catch (error) {
      return {
        success: false,
        error: error instanceof Error ? error.message : String(error),
      };
    }
  }

  /**
   * Batch inference support (Phase 4)
   * Process multiple prompts in a single request for efficiency
   */
  async doBatchGenerate(prompts, options) {
    try {
      logger.debug("SageMaker batch generate called", {
        batchSize: prompts.length,
        endpointName: this.modelConfig.endpointName,
      });
      // Advanced parallel processing with dynamic concurrency and error handling
      const results = await this.processPromptsInParallel(prompts, options);
      logger.debug("SageMaker batch generate completed", {
        batchSize: prompts.length,
        successCount: results.length,
      });
      return results;
    } catch (error) {
      logger.error("SageMaker batch generate failed", {
        error: error instanceof Error ? error.message : String(error),
        batchSize: prompts.length,
      });
      throw handleSageMakerError(error, this.modelConfig.endpointName);
    }
  }

  /**
   * Process prompts in parallel with advanced concurrency control and error handling
   */
  async processPromptsInParallel(prompts, options) {
    // Dynamic concurrency based on batch size and endpoint capacity
    const INITIAL_CONCURRENCY = Math.min(this.modelConfig.initialConcurrency ?? DEFAULT_INITIAL_CONCURRENCY, prompts.length);
    const MAX_CONCURRENCY = this.modelConfig.maxConcurrency ?? DEFAULT_MAX_CONCURRENCY;
    const MIN_CONCURRENCY = this.modelConfig.minConcurrency ?? DEFAULT_MIN_CONCURRENCY;
    const results = new Array(prompts.length);
    const errors = [];
    // Use adaptive semaphore utility for concurrency control
    const semaphore = createAdaptiveSemaphore(INITIAL_CONCURRENCY, MAX_CONCURRENCY, MIN_CONCURRENCY);
    // Process each prompt with adaptive concurrency
    const processPrompt = async (prompt, index) => {
      await semaphore.acquire();
      const startTime = Date.now();
      try {
        const result = await this.doGenerate({
          inputFormat: "messages",
          mode: { type: "regular" },
          prompt: [
            {
              role: "user",
              content: [{ type: "text", text: prompt }],
            },
          ],
          maxTokens: options?.maxTokens,
          temperature: options?.temperature,
          topP: options?.topP,
        });
        const duration = Date.now() - startTime;
        results[index] = {
          text: result.text || "",
          usage: {
            promptTokens: result.usage.promptTokens,
            completionTokens: result.usage.completionTokens,
            totalTokens: result.usage.totalTokens ||
              result.usage.promptTokens + result.usage.completionTokens,
          },
          finishReason: result.finishReason,
          index,
        };
        // Record successful completion for adaptive concurrency adjustment
        semaphore.recordSuccess(duration);
      } catch (error) {
        errors.push({
          index,
          error: error instanceof Error ? error : new Error(String(error)),
        });
        // Record error for adaptive concurrency adjustment
        const duration = Date.now() - startTime;
        semaphore.recordError(duration);
        // Create error result
        results[index] = {
          text: "",
          usage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
          finishReason: "error",
          index,
        };
      } finally {
        semaphore.release();
      }
    };
    // Start all requests with concurrency control
    const allPromises = prompts.map((prompt, index) => processPrompt(prompt, index));
    // Wait for all requests to complete
    await Promise.all(allPromises);
    // Log final statistics using semaphore metrics
    const metrics = semaphore.getMetrics();
    logger.debug("Parallel batch processing completed", {
      totalPrompts: prompts.length,
      successCount: metrics.completedCount,
      errorCount: metrics.errorCount,
      finalConcurrency: metrics.currentConcurrency,
      errorRate: metrics.errorCount / prompts.length,
      averageResponseTime: metrics.averageResponseTime,
    });
    // If we have too many errors, log them for debugging
    if (errors.length > 0) {
      logger.warn("Batch processing encountered errors", {
        errorCount: errors.length,
        sampleErrors: errors.slice(0, 3).map((e) => ({
          index: e.index,
          message: e.error.message,
        })),
      });
    }
    // Return results in original order (already sorted by index)
    return results.map(({ text, usage, finishReason }) => ({
      text,
      usage,
      finishReason,
    }));
  }

  /**
   * Enhanced model information with batch capabilities
   */
  getModelCapabilities() {
    return {
      ...this.getModelInfo(),
      capabilities: {
        streaming: true,
        toolCalling: true,
        structuredOutput: true,
        batchInference: true,
        supportedResponseFormats: ["text", "json_object", "json_schema"],
        supportedToolTypes: ["function"],
        maxBatchSize: 100, // Increased limit with parallel processing
        adaptiveConcurrency: true,
        errorRecovery: true,
      },
    };
  }
}
export default SageMakerLanguageModel;
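
For orientation, a minimal usage sketch of this class follows. It mirrors the call shape used by the module's own testConnectivity method; the import path, endpoint name, and region values are placeholders, and the exact shapes of the provider/model config objects beyond endpointName and region are assumptions rather than documented API.

// Hypothetical usage sketch -- file path, endpoint name, and config shapes are assumptions.
import SageMakerLanguageModel from "./language-model.js";

const model = new SageMakerLanguageModel(
  "my-sagemaker-model",            // modelId
  { region: "us-east-1" },         // provider config (region is read by getModelInfo)
  { endpointName: "my-llm-endpoint" }, // model config (endpointName is required by this class)
);

// Same option shape as testConnectivity(): messages-style prompt with text parts.
const { text, usage, finishReason } = await model.doGenerate({
  inputFormat: "messages",
  mode: { type: "regular" },
  prompt: [{ role: "user", content: [{ type: "text", text: "Hello" }] }],
  maxTokens: 64,
});
console.log(text, usage, finishReason);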