UNPKG

@tensorify.io/sdk

Version:

TypeScript SDK for developing Tensorify plugins with V2-Alpha definition/execution pattern and legacy compatibility

110 lines (97 loc) 2.8 kB
import { BasePlugin } from "./base-plugin";
import {
  GenerateCodeContext,
  GenerateCodeResult,
  NodeRegistry,
} from "./v2-alpha-types";

/**
 * Return `body` unchanged, or the Python no-op statement `pass` when `body`
 * is empty / whitespace-only, so the enclosing `def` always has a valid suite.
 */
function bodyOrPass(body: string): string {
  return body.trim().length > 0 ? body : "pass";
}

/**
 * Base class for machine learning plugins.
 *
 * Provides helpers that emit Python source text for ML/AI workflow nodes:
 * common PyTorch imports, device selection, indentation, import merging and
 * function/class scaffolding. Every helper returns strings; nothing here
 * executes Python.
 */
export abstract class MLPlugin extends BasePlugin {
  /**
   * Get common PyTorch imports.
   *
   * @returns A fresh array of Python import statements.
   */
  protected getCommonImports(): string[] {
    return ["import torch", "import torch.nn as nn"];
  }

  /**
   * Generate device setup code.
   *
   * @returns Three lines of Python that bind `device` to CUDA when available
   *   (CPU otherwise) and print the selection.
   */
  protected generateDeviceSetup(): string {
    // Each statement must sit on its own line: if they are joined onto one
    // line the leading `#` turns the whole snippet into a comment.
    return [
      "# Get the device",
      'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")',
      'print(f"Using {device} device")',
    ].join("\n");
  }

  /**
   * Helper to move a variable to device.
   *
   * @param varName - Name of the Python variable to reassign.
   * @returns A Python statement `<varName> = <varName>.to(device)`.
   */
  protected moveToDevice(varName: string): string {
    return `${varName} = ${varName}.to(device)`;
  }

  /**
   * Get the node settings typed as `T`.
   *
   * This is an unchecked cast of `ctx.node.config`; no runtime validation is
   * performed, so the caller is responsible for `T` matching the actual data.
   *
   * @param ctx - Code generation context whose `node.config` is returned.
   * @returns `ctx.node.config`, asserted to be `T`.
   */
  protected getSettings<T extends Record<string, any>>(
    ctx: GenerateCodeContext
  ): T {
    return ctx.node.config as T;
  }

  /**
   * Indent every non-empty line of `code` by `indent` spaces.
   *
   * Empty lines are left empty so the output carries no trailing whitespace.
   *
   * @param code - Source text; lines are separated by `\n`.
   * @param indent - Number of spaces to prepend (default `0`).
   * @returns The re-indented text.
   */
  protected formatCode(code: string, indent: number = 0): string {
    const spaces = " ".repeat(indent);
    return code
      .split("\n")
      .map((line) => (line.length > 0 ? spaces + line : line))
      .join("\n");
  }

  /**
   * Combine imports from multiple sources and deduplicate.
   *
   * @param importLists - Any number of import-statement lists.
   * @returns Unique statements in first-seen order (exact string match).
   */
  protected mergeImports(...importLists: string[][]): string[] {
    const allImports = importLists.flat();
    return [...new Set(allImports)];
  }

  /**
   * Generate a Python function definition.
   *
   * @param name - Function name.
   * @param params - Parameter declarations, joined with `", "`.
   * @param body - Function body, indented by 4 spaces; an empty body is
   *   emitted as `pass` so the definition stays syntactically valid.
   * @param returnType - Optional return annotation (`-> returnType`).
   * @returns The `def` line followed by the indented body.
   */
  protected generateFunction(
    name: string,
    params: string[],
    body: string,
    returnType?: string
  ): string {
    const paramList = params.join(", ");
    const signature = returnType
      ? `def ${name}(${paramList}) -> ${returnType}:`
      : `def ${name}(${paramList}):`;
    return `${signature}\n${this.formatCode(bodyOrPass(body), 4)}`;
  }

  /**
   * Generate a Python class definition.
   *
   * When `initBody` is provided an `__init__` is emitted that takes `self`
   * plus the params of the `methods` entry named `"__init__"` (if any), calls
   * `super().__init__()`, and then runs `initBody`. Only the params of that
   * entry are used; its `body` is ignored. A `methods` entry named
   * `"__init__"` is never emitted as a regular method, so without `initBody`
   * no constructor is generated.
   *
   * All other methods are emitted verbatim from `params`, so they must list
   * `self` themselves.
   *
   * @param name - Class name.
   * @param baseClass - Base class expression placed in the parentheses.
   * @param methods - Methods to emit, in order.
   * @param initBody - Optional constructor body (placed after the super call).
   * @returns The class source without trailing whitespace.
   */
  protected generateClass(
    name: string,
    baseClass: string,
    methods: { name: string; params: string[]; body: string }[],
    initBody?: string
  ): string {
    let code = `class ${name}(${baseClass}):\n`;
    let hasMembers = false;

    if (initBody) {
      const declared = (
        methods.find((m) => m.name === "__init__")?.params ?? []
      ).filter((p) => p.trim().length > 0);
      // `self` is always supplied here; drop a caller-provided leading `self`
      // so the signature never contains a duplicate argument.
      const extra = declared[0] === "self" ? declared.slice(1) : declared;
      // Joining a list avoids the dangling `self, ` when there are no extras.
      const initParams = ["self", ...extra].join(", ");

      code += `    def __init__(${initParams}):\n`;
      code += `        super().__init__()\n`;
      code += this.formatCode(initBody, 8) + "\n\n";
      hasMembers = true;
    }

    for (const method of methods) {
      if (method.name === "__init__") continue;

      code += `    def ${method.name}(${method.params.join(", ")}):\n`;
      code += this.formatCode(bodyOrPass(method.body), 8) + "\n\n";
      hasMembers = true;
    }

    // A class header with no suite is a Python syntax error.
    if (!hasMembers) {
      code += "    pass\n";
    }

    return code.trimEnd();
  }
}