UNPKG

itk-wasm

Version:

High-performance spatial analysis in a web browser, Node.js, and reproducible execution across programming languages and hardware architectures.

343 lines (323 loc) 13.6 kB
import fs from 'fs-extra' import snakeCase from '../../snake-case.js' import canonicalType from '../../canonical-type.js' import functionModuleImports from '../function-module-imports.js' import functionModuleReturnType from '../function-module-return-type.js' import functionModuleDocstring from '../function-module-docstring.js' import functionModuleArgs from '../function-module-args.js' import interfaceJsonTypeToInterfaceType from '../../interface-json-type-to-interface-type.js' import interfaceJsonTypeToPythonType from '../interface-json-type-to-python-type.js' import writeIfOverrideNotPresent from '../../write-if-override-not-present.js' function wasiFunctionModule(interfaceJson, pypackage, modulePath) { const functionName = snakeCase(interfaceJson.name) let moduleContent = `from pathlib import Path, PurePosixPath import os from typing import Dict, Tuple, Optional, List, Any from importlib_resources import files as file_resources _pipeline = None from itkwasm import ( InterfaceTypes, PipelineOutput, PipelineInput, Pipeline,` moduleContent += functionModuleImports(interfaceJson) const functionArgs = functionModuleArgs(interfaceJson) const returnType = functionModuleReturnType(interfaceJson) const docstring = functionModuleDocstring(interfaceJson) let pipelineOutputFilePrep = '' interfaceJson.outputs.forEach((output) => { if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) const isArray = output.itemsExpectedMax > 1 if (interfaceType.includes('File')) { const snake = snakeCase(output.name) if (isArray) { pipelineOutputFilePrep += ` ${snake}_pipeline_outputs = [PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(p))) for p in ${snake}]\n` } } } }) let pipelineOutputs = '' let haveArray = false interfaceJson.outputs.forEach((output) => { if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) const isArray = output.itemsExpectedMax > 1 switch (interfaceType) { case 'TextFile': case 'BinaryFile': if (isArray) { haveArray = true pipelineOutputs += ` *${snakeCase(output.name)}_pipeline_outputs,\n` } else { pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(output.name)}))),\n` } break default: pipelineOutputs += ` PipelineOutput(InterfaceTypes.${interfaceType}),\n` } } }) let pipelineOutputIndices = '' if (haveArray) { pipelineOutputIndices += ` output_index = 0\n` interfaceJson.outputs.forEach((output) => { if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) if (interfaceType.includes('File')) { const snake = snakeCase(output.name) const isArray = output.itemsExpectedMax > 1 if (isArray) { pipelineOutputIndices += ` ${snake}_start = output_index\n` pipelineOutputIndices += ` output_index += len(${snake})\n` pipelineOutputIndices += ` ${snake}_end = output_index\n` } else { pipelineOutputIndices += ` ${snake}_index = output_index\n` pipelineOutputIndices += ` output_index += 1\n` } } else if (!interfaceType.includes('File')) { pipelineOutputIndices += ` ${snake}_index = output_index\n` pipelineOutputIndices += ` output_index += 1\n` } } }) } let pipelineInputs = '' interfaceJson['inputs'].forEach((input) => { if (interfaceJsonTypeToInterfaceType.has(input.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(input.type) switch (interfaceType) { case 'TextFile': case 'BinaryFile': pipelineInputs += ` PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(PurePosixPath(${snakeCase(input.name)}))),\n` break case 'TextStream': case 'BinaryStream': pipelineInputs += ` PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(${snakeCase(input.name)})),\n` break default: pipelineInputs += ` PipelineInput(InterfaceTypes.${interfaceType}, ${snakeCase(input.name)}),\n` } } }) let args = ` args: List[str] = ['--memory-io',]\n` let inputCount = 0 args += ' # Inputs\n' interfaceJson.inputs.forEach((input) => { const snakeName = snakeCase(input.name) if (interfaceJsonTypeToInterfaceType.has(input.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(input.type) if (interfaceType.includes('File')) { args += ` if not Path(${snakeName}).exists():\n` args += ` raise FileNotFoundError("${snakeName} does not exist")\n` } const name = interfaceType.includes('File') ? `str(PurePosixPath(${snakeName}))` : `'${inputCount.toString()}'` args += ` args.append(${name})\n` inputCount++ } else { args += ` args.append(str(${snakeName}))\n` } }) let outputCount = 0 args += ' # Outputs\n' interfaceJson.outputs.forEach((output) => { const snake = snakeCase(output.name) if (interfaceJsonTypeToInterfaceType.has(output.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get(output.type) const isArray = output.itemsExpectedMax > 1 let name = ` ${snake}_name = '${outputCount.toString()}'\n` if (interfaceType.includes('File')) { if (isArray) { name = '' } else { name = ` ${snake}_name = str(PurePosixPath(${snake}))\n` } } args += name if (isArray) { args += ` args.extend([str(PurePosixPath(p)) for p in ${snake}])\n` } else { args += ` args.append(${snake}_name)\n` } args += '\n' outputCount++ } else { args += ` args.append(str(${snake}))\n` } }) args += ' # Options\n' args += ` input_count = len(pipeline_inputs)\n` interfaceJson.parameters.forEach((parameter) => { if (parameter.name === 'memory-io' || parameter.name === 'version') { // Internal return } const snake = snakeCase(parameter.name) if (parameter.type === 'BOOL') { args += ` if ${snake}:\n` args += ` args.append('--${parameter.name}')\n` } else if (parameter.itemsExpectedMax > 1) { const notNone = parameter.required ? '' : `${snake} is not None and ` args += ` if ${notNone}len(${snake}) < ${parameter.itemsExpectedMin}:\n` args += ` raise ValueError('"${parameter.name}" kwarg must have a length > ${parameter.itemsExpectedMin}')\n` args += ` if ${notNone}len(${snake}) > 0:\n` args += ` args.append('--${parameter.name}')\n` args += ` for value in ${snake}:\n` if (interfaceJsonTypeToInterfaceType.has(parameter.type)) { const interfaceType = interfaceJsonTypeToInterfaceType.get( parameter.type ) if (interfaceType.includes('File')) { // for files args += ` input_file = str(PurePosixPath(value))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(value)))\n` args += ` args.append(input_file)\n` } else if (interfaceType.includes('Stream')) { // for streams args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(value)))\n` args += ` args.append(str(input_count))\n` args += ` input_count += 1\n` } else { // Image, Mesh, PolyData, JsonCompatible args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, value))\n` args += ` args.append(str(input_count))\n` args += ` input_count += 1\n` } } else { if (parameter.type.startsWith('TEXT:{')) { const choices = parameter.type.split('{')[1].split('}')[0].split(',') args += ` if ${snake} not in (${choices.map((c) => `'${c}'`).join(', ')}):\n` args += ` raise ValueError(f'${snake} must be one of ${choices.join(', ')}')\n` } args += ` args.append(str(value))\n` } } else { if (interfaceJsonTypeToInterfaceType.has(parameter.type)) { args += ` if ${snake} is not None:\n` const interfaceType = interfaceJsonTypeToInterfaceType.get( parameter.type ) if (interfaceType.includes('File')) { // for files args += ` input_file = str(PurePosixPath(${snakeCase(parameter.name)}))\n` args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(${snake})))\n` args += ` args.append('--${parameter.name}')\n` args += ` args.append(input_file)\n` } else if (interfaceType.includes('Stream')) { // for streams args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${interfaceType}(${snake})))\n` args += ` args.append('--${parameter.name}')\n` args += ` args.append(str(input_count))\n` args += ` input_count += 1\n` } else { // Image, Mesh, PolyData, JsonCompatible args += ` pipeline_inputs.append(PipelineInput(InterfaceTypes.${interfaceType}, ${snake}))\n` args += ` args.append('--${parameter.name}')\n` args += ` args.append(str(input_count))\n` args += ` input_count += 1\n` } } else { args += ` if ${snake}:\n` if (parameter.type.startsWith('TEXT:{')) { const choices = parameter.type.split('{')[1].split('}')[0].split(', ') args += ` if ${snake} not in (${choices.map((c) => `'${c}'`).join(',')}):\n` args += ` raise ValueError(f'${snake} must be one of ${choices.join(', ')}')\n` } args += ` args.append('--${parameter.name}')\n` args += ` args.append(str(${snake}))\n` } } args += `\n` }) let postOutput = '' function toPythonType(type, value) { const canonical = canonicalType(type) const pythonType = interfaceJsonTypeToPythonType.get(canonical) switch (pythonType) { case 'os.PathLike': return `Path(${value}.data.path)` case 'str': if (type === 'TEXT') { return `${value}` } else { return `${value}.data.data` } case 'bytes': return `${value}.data.data` case 'int': return `int(${value})` case 'bool': return `bool(${value})` case 'float': return `float(${value})` case 'Any': return `${value}.data` default: return `${value}.data` } } outputCount = 0 const jsonOutputs = interfaceJson['outputs'] const numOutputs = interfaceJson.outputs.filter( (o) => !o.type.includes('FILE') ).length if (numOutputs > 1) { postOutput += ' result = (\n' } else if (numOutputs === 1) { postOutput = ' result = ' } if (numOutputs > 0) { // const outputValue = "outputs[0]" // postOutput = ` result = ${toPythonType(jsonOutputs[0].type, outputValue)}\n` const indent = numOutputs > 1 ? ' ' : '' const comma = numOutputs > 1 ? ',' : '' jsonOutputs.forEach((value) => { if (value.type.includes('FILE')) { outputCount++ return } const snake = snakeCase(value.name) const outputIndex = haveArray ? `${snake}_index` : outputCount.toString() if (haveArray) { const isArray = value.itemsExpectedMax > 1 if (isArray) { const outputValue = `outputs[${snake}_start:${snake}_end]` postOutput += `${indent}*[${toPythonType(value.type, 'v')} for v in ${outputValue}]${comma}\n` } else { const outputValue = `outputs[${snake}_index]` postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n` } } else { const outputValue = `outputs[${outputIndex}]` postOutput += `${indent}${toPythonType(value.type, outputValue)}${comma}\n` } outputCount++ }) } if (numOutputs > 1) { postOutput += ' )\n' } if (numOutputs > 0) { postOutput += ' return result\n' } moduleContent += `def ${functionName}( ${functionArgs}) -> ${returnType}: ${docstring} global _pipeline if _pipeline is None: _pipeline = Pipeline(file_resources('${pypackage}').joinpath(Path('wasm_modules') / Path('${interfaceJson.name}.wasi.wasm'))) ${pipelineOutputFilePrep} pipeline_outputs: List[PipelineOutput] = [ ${pipelineOutputs} ] ${pipelineOutputIndices} pipeline_inputs: List[PipelineInput] = [ ${pipelineInputs} ] ${args} outputs = _pipeline.run(args, pipeline_outputs, pipeline_inputs) ${postOutput} ` writeIfOverrideNotPresent(modulePath, moduleContent, '#') } export default wasiFunctionModule