// alwaysai
// Version:
// The alwaysAI command-line interface (CLI)
// 410 lines (392 loc) • 12.2 kB
// text/typescript
import * as logSymbols from 'log-symbols';
import {
CliFlagInput,
CliLeaf,
CliNumberArrayInput,
CliNumberInput,
CliOneOfInput,
CliStringArrayInput,
CliStringInput,
CliTerseError,
CliUsageError
} from '@alwaysai/alwayscli';
import { echo, keyMirror, logger } from '../../util';
import {
ModelJson,
ModelJsonParameters,
modelPurposeValues,
HailoArchitecture,
hailoArchitectureEnum,
hailoArchitectureValues,
HailoFormat,
hailoFormatValues,
HailoPurpose,
hailoPurposeEnum,
modelFrameworkValues,
OnnxArchitectureObjectDetection,
onnxArchitectureObjectDetectionEnum,
onnxArchitectureObjectDetectionValues,
TensorrtArchitectureObjectDetection,
tensorrtArchitectureObjectDetectionEnum,
tensorrtArchitectureObjectDetectionValues,
TensorrtDevice,
tensorrtDeviceEnum,
tensorrtDeviceValues,
hailoPurposeValues,
validateModel
} from '@alwaysai/model-configuration-schemas';
import { ModelPackageJsonFile } from '../../core/model/model-package-json-file';
import { CliAuthenticationClient } from '../../infrastructure';
// Choices offered by the `--device` option: the values of the key-mirrored
// TensorRT device enum.
const mirroredDeviceEnum = keyMirror(tensorrtDeviceEnum);
const ModelDeviceValues = Object.values(mirroredDeviceEnum);

// Choices offered by the `--architecture` option: the TensorRT, Hailo and ONNX
// architecture enums merged into one object (later spreads win on duplicate
// keys) and then key-mirrored. Framework-specific validation of the chosen
// value happens inside the command's action handler.
const combinedArchitectureEnum = {
  ...tensorrtArchitectureObjectDetectionEnum,
  ...hailoArchitectureEnum,
  ...onnxArchitectureObjectDetectionEnum
};
const ModelArchitectureValues = Object.values(
  keyMirror(combinedArchitectureEnum)
);
/**
 * `configure` CLI leaf: builds a model configuration from command-line
 * options and writes it to the model package file in the current working
 * directory.
 *
 * Flow of the action handler:
 *  1. Assemble `model_parameters` for the selected `--framework`, filling in
 *     defaults for any option that was not supplied.
 *  2. Wrap the parameters in a `ModelJson` whose id is `<username>/<name>`,
 *     where the username comes from the authentication client.
 *  3. Run `validateModel`; if it reports errors, print them and abort.
 *  4. Write the result through `ModelPackageJsonFile(process.cwd())`.
 *
 * @throws CliUsageError when a framework-specific option is missing or not
 *   supported for the selected framework/purpose.
 * @throws CliTerseError when validation reports errors or the write fails.
 * @throws Error when the framework matches none of the handled cases.
 */
export const modelConfigure = CliLeaf({
  name: 'configure',
  description: 'Generate or modify the model configuration of the application.',
  namedInputs: {
    name: CliStringInput({
      description: 'The model name to be used to generate the model ID',
      required: true
    }),
    // Required fields for all frameworks
    framework: CliOneOfInput({
      description: 'Framework of the model.',
      required: true,
      values: modelFrameworkValues
    }),
    model_file: CliStringInput({
      description: 'Path to model binary file.'
    }),
    mean: CliNumberArrayInput({
      description:
        'The average pixel intensity in the red, green, and blue channels of the training dataset.'
    }),
    scalefactor: CliNumberInput({
      description: 'Factor to scale pixel intensities by.'
    }),
    size: CliNumberArrayInput({
      description: 'The input size of the neural network.'
    }),
    purpose: CliOneOfInput({
      description: 'Computer vision purpose of the model.',
      required: true,
      values: modelPurposeValues
    }),
    crop: CliFlagInput({
      description: 'Crop before resize?'
    }),
    // Optional fields
    config_file: CliStringInput({
      description: 'Path to model structure.'
    }),
    label_file: CliStringInput({
      description: 'File containing labels for each class index.'
    }),
    colors_file: CliStringInput({
      description: 'File containing colors to be used by each class index.'
    }),
    swaprb: CliFlagInput({
      description: 'Swap red and blue channels after image blob generation'
    }),
    softmax: CliFlagInput({
      description:
        'Apply softmax to the output of the neural network? Boolean true/false'
    }),
    batch_size: CliNumberInput({
      description: 'The inference batch size of the model'
    }),
    output_layer_names: CliStringArrayInput({
      description: 'List of output layers provided in advance',
      placeholder: '<>'
    }),
    device: CliOneOfInput({
      description: 'Define the device model is intended to be used on.',
      values: ModelDeviceValues
    }),
    architecture: CliOneOfInput({
      description: 'Define the architecture type intended to be used.',
      values: ModelArchitectureValues
    }),
    quantize_input: CliFlagInput({
      description: 'Quantize input? Boolean true/false.'
    }),
    quantize_output: CliFlagInput({
      description: 'Quantize output? Boolean true/false.'
    }),
    input_format: CliOneOfInput({
      description: 'Define the input format of the data.',
      values: hailoFormatValues
    }),
    output_format: CliOneOfInput({
      description: 'Define the output format of the data.',
      values: hailoFormatValues
    })
  },
  async action(_, opts) {
    const {
      name,
      framework,
      model_file,
      config_file,
      mean,
      scalefactor,
      size,
      purpose,
      crop,
      label_file,
      colors_file,
      swaprb,
      softmax,
      batch_size,
      output_layer_names,
      device,
      architecture,
      quantize_input,
      quantize_output,
      input_format,
      output_format
    } = opts;

    // Each case builds the parameter object for one framework. Options that
    // were not supplied fall back to defaults: empty strings for file paths,
    // [0, 0, 0] for mean, 1 for scalefactor and [300, 300] for size.
    let modelParameters: ModelJsonParameters;
    switch (framework) {
      case 'tensorflow': {
        modelParameters = {
          framework_type: 'tensorflow',
          model_file: model_file || '',
          label_file: label_file || '',
          mean: mean || [0, 0, 0],
          scalefactor: scalefactor || 1,
          size: size || [300, 300],
          purpose,
          crop,
          config_file: config_file || '',
          colors_file: colors_file || '',
          swaprb,
          softmax
        };
        break;
      }
      case 'caffe': {
        modelParameters = {
          framework_type: 'caffe',
          config_file: config_file || '',
          size: size || [300, 300],
          model_file: model_file || '',
          label_file: label_file || '',
          scalefactor: scalefactor || 1,
          mean: mean || [0, 0, 0],
          crop,
          swaprb,
          softmax,
          purpose,
          // Unlike the other frameworks (which default to null), caffe
          // defaults to two empty layer names.
          output_layer_names: output_layer_names || ['', '']
        };
        break;
      }
      case 'enet': {
        modelParameters = {
          framework_type: 'enet',
          size: size || [300, 300],
          model_file: model_file || '',
          label_file: label_file || '',
          colors_file: colors_file || '',
          scalefactor: scalefactor || 1,
          mean: mean || [0, 0, 0],
          crop,
          swaprb,
          purpose
        };
        break;
      }
      case 'darknet': {
        modelParameters = {
          framework_type: 'darknet',
          config_file: config_file || '',
          size: size || [300, 300],
          model_file: model_file || '',
          label_file: label_file || '',
          colors_file: colors_file || '',
          scalefactor: scalefactor || 1,
          mean: mean || [0, 0, 0],
          crop,
          swaprb,
          purpose,
          output_layer_names: output_layer_names || null
        };
        break;
      }
      case 'onnx': {
        if (purpose === 'ObjectDetection') {
          // --architecture is optional here, but when given it must be one of
          // the ONNX object-detection architectures (the CLI option itself
          // also accepts TensorRT and Hailo names).
          if (
            architecture &&
            !(architecture in onnxArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${onnxArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: size || [300, 300],
            model_file: model_file || '',
            label_file: label_file || '',
            colors_file: colors_file || '',
            scalefactor: scalefactor || 1,
            crop,
            swaprb,
            purpose,
            mean: mean || [0, 0, 0],
            output_layer_names: output_layer_names || null,
            architecture: architecture as OnnxArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: size || [300, 300],
            model_file: model_file || '',
            label_file: label_file || '',
            colors_file: colors_file || '',
            scalefactor: scalefactor || 1,
            crop,
            swaprb,
            purpose,
            mean: mean || [0, 0, 0],
            output_layer_names: output_layer_names || null
          };
        }
        // batch_size is only recorded for onnx when a non-zero value is given.
        if (batch_size) {
          modelParameters.batch_size = batch_size;
        }
        break;
      }
      case 'tensor-rt': {
        // A missing (or zero) batch size is rejected for tensor-rt.
        if (!batch_size) {
          throw new CliUsageError(`Parameter --batch_size required!`);
        }
        if (device && !(device in tensorrtDeviceEnum)) {
          throw new CliUsageError(
            `Device not supported! (${tensorrtDeviceValues})`
          );
        }
        if (purpose === 'ObjectDetection') {
          if (
            architecture &&
            !(architecture in tensorrtArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${tensorrtArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: size || [300, 300],
            model_file: model_file || '',
            label_file: label_file || '',
            scalefactor: scalefactor || 1,
            mean: mean || [0, 0, 0],
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colors_file || '',
            device: device as TensorrtDevice,
            architecture: architecture as TensorrtArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: size || [300, 300],
            model_file: model_file || '',
            label_file: label_file || '',
            scalefactor: scalefactor || 1,
            mean: mean || [0, 0, 0],
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colors_file || '',
            device: device as TensorrtDevice
          };
        }
        break;
      }
      case 'hailo': {
        // For hailo the architecture is mandatory and both architecture and
        // purpose must belong to the Hailo-specific enums.
        if (!architecture || !(architecture in hailoArchitectureEnum)) {
          throw new CliUsageError(
            `Parameter --architecture required! (${hailoArchitectureValues})`
          );
        }
        if (!purpose || !(purpose in hailoPurposeEnum)) {
          throw new CliUsageError(
            `Parameter --purpose required! (${hailoPurposeValues})`
          );
        }
        modelParameters = {
          framework_type: 'hailo',
          architecture: architecture as HailoArchitecture,
          // NOTE(review): `flag || true` can never produce a falsy value, so
          // both quantize settings are always written as truthy regardless of
          // whether the flags are passed. Left unchanged because correcting
          // it requires deciding how a user should express "false" - confirm
          // the intended default before altering.
          quantize_input: quantize_input || true,
          quantize_output: quantize_output || true,
          input_format: (input_format as HailoFormat) || 'auto',
          output_format: (output_format as HailoFormat) || 'auto',
          size: size || [300, 300],
          model_file: model_file || '',
          label_file: label_file || '',
          purpose: purpose as HailoPurpose,
          crop,
          swaprb,
          mean: mean || [0, 0, 0],
          scalefactor: scalefactor || 1
        };
        break;
      }
      default: {
        throw new Error('Unsupported framework.');
      }
    }

    // The model id is namespaced by the name of the authenticated user.
    const { username } = await CliAuthenticationClient().getInfo();
    const newModel: ModelJson = {
      accuracy: '',
      dataset: '',
      description: '',
      id: `${username}/${name}`,
      inference_time: null,
      license: '',
      mean_average_precision_top_1: null,
      mean_average_precision_top_5: null,
      public: false,
      website_url: '',
      model_parameters: modelParameters
    };

    // validateModel exposes its findings on its own `errors` property.
    validateModel(newModel);
    if (validateModel.errors) {
      // Explicit `null` replacer: pretty-print every property of the errors.
      echo(JSON.stringify(validateModel.errors, null, 2));
      throw new CliTerseError('Model package contents are invalid!');
    }

    const message = `Write alwaysai.model.json file`;
    const modelPkg = ModelPackageJsonFile(process.cwd());
    try {
      modelPkg.write(newModel);
      echo(`${logSymbols.success} ${message}`);
    } catch (exception) {
      // Report the failed step, keep the full exception in the log, and
      // surface a terse error to the user.
      echo(`${logSymbols.error} ${message}`);
      logger.error(exception);
      throw new CliTerseError(`Failed to write model package! ${exception}`);
    }
  }
});