Image classification with TensorFlow Lite on Android

As I've already listed in my recent blog post there are lots of advantages of making inference directly on a mobile device instead of using cloud solutions. Because of mobile devices' computation limitations, we can't migrate all of the available models to work on mobile. Unfortunately, plenty of them won't work on mobile devices but that's fine because we often don't need these heavy models on mobile devices. In this blog post, we will create a simple Android application that will take advantage of MobileNetV2 that was pre-trained on ImageNet.


Let's make our hands dirty...

In our Android app project we need to add TFLite dependency to build.gradle file.

implementation 'org.tensorflow:tensorflow-lite:1.13.1'

and the undermentioned snippet to prevent compressing the model.

aaptOptions {  
    noCompress "tflite"  
    noCompress "lite"  

The next step is to get a model for the image classification problem. One way is to create your own or take pre-trained one from here and put it to the assets folder. We will be using customized pre-trained MobileNetV2 that I've created for the sake of this demo. Our model will be able to or at least it should distinguish 🌊 kitesurfing, windsurfing, and surfing 🏄‍♂️. You can download this model as well as labels from my git repository.


Dive into the code

In order to make use of the prepared model we need to somehow import it into code. Let's use tf.lite.Interpreter interface for the model.

You can set up an interpreter in many ways, one recommended on TF website is to make use of MappedByteBuffer.

private fun getModelByteBuffer(assetManager: AssetManager, modelPath: String): MappedByteBuffer {  
    val fileDescriptor = assetManager.openFd(modelPath)  
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)  
    val fileChannel =  
    val startOffset = fileDescriptor.startOffset  
    val declaredLength = fileDescriptor.declaredLength  
    return, startOffset, declaredLength) 

and then...

model = Interpreter(loadModelFile(activity))

Unfortunately using MappedByteBuffer as an argument is already deprecated and will be deleted in future releases but you can solve this problem by delivering ByteBuffer instead and it is as simple as invoking .asReadOnlyBuffer() on loadModelFile method.

Next step is to read the file with labels. You can easily grab them with:

private fun getLabels(assetManager: AssetManager, labelPath: String): List<String> {  
    val labels = ArrayList<String>()  
    val reader = BufferedReader(InputStreamReader(  
    while (true) {
        val label = reader.readLine() ?: break
    return labels  

The last thing is to create a method that will take an image as an argument and return a list of labels with assigned probabilities to them.

fun recognize(bitmap: Bitmap): List<Recognition>{

Because our model expects the exact input shape (224x224 pixels) we need to rescale a delivered bitmap to fit into these constraints.

	val scaledBitmap =  Bitmap.createScaledBitmap(bitmap, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, false)

Next, we need to create byteBuffer of appropriate size that will be passed as an argument to the model.

	val byteBuffer = ByteBuffer  
	                BATCH_SIZE *  		 // amount of images per single processing
	                MODEL_INPUT_SIZE *   // img height
	                MODEL_INPUT_SIZE *   // img width
	                BYTES_PER_CHANNEL *  // size of float = 4
	                PIXEL_SIZE           // r+g+b = 1+1+1
	    .apply { order(ByteOrder.nativeOrder()) } // force device's native order (BIG_ENDIAN or LITTLE_ENDIAN)

And load byteByffer with the image data as floating point numbers. In order to decode color (ignoring alpha) in each pixel on a bitmap, we need to mask the least significant 8 bits and its multiple.

	val pixelValues = IntArray(MODEL_INPUT_SIZE * MODEL_INPUT_SIZE)  
	bitmap.getPixels(pixelValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)  
	var pixel = 0  
	for (i in 0 until MODEL_INPUT_SIZE) {  
	    for (j in 0 until MODEL_INPUT_SIZE) {  
	        val pixelValue = pixelValues[pixel++]  
	        byteBuffer.putFloat((pixelValue shr 16 and 0xFF) / 255f)  
	        byteBuffer.putFloat((pixelValue shr 8 and 0xFF) / 255f)  
	        byteBuffer.putFloat((pixelValue and 0xFF) / 255f)  

Finally, we can pass byteBuffer to the model. The interpreter expects for the second argument container for results and it is array of float arrays (array for each image and each one will contain float array of probabilities).

	val results = Array(BATCH_SIZE) { FloatArray(labels.size) }, results)
	return parseResults(results)

The last step is to bond probability with a proper class.

private fun parseResults(result: Array<FloatArray>): List<Recognition> {  
    val recognitions = mutableListOf<Recognition>()  
    labels.forEachIndexed { index, label ->  
        val probability = result[0][index]  
        recognitions.add(Recognition(label, probability))  
  return recognitions.sortedByDescending { it.probability }  

where Recognition is our humble result data class.

data class Recognition(  
    val name: String,  
    val probability: Float  
) {  
    override fun toString() =  
        "$name : ${probability*100}%"  
Don't forget that there are many things you should consider to make it work well i.e. handling a camera orientation or using post-training quantization if you need higher speed with a bit lower accuracy and lighter.

It’s showtime!

The above code is a minimalistic version for getting TFLite solving for us image classification problem. With the provided model you can successfully classify all photos that are in this blog post. 📸 You can find the demo here.