File size: 4,129 Bytes
ec387c2
 
 
 
 
 
 
64183cb
ec387c2
448426c
64183cb
ec387c2
448426c
755c3f6
 
 
64183cb
ec387c2
 
82768ad
b31004e
 
 
 
 
 
 
 
ec387c2
 
755c3f6
448426c
755c3f6
 
448426c
 
ec387c2
448426c
 
ec387c2
64183cb
 
ec387c2
755c3f6
448426c
755c3f6
 
 
448426c
 
 
 
 
 
 
 
755c3f6
 
 
 
 
 
 
64183cb
448426c
755c3f6
 
 
ec387c2
 
 
448426c
ec387c2
755c3f6
448426c
 
 
 
 
64183cb
 
 
 
448426c
64183cb
 
ec387c2
 
448426c
ec387c2
 
64183cb
 
 
 
448426c
64183cb
 
448426c
64183cb
 
448426c
ec387c2
 
755c3f6
448426c
 
 
 
755c3f6
ec387c2
 
755c3f6
448426c
ec387c2
448426c
ec387c2
755c3f6
64183cb
ec387c2
 
 
 
 
 
448426c
 
ec387c2
 
b31004e
 
 
 
 
448426c
 
b31004e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests

from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain.tools import StructuredTool


# Spatial size the segmentation model is fed; inputs are resized to this.
IMG_HEIGHT = 256
IMG_WIDTH = 256

# === Download model if not exists ===
# The UNet++ weights are fetched once from the Hugging Face Hub and cached
# next to this script; later runs reuse the local copy.
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    # Stream into a temporary ".part" file and move it into place only once
    # the whole body has arrived. Writing straight to model_path meant an
    # interrupted download left a truncated file that the exists() check
    # above would then trust on every subsequent start.
    tmp_path = model_path + ".part"
    try:
        # timeout bounds the connect and each read so startup cannot hang forever.
        with requests.get(hf_url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, model_path)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# state needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)


# === Segmentation + Stats + Overlay ===
def classify_image_and_stats(image_input, *, threshold=0.5, min_tumor_ratio=0.005, alpha=0.4):
    """Segment a brain MRI slice and summarise the predicted tumor region.

    The image is resized to ``IMG_HEIGHT`` x ``IMG_WIDTH``, scaled by 1/255
    and passed through the module-level ``model``. Pixels whose predicted
    value exceeds ``threshold`` form the binary tumor mask.

    Args:
        image_input: Image array as delivered by ``gr.Image(type="numpy")``
            (presumably (H, W, 3) uint8 RGB — confirm against the model's
            expected input).
        threshold: Cut-off applied to the model output to mark a pixel as tumor.
        min_tumor_ratio: The slice is labelled "Tumor Detected" only when the
            fraction of tumor pixels is strictly greater than this.
        alpha: Opacity of the red highlight blended over tumor pixels.

    Returns:
        ``(overlay_img, stats)`` where ``overlay_img`` is the resized input as
        uint8 with tumor pixels tinted red, and ``stats`` is a dict with
        ``tumor_area`` (int pixel count), ``total_area`` (int),
        ``tumor_ratio`` (float) and ``tumor_label`` (str).
    """
    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)

    # Take the single item of the batch; presumably (256, 256, 1) — confirm.
    # verbose=0 stops Keras printing a progress bar on every request.
    prediction = model.predict(img_batch, verbose=0)[0]
    mask = (prediction > threshold).astype(np.uint8)

    # Drop a trailing singleton channel axis so the mask is 2-D.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    # Tumor stats (plain Python numbers so the dict is JSON/format friendly).
    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area

    tumor_label = "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"

    # === Overlay mask on original ===
    # Blend red only where the mask is set. Blending every pixel would also
    # dim the healthy tissue to (1 - alpha) brightness.
    overlay = np.array(img)  # resized input, float values in 0..255
    red = np.zeros_like(overlay)
    red[..., 0] = 255  # Red channel
    blended = (1.0 - alpha) * overlay + alpha * red
    tumor_pixels = mask.astype(bool)[..., np.newaxis]  # broadcast over channels
    overlay_img = np.clip(np.where(tumor_pixels, blended, overlay), 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label
    }

    return overlay_img, stats


# === Gradio handler ===
def rishigpt_handler(image_input, groq_api_key):
    """Segment an uploaded MRI slice and produce a patient-friendly explanation.

    Runs ``classify_image_and_stats`` on the image and exposes the resulting
    stats to a LangChain zero-shot ReAct agent through a ``segment_brain``
    tool. It then asks the same Groq-hosted LLM to restate the agent's answer
    as a calm explanation addressed to the patient.

    Args:
        image_input: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing was uploaded.
        groq_api_key: Groq API key typed into the UI. When blank, a
            ``GROQ_API_KEY`` already present in the process environment is used.

    Returns:
        ``(overlay_img, description)``: the overlay image from
        ``classify_image_and_stats`` and the LLM-written explanation text.

    Raises:
        gr.Error: if no image was uploaded or no API key is available, so
            Gradio shows a readable message instead of a generic failure.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")

    # The key is handed to the client directly rather than written into
    # os.environ. The environment is shared by every concurrent Gradio
    # request, so one user's key could otherwise be used for another's call.
    api_key = (groq_api_key or "").strip() or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise gr.Error("Please provide a Groq API key.")

    overlay_img, stats = classify_image_and_stats(image_input)

    # The agent's tool input is ignored: the stats for this request are
    # already computed and captured in the closure.
    def segment_brain_tool(input_text: str) -> str:
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image."
    )

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
        api_key=api_key,
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False,
        # The ReAct parser raises on LLM replies that do not follow its exact
        # Thought/Action format. This flag feeds the error back to the LLM
        # instead of failing the whole request.
        handle_parsing_errors=True,
    )

    user_query = "Give me the segmentation details"
    classification = agent.run(user_query)

    # Second pass: rephrase the agent's answer for the patient.
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation like you're talking to the patient in calm, clear language. "
            "Add any recommendations for next steps too, but keep it easy to understand."
        )
    )

    chain = prompt | llm
    description = chain.invoke({"result": classification}).content.strip()

    return overlay_img, description


# === Gradio UI ===
# Input widgets, in the order rishigpt_handler receives them: the MRI slice
# (as a numpy array) and the Groq API key (masked in the browser).
inputs = [
    gr.Image(label="Upload Brain MRI Slice", type="numpy"),
    gr.Textbox(label="Groq API Key", type="password"),
]

# Output widgets, matching the (image, text) tuple the handler returns.
outputs = [
    gr.Image(label="Overlay: Brain MRI + Tumor Mask", type="numpy"),
    gr.Textbox(label="Doctor's Explanation"),
]

if __name__ == "__main__":
    # Build the interface, then start the web server (only when run as a script).
    demo = gr.Interface(
        fn=rishigpt_handler,
        title="🧠 RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with mask overlay, detailed stats, and human-like explanation.",
        inputs=inputs,
        outputs=outputs,
    )
    demo.launch()