Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow import keras | |
import tensorflow_hub as hub | |
from PIL import Image | |
def _load_hub_model(path):
    """Load a saved Keras ``.h5`` model that may contain TF-Hub ``KerasLayer``s.

    ``custom_objects`` maps the serialized class name ``'KerasLayer'`` back to
    ``hub.KerasLayer`` so deserialization can rebuild those layers.
    """
    return keras.models.load_model(path, custom_objects={'KerasLayer': hub.KerasLayer})


# Stage-one classifier plus one specialist model per disease. All four are
# loaded at import time from paths relative to the current working directory.
model_initial = _load_hub_model("models/initial_model.h5")
model_tumor = _load_hub_model("models/model_tumor.h5")
model_stroke = _load_hub_model("models/model_stroke.h5")
model_alzheimer = _load_hub_model("models/model_alzheimer.h5")
class CombinedDiseaseModel(tf.keras.Model):
    """Two-stage MRI classifier that returns a human-readable diagnosis string.

    Stage one (``model_initial``) picks one of the top-level classes in
    ``disease_labels``. Unless that class is ``'No Disease'``, stage two runs
    the matching specialist model from ``sub_models`` and maps its argmax to a
    subtype name.

    Note that ``call`` returns a Python ``str``, not a tensor.

    NOTE(review): only the first element of the input batch is inspected, and
    ``call`` uses ``.numpy()``, so it needs eager execution (it will not work
    inside ``tf.function`` / ``model.predict``).
    """

    # Subtype names per disease, indexed by the specialist model's argmax.
    # NOTE(review): order is assumed to match each sub-model's output units.
    _SUBTYPE_NAMES = {
        "Alzheimer's": ['Very Mild', 'Mild', 'Moderate'],
        'Tumor': ['Glioma', 'Meningioma', 'Pituitary'],
        'Stroke': ['Ischemic', 'Hemorrhagic'],
    }

    def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
        super().__init__()
        self.model_initial = model_initial
        self.model_alzheimer = model_alzheimer
        self.model_tumor = model_tumor
        self.model_stroke = model_stroke
        # Index order must match the output units of ``model_initial``.
        self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
        self.sub_models = {
            "Alzheimer's": model_alzheimer,
            'Tumor': model_tumor,
            'Stroke': model_stroke,
        }

    def call(self, inputs):
        """Classify ``inputs[0]`` and return the diagnosis as formatted text."""
        stage_one = self.model_initial(inputs, training=False)
        top_idx = tf.argmax(stage_one, axis=1)[0]
        disease = self.disease_labels[top_idx.numpy()]
        disease_prob = stage_one[0, top_idx].numpy()

        if disease == 'No Disease':
            # No specialist model: report the stage-one result for both lines.
            subtype, subtype_prob = "No Disease", disease_prob
        else:
            stage_two = self.sub_models[disease](inputs, training=False)
            sub_idx = tf.argmax(stage_two, axis=1).numpy()[0]
            subtype_prob = stage_two[0, sub_idx].numpy()
            subtype = self._SUBTYPE_NAMES[disease][sub_idx]

        return (
            f"The MRI image shows {disease} with a probability of {disease_prob*100:.2f}%.\n"
            f"The subcategory of {disease} is {subtype} with a probability of {subtype_prob*100:.2f}%."
        )
# Bundle the stage-one classifier and the three specialist models into the
# single callable that process_image uses for inference. Arguments follow the
# constructor's order: initial, alzheimer, tumor, stroke.
cnn_model = CombinedDiseaseModel(model_initial, model_alzheimer, model_tumor, model_stroke)
def process_image(image, *, model=None, size=(256, 256)):
    """Preprocess a PIL image and run it through the combined classifier.

    Args:
        image: PIL image, as delivered by the ``gr.Image(type="pil")`` input.
        model: Callable applied to the preprocessed batch. Defaults to the
            module-level ``cnn_model``; exposed for reuse and testing.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to the
            original 256x256.

    Returns:
        Whatever ``model`` returns for the batch (for ``cnn_model``, the
        diagnosis text).
    """
    if model is None:
        model = cnn_model
    # ``Image.convert`` returns a NEW image. The previous code discarded that
    # result, so RGBA, grayscale or palette uploads reached the model with the
    # wrong channel layout. Converting before resizing also avoids resampling
    # an alpha channel.
    rgb = image.convert("RGB").resize(size)
    # Scale 8-bit RGB pixels to [0, 1] as float32 and add a leading batch axis.
    batch = np.asarray(rgb, dtype=np.float32)[np.newaxis] / 255.0
    return model(batch)
def gradio_interface(patient_info, query_type, image):
    """Build the text response shown in the Gradio output box.

    Echoes the patient information and query type, then appends the model's
    diagnosis for ``image``. When no image was supplied (``image is None``),
    returns a prompt asking for one instead.
    """
    if image is None:
        return "Please upload an image."
    diagnosis = process_image(image)
    return f"Patient Info: {patient_info}\nQuery Type: {query_type}\n{diagnosis}"
# Gradio UI: free-text patient details, a free-text query type and an image
# upload go in; a single text response comes out.
_input_components = [
    gr.Textbox(
        label="Patient Information",
        placeholder="Enter patient details here...",
        lines=5,
        max_lines=10,
    ),
    gr.Textbox(label="Query Type"),
    # type="pil" asks Gradio to hand the upload over as a PIL image.
    gr.Image(type="pil", label="Upload an Image"),
]
_output_component = gr.Textbox(
    label="Response",
    placeholder="The response will appear here...",
)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=_output_component,
    title="Medical Diagnosis with MRI",
    description="Upload MRI images and provide patient information for diagnosis.",
)
# Launched at module level, so running or importing this file starts the server.
iface.launch()