UNPKG

react-native-tflite-classification

Version:

Run TensorFlow Lite models in React Native Android apps

289 lines (267 loc) 10.4 kB
/** * this contianer is shown to take or select a picture, and then classify the picture. */ import React from 'react' import { StyleSheet, Text, TouchableOpacity, View, Image, ScrollView, BackHandler } from 'react-native' import { Icon } from 'react-native-elements'; import * as ImagePicker from 'expo-image-picker'; import { ListItem } from 'react-native-elements'; const RNFS = require('react-native-fs'); import { Tflite } from 'react-native-tflite-classification'; let tflite = new Tflite() export default class ModelUse extends React.Component { componentDidMount = async () => { // check if this is the first time app is being opened. If it is, // move the starter model and image from android assets folder to internal // storage if (!(await RNFS.exists(RNFS.DocumentDirectoryPath + '/Model'))) { await RNFS.mkdir(RNFS.DocumentDirectoryPath + '/Model') await RNFS.copyFileAssets('Model/graph.lite', RNFS.DocumentDirectoryPath + '/Model/graph.lite') await RNFS.copyFileAssets('Model/labels.txt', RNFS.DocumentDirectoryPath + '/Model/labels.txt') await RNFS.copyFileAssets('example.jpg', RNFS.DocumentDirectoryPath + '/example.jpg') } // start with example photo to clasify this.setState(() => ({ image: 'file://' + RNFS.DocumentDirectoryPath + '/example.jpg' })) // load model tflite.loadModel({ modelPath: '/Model/graph.lite', labelsPath: '/Model/labels.txt' }, (err, res) => { if (err) console.log(err); else console.log(res); } ); } closeApp = async () => { // Releases all resources used by the model on the native side of things. tflite.close() BackHandler.exitApp(); } takePicture = async () => { /** * Asks the user for permission to use the camera, has the user * take a picture, and makes sure the picture has a square aspect * ratio so that it feeds in to the models correctly */ const permissionResult = await ImagePicker.requestCameraPermissionsAsync(); if (permissionResult.granted === false) { alert('Sorry, camera permissions are needed!'); return } const result = await ImagePicker.launchCameraAsync({ allowsEditing: true, aspect: [4, 4] }); if (result.assets) { this.setState({ image: result.assets[0].uri }); } }; selectPicture = async () => { /** * Asks the user for permission to obtain an image from the gallery, * and makes sure the image has a square aspect ration through cropping */ const permissionResult = await ImagePicker.requestCameraPermissionsAsync(); if (permissionResult.granted === false) { alert('Sorry, camera permissions are needed!'); return } const result = await ImagePicker.launchImageLibraryAsync({ allowsEditing: true, aspect: [4, 4] }); if (result.assets) { this.setState({ image: result.assets[0].uri }); } }; classifyPicture = () => { /** * If the user has a picture selected, classify it using the selected tflite * model. */ if (this.state.image != null) { // run the image against the loaded model tflite.runModelOnImage({ path: this.state.image, numResults: 10, threshold: 0 }, (err, res) => { if (err) console.log(err + '\n' + res); else { this.setState(() => ( { results: res } )) } }); } else { alert('Please first take an image or select an image from your gallery!'); } }; state = { results: [], image: null, imageSize: 0, iconSize: 0, } find_image_dimesions(layout) { /** * finds the dimensions of the current view to size image correctly */ const { x, y, width, height } = layout; this.setState(() => ( { imageSize: width > height ? height : width } )) } find_icon_dimesions(layout) { /** * finds the dimensions of the current view to size icon in buttons correctly */ const { x, y, width, height } = layout; this.setState(() => ( { iconSize: (width > height ? height : width) - 25 } )) } render() { return ( <View style={styles.container}> <View style={styles.imageContainer} onLayout={(event) => { this.find_image_dimesions(event.nativeEvent.layout) }}> <Image style={[{ height: this.state.imageSize, width: this.state.imageSize }, styles.image]} source={this.state.image != null ? { uri: this.state.image } : require('./assets/noImage.png')} /> </View> <View style={styles.buttonContainer}> <View style={styles.buttonAndLabel}> <View style={styles.singleButtonContainer}> <TouchableOpacity style={styles.circleButton} onPress={this.closeApp} > <Icon name={'exit-to-app'} size={this.state.iconSize} color='#01a699' /> </TouchableOpacity> </View> <View style={styles.buttonLabel}> <Text>{'Exit'}</Text> </View> </View> <View style={styles.buttonAndLabel}> <View style={styles.singleButtonContainer}> <TouchableOpacity style={styles.circleButton} onPress={this.takePicture} > <Icon name={'camera-alt'} size={this.state.iconSize} color='#01a699' /> </TouchableOpacity> </View> <View style={styles.buttonLabel}> <Text>{'Take Image'}</Text> </View> </View> <View style={styles.buttonAndLabel}> <View style={styles.singleButtonContainer}> <TouchableOpacity style={styles.circleButton} onPress={this.selectPicture} > <Icon name={'photo-album'} size={this.state.iconSize} color='#01a699' /> </TouchableOpacity> </View> <View style={styles.buttonLabel}> <Text>{'Select Image'}</Text> </View> </View> <View style={styles.buttonAndLabel}> <View style={styles.singleButtonContainer}> <TouchableOpacity onLayout={(event) => { this.find_icon_dimesions(event.nativeEvent.layout) }} style={styles.circleButton} onPress={this.classifyPicture} > <Icon name={'check'} size={this.state.iconSize} color='#01a699' /> </TouchableOpacity> </View> <View style={styles.buttonLabel}> <Text>{'Classify Image'}</Text> </View> </View> </View> <View style={styles.classifierList}> <View style={styles.insideList}> <ScrollView> { this.state.results.map((l, i) => ( <ListItem key={i} topDivider bottomDivider> <ListItem.Content> <ListItem.Title>{(i + 1).toString() + '. ' + l.label}</ListItem.Title> </ListItem.Content> <ListItem.Content right> <ListItem.Title right>{'Confidence: ' + (l.confidence * 100).toFixed(2) + '%'}</ListItem.Title> </ListItem.Content> </ListItem> )) } </ScrollView> </View> </View> </View> ) } } const styles = StyleSheet.create({ container: { flex: 1 }, imageContainer: { flex: 3, justifyContent: 'center', alignItems: 'center', margin: 30 }, image: { backgroundColor: 'gray', borderRadius: 10 }, buttonLabel: { justifyContent: 'center', alignItems: 'center' }, buttonContainer: { flexDirection: 'row', margin: 10, flex: 1 }, buttonAndLabel: { flex: 1 }, classifierList: { flex: 2 }, insideList: { flex: 1, margin: 20, borderColor: 'black', borderWidth: 1, backgroundColor: 'white', borderRadius: 20, overflow: 'hidden' }, singleButtonContainer: { justifyContent: 'center', alignItems: 'center' }, circleButton: { borderWidth: 1, borderColor: 'rgba(0,0,0,0.2)', alignItems: 'center', justifyContent: 'center', height: '85%', aspectRatio: 1, backgroundColor: '#fff', borderRadius: 50, } })