UNPKG

alwaysai

Version:

The alwaysAI command-line interface (CLI)

410 lines (392 loc) 12.2 kB
import * as logSymbols from 'log-symbols';
import {
  CliFlagInput,
  CliLeaf,
  CliNumberArrayInput,
  CliNumberInput,
  CliOneOfInput,
  CliStringArrayInput,
  CliStringInput,
  CliTerseError,
  CliUsageError
} from '@alwaysai/alwayscli';
import { echo, keyMirror, logger } from '../../util';
import {
  ModelJson,
  ModelJsonParameters,
  modelPurposeValues,
  HailoArchitecture,
  hailoArchitectureEnum,
  hailoArchitectureValues,
  HailoFormat,
  hailoFormatValues,
  HailoPurpose,
  hailoPurposeEnum,
  modelFrameworkValues,
  OnnxArchitectureObjectDetection,
  onnxArchitectureObjectDetectionEnum,
  onnxArchitectureObjectDetectionValues,
  TensorrtArchitectureObjectDetection,
  tensorrtArchitectureObjectDetectionEnum,
  tensorrtArchitectureObjectDetectionValues,
  TensorrtDevice,
  tensorrtDeviceEnum,
  tensorrtDeviceValues,
  hailoPurposeValues,
  validateModel
} from '@alwaysai/model-configuration-schemas';
import { ModelPackageJsonFile } from '../../core/model/model-package-json-file';
import { CliAuthenticationClient } from '../../infrastructure';

/** Values accepted by `--device`, derived from the TensorRT device enum. */
const ModelDeviceValues = Object.values(keyMirror(tensorrtDeviceEnum));

/**
 * Values accepted by `--architecture`: the combined TensorRT, Hailo and ONNX
 * object-detection architecture enums. Whether a given architecture is valid
 * for the chosen framework is checked inside the action below.
 */
const ModelArchitectureValues = Object.values(
  keyMirror({
    ...tensorrtArchitectureObjectDetectionEnum,
    ...hailoArchitectureEnum,
    ...onnxArchitectureObjectDetectionEnum
  })
);

/**
 * `configure` CLI command.
 *
 * Builds framework-specific model parameters from the command-line options.
 * It then wraps them in a `ModelJson` whose id is `<username>/<name>`,
 * validates the result with `validateModel`, and writes it as the model
 * package file in the current working directory.
 *
 * @throws CliUsageError when an option is missing or not supported for the
 *   selected framework/purpose.
 * @throws CliTerseError when the generated model fails schema validation or
 *   the model package file cannot be written.
 */
export const modelConfigure = CliLeaf({
  name: 'configure',
  description: 'Generate or modify the model configuration of the application.',
  namedInputs: {
    name: CliStringInput({
      description: 'The model name to be used to generate the model ID',
      required: true
    }),
    // Required fields for all frameworks
    framework: CliOneOfInput({
      description: 'Framework of the model.',
      required: true,
      values: modelFrameworkValues
    }),
    model_file: CliStringInput({
      description: 'Path to model binary file.'
    }),
    mean: CliNumberArrayInput({
      description:
        'The average pixel intensity in the red, green, and blue channels of the training dataset.'
    }),
    scalefactor: CliNumberInput({
      description: 'Factor to scale pixel intensities by.'
    }),
    size: CliNumberArrayInput({
      description: 'The input size of the neural network.'
    }),
    purpose: CliOneOfInput({
      description: 'Computer vision purpose of the model.',
      required: true,
      values: modelPurposeValues
    }),
    crop: CliFlagInput({
      description: 'Crop before resize?'
    }),
    // Optional fields
    config_file: CliStringInput({
      description: 'Path to model structure.'
    }),
    label_file: CliStringInput({
      description: 'File containing labels for each class index.'
    }),
    colors_file: CliStringInput({
      description: 'File containing colors to be used by each class index.'
    }),
    swaprb: CliFlagInput({
      description: 'Swap red and blue channels after image blob generation'
    }),
    softmax: CliFlagInput({
      description:
        'Apply softmax to the output of the neural network? Boolean true/false'
    }),
    batch_size: CliNumberInput({
      description: 'The inference batch size of the model'
    }),
    output_layer_names: CliStringArrayInput({
      description: 'List of output layers provided in advance',
      placeholder: '<>'
    }),
    device: CliOneOfInput({
      description: 'Define the device model is intended to be used on.',
      values: ModelDeviceValues
    }),
    architecture: CliOneOfInput({
      description: 'Define the architecture type intended to be used.',
      values: ModelArchitectureValues
    }),
    quantize_input: CliFlagInput({
      description: 'Quantize input? Boolean true/false.'
    }),
    quantize_output: CliFlagInput({
      description: 'Quantize output? Boolean true/false.'
    }),
    input_format: CliOneOfInput({
      description: 'Define the input format of the data.',
      values: hailoFormatValues
    }),
    output_format: CliOneOfInput({
      description: 'Define the output format of the data.',
      values: hailoFormatValues
    })
  },
  async action(_, opts) {
    const {
      name,
      framework,
      model_file,
      config_file,
      mean,
      scalefactor,
      size,
      purpose,
      crop,
      label_file,
      colors_file,
      swaprb,
      softmax,
      batch_size,
      output_layer_names,
      device,
      architecture,
      quantize_input,
      quantize_output,
      input_format,
      output_format
    } = opts;

    // Defaults shared by every framework. `||` (not `??`) is intentional:
    // it keeps the original behavior where empty strings / a 0 scale factor
    // also fall back to the default.
    const modelFile = model_file || '';
    const labelFile = label_file || '';
    const configFile = config_file || '';
    const colorsFile = colors_file || '';
    const inputMean = mean || [0, 0, 0];
    const inputScale = scalefactor || 1;
    const inputSize = size || [300, 300];

    // Property order inside each literal is kept stable so the written
    // model JSON does not change shape between releases.
    let modelParameters: ModelJsonParameters;
    switch (framework) {
      case 'tensorflow': {
        modelParameters = {
          framework_type: 'tensorflow',
          model_file: modelFile,
          label_file: labelFile,
          mean: inputMean,
          scalefactor: inputScale,
          size: inputSize,
          purpose,
          crop,
          config_file: configFile,
          colors_file: colorsFile,
          swaprb,
          softmax
        };
        break;
      }
      case 'caffe': {
        modelParameters = {
          framework_type: 'caffe',
          config_file: configFile,
          size: inputSize,
          model_file: modelFile,
          label_file: labelFile,
          scalefactor: inputScale,
          mean: inputMean,
          crop,
          swaprb,
          softmax,
          purpose,
          output_layer_names: output_layer_names || ['', '']
        };
        break;
      }
      case 'enet': {
        modelParameters = {
          framework_type: 'enet',
          size: inputSize,
          model_file: modelFile,
          label_file: labelFile,
          colors_file: colorsFile,
          scalefactor: inputScale,
          mean: inputMean,
          crop,
          swaprb,
          purpose
        };
        break;
      }
      case 'darknet': {
        modelParameters = {
          framework_type: 'darknet',
          config_file: configFile,
          size: inputSize,
          model_file: modelFile,
          label_file: labelFile,
          colors_file: colorsFile,
          scalefactor: inputScale,
          mean: inputMean,
          crop,
          swaprb,
          purpose,
          output_layer_names: output_layer_names || null
        };
        break;
      }
      case 'onnx': {
        if (purpose === 'ObjectDetection') {
          // --architecture is optional for ONNX, but must be an ONNX
          // object-detection architecture when given.
          if (
            architecture &&
            !(architecture in onnxArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${onnxArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: inputSize,
            model_file: modelFile,
            label_file: labelFile,
            colors_file: colorsFile,
            scalefactor: inputScale,
            crop,
            swaprb,
            purpose,
            mean: inputMean,
            output_layer_names: output_layer_names || null,
            architecture: architecture as OnnxArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: inputSize,
            model_file: modelFile,
            label_file: labelFile,
            colors_file: colorsFile,
            scalefactor: inputScale,
            crop,
            swaprb,
            purpose,
            mean: inputMean,
            output_layer_names: output_layer_names || null
          };
        }
        // batch_size is optional for ONNX and only written when provided.
        if (batch_size) {
          modelParameters.batch_size = batch_size;
        }
        break;
      }
      case 'tensor-rt': {
        if (!batch_size) {
          throw new CliUsageError(`Parameter --batch_size required!`);
        }
        if (device && !(device in tensorrtDeviceEnum)) {
          throw new CliUsageError(
            `Device not supported! (${tensorrtDeviceValues})`
          );
        }
        if (purpose === 'ObjectDetection') {
          if (
            architecture &&
            !(architecture in tensorrtArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${tensorrtArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: inputSize,
            model_file: modelFile,
            label_file: labelFile,
            scalefactor: inputScale,
            mean: inputMean,
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colorsFile,
            device: device as TensorrtDevice,
            architecture: architecture as TensorrtArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: inputSize,
            model_file: modelFile,
            label_file: labelFile,
            scalefactor: inputScale,
            mean: inputMean,
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colorsFile,
            device: device as TensorrtDevice
          };
        }
        break;
      }
      case 'hailo': {
        if (!architecture || !(architecture in hailoArchitectureEnum)) {
          throw new CliUsageError(
            `Parameter --architecture required! (${hailoArchitectureValues})`
          );
        }
        if (!purpose || !(purpose in hailoPurposeEnum)) {
          throw new CliUsageError(
            `Parameter --purpose required! (${hailoPurposeValues})`
          );
        }
        modelParameters = {
          framework_type: 'hailo',
          architecture: architecture as HailoArchitecture,
          // NOTE(review): `x || true` is always true, so --quantize_input /
          // --quantize_output cannot turn quantization off (a flag can only
          // add `true`). Kept as-is to preserve the generated config; allowing
          // `false` would need a different CLI input type.
          quantize_input: quantize_input || true,
          quantize_output: quantize_output || true,
          input_format: (input_format as HailoFormat) || 'auto',
          output_format: (output_format as HailoFormat) || 'auto',
          size: inputSize,
          model_file: modelFile,
          label_file: labelFile,
          purpose: purpose as HailoPurpose,
          crop,
          swaprb,
          mean: inputMean,
          scalefactor: inputScale
        };
        break;
      }
      default: {
        // Reached when --framework passed CliOneOfInput validation but has no
        // branch above. Report it as a usage error that names the framework.
        throw new CliUsageError(`Unsupported framework: ${framework}`);
      }
    }

    const { username } = await CliAuthenticationClient().getInfo();
    const newModel: ModelJson = {
      accuracy: '',
      dataset: '',
      description: '',
      id: `${username}/${name}`,
      inference_time: null,
      license: '',
      mean_average_precision_top_1: null,
      mean_average_precision_top_5: null,
      public: false,
      website_url: '',
      model_parameters: modelParameters
    };

    validateModel(newModel);
    if (validateModel.errors) {
      // `null` replacer: serialize every property of each validation error.
      // (Previously the positional-args value `_` was passed here, which would
      // act as a property allow-list if it were an array.)
      echo(JSON.stringify(validateModel.errors, null, 2));
      throw new CliTerseError('Model package contents are invalid!');
    }

    const message = `Write alwaysai.model.json file`;
    const modelPkg = ModelPackageJsonFile(process.cwd());
    try {
      modelPkg.write(newModel);
      echo(`${logSymbols.success} ${message}`);
    } catch (exception) {
      echo(`${logSymbols.error} ${message}`);
      logger.error(exception);
      throw new CliTerseError(`Failed to write model package! ${exception}`);
    }
  }
});