Inference with TensorFlow in Java
This post covers the struggles I had running inference in Java with a Keras model trained in TensorFlow.
This is the train.py script used in this example.
import tensorflow as tf
#import tensorflowjs as tfjs
import os
from datetime import datetime
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Layer
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense, Input, InputLayer, Reshape
# Target frame size in pixels: tf_parse crops/pads every decoded image to this
# size, and the first Conv2D declares the same (IMG_HEIGHT, IMG_WIDTH, 3) input.
IMG_HEIGHT, IMG_WIDTH = 400, 600
def tf_parse(record):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Args:
        record: scalar ``tf.string`` tensor holding one serialized Example
            (one element of a ``TFRecordDataset``).

    Returns:
        img: ``uint8`` tensor of shape ``[IMG_HEIGHT, IMG_WIDTH, 3]``. The
            JPEG is decoded to 3 channels, then cropped or padded to the
            target size by ``tf.image.resize_with_crop_or_pad``.
        label: scalar ``int64`` class id.
    """
    keys_to_features = {
        "width": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "height": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        # Parsed to document the record schema; not used below (see reshape).
        "depth": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "label": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "image_raw": tf.io.FixedLenFeature([], tf.string, default_value=''),
    }
    parsed = tf.io.parse_single_example(record, keys_to_features)
    # channels=3 forces a 3-channel decode regardless of how the JPEG was stored.
    img = tf.image.decode_jpeg(parsed["image_raw"], channels=3)
    # Validate the decoded pixels against the stored height/width, but pin the
    # channel count to 3 instead of using the stored "depth": a depth that
    # disagrees with channels=3 (e.g. a grayscale source) would make the reshape
    # fail, and a tensor-valued depth leaves the channel dimension statically
    # unknown for the Conv2D input.
    img = tf.reshape(img, shape=[parsed["height"], parsed["width"], 3])
    img = tf.image.resize_with_crop_or_pad(img, IMG_HEIGHT, IMG_WIDTH)
    # Already int64 per the feature spec above; no cast needed.
    label = parsed["label"]
    return img, label
def list_tfrecord_file(file_list, directory="records"):
    """Return absolute paths of the ``.tfrecords`` entries in *file_list*.

    Args:
        file_list: bare file names (e.g. the output of ``os.listdir``) that are
            located in *directory*.
        directory: folder the names are relative to, itself resolved against
            the current working directory. Defaults to ``"records"``.

    Returns:
        List of absolute paths, in the order of *file_list*, for every entry
        whose resolved path ends with ``.tfrecords``. Each match is also
        reported on stdout.
    """
    tfrecord_list = []
    for file_name in file_list:
        abs_path = os.path.abspath(os.path.join(directory, file_name))
        if not abs_path.endswith(".tfrecords"):
            continue  # silently skip non-TFRecord files
        tfrecord_list.append(abs_path)
        print("Found %s successfully!" % file_name)
    return tfrecord_list
def tfrecord_auto_traversal():
    """List every ``.tfrecords`` file found in ``./records/`` as absolute paths.

    Prints how many directory entries were found, echoes each selected path,
    and prints a warning when nothing matches.

    Returns:
        List of absolute paths (possibly empty).

    Raises:
        FileNotFoundError: if ``./records/`` does not exist (from ``os.listdir``).
    """
    # Change this PATH to traverse other directories if you want. Note that
    # list_tfrecord_file resolves the names against "records/" by default,
    # so both places must change together.
    current_folder_filename_list = os.listdir("./records/")
    # os.listdir always returns a list (it raises rather than returning None),
    # so there is no None case to guard against.
    print("%s files were found under current folder. " % len(current_folder_filename_list))
    print("Please be noted that only files end with '*.tfrecords' will be load!")
    tfrecord_list = list_tfrecord_file(current_folder_filename_list)
    if tfrecord_list:
        for tfrecord_path in tfrecord_list:
            print(tfrecord_path)
    else:
        print("Cannot find any tfrecords files, please check the path.")
    return tfrecord_list
logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

VAL_SIZE = 1024          # number of examples held out for validation
BATCH_SIZE = 16
SHUFFLE_BUFFER = 50000
SPLIT_SEED = 1234        # fixed so the train/validation split is reproducible

dataset = tf.data.TFRecordDataset(filenames=tfrecord_auto_traversal())
# Shuffle ONCE to define the split. With the default reshuffle_each_iteration=True
# every re-iteration of the pipeline (each take()/skip() branch, and every pass
# made by .repeat()) sees a different order, so the "validation" 1024 would be
# re-drawn each pass and overlap the training examples. Shuffling the serialized
# records (before map) also keeps the buffer to JPEG bytes instead of decoded
# IMG_HEIGHT x IMG_WIDTH x 3 images.
split_records = dataset.shuffle(
    buffer_size=SHUFFLE_BUFFER, seed=SPLIT_SEED, reshuffle_each_iteration=False
)
parsed_dataset = split_records.map(tf_parse)
val_dataset = parsed_dataset.take(VAL_SIZE).repeat().batch(BATCH_SIZE, drop_remainder=True)
train_dataset = (
    split_records.skip(VAL_SIZE)
    # Fresh order on every pass, but only within the training split.
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .map(tf_parse)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
)
# Five conv stages, each Conv2D(3x3, ReLU) -> MaxPooling2D(2x2) -> BatchNormalization,
# followed by a small dense classifier with a 5-way softmax.
CONV_FILTERS = (32, 64, 64, 96, 32)

model = Sequential()
for stage, n_filters in enumerate(CONV_FILTERS):
    conv_kwargs = dict(kernel_size=(3, 3), activation='relu')
    if stage == 0:
        # Only the first layer declares the input shape.
        conv_kwargs['input_shape'] = (IMG_HEIGHT, IMG_WIDTH, 3)
    model.add(Conv2D(n_filters, **conv_kwargs))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(BatchNormalization())

model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(5, activation='softmax'))
# Pass-through base Keras Layer that gives the final output the explicit name
# 'main_output' (presumably so the Java side can address it by name — confirm
# against the exported signature).
model.add(Layer(name='main_output', dtype='float32'))
model.compile(
    optimizer='adam',
    # Labels from tf_parse are scalar integer class ids, hence the sparse loss.
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Both datasets call .repeat() with no count, so Keras needs explicit step
# counts to know where an epoch and a validation pass end.
model.fit(
    train_dataset,
    epochs=50,
    steps_per_epoch=1000,
    validation_data=val_dataset,
    validation_steps=10,
    callbacks=[tensorboard_callback],
)

# model.evaluate(val_dataset)  # val_dataset repeats forever; pass steps=... if re-enabled

# SavedModel directory (reloaded by plot_train.py; presumably also what the
# Java side loads — confirm).
tf.saved_model.save(model, "my_model")
# Additional Keras HDF5 copy of the same model.
model.save('trained_model.h5')
# tfjs.converters.save_keras_model(model, 'tensorjs')
This is the plot_train.py script used in this example; it reloads the saved model and renders its architecture.
import tensorflow as tf
# Reload the SavedModel directory written by train.py, print the layer table,
# and render the architecture to model.png. show_shapes=True annotates every
# layer with its input/output shape. That is the information needed to build
# correctly shaped tensors on the Java side, and the default diagram omits it.
# (plot_model requires pydot and Graphviz to be installed.)
new_model = tf.keras.models.load_model('my_model')
new_model.summary()
tf.keras.utils.plot_model(new_model, to_file='model.png', show_shapes=True, dpi=150)