UNPKG

@tensorify.io/sdk

Version:

TypeScript SDK for developing Tensorify plugins with V2-Alpha definition/execution pattern and legacy compatibility

84 lines 2.67 kB
"use strict"; Object.defineProperty(exports, "__esModule", { value: true }); exports.MLPlugin = void 0; const base_plugin_1 = require("./base-plugin"); /** * Base class for machine learning plugins * * Provides common functionality for ML/AI workflow nodes like * device management, tensor operations, and common imports. */ class MLPlugin extends base_plugin_1.BasePlugin { /** * Get common PyTorch imports */ getCommonImports() { return ["import torch", "import torch.nn as nn"]; } /** * Generate device setup code */ generateDeviceSetup() { return `# Get the device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using {device} device")`; } /** * Helper to move a variable to device */ moveToDevice(varName) { return `${varName} = ${varName}.to(device)`; } /** * Get settings with type safety */ getSettings(ctx) { return ctx.node.config; } /** * Format Python code with proper indentation */ formatCode(code, indent = 0) { const spaces = " ".repeat(indent); return code .split("\n") .map((line) => (line.length > 0 ? spaces + line : line)) .join("\n"); } /** * Combine imports from multiple sources and deduplicate */ mergeImports(...importLists) { const allImports = importLists.flat(); return [...new Set(allImports)]; } /** * Generate a Python function definition */ generateFunction(name, params, body, returnType) { const signature = returnType ? `def ${name}(${params.join(", ")}) -> ${returnType}:` : `def ${name}(${params.join(", ")}):`; return `${signature}\n${this.formatCode(body, 4)}`; } /** * Generate a Python class definition */ generateClass(name, baseClass, methods, initBody) { let code = `class ${name}(${baseClass}):\n`; if (initBody) { code += ` def __init__(self${methods.length > 0 || initBody ? ", " : ""}${methods.find((m) => m.name === "__init__")?.params.join(", ") || ""}):\n`; code += ` super().__init__()\n`; code += this.formatCode(initBody, 8) + "\n\n"; } for (const method of methods) { if (method.name === "__init__") continue; code += ` def ${method.name}(${method.params.join(", ")}):\n`; code += this.formatCode(method.body, 8) + "\n\n"; } return code.trimEnd(); } } exports.MLPlugin = MLPlugin; //# sourceMappingURL=ml-plugin.js.map