File size: 1,330 Bytes
fdadfe7 9324034 fdadfe7 126b8a2 fdadfe7 e49974f 513ce06 c5387fd 613b720 513ce06 fdadfe7 4156c99 fdadfe7 d75235b 2f86dc1 893c61f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import gradio as gr
from PIL import Image
import tensorflow as tf
import os
import numpy as np
import base64
from io import BytesIO
# Output class names. predict_input_image pairs the model's i-th output score
# with the i-th entry of this list, so the order must match the model's outputs.
LABELS = [
    'NORMAL',
    'TUBERCULOSIS',
    'PNEUMONIA',
    'COVID19',
]

# Trained Keras classifier, loaded once at import time from 'model.hdf5'
# (a path relative to the current working directory).
model = tf.keras.models.load_model('model.hdf5')
def _to_pil_image(img):
    """Return *img* as a PIL image, decoding base64 text when necessary.

    Accepts a PIL image (returned as-is), a NumPy array (as delivered by the
    Gradio ``Image`` component), or base64-encoded image data as ``str`` or
    ``bytes``, optionally prefixed with a ``data:<mime>;base64,`` header.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    # Strip an optional data-URL header; b64decode would otherwise fold its
    # alphabet characters (letters and '/') into the decoded payload.
    if isinstance(img, str) and img.startswith('data:'):
        img = img.partition(',')[2]
    elif isinstance(img, bytes) and img.startswith(b'data:'):
        img = img.partition(b',')[2]
    return Image.open(BytesIO(base64.b64decode(img)))


def predict_input_image(img):
    """Classify a chest X-ray image and return a score for every class.

    Args:
        img: The image to classify. May be a NumPy array (what the Gradio
            ``Image`` input passes), a PIL image, or base64-encoded image data
            as ``str``/``bytes`` (optionally with a ``data:...;base64,``
            prefix).

    Returns:
        A dict mapping each name in ``LABELS`` to the model's float score for
        that class, or ``{"error": <message>}`` when *img* cannot be turned
        into an image.
    """
    try:
        pixels = np.array(_to_pil_image(img).convert('RGB').resize((128, 128)))
    except Exception as e:
        return {"error": str(e)}

    batch = pixels.reshape(-1, 128, 128, 3)
    # Scale by the image's own brightest value so pixels land in [0, 1].
    # An all-black image (peak 0) stays all zeros instead of becoming 0/0 NaNs.
    peak = batch.max()
    img_4d = batch / peak if peak > 0 else batch.astype(np.float64)

    prediction = model.predict(img_4d)[0]
    return {label: float(score) for label, score in zip(LABELS, prediction)}
# Web UI: a title row, then an image input, a Submit button and a label
# output showing the per-class scores returned by predict_input_image.
with gr.Blocks(title="Chest X-Ray Disease Classification", css="") as demo:
    with gr.Row():
        textmd = gr.Markdown('''
# Chest X-Ray Disease Classification
View the full training code at <a href="https://www.kaggle.com/code/mushfirat/chest-x-ray-disease-classification"><b>kaggle</b></a>
''')
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            # gr.inputs / gr.outputs are the deprecated Interface-era wrappers;
            # inside gr.Blocks the plain component classes are the supported API.
            image = gr.Image(shape=(128, 128))
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
            # Show every class the model knows about, not a hard-coded count.
            label = gr.Label(num_top_classes=len(LABELS))
    submit_btn.click(predict_input_image, inputs=image, outputs=label, api_name="prediction_place")
demo.launch()