Inference with TensorFlow in Java

This post describes the struggles I had running inference in Java with a Keras model trained in TensorFlow.

This is the train.py script used in this example:

import tensorflow as tf
#import tensorflowjs as tfjs
import os
from datetime import datetime

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Layer
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense, Input, InputLayer, Reshape

IMG_WIDTH = 600
IMG_HEIGHT = 400

def tf_parse(record):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Only the ``image_raw`` (JPEG bytes) and ``label`` (integer class id)
    features are read. The size fields the writer may also store
    (width/height/depth) are not needed, because ``decode_jpeg`` recovers the
    real image size from the JPEG itself.

    Args:
        record: scalar string tensor holding one serialized example.

    Returns:
        img: uint8 tensor of shape ``(IMG_HEIGHT, IMG_WIDTH, 3)``, centrally
            cropped and/or zero-padded to the training size.
        label: int64 scalar tensor.
    """
    keys_to_features = {
        "label":     tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "image_raw": tf.io.FixedLenFeature([], tf.string, default_value=''),
    }
    parsed = tf.io.parse_single_example(record, keys_to_features)

    # channels=3 already produces a (height, width, 3) tensor with a static
    # channel dimension. The stored width/height/depth are deliberately NOT used
    # to reshape it:
    #   - a reshape from runtime values loses the static channel dimension;
    #   - it fails whenever depth != 3 (channels=3 forces RGB);
    #   - it silently scrambles pixels if width and height were stored swapped.
    img = tf.image.decode_jpeg(parsed["image_raw"], channels=3)
    img = tf.image.resize_with_crop_or_pad(img, IMG_HEIGHT, IMG_WIDTH)

    label = parsed["label"]  # already int64 per the feature spec above
    return img, label

def list_tfrecord_file(file_list, directory="records"):
    """Return absolute paths of the entries in *file_list* that are TFRecord files.

    Each name is resolved relative to *directory* and kept only if it ends
    with ``.tfrecords``. A confirmation line is printed for every match.

    Args:
        file_list: file names (not paths) that live inside *directory*.
        directory: folder the names are relative to. The default ``"records"``
            matches the previously hard-coded prefix.

    Returns:
        List of absolute paths, in the same order as *file_list*.
    """
    tfrecord_list = []
    for file_name in file_list:
        abs_path = os.path.abspath(os.path.join(directory, file_name))
        if abs_path.endswith(".tfrecords"):
            tfrecord_list.append(abs_path)
            print(f"Found {file_name} successfully!")
    return tfrecord_list

def tfrecord_auto_traversal():
    """List the ``.tfrecords`` files found in the ``./records/`` folder.

    Prints how many directory entries were found and each matching path, or
    a warning when nothing matches.

    Returns:
        Absolute paths of the ``.tfrecords`` files (possibly an empty list).

    Raises:
        FileNotFoundError: if ``./records/`` does not exist (from ``os.listdir``).
    """
    # Change this PATH to traverse other directories if you want.
    current_folder_filename_list = os.listdir("./records/")
    # os.listdir always returns a list (it raises instead of returning None),
    # so no None check is needed. tfrecord_list is therefore always bound below.
    print("%s files were found under current folder. " % len(current_folder_filename_list))
    print("Please be noted that only files end with '*.tfrecords' will be load!")
    tfrecord_list = list_tfrecord_file(current_folder_filename_list)
    if tfrecord_list:
        for tfrecord_path in tfrecord_list:
            print(tfrecord_path)
    else:
        print("Cannot find any tfrecords files, please check the path.")
    return tfrecord_list

# One TensorBoard run directory per launch, named after the start timestamp.
logdir = f"logs/scalars/{datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# Shuffle the *serialized* records once, in a fixed order
# (reshuffle_each_iteration=False), so that take()/skip() below always see the
# same permutation and the validation and training splits stay disjoint.
# With the default (True), every re-iteration draws a new order, so take(1024)
# picks a different subset each epoch and validation images leak into training.
# Shuffling before map() also buffers raw JPEG bytes rather than decoded
# IMG_HEIGHT x IMG_WIDTH x 3 images.
dataset = tf.data.TFRecordDataset(filenames=tfrecord_auto_traversal()).shuffle(
    buffer_size=50000, reshuffle_each_iteration=False
)
parsed_dataset = dataset.map(tf_parse)

val_dataset = parsed_dataset.take(1024).repeat().batch(16, drop_remainder=True)
# The training split gets its own per-epoch reshuffle, independent of the fixed
# split order above.
train_dataset = (
    dataset.skip(1024)
    .shuffle(buffer_size=50000)
    .map(tf_parse)
    .repeat()
    .batch(16, drop_remainder=True)
)

# Five Conv2D -> MaxPooling2D -> BatchNormalization stages. Only the first
# stage declares the input shape. Layers are created in the same order as the
# unrolled form, so Keras assigns the same automatic layer names.
_CONV_FILTERS = (32, 64, 64, 96, 32)

model = Sequential()
for _stage, _n_filters in enumerate(_CONV_FILTERS):
    _extra = {'input_shape': (IMG_HEIGHT, IMG_WIDTH, 3)} if _stage == 0 else {}
    model.add(Conv2D(_n_filters, kernel_size=(3, 3), activation='relu', **_extra))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(BatchNormalization())

# Classifier head: 5-way softmax, followed by a pass-through base Layer that
# gives the output a fixed name ('main_output'). Presumably this is so the
# Java side can look the output up by name.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(5, activation='softmax'))
model.add(Layer(name='main_output', dtype='float32'))

# Integer class labels -> sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Both datasets repeat forever, so epoch length and validation length are
# fixed by step counts rather than by dataset exhaustion.
model.fit(
    train_dataset,
    epochs=50,
    steps_per_epoch=1000,
    validation_data=val_dataset,
    validation_steps=10,
    callbacks=[tensorboard_callback],
)

#model.evaluate(val_dataset)
# Export the trained model twice:
#   - "my_model": a TensorFlow SavedModel directory (the one plot_train.py
#     reloads; presumably also what the Java side loads, per the post's topic);
#   - 'trained_model.h5': a single-file Keras HDF5 copy (chosen by the .h5 suffix).
tf.saved_model.save(model, "my_model")
model.save('trained_model.h5')

# Disabled: requires the tensorflowjs import that is commented out at the top.
#tfjs.converters.save_keras_model(model, 'tensorjs')

This is the plot_train.py script used in this example:

import tensorflow as tf

# Reload the SavedModel exported by train.py, print its layer summary and
# render the architecture diagram (plot_model needs pydot + graphviz installed).
_saved_model_dir = 'my_model'
new_model = tf.keras.models.load_model(_saved_model_dir)
new_model.summary()
tf.keras.utils.plot_model(model=new_model, to_file='model.png', dpi=150)

Similar Posts

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.