/*
 * @tensorify.io/sdk
 * Version:
 * TypeScript SDK for developing Tensorify plugins with V2-Alpha definition/execution pattern and legacy compatibility
 * 84 lines • 2.67 kB
 * JavaScript
 */
// tsc emits this directive at the top of CommonJS output; only its trailing
// `;` survived in this copy, so the directive is restored here.
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.MLPlugin = void 0;
const base_plugin_1 = require("./base-plugin");
/**
 * Base class for machine learning plugins.
 *
 * Provides helpers shared by ML/AI workflow nodes: common PyTorch imports,
 * device-selection code, and small generators that emit Python source
 * (indentation, function definitions, class definitions).
 */
class MLPlugin extends base_plugin_1.BasePlugin {
    /**
     * Get the PyTorch imports most ML nodes need.
     *
     * @returns {string[]} Python import statements.
     */
    getCommonImports() {
        return ["import torch", "import torch.nn as nn"];
    }
    /**
     * Generate Python code that picks CUDA when available (CPU otherwise),
     * stores it in a `device` variable, and prints the choice.
     *
     * @returns {string} Unindented, multi-line Python source.
     */
    generateDeviceSetup() {
        return `# Get the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")`;
    }
    /**
     * Build a Python statement that rebinds a variable to its `.to(device)` result.
     *
     * @param {string} varName - Python identifier to move.
     * @returns {string} e.g. `x = x.to(device)`.
     */
    moveToDevice(varName) {
        return `${varName} = ${varName}.to(device)`;
    }
    /**
     * Return this node's configuration object.
     *
     * At runtime this is a plain property read; any "type safety" presumably
     * comes from the TypeScript declarations, which are not part of this file.
     *
     * @param {object} ctx - Context object; only `ctx.node.config` is read.
     * @returns {*} `ctx.node.config`, returned as-is.
     */
    getSettings(ctx) {
        return ctx.node.config;
    }
    /**
     * Indent every non-empty line of `code` by `indent` spaces.
     *
     * Zero-length lines are left untouched so blank lines gain no trailing
     * whitespace.
     *
     * @param {string} code - Source text, lines separated by "\n".
     * @param {number} [indent=0] - Number of spaces to prepend.
     * @returns {string} The re-indented text.
     */
    formatCode(code, indent = 0) {
        const spaces = " ".repeat(indent);
        return code
            .split("\n")
            .map((line) => (line.length > 0 ? spaces + line : line))
            .join("\n");
    }
    /**
     * Concatenate several import lists and drop duplicates.
     *
     * @param {...string[]} importLists - Lists of import statements.
     * @returns {string[]} Unique imports, in first-seen order.
     */
    mergeImports(...importLists) {
        return [...new Set(importLists.flat())];
    }
    /**
     * Generate a Python function definition.
     *
     * Python requires at least one statement in a function body, so an empty
     * or whitespace-only `body` is replaced by `pass`.
     *
     * @param {string} name - Function name.
     * @param {string[]} params - Parameter declarations, joined with ", ".
     * @param {string} body - Unindented body source; indented by 4 spaces.
     * @param {string} [returnType] - Optional return annotation.
     * @returns {string} The Python `def` source.
     */
    generateFunction(name, params, body, returnType) {
        const paramList = params.join(", ");
        const signature = returnType
            ? `def ${name}(${paramList}) -> ${returnType}:`
            : `def ${name}(${paramList}):`;
        const suite = body.trim() === "" ? "pass" : body;
        return `${signature}\n${this.formatCode(suite, 4)}`;
    }
    /**
     * Generate a Python class definition.
     *
     * When `initBody` is truthy, an `__init__(self, ...)` is emitted that calls
     * `super().__init__()` and then runs `initBody`. Its extra parameters come
     * from the `__init__` entry in `methods`, if any. That entry's `body` is
     * never used, and it is not emitted at all when `initBody` is falsy.
     * Every other method is emitted with exactly the `params` it lists, so
     * those must include `self` themselves.
     *
     * @param {string} name - Class name.
     * @param {string} baseClass - Base-class expression placed in parentheses.
     * @param {{name: string, params: string[], body: string}[]} methods - Methods to emit.
     * @param {string} [initBody] - Body of `__init__` after the `super()` call.
     * @returns {string} The Python class source, without trailing whitespace.
     */
    generateClass(name, baseClass, methods, initBody) {
        const sections = [];
        if (initBody) {
            const initMethod = methods.find((m) => m.name === "__init__");
            // `self` is always emitted first; drop any explicit "self" so the
            // signature never contains a duplicate argument (a SyntaxError).
            const initParams = (initMethod?.params ?? []).filter((p) => p !== "self");
            const initSignature = `def __init__(${["self", ...initParams].join(", ")}):`;
            sections.push([
                this.formatCode(initSignature, 4),
                this.formatCode("super().__init__()", 8),
                this.formatCode(initBody, 8),
            ].join("\n"));
        }
        for (const method of methods) {
            if (method.name === "__init__") {
                continue;
            }
            const signature = `def ${method.name}(${method.params.join(", ")}):`;
            // An empty suite is a Python SyntaxError.
            const suite = method.body.trim() === "" ? "pass" : method.body;
            sections.push(`${this.formatCode(signature, 4)}\n${this.formatCode(suite, 8)}`);
        }
        // A class with no members still needs one statement in its body.
        if (sections.length === 0) {
            sections.push(this.formatCode("pass", 4));
        }
        return `class ${name}(${baseClass}):\n${sections.join("\n\n")}`.trimEnd();
    }
}
// Expose MLPlugin as a named CommonJS export (fills the slot pre-declared as
// `void 0` in the preamble). Kept in the literal `exports.X = X` form emitted
// by tsc; static export detectors pattern-match this shape.
exports.MLPlugin = MLPlugin;
//# sourceMappingURL=ml-plugin.js.map