File size: 3,645 Bytes
ec387c2
 
 
 
 
 
 
64183cb
ec387c2
755c3f6
64183cb
ec387c2
755c3f6
 
 
64183cb
ec387c2
 
82768ad
b31004e
 
 
 
 
 
 
 
ec387c2
 
755c3f6
 
 
 
ec387c2
 
 
755c3f6
 
ec387c2
64183cb
 
ec387c2
755c3f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64183cb
755c3f6
 
 
 
ec387c2
 
 
755c3f6
ec387c2
755c3f6
 
64183cb
 
 
 
755c3f6
64183cb
 
ec387c2
 
 
 
 
64183cb
 
 
 
 
 
 
755c3f6
64183cb
 
ec387c2
 
755c3f6
 
 
 
ec387c2
 
755c3f6
64183cb
ec387c2
 
 
755c3f6
64183cb
ec387c2
 
 
 
 
 
 
 
 
 
b31004e
 
 
 
 
 
755c3f6
b31004e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests

from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence  # modern
from langchain.tools import StructuredTool

IMG_HEIGHT = 256
IMG_WIDTH = 256

# === Download model if not exists ===
model_path = "unet_model.h5"
hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"


def _download_file(url, dest, *, chunk_size=8192, timeout=60):
    """Stream ``url`` to ``dest`` without ever leaving a truncated ``dest``.

    The body is written to a sibling ``<dest>.part`` file and only moved onto
    ``dest`` (via ``os.replace``) after the whole response has been received.
    Without this, an interrupted download would leave a partial file at
    ``dest`` that the existence check below treats as a finished model on the
    next start, so loading would fail until the file is deleted by hand.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final path of the downloaded file.
        chunk_size: Bytes per streamed chunk.
        timeout: Seconds passed to ``requests.get`` so a stalled server cannot
            hang startup forever.

    Raises:
        requests.HTTPError: on a non-2xx response (from ``raise_for_status``).
        requests.RequestException: on connection/timeout errors.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # Only present if the download or the rename failed; don't leave junk.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


if not os.path.exists(model_path):
    print(f"Downloading model from {hf_url}...")
    _download_file(hf_url, model_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)


# === Segmentation + Stats ===
def classify_image_and_stats(image_input, *, threshold=0.5, min_tumor_ratio=0.01):
    """Segment a brain MRI slice with the loaded model and summarise the result.

    The image is resized to ``IMG_HEIGHT`` x ``IMG_WIDTH``, divided by 255 and
    run through the module-level ``model``. Pixels whose prediction exceeds
    ``threshold`` make up the tumor mask.

    Args:
        image_input: Image array accepted by ``tf.image.resize`` — here the
            value produced by ``gr.Image(type="numpy")``. Pixel values are
            assumed to be in ``[0, 255]`` (TODO confirm for non-uint8 input).
        threshold: Per-pixel cut-off on the model output for "tumor".
        min_tumor_ratio: Tumor-pixel fraction above which the slice is
            labelled "Tumor Detected".

    Returns:
        ``(mask, stats)``:
        * ``mask`` — ``uint8`` array of 0/255 at the model's output
          resolution; a trailing channel axis of size 1 is squeezed away.
        * ``stats`` — dict with ``tumor_area`` (int pixel count),
          ``total_area`` (int), ``tumor_ratio`` (float) and
          ``tumor_label`` (str).

    Raises:
        ValueError: if ``image_input`` is ``None`` (nothing was uploaded).
    """
    if image_input is None:
        raise ValueError("No image provided for segmentation.")

    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img = img / 255.0
    img = np.expand_dims(img, axis=0)  # add the batch axis the model expects

    # verbose=0 keeps Keras' per-call progress bar out of the server log.
    prediction = model.predict(img, verbose=0)[0]
    # Threshold once and reuse it for both the mask and the statistics.
    tumor_pixels = prediction > threshold

    mask = tumor_pixels.astype(np.uint8) * 255
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    # Tumor stats (plain Python numbers so the dict is JSON/format friendly).
    tumor_area = int(np.count_nonzero(tumor_pixels))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area

    tumor_label = "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }

    return mask, stats


# === Gradio handler ===
def rishigpt_handler(image_input, groq_api_key):
    """Gradio callback: segment the MRI slice, then have an LLM explain it.

    A LangChain zero-shot ReAct agent is given a single tool that reports the
    precomputed segmentation stats. The agent's answer is then rephrased for
    a patient by a prompt | LLM chain.

    Args:
        image_input: Image array from the ``gr.Image(type="numpy")`` input,
            or ``None`` when nothing was uploaded.
        groq_api_key: Groq API key entered by the user. When blank, no key is
            passed and ``ChatGroq`` uses its own default lookup. The original
            code relied on it reading ``GROQ_API_KEY`` from the environment.

    Returns:
        ``(mask, description)`` — the mask from ``classify_image_and_stats``
        and the LLM's explanation as a plain string.

    Raises:
        gr.Error: if no image was uploaded (shown to the user by Gradio).
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")

    mask, stats = classify_image_and_stats(image_input)

    def segment_brain_tool(input_text: str) -> str:
        # The agent must send some text to a tool; it is ignored because the
        # stats for this request's image were already computed above.
        return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
    )

    # Hand the key to this client only. Writing it into os.environ (as before)
    # is process-wide, so concurrent requests raced on it, and one user's key
    # stayed in the server environment for every later request.
    llm_kwargs = {"api_key": groq_api_key} if groq_api_key else {}
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3,
        **llm_kwargs,
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=True,
        # Feed malformed ReAct output back to the model instead of aborting.
        handle_parsing_errors=True,
    )

    user_query = "Get segmentation details"
    # AgentExecutor.run is deprecated; invoke returns a dict whose final
    # answer is stored under "output".
    classification = agent.invoke({"input": user_query})["output"]

    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a medical imaging expert. Based on this tumor analysis result: {result}, "
            "explain what this means for the patient in simple language."
        )
    )

    chain = prompt | llm
    # `prompt | llm` yields a chat message object; return its text so the
    # output Textbox shows the explanation rather than the message repr.
    description = chain.invoke({"result": classification}).content

    return mask, description


# === Gradio UI ===
# Input components, in the same order as rishigpt_handler's parameters
# (image_input, groq_api_key).
inputs = [
    gr.Image(label="Upload Brain MRI Slice", type="numpy"),
    gr.Textbox(label="Groq API Key", type="password"),
]

# Output components, matching rishigpt_handler's return tuple
# (mask, description).
outputs = [
    gr.Image(label="Tumor Segmentation Mask", type="numpy"),
    gr.Textbox(label="Medical Explanation"),
]

if __name__ == "__main__":
    # Wire the handler to the input/output components and start the server.
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation.",
    )
    demo.launch()