@bigin/ns-firebase
Version:
139 lines (138 loc) • 6.42 kB
JavaScript
import { ImageSource, knownFolders } from '@nativescript/core';
import { getLabelsFromAppFolder, MLKitCustomModel as MLKitCustomModelBase } from './custommodel-common';
const gmsTasks = com.google.android.gms.tasks;
export class MLKitCustomModel extends MLKitCustomModelBase {
createDetector() {
this.detector = getInterpreter(this.localModelFile);
return this.detector;
}
runDetector(imageByteBuffer, previewWidth, previewHeight) {
if (this.detectorBusy) {
return;
}
this.detectorBusy = true;
if (!this.onFailureListener) {
this.onFailureListener = new gmsTasks.OnFailureListener({
onFailure: (exception) => {
console.log(exception.getMessage());
this.detectorBusy = false;
},
});
}
const modelExpectsWidth = this.modelInputShape[1];
const modelExpectsHeight = this.modelInputShape[2];
const isQuantized = this.modelInputType !== 'FLOAT32';
if (!this.inputOutputOptions) {
let intArrayIn = Array.create('int', 4);
intArrayIn[0] = this.modelInputShape[0];
intArrayIn[1] = modelExpectsWidth;
intArrayIn[2] = modelExpectsHeight;
intArrayIn[3] = this.modelInputShape[3];
const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;
let intArrayOut = Array.create('int', 2);
intArrayOut[0] = 1;
intArrayOut[1] = this.labels.length;
this.inputOutputOptions = new com.google.firebase.ml.custom.FirebaseModelInputOutputOptions.Builder().setInputFormat(0, inputType, intArrayIn).setOutputFormat(0, inputType, intArrayOut).build();
}
const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.byteBufferToByteBuffer(imageByteBuffer, previewWidth, previewHeight, modelExpectsWidth, modelExpectsHeight, isQuantized);
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder().add(input).build();
this.detector.run(inputs, this.inputOutputOptions).addOnSuccessListener(this.onSuccessListener).addOnFailureListener(this.onFailureListener);
}
createSuccessListener() {
this.onSuccessListener = new gmsTasks.OnSuccessListener({
onSuccess: (output) => {
const probabilities = output.getOutput(0)[0];
if (this.labels.length !== probabilities.length) {
console.log(`The number of labels (${this.labels.length}) is not equal to the interpretation result (${probabilities.length})!`);
return;
}
const result = {
result: getSortedResult(this.labels, probabilities, this.maxResults),
};
this.notify({
eventName: MLKitCustomModel.scanResultEvent,
object: this,
value: result,
});
this.detectorBusy = false;
},
});
return this.onSuccessListener;
}
}
const registeredModels = [];
function getInterpreter(localModelFile) {
const localModelName = localModelFile.lastIndexOf('/') === -1 ? localModelFile : localModelFile.substring(localModelFile.lastIndexOf('/') + 1);
let localModelRegistrationSuccess = false;
if (localModelFile) {
const localModelBuilder = new com.google.firebase.ml.custom.FirebaseCustomLocalModel.Builder();
if (localModelFile.indexOf('~/') === 0) {
localModelBuilder.setFilePath(knownFolders.currentApp().path + localModelFile.substring(1));
} else {
localModelBuilder.setAssetFilePath(localModelFile);
}
const firModelOptions = new com.google.firebase.ml.custom.FirebaseModelInterpreterOptions.Builder(localModelBuilder.build()).build();
return com.google.firebase.ml.custom.FirebaseModelInterpreter.getInstance(firModelOptions);
}
return null;
}
export function useCustomModel(options) {
return new Promise((resolve, reject) => {
try {
const interpreter = getInterpreter(options.localModelFile);
let labels;
if (options.labelsFile.indexOf('~/') === 0) {
labels = getLabelsFromAppFolder(options.labelsFile);
} else {
reject('Use the ~/ prefix for now..');
return;
}
const onSuccessListener = new gmsTasks.OnSuccessListener({
onSuccess: (output) => {
const probabilities = output.getOutput(0)[0];
if (labels.length !== probabilities.length) {
console.log(`The number of labels in ${options.labelsFile} (${labels.length}) is not equal to the interpretation result (${probabilities.length})!`);
return;
}
const result = {
result: getSortedResult(labels, probabilities, options.maxResults),
};
resolve(result);
interpreter.close();
},
});
const onFailureListener = new gmsTasks.OnFailureListener({
onFailure: (exception) => reject(exception.getMessage()),
});
let intArrayIn = Array.create('int', 4);
intArrayIn[0] = options.modelInput[0].shape[0];
intArrayIn[1] = options.modelInput[0].shape[1];
intArrayIn[2] = options.modelInput[0].shape[2];
intArrayIn[3] = options.modelInput[0].shape[3];
const isQuantized = options.modelInput[0].type !== 'FLOAT32';
const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;
let intArrayOut = Array.create('int', 2);
intArrayOut[0] = 1;
intArrayOut[1] = labels.length;
const inputOutputOptions = new com.google.firebase.ml.custom.FirebaseModelInputOutputOptions.Builder().setInputFormat(0, inputType, intArrayIn).setOutputFormat(0, inputType, intArrayOut).build();
const image = options.image instanceof ImageSource ? options.image.android : options.image.imageSource.android;
const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.bitmapToByteBuffer(image, options.modelInput[0].shape[1], options.modelInput[0].shape[2], isQuantized);
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder().add(input).build();
interpreter.run(inputs, inputOutputOptions).addOnSuccessListener(onSuccessListener).addOnFailureListener(onFailureListener);
} catch (ex) {
console.log('Error in firebase.mlkit.useCustomModel: ' + ex);
reject(ex);
}
});
}
function getSortedResult(labels, probabilities, maxResults = 5) {
const result = [];
labels.forEach((text, i) => result.push({ text, confidence: probabilities[i] }));
result.sort((a, b) => (a.confidence < b.confidence ? 1 : a.confidence === b.confidence ? 0 : -1));
if (result.length > maxResults) {
result.splice(maxResults);
}
result.map((r) => (r.confidence = (r.confidence & 0xff) / 255.0));
return result;
}
//# sourceMappingURL=index.android.js.map