File size: 3,945 Bytes
e7f38cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from PIL import Image

# Load the four pretrained classifiers from the local ``models/`` directory.
# Each .h5 file is deserialized with ``hub.KerasLayer`` registered under the
# name 'KerasLayer' so layers saved with that class name can be rebuilt.


def _load_model(path):
    """Load one saved Keras model from *path* with the Hub custom object registered."""
    return keras.models.load_model(
        path, custom_objects={'KerasLayer': hub.KerasLayer}
    )


# Load order matches the original: initial, tumor, stroke, alzheimer.
model_initial = _load_model("models/initial_model.h5")
model_tumor = _load_model("models/model_tumor.h5")
model_stroke = _load_model("models/model_stroke.h5")
model_alzheimer = _load_model("models/model_alzheimer.h5")

class CombinedDiseaseModel(tf.keras.Model):
    """Two-stage MRI classifier that returns a human-readable diagnosis string.

    Stage one runs ``model_initial`` and picks the main category from
    ``disease_labels`` by argmax. If that category is not ``'No Disease'``,
    stage two runs the matching model from ``sub_models``. Its argmax is then
    mapped to a name via ``SUB_CATEGORY_LABELS``.

    Only the first item of the input batch is reported. ``call`` uses
    ``.numpy()`` and Python branching, so it has to run eagerly.

    NOTE(review): every label list here is assumed to follow the class-index
    order the corresponding model was trained with. Confirm this against the
    training code, especially the Alzheimer's and Stroke orderings, which are
    not alphabetical.
    """

    # Subcategory names per main category, indexed by each sub-model's output
    # class. Keep the keys identical to ``sub_models`` below; a single mapping
    # replaces the old if/elif chain, which could leave the label list unbound.
    SUB_CATEGORY_LABELS = {
        "Alzheimer's": ('Very Mild', 'Mild', 'Moderate'),
        'Tumor': ('Glioma', 'Meningioma', 'Pituitary'),
        'Stroke': ('Ischemic', 'Hemorrhagic'),
    }

    def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
        super().__init__()
        self.model_initial = model_initial
        self.model_alzheimer = model_alzheimer
        self.model_tumor = model_tumor
        self.model_stroke = model_stroke
        # Output index i of model_initial is reported as disease_labels[i].
        self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']

        # Second-stage model for each main category ('No Disease' has none).
        self.sub_models = {
            "Alzheimer's": model_alzheimer,
            'Tumor': model_tumor,
            'Stroke': model_stroke
        }

    def call(self, inputs):
        """Diagnose the first image in ``inputs`` and describe the result as text.

        Args:
            inputs: a batch passed unchanged to ``model_initial`` and to the
                selected sub-model. It is presumably ``(batch, 256, 256, 3)``
                floats in [0, 1], as built by ``process_image`` (TODO confirm
                against the saved models).

        Returns:
            A two-line string naming the main category and the subcategory,
            each with its predicted probability as a percentage. For
            ``'No Disease'`` the subcategory is also ``'No Disease'``, and the
            main-stage probability is repeated.
        """
        initial_probs = self.model_initial(inputs, training=False)
        # Plain Python ints are used for all label/tensor indexing below.
        main_idx = int(tf.argmax(initial_probs, axis=1)[0].numpy())
        main_disease = self.disease_labels[main_idx]
        main_disease_prob = initial_probs[0, main_idx].numpy()

        if main_disease == 'No Disease':
            sub_category = "No Disease"
            sub_category_prob = main_disease_prob
        else:
            sub_category_pred = self.sub_models[main_disease](inputs, training=False)
            sub_idx = int(tf.argmax(sub_category_pred, axis=1)[0].numpy())
            sub_category_prob = sub_category_pred[0, sub_idx].numpy()
            sub_category = self.SUB_CATEGORY_LABELS[main_disease][sub_idx]

        return (
            f"The MRI image shows {main_disease} with a probability of "
            f"{main_disease_prob*100:.2f}%.\n"
            f"The subcategory of {main_disease} is {sub_category} with a "
            f"probability of {sub_category_prob*100:.2f}%."
        )


# Build the single two-stage model instance that process_image calls.
# Arguments are passed positionally, in the constructor's parameter order.
cnn_model = CombinedDiseaseModel(
    model_initial, model_alzheimer, model_tumor, model_stroke
)


def process_image(image):
    """Preprocess an uploaded PIL image and run the combined diagnosis model.

    The image is converted to 3-channel RGB and resized to 256x256. It is then
    scaled to [0, 1] as float32 and given a leading batch axis before being
    passed to ``cnn_model``.

    Args:
        image: a ``PIL.Image.Image``. The Gradio image input is declared with
            ``type="pil"``, so it supplies one.

    Returns:
        Whatever ``cnn_model`` returns for the single-image batch.
    """
    # ``Image.convert`` returns a new image and does not convert in place, so
    # the result must be kept. Otherwise grayscale ("L") or RGBA uploads reach
    # the model with 1 or 4 channels instead of 3. Converting before resizing
    # also lets palette images be resampled in RGB.
    image = image.convert("RGB").resize((256, 256))
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    image_array = np.expand_dims(image_array, axis=0)  # -> (1, 256, 256, 3)
    return cnn_model(image_array)


def gradio_interface(patient_info, query_type, image):
    """Gradio callback: echo the form fields and append the image diagnosis.

    Args:
        patient_info: free text from the "Patient Information" box.
        query_type: free text from the "Query Type" box.
        image: the uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        The prompt "Please upload an image." when no image was given.
        Otherwise, the patient info and query type followed by the text
        returned by ``process_image``.
    """
    if image is None:
        return "Please upload an image."

    diagnosis = process_image(image)
    return f"Patient Info: {patient_info}\nQuery Type: {query_type}\n{diagnosis}"


# Build the Gradio UI. Inputs are patient notes, a free-text query type and
# an image delivered as a PIL object; the output is the text produced by
# gradio_interface. Components are created in the same order as before.
_interface_inputs = [
    gr.Textbox(
        label="Patient Information",
        placeholder="Enter patient details here...",
        lines=5,
        max_lines=10,
    ),
    gr.Textbox(label="Query Type"),
    gr.Image(type="pil", label="Upload an Image"),
]
_interface_output = gr.Textbox(
    label="Response", placeholder="The response will appear here..."
)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_interface_inputs,
    outputs=_interface_output,
    title="Medical Diagnosis with MRI",
    description="Upload MRI images and provide patient information for diagnosis.",
)

# Launched unconditionally at module level (there is no __main__ guard).
iface.launch()