Mushfi's picture
Update app.py
9324034
raw
history blame
1.35 kB
import gradio as gr
from PIL import Image
import tensorflow as tf
import os
import numpy as np
import base64
from io import BytesIO
# Class names; predict_input_image pairs LABELS[i] with the i-th value of the
# model's output vector, so the order here must match the model's outputs.
LABELS = [
    'NORMAL',
    'TUBERCULOSIS',
    'PNEUMONIA',
    'COVID19',
]

# Keras model loaded once at import time from a file next to the app.
model = tf.keras.models.load_model('model.hdf5')
def predict_input_image(img):
    """Classify a chest X-ray image and return per-class scores.

    Args:
        img: The image to classify. Accepted forms:
            * a base64-encoded image file (``str`` or bytes-like), which is
              what the original implementation expected;
            * an already-decoded ``numpy.ndarray`` of pixel data;
            * a ``PIL.Image.Image``.

    Returns:
        A dict mapping each name in ``LABELS`` to the matching model output
        converted to ``float``. On any failure the function does not raise;
        it returns ``{"error": "<exception message>"}`` instead.
    """
    try:
        # Accept already-decoded inputs as well as base64 payloads.
        # NOTE(review): the Gradio Image component presumably hands the
        # handler a numpy array rather than a base64 string -- confirm
        # against the installed Gradio version.
        if isinstance(img, Image.Image):
            picture = img
        elif isinstance(img, np.ndarray):
            picture = Image.fromarray(img)
        else:
            picture = Image.open(BytesIO(base64.b64decode(img)))

        # The model is fed 128x128 RGB images.
        pixels = np.array(picture.convert('RGB').resize((128, 128)))

        # Scale so the brightest pixel becomes 1.0. An all-black image has a
        # maximum of 0; dividing by it would fill the batch with NaN, so fall
        # back to a divisor of 1 (the batch then stays all zeros).
        peak = pixels.max()
        batch = pixels.reshape(-1, 128, 128, 3) / (peak if peak > 0 else 1)

        # predict() returns one row per batch item; there is a single item.
        prediction = model.predict(batch)[0]
        return {name: float(prediction[i]) for i, name in enumerate(LABELS)}
    except Exception as e:
        # UI/API boundary: report the failure to the caller instead of
        # letting the request crash.
        return {"error": str(e)}
# Gradio UI: a header, an image upload with a submit button, and a label
# widget showing the class scores returned by predict_input_image.
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a copy of the file that had lost its indentation -- confirm it matches the
# intended layout.
with gr.Blocks(title="Chest X-Ray Disease Classification", css="") as demo:
    # Page header with a link to the training notebook.
    with gr.Row():
        textmd = gr.Markdown('''
# Chest X-Ray Disease Classification
View the full training code at <a href="https://www.kaggle.com/code/mushfirat/chest-x-ray-disease-classification"><b>kaggle</b></a>
''')
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            # gr.Image / gr.Label replace the deprecated gr.inputs.Image and
            # gr.outputs.Label shims; the arguments are unchanged.
            image = gr.Image(shape=(128, 128))
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
            label = gr.Label(num_top_classes=4)
    # Run the classifier on click; api_name also exposes this handler as a
    # named API endpoint.
    submit_btn.click(predict_input_image, inputs=image, outputs=label, api_name="prediction_place")

demo.launch()