harshaUwm163
added link to the model page with the notebook for training it
fd1f777
raw
history blame
2.53 kB
import numpy as np
import tensorflow as tf
import gradio as gr
from huggingface_hub import from_pretrained_keras
import cv2
# import matplotlib.pyplot as plt
# Hugging Face Hub repository that hosts the pretrained Keras classifier.
MODEL_REPO_ID = "keras-io/CutMix_data_augmentation_for_image_classification"

# Fetched once at import time so every inference request reuses the same model.
model = from_pretrained_keras(MODEL_REPO_ID)
# --- constants used by the inference helpers ---

# Side length (in pixels) of the square image the helpers resize inputs to.
IMG_SIZE = 32

# Human-readable class names. `infer` indexes this list with the argmax of the
# model output, so the order here must match the model's output order.
class_names = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]
def preprocess_image(image, label):
    """Resize an image to IMG_SIZE x IMG_SIZE and divide its values by 255.

    Args:
        image: Image data accepted by ``tf.image.resize``
            (presumably 8-bit pixel values in 0..255 -- confirm with callers,
            since the result is only in [0, 1] under that assumption).
        label: Label associated with the image; returned unchanged.

    Returns:
        A ``(image, label)`` tuple where ``image`` is the resized float32
        tensor scaled by 1/255.
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    resized = tf.image.resize(image, target_size)
    as_float = tf.image.convert_image_dtype(resized, tf.float32)
    return as_float / 255.0, label
def read_image(image):
    """Convert a raw input image into a preprocessed tensor.

    Args:
        image: Image data convertible with ``tf.convert_to_tensor``; it must be
            compatible with a ``(height, width, 3)`` shape.

    Returns:
        The image after ``preprocess_image`` has resized and rescaled it.
    """
    tensor = tf.convert_to_tensor(image)
    # Pin the channel dimension to 3 while leaving height and width open.
    tensor.set_shape([None, None, 3])
    print('$$$$$$$$$$$$$$$$$$$$$ in read image $$$$$$$$$$$$$$$$$$$$$$')
    print(tensor.shape)
    # preprocess_image takes an (image, label) pair; no real label exists at
    # inference time, so a placeholder is passed in and then discarded.
    processed, _placeholder_label = preprocess_image(tensor, 1)
    return processed
def infer(input_image):
    """Classify one image and return the name of the predicted class.

    Args:
        input_image: Raw image handed over by the UI; forwarded to
            ``read_image`` for conversion and preprocessing.

    Returns:
        The entry of ``class_names`` at the index of the highest model output,
        as a string.
    """
    print('#$$$$$$$$$$$$$$$$$$$$$$$$$ IN INFER $$$$$$$$$$$$$$$$$$$$$$$')
    image_tensor = read_image(input_image)
    print(image_tensor.shape)
    # The model is called on a batch, so add a leading batch axis of size 1.
    batch = np.expand_dims(image_tensor, axis=0)
    # Drop the batch axis again and pick the index of the best-scoring class.
    scores = np.squeeze(model.predict(batch))
    best_index = np.argmax(scores)
    return str(class_names[best_index.item()])
# --- Gradio UI wiring ---

# Input widget: an uploaded image, shaped to IMG_SIZE x IMG_SIZE by Gradio.
# NOTE: the name `input` shadows the Python builtin; it is kept so that the
# module-level names of this script stay unchanged.
input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))

# Output widget: a single label showing the predicted class name.
output = [gr.outputs.Label()]

# Examples, a title and a description guide users of the demo.
examples = [["./content/examples/Frog.jpg"], ["./content/examples/Truck.jpg"]]
title = "Image classification"
# The description is rendered as HTML, so every <a> tag must be closed.
description = "Upload an image or select from examples to classify it. The allowed classes are - Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck <p><b>Space author: Harshavardhan</b> <br><b> Keras example author: <a href=\"https://twitter.com/sayannath2350\"> Sayan Nath </a> </b> <a href=\"https://huggingface.co/keras-io/CutMix_data_augmentation_for_image_classification\">link to the model page</a> </p>"

# Keep a reference to the Interface itself. Binding the result of `.launch()`
# here would store launch's return value rather than the Interface, and a
# later `.launch()` call on that value would fail.
gr_interface = gr.Interface(
    infer,
    input,
    output,
    examples=examples,
    allow_flagging=False,
    analytics_enabled=False,
    title=title,
    description=description,
)

# Launch exactly once, with the same options the app was originally started with.
gr_interface.launch(enable_queue=True, debug=False)