In a recent interview I was asked about my experience deploying trained TensorFlow models on Android. I had previously tried an Android project cloned from GitHub that embedded a TFLite model, but I had never deployed a model of my own in an Android application. So I did that exercise today, and I got my CNN model working on my Redmi Note 8 Pro.
CNN model
Here is the code for training a CNN model on the MNIST dataset. The trained model is then converted to a TFLite model, to be used in an Android application for recognizing handwritten digits.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics,models
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train[0,:,:])
## x_train.shape => (60000, 28, 28), y_train.shape => (60000,)
## x_test.shape => (10000, 28, 28), y_test.shape => (10000,)
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
## labels are already 1-D, so squeeze is a no-op here; kept as a safeguard
y_train = tf.squeeze(y_train)
y_test = tf.squeeze(y_test)
print("Dataset info: ", x_train.shape, y_train.shape, x_test.shape, y_test.shape)
batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(batch_size)
train_iter = iter(train_db)
sample = next(train_iter)
print(sample[0].shape, sample[1].shape)
## build a standard cnn model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
train_history = model.fit(train_db, epochs=10, validation_data=test_db)
## once the model has been trained, convert it to tflite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('qiu_mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)
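One thing to keep in mind when consuming this model: the last layer is `Dense(10)` with no activation, and the loss is built with `from_logits=True`, so the converted model outputs ten raw logits rather than probabilities. The app side therefore has to take the argmax to get the digit, and apply a softmax if it wants to show a confidence. Below is a minimal pure-Python sketch of that post-processing. The helper names `softmax` and `predict_digit` and the example logits are made up for illustration and are not part of the original project.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(logits):
    """Return (digit, confidence) for one 10-element logit vector."""
    probs = softmax(logits)
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs[digit]


# Hypothetical logits, invented for illustration (not real model output).
example_logits = [-2.1, 0.3, 7.5, 1.2, -0.8, 0.0, -3.4, 2.2, 0.9, -1.0]
digit, confidence = predict_digit(example_logits)
print(digit, round(confidence, 3))
```

The same two steps translate directly to Kotlin on the output array returned by the TFLite interpreter.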
Implementation in Android app
I referred to this post to obtain the original Android project. I imported the Kotlin version into Android Studio. However, there were some bugs at first when I loaded my model into it.
My own model is placed in the assets directory:
The most important thing for this work is the following Gradle setting:
After about 15 minutes of debugging and code changes, I got my model working.
Check out the video (there is still an accuracy issue):
I will upload the Android project source code to my GitHub repo once I finish cleaning up the code and improving the performance.
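On the accuracy issue: one possible cause, which I have not confirmed, is a mismatch between what the model saw in training and what the app feeds it. The training script above passes raw MNIST pixels to the model, that is, values in 0..255 with no division by 255, and MNIST digits are light strokes on a dark background. If the drawing canvas produces dark ink on a light background, or values scaled to 0..1, the input would need converting first. Here is a minimal pure-Python sketch of that conversion. The function name `to_mnist_input`, its `ink_is_dark` flag, and the example canvas are hypothetical, and the real app would do the equivalent in Kotlin.

```python
def to_mnist_input(gray, ink_is_dark=True):
    """Turn a 28x28 grid of grayscale ints (0..255) into a nested list
    shaped [1, 28, 28, 1] of floats in the 0..255 range used in training."""
    if len(gray) != 28 or any(len(row) != 28 for row in gray):
        raise ValueError("expected a 28x28 grid")
    image = []
    for row in gray:
        out_row = []
        for v in row:
            v = min(255, max(0, v))   # clamp to the valid pixel range
            if ink_is_dark:
                v = 255 - v           # invert so strokes are bright, as in MNIST
            out_row.append([float(v)])  # trailing channel dimension
        image.append(out_row)
    return [image]                    # leading batch dimension


# Hypothetical canvas: white background (255) with one dark stroke pixel.
canvas = [[255] * 28 for _ in range(28)]
canvas[14][14] = 0
batch = to_mnist_input(canvas)
print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(batch[0][0][0][0], batch[0][14][14][0])
```

If the model were retrained on inputs divided by 255, the app would have to apply the same scaling instead.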