@tensorify.io/sdk
Version:
TypeScript SDK for developing Tensorify plugins with V2-Alpha definition/execution pattern and legacy compatibility
110 lines (97 loc) • 2.8 kB
text/typescript
import { BasePlugin } from "./base-plugin";
import {
GenerateCodeContext,
GenerateCodeResult,
NodeRegistry,
} from "./v2-alpha-types";
/**
* Base class for machine learning plugins
*
* Provides common functionality for ML/AI workflow nodes like
* device management, tensor operations, and common imports.
*/
export abstract class MLPlugin extends BasePlugin {
  /**
   * Get common PyTorch imports.
   *
   * @returns Import statements for `torch` and `torch.nn` (aliased as `nn`).
   */
  protected getCommonImports(): string[] {
    return ["import torch", "import torch.nn as nn"];
  }

  /**
   * Generate device setup code.
   *
   * @returns Python source that binds `device` to CUDA when
   *   `torch.cuda.is_available()` is true (CPU otherwise) and prints it.
   */
  protected generateDeviceSetup(): string {
    return `# Get the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")`;
  }

  /**
   * Helper to move a variable to the device.
   *
   * @param varName - Python identifier to rebind.
   * @returns A Python statement `varName = varName.to(device)`; it assumes a
   *   `device` variable exists in the generated scope (see `generateDeviceSetup`).
   */
  protected moveToDevice(varName: string): string {
    return `${varName} = ${varName}.to(device)`;
  }

  /**
   * Get the node's settings typed as `T`.
   *
   * NOTE: this is an unchecked cast of `ctx.node.config`; no runtime
   * validation is performed, so callers must trust the config shape.
   *
   * @param ctx - Code generation context whose `node.config` is returned.
   * @returns `ctx.node.config` viewed as `T`.
   */
  // `any` (not `unknown`) keeps the constraint satisfiable by interface types,
  // which lack the implicit index signature `Record<string, unknown>` requires.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  protected getSettings<T extends Record<string, any>>(
    ctx: GenerateCodeContext
  ): T {
    return ctx.node.config as T;
  }

  /**
   * Indent Python code by a fixed number of spaces.
   *
   * Only non-empty lines are prefixed, so blank lines carry no trailing
   * whitespace.
   *
   * @param code - Source text, split on `"\n"`.
   * @param indent - Number of spaces to prepend (default 0).
   * @returns The re-indented text.
   */
  protected formatCode(code: string, indent: number = 0): string {
    const spaces = " ".repeat(indent);
    return code
      .split("\n")
      .map((line) => (line.length > 0 ? spaces + line : line))
      .join("\n");
  }

  /**
   * Combine imports from multiple sources and deduplicate.
   *
   * Deduplication is by exact string match; the first occurrence of each
   * import keeps its position.
   *
   * @param importLists - Lists of import statements to merge.
   * @returns The merged, de-duplicated list.
   */
  protected mergeImports(...importLists: string[][]): string[] {
    return [...new Set(importLists.flat())];
  }

  /**
   * Generate a Python function definition.
   *
   * An empty or comment-only `body` gets a `pass` statement so the emitted
   * `def` is always syntactically valid Python.
   *
   * @param name - Function name.
   * @param params - Parameter list entries, joined with `", "`.
   * @param body - Unindented function body; indented by 4 spaces here.
   * @param returnType - Optional return annotation; omitted when falsy.
   * @returns The complete function definition.
   */
  protected generateFunction(
    name: string,
    params: string[],
    body: string,
    returnType?: string
  ): string {
    const annotation = returnType ? ` -> ${returnType}` : "";
    const signature = `def ${name}(${params.join(", ")})${annotation}:`;
    return `${signature}\n${this.formatCode(pythonSuite(body), 4)}`;
  }

  /**
   * Generate a Python class definition.
   *
   * An `__init__` is emitted when `initBody` is given or when `methods`
   * contains an entry named `"__init__"`. Its body is `initBody` if non-empty,
   * otherwise that entry's `body`. It always starts with `super().__init__()`,
   * so the supplied body should not call it again. The `__init__` entry's
   * `params` are listed after an implicit `self` (a `"self"` entry is ignored).
   * Every other method's `params` are used verbatim, so they must include
   * `self` themselves. Empty method bodies, and a class with no members at all,
   * get `pass` so the output is valid Python.
   *
   * @param name - Class name.
   * @param baseClass - Base class expression placed in the parentheses.
   * @param methods - Methods to emit, in order (the `__init__` entry is
   *   handled specially as described above).
   * @param initBody - Optional unindented `__init__` body.
   * @returns The complete class definition, without trailing whitespace.
   */
  protected generateClass(
    name: string,
    baseClass: string,
    methods: { name: string; params: string[]; body: string }[],
    initBody?: string
  ): string {
    const members: string[] = [];
    const initMethod = methods.find((m) => m.name === "__init__");

    if (initBody || initMethod) {
      // `self` is always added here; drop a caller-supplied copy to avoid the
      // Python SyntaxError "duplicate argument 'self'".
      const initParams = [
        "self",
        ...(initMethod?.params ?? []).filter((p) => p.trim() !== "self"),
      ];
      const suppliedBody = initBody || initMethod?.body || "";
      const initLines = ["super().__init__()"];
      if (suppliedBody.trim().length > 0) initLines.push(suppliedBody);
      members.push(
        `    def __init__(${initParams.join(", ")}):\n` +
          this.formatCode(initLines.join("\n"), 8)
      );
    }

    for (const method of methods) {
      if (method.name === "__init__") continue; // already rendered above
      members.push(
        `    def ${method.name}(${method.params.join(", ")}):\n` +
          this.formatCode(pythonSuite(method.body), 8)
      );
    }

    // A class statement needs at least one statement in its suite.
    const classBody = members.length > 0 ? members.join("\n\n") : "    pass";
    return `class ${name}(${baseClass}):\n${classBody}`.trimEnd();
  }
}

/**
 * Make `body` usable as a Python suite (the block under `def`/`class`).
 *
 * Python requires at least one statement there, and comments alone do not
 * count. A blank body becomes `pass`. A body whose only non-blank lines are
 * `#` comments keeps its comments and gets `pass` appended. Any other body is
 * returned unchanged.
 */
function pythonSuite(body: string): string {
  const hasStatement = body.split("\n").some((line) => {
    const trimmed = line.trim();
    return trimmed.length > 0 && !trimmed.startsWith("#");
  });
  if (hasStatement) return body;
  return body.trim().length > 0 ? `${body}\npass` : "pass";
}