react-native-tflite-classification
Version:
Run TensorFlow Lite models in React Native Android apps
289 lines (267 loc) • 10.4 kB
JavaScript
/**
* This container lets the user take or select a picture and then classify it.
*/
import React from 'react'
import { StyleSheet, Text, TouchableOpacity, View, Image, ScrollView, BackHandler } from 'react-native'
import { Icon } from 'react-native-elements';
import * as ImagePicker from 'expo-image-picker';
import { ListItem } from 'react-native-elements';
const RNFS = require('react-native-fs');
import { Tflite } from 'react-native-tflite-classification';
// Single classifier instance shared by the component below; never reassigned.
const tflite = new Tflite();
export default class ModelUse extends React.Component {
componentDidMount = async () => {
// check if this is the first time app is being opened. If it is,
// move the starter model and image from android assets folder to internal
// storage
if (!(await RNFS.exists(RNFS.DocumentDirectoryPath + '/Model'))) {
await RNFS.mkdir(RNFS.DocumentDirectoryPath + '/Model')
await RNFS.copyFileAssets('Model/graph.lite', RNFS.DocumentDirectoryPath + '/Model/graph.lite')
await RNFS.copyFileAssets('Model/labels.txt', RNFS.DocumentDirectoryPath + '/Model/labels.txt')
await RNFS.copyFileAssets('example.jpg', RNFS.DocumentDirectoryPath + '/example.jpg')
}
// start with example photo to clasify
this.setState(() => ({ image: 'file://' + RNFS.DocumentDirectoryPath + '/example.jpg' }))
// load model
tflite.loadModel({
modelPath: '/Model/graph.lite',
labelsPath: '/Model/labels.txt'
},
(err, res) => {
if (err)
console.log(err);
else
console.log(res);
}
);
}
closeApp = async () => {
// Releases all resources used by the model on the native side of things.
tflite.close()
BackHandler.exitApp();
}
takePicture = async () => {
/**
* Asks the user for permission to use the camera, has the user
* take a picture, and makes sure the picture has a square aspect
* ratio so that it feeds in to the models correctly
*/
const permissionResult = await ImagePicker.requestCameraPermissionsAsync();
if (permissionResult.granted === false) {
alert('Sorry, camera permissions are needed!');
return
}
const result = await ImagePicker.launchCameraAsync({
allowsEditing: true,
aspect: [4, 4]
});
if (result.assets) {
this.setState({ image: result.assets[0].uri });
}
};
selectPicture = async () => {
/**
* Asks the user for permission to obtain an image from the gallery,
* and makes sure the image has a square aspect ration through cropping
*/
const permissionResult = await ImagePicker.requestCameraPermissionsAsync();
if (permissionResult.granted === false) {
alert('Sorry, camera permissions are needed!');
return
}
const result = await ImagePicker.launchImageLibraryAsync({
allowsEditing: true,
aspect: [4, 4]
});
if (result.assets) {
this.setState({ image: result.assets[0].uri });
}
};
classifyPicture = () => {
/**
* If the user has a picture selected, classify it using the selected tflite
* model.
*/
if (this.state.image != null) {
// run the image against the loaded model
tflite.runModelOnImage({
path: this.state.image,
numResults: 10,
threshold: 0
},
(err, res) => {
if (err)
console.log(err + '\n' + res);
else {
this.setState(() => (
{ results: res }
))
}
});
} else {
alert('Please first take an image or select an image from your gallery!');
}
};
state = {
results: [],
image: null,
imageSize: 0,
iconSize: 0,
}
find_image_dimesions(layout) {
/**
* finds the dimensions of the current view to size image correctly
*/
const { x, y, width, height } = layout;
this.setState(() => (
{ imageSize: width > height ? height : width }
))
}
find_icon_dimesions(layout) {
/**
* finds the dimensions of the current view to size icon in buttons correctly
*/
const { x, y, width, height } = layout;
this.setState(() => (
{ iconSize: (width > height ? height : width) - 25 }
))
}
render() {
return (
<View style={styles.container}>
<View style={styles.imageContainer} onLayout={(event) => { this.find_image_dimesions(event.nativeEvent.layout) }}>
<Image
style={[{ height: this.state.imageSize, width: this.state.imageSize }, styles.image]}
source={this.state.image != null ? { uri: this.state.image } : require('./assets/noImage.png')}
/>
</View>
<View style={styles.buttonContainer}>
<View style={styles.buttonAndLabel}>
<View style={styles.singleButtonContainer}>
<TouchableOpacity
style={styles.circleButton}
onPress={this.closeApp}
>
<Icon name={'exit-to-app'} size={this.state.iconSize} color='#01a699' />
</TouchableOpacity>
</View>
<View style={styles.buttonLabel}>
<Text>{'Exit'}</Text>
</View>
</View>
<View style={styles.buttonAndLabel}>
<View style={styles.singleButtonContainer}>
<TouchableOpacity
style={styles.circleButton}
onPress={this.takePicture}
>
<Icon name={'camera-alt'} size={this.state.iconSize} color='#01a699' />
</TouchableOpacity>
</View>
<View style={styles.buttonLabel}>
<Text>{'Take Image'}</Text>
</View>
</View>
<View style={styles.buttonAndLabel}>
<View style={styles.singleButtonContainer}>
<TouchableOpacity
style={styles.circleButton}
onPress={this.selectPicture}
>
<Icon name={'photo-album'} size={this.state.iconSize} color='#01a699' />
</TouchableOpacity>
</View>
<View style={styles.buttonLabel}>
<Text>{'Select Image'}</Text>
</View>
</View>
<View style={styles.buttonAndLabel}>
<View style={styles.singleButtonContainer}>
<TouchableOpacity
onLayout={(event) => { this.find_icon_dimesions(event.nativeEvent.layout) }}
style={styles.circleButton}
onPress={this.classifyPicture}
>
<Icon name={'check'} size={this.state.iconSize} color='#01a699' />
</TouchableOpacity>
</View>
<View style={styles.buttonLabel}>
<Text>{'Classify Image'}</Text>
</View>
</View>
</View>
<View style={styles.classifierList}>
<View style={styles.insideList}>
<ScrollView>
{
this.state.results.map((l, i) => (
<ListItem key={i} topDivider bottomDivider>
<ListItem.Content>
<ListItem.Title>{(i + 1).toString() + '. ' + l.label}</ListItem.Title>
</ListItem.Content>
<ListItem.Content right>
<ListItem.Title right>{'Confidence: ' + (l.confidence * 100).toFixed(2) + '%'}</ListItem.Title>
</ListItem.Content>
</ListItem>
))
}
</ScrollView>
</View>
</View>
</View>
)
}
}
// Centers children on both axes; shared by several containers below.
const centerContents = { justifyContent: 'center', alignItems: 'center' };

const styles = StyleSheet.create({
  // Root view: fills the available space.
  container: { flex: 1 },

  // Image area (3 parts of the 3:1:2 vertical split); the image is centered in it.
  imageContainer: { ...centerContents, flex: 3, margin: 30 },
  image: { backgroundColor: 'gray', borderRadius: 10 },

  // Button row (1 part of the split): four equally wide button+label columns.
  buttonContainer: { flex: 1, flexDirection: 'row', margin: 10 },
  buttonAndLabel: { flex: 1 },
  singleButtonContainer: { ...centerContents },
  buttonLabel: { ...centerContents },
  circleButton: {
    ...centerContents,
    height: '85%',
    aspectRatio: 1,
    borderWidth: 1,
    borderColor: 'rgba(0,0,0,0.2)',
    borderRadius: 50,
    backgroundColor: '#fff',
  },

  // Results list (2 parts of the split): a rounded, bordered white panel.
  classifierList: { flex: 2 },
  insideList: {
    flex: 1,
    margin: 20,
    borderWidth: 1,
    borderColor: 'black',
    borderRadius: 20,
    backgroundColor: 'white',
    overflow: 'hidden',
  },
});