|
import gradio as gr |
|
from PIL import Image |
|
import tensorflow as tf |
|
import os |
|
import numpy as np |
|
import base64 |
|
from io import BytesIO |
|
|
|
# Trained Keras classifier, loaded once at import time so every request
# reuses the same in-memory model.
model = tf.keras.models.load_model('model.hdf5')

# Class names reported to the user. Predictions are matched to these by
# position, so the order has to follow the model's output order.
LABELS = [
    'NORMAL',
    'TUBERCULOSIS',
    'PNEUMONIA',
    'COVID19',
]
|
|
|
def predict_input_image(img):
    """Classify a chest X-ray image and return one score per label.

    Args:
        img: The image to classify. Accepted forms are a base64-encoded
            image file (``str`` or bytes-like, optionally carrying a
            ``data:<mime>;base64,`` prefix), a PIL image, or a pixel array
            that ``PIL.Image.fromarray`` can read.

    Returns:
        A dict mapping each name in ``LABELS`` to the model's score for
        that class as a ``float``. When the input cannot be turned into an
        image, ``{"error": <message>}`` is returned instead of raising.
    """
    try:
        if isinstance(img, (str, bytes, bytearray, memoryview)):
            # Drop a data-URL header; non-strict base64 decoding would
            # otherwise read its letters as payload and corrupt the image.
            if isinstance(img, str) and img.startswith("data:"):
                img = img.split(",", 1)[-1]
            picture = Image.open(BytesIO(base64.b64decode(img)))
        elif isinstance(img, Image.Image):
            picture = img
        else:
            # Gradio's Image component passes a NumPy pixel array by
            # default rather than base64 text.
            picture = Image.fromarray(np.asarray(img))
        pixels = np.array(picture.convert('RGB').resize((128, 128)))
    except Exception as e:
        # Request boundary: report undecodable input to the caller instead
        # of letting the handler crash.
        return {"error": str(e)}

    # Scale by the brightest pixel. An all-black image has a peak of 0, and
    # dividing by it would fill the batch with NaN, so fall back to 1.
    peak = pixels.max()
    scale = peak if peak > 0 else 1
    img_4d = pixels.reshape(-1, 128, 128, 3) / scale

    # Report the value range that is fed to the model.
    print(img_4d.min())
    print(img_4d.max())

    # The batch holds a single image, so keep only its row of scores.
    prediction = model.predict(img_4d)[0]
    return {name: float(score) for name, score in zip(LABELS, prediction)}
|
|
|
|
|
# Page layout: a heading row, then an input column (image + submit button)
# next to the prediction label.
# NOTE(review): the nesting of the `with` blocks below was reconstructed
# after the original indentation was lost; confirm it matches the intended
# layout.
with gr.Blocks(title="Chest X-Ray Disease Classification", css="") as demo:
    with gr.Row():
        # Text is kept flush-left so the first line always renders as a
        # Markdown heading, never as an indented code block.
        textmd = gr.Markdown(
            '\n'
            '# Chest X-Ray Disease Classification\n'
            'View the full training code at <a href="https://www.kaggle.com/code/mushfirat/chest-x-ray-disease-classification"><b>kaggle</b></a>\n'
        )
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            # gr.Image / gr.Label replace the deprecated gr.inputs.Image /
            # gr.outputs.Label aliases and take the same arguments.
            image = gr.Image(shape=(128, 128))
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
        label = gr.Label(num_top_classes=4)

    # Run the classifier on click; api_name also exposes this handler as a
    # named API endpoint.
    submit_btn.click(
        predict_input_image,
        inputs=image,
        outputs=label,
        api_name="prediction_place",
    )

# Start the web server for the app.
demo.launch()