NeuroGenAI / app.py
yashbyname's picture
Create app.py
e7f38cd verified
raw
history blame
3.95 kB
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from PIL import Image
# Load the four pretrained Keras models from their .h5 files. Each one is
# deserialized with ``hub.KerasLayer`` registered under the name 'KerasLayer',
# so any layer saved under that class name is rebuilt as the TF-Hub layer.


def _load_hub_model(path):
    """Load a Keras .h5 model, resolving 'KerasLayer' to ``hub.KerasLayer``."""
    return keras.models.load_model(
        path, custom_objects={'KerasLayer': hub.KerasLayer}
    )


model_initial = _load_hub_model("models/initial_model.h5")
model_tumor = _load_hub_model("models/model_tumor.h5")
model_stroke = _load_hub_model("models/model_stroke.h5")
model_alzheimer = _load_hub_model("models/model_alzheimer.h5")
class CombinedDiseaseModel(tf.keras.Model):
    """Two-stage MRI classifier that reports its result as text.

    A first-stage model chooses a main class from ``disease_labels``. Unless
    that class is 'No Disease', the matching disease-specific model then
    chooses a subcategory. ``call`` returns a formatted summary string rather
    than tensors. It only inspects the first element of the batch.

    ``call`` converts intermediate tensors with ``.numpy()``, so it must run
    eagerly. NOTE(review): it will not work inside a traced ``tf.function``
    (e.g. ``model.predict``); confirm before using it that way.
    """

    def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
        super().__init__()
        # Attribute assignment order is kept as-is: Keras tracks sub-models
        # in the order they are assigned.
        self.model_initial = model_initial
        self.model_alzheimer = model_alzheimer
        self.model_tumor = model_tumor
        self.model_stroke = model_stroke
        # Position i names output unit i of the first-stage model.
        # NOTE(review): this order must match how model_initial was trained.
        self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
        # Second-stage model for every main class that has subcategories.
        self.sub_models = {
            "Alzheimer's": model_alzheimer,
            'Tumor': model_tumor,
            'Stroke': model_stroke
        }

    def call(self, inputs):
        """Classify ``inputs`` and return a human-readable diagnosis string."""
        # Subcategory names per main disease, indexed by the sub-model's
        # argmax output. Kept local (not an attribute) so Keras does not
        # track it.
        labels_by_disease = {
            "Alzheimer's": ['Very Mild', 'Mild', 'Moderate'],
            'Tumor': ['Glioma', 'Meningioma', 'Pituitary'],
            'Stroke': ['Ischemic', 'Hemorrhagic'],
        }

        # Stage 1: pick the main class for the first batch element.
        main_scores = self.model_initial(inputs, training=False)
        top_main = tf.argmax(main_scores, axis=1)[0]
        main_disease = self.disease_labels[top_main.numpy()]
        main_disease_prob = main_scores[0, top_main].numpy()

        # Stage 2: refine with the disease-specific model, if there is one.
        if main_disease == 'No Disease':
            sub_label, sub_prob = "No Disease", main_disease_prob
        else:
            sub_scores = self.sub_models[main_disease](inputs, training=False)
            top_sub = tf.argmax(sub_scores, axis=1).numpy()[0]
            sub_prob = sub_scores[0, top_sub].numpy()
            sub_label = labels_by_disease[main_disease][top_sub]

        return f"The MRI image shows {main_disease} with a probability of {main_disease_prob*100:.2f}%.\nThe subcategory of {main_disease} is {sub_label} with a probability of {sub_prob*100:.2f}%."
# Build the single model object used by the UI callback: it routes each image
# through the first-stage classifier and then the matching disease model.
cnn_model = CombinedDiseaseModel(
    model_initial=model_initial,
    model_tumor=model_tumor,
    model_stroke=model_stroke,
    model_alzheimer=model_alzheimer,
)
def process_image(image):
    """Preprocess a PIL image and run it through ``cnn_model``.

    Steps:
    1. Convert the image to 3-channel RGB.
    2. Resize it to 256x256.
    3. Scale pixel values to [0, 1].
    4. Add a leading batch axis of size 1.

    Args:
        image: A ``PIL.Image.Image``, as delivered by ``gr.Image(type="pil")``.

    Returns:
        Whatever ``cnn_model`` returns for the one-image batch.
    """
    # PIL's convert() returns a NEW image and never modifies in place, so its
    # result must be rebound. Previously it was discarded, and RGBA, grayscale
    # and palette uploads kept their own channel count (4, none, or palette
    # indices) instead of the 3 channels the model presumably expects.
    # Converting before resizing also avoids Pillow forcing nearest-neighbour
    # resampling on "P"/"1" mode images.
    image = image.convert("RGB").resize((256, 256))
    image_array = np.array(image) / 255.0
    batch = np.expand_dims(image_array, axis=0)
    return cnn_model(batch)
def gradio_interface(patient_info, query_type, image):
    """Gradio callback: echo the text inputs and append the image diagnosis.

    If no image was uploaded (``image is None``), return a prompt asking for
    one instead of running the model.
    """
    if image is None:
        return "Please upload an image."
    diagnosis = process_image(image)
    return f"Patient Info: {patient_info}\nQuery Type: {query_type}\n{diagnosis}"
# Create the Gradio app and start serving it.


def _build_interface():
    """Return the Gradio Interface wired to ``gradio_interface``.

    Inputs, in callback-argument order:
    - free-text patient details;
    - a query-type label;
    - an uploaded image, passed to the callback as a PIL image.

    The output is a single text box.
    """
    patient_box = gr.Textbox(
        label="Patient Information",
        placeholder="Enter patient details here...",
        lines=5,
        max_lines=10,
    )
    query_box = gr.Textbox(label="Query Type")
    image_input = gr.Image(type="pil", label="Upload an Image")
    response_box = gr.Textbox(
        label="Response", placeholder="The response will appear here..."
    )
    return gr.Interface(
        fn=gradio_interface,
        inputs=[patient_box, query_box, image_input],
        outputs=response_box,
        title="Medical Diagnosis with MRI",
        description="Upload MRI images and provide patient information for diagnosis.",
    )


iface = _build_interface()
iface.launch()