| import gradio as gr | |
| from PIL import Image | |
| import tensorflow as tf | |
| import os | |
| import numpy as np | |
| import base64 | |
| from io import BytesIO | |
# Keras classifier, loaded once when the module is imported. The relative
# path is resolved against the current working directory.
model = tf.keras.models.load_model(filepath='model.hdf5')

# Class names. predict_input_image scores LABELS[i] with output unit i of the
# model, so this order is assumed to match the training label order.
LABELS = [
    'NORMAL',
    'TUBERCULOSIS',
    'PNEUMONIA',
    'COVID19',
]
# Side length (pixels) of the square RGB input the classifier is fed.
_INPUT_SIZE = 128


def _to_rgb_image(img):
    """Return *img* as an RGB PIL image.

    Accepts a numpy array (what ``gr.Image`` passes to callbacks by default),
    a PIL image, or a base64-encoded image as ``str``/``bytes``. A ``str``
    may carry a ``data:<mime>;base64,`` data-URL header.
    """
    if isinstance(img, Image.Image):
        return img.convert('RGB')
    if isinstance(img, np.ndarray):
        return Image.fromarray(img).convert('RGB')
    if isinstance(img, str) and img.startswith('data:'):
        # b64decode silently discards non-alphabet characters by default, so
        # the header's letters (and '/') would otherwise be decoded into the
        # payload and corrupt the image.
        img = img.partition(',')[2]
    return Image.open(BytesIO(base64.b64decode(img))).convert('RGB')


def _to_model_input(rgb_img):
    """Resize and scale an RGB image into a (1, H, W, 3) float batch."""
    pixels = np.array(rgb_img.resize((_INPUT_SIZE, _INPUT_SIZE)), dtype=np.float64)
    peak = pixels.max()
    # Scale by the image's own maximum (the original preprocessing). An
    # all-black image (peak 0) stays all zeros instead of becoming 0/0 = NaN.
    if peak > 0:
        pixels /= peak
    return pixels.reshape(-1, _INPUT_SIZE, _INPUT_SIZE, 3)


def predict_input_image(img):
    """Classify a chest X-ray image with the module-level ``model``.

    Args:
        img: The image to classify. This can be a numpy array, a PIL image, or
            a base64-encoded image (``str``/``bytes``; a ``str`` may include a
            ``data:`` URL header).

    Returns:
        A dict mapping each name in ``LABELS`` to the model's output score for
        that class. If decoding, preprocessing or prediction fails, returns
        ``{"error": <message>}`` instead. Errors are returned rather than
        raised to preserve this endpoint's existing API contract.
    """
    try:
        batch = _to_model_input(_to_rgb_image(img))
        # verbose=0: don't print a progress bar for every single-image request.
        scores = model.predict(batch, verbose=0)[0]
        return {name: float(score) for name, score in zip(LABELS, scores)}
    except Exception as e:  # UI/API boundary: report the failure to the caller
        return {"error": str(e)}
# Gradio UI: an image input and a Submit button, with a label output showing
# the scores from predict_input_image. The handler is also exposed over the
# API as "prediction_place".
with gr.Blocks(title="Chest X-Ray Disease Classification", css="") as demo:
    with gr.Row():
        textmd = gr.Markdown('''
# Chest X-Ray Disease Classification
View the full training code at <a href="https://www.kaggle.com/code/mushfirat/chest-x-ray-disease-classification"><b>kaggle</b></a>
''')
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            # gr.inputs.* / gr.outputs.* were deprecated in Gradio 3 and removed
            # in Gradio 4, as was Image's `shape` argument. Use the top-level
            # components instead. predict_input_image resizes its input to
            # 128x128 itself, so no component-side resize is needed.
            image = gr.Image()
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
        # Show every class rather than a hard-coded count.
        label = gr.Label(num_top_classes=len(LABELS))
    submit_btn.click(predict_input_image, inputs=image, outputs=label, api_name="prediction_place")
demo.launch()