File size: 3,915 Bytes
ec387c2
 
 
 
 
270b631
ec387c2
 
64183cb
ec387c2
448426c
64183cb
ec387c2
448426c
755c3f6
 
 
27b7bc3
ec387c2
 
82768ad
b31004e
 
 
 
 
 
 
 
ec387c2
 
755c3f6
27b7bc3
755c3f6
 
448426c
 
ec387c2
270b631
448426c
ec387c2
64183cb
 
ec387c2
448426c
755c3f6
 
 
27b7bc3
448426c
270b631
448426c
270b631
448426c
 
755c3f6
 
 
 
 
 
 
64183cb
448426c
755c3f6
 
27b7bc3
ec387c2
 
 
448426c
ec387c2
755c3f6
448426c
 
 
 
 
64183cb
 
 
 
448426c
64183cb
 
ec387c2
 
448426c
ec387c2
 
64183cb
 
 
 
448426c
64183cb
 
448426c
64183cb
 
ec387c2
 
755c3f6
448426c
 
270b631
755c3f6
ec387c2
 
755c3f6
270b631
ec387c2
27b7bc3
270b631
 
 
 
 
755c3f6
27b7bc3
ec387c2
 
 
 
 
 
448426c
eb80dfb
ec387c2
 
b31004e
 
 
 
 
270b631
 
b31004e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
import time

from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain.tools import StructuredTool


# Spatial size the UNet++ model expects; inputs are resized to this.
IMG_HEIGHT = 256
IMG_WIDTH = 256


def _download_file(url, dest, chunk_size=8192, timeout=60):
    """Stream ``url`` to ``dest``, only creating ``dest`` once fully written.

    Data is written to ``dest + ".part"`` and renamed into place with
    ``os.replace`` after the transfer completes, so an interrupted download
    never leaves a truncated ``dest`` that later runs would mistake for a
    valid model file. On any failure the partial file is removed and the
    exception is re-raised.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final local path.
        chunk_size: Bytes per streamed chunk.
        timeout: Seconds passed to ``requests.get`` (connect/read timeout),
            so a stalled server cannot hang startup forever.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Don't leave a half-written file behind; the next run will retry.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Fetch the pretrained weights from Hugging Face on first run, then load them.
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    _download_file(hf_url, model_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)



def classify_image_and_stats(image_input, mask_threshold=0.5, min_tumor_ratio=0.00385):
    """Segment an MRI slice with the UNet++ model and summarize the tumor mask.

    Args:
        image_input: Image array as delivered by ``gr.Image(type="numpy")``.
            Presumably (H, W, 3) uint8 on a 0-255 scale; TODO confirm for
            any caller other than the Gradio UI.
        mask_threshold: Per-pixel model output above which a pixel is counted
            as tumor.
        min_tumor_ratio: Fraction of the resized image that must be tumor for
            the slice to be labelled "Tumor Detected". The default is the
            original hard-coded heuristic.

    Returns:
        ``(overlay_img, stats)``. ``overlay_img`` is a uint8 array of the
        resized image blended with red where the mask is set. ``stats`` is a
        dict with keys ``tumor_area`` (int pixel count), ``total_area`` (int),
        ``tumor_ratio`` (float) and ``tumor_label`` (str).

    Raises:
        ValueError: if ``image_input`` is None (no image was uploaded).
    """
    if image_input is None:
        raise ValueError("No image provided for segmentation.")

    # Resize to the model's input size; values are still on the 0-255 scale.
    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)

    # verbose=0 stops Keras from printing a progress bar on every request.
    prediction = model.predict(img_batch, verbose=0)[0]
    mask = (prediction > mask_threshold).astype(np.uint8)

    # Drop a trailing singleton channel so the mask is 2-D.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area

    tumor_label = "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"

    # Blend 60% of the resized image with 40% of a red layer built from the
    # mask. Channel 0 is treated as red (assumes channel-last RGB).
    # Note the blend also dims unmasked pixels to 60% brightness.
    overlay = np.array(img)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255

    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }

    return overlay_img, stats



def rishigpt_handler(image_input, groq_api_key):
    """Gradio generator: segment the MRI, then stream an LLM-written explanation.

    Args:
        image_input: Uploaded image array from the ``gr.Image`` input.
        groq_api_key: Key typed into the UI. If blank, the server's
            ``GROQ_API_KEY`` environment variable is used instead.

    Yields:
        ``(overlay_img, text)`` tuples. The first yield shows the overlay with
        empty text as soon as segmentation finishes. Later yields reveal the
        explanation one character at a time ("typing" effect).

    Raises:
        gr.Error: if no image was uploaded or no API key is available.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")

    # Pass the key straight to ChatGroq instead of writing it to os.environ.
    # The environment is process-global, so concurrent requests would
    # otherwise overwrite (and use) each other's keys.
    api_key = (groq_api_key or "").strip() or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise gr.Error("Please provide a Groq API key.")

    overlay_img, stats = classify_image_and_stats(image_input)
    # Show the segmentation right away, before the slow LLM calls. This also
    # guarantees the overlay appears even if the LLM returns empty text.
    yield overlay_img, ""

    def segment_brain_tool(input_text: str) -> str:
        # Ignores its argument; it only reports the stats computed above.
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image."
    )

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
        api_key=api_key,
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False,
        # Feed malformed ReAct output back to the LLM instead of raising.
        handle_parsing_errors=True,
    )

    user_query = "Give me the segmentation details"
    classification = agent.invoke({"input": user_query})["output"]

    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, calm, clear tone, with next steps."
        )
    )

    chain = prompt | llm
    final_text = chain.invoke({"result": classification}).content.strip()

    # Typing effect: yield ever-longer prefixes of the final text.
    for end in range(1, len(final_text) + 1):
        time.sleep(0.015)
        yield overlay_img, final_text[:end]


# UI inputs: the MRI slice to segment and the user's Groq key (masked field).
inputs = [
    gr.Image(label="Upload Brain MRI Slice", type="numpy"),
    gr.Textbox(label="Groq API Key", type="password"),
]

# UI outputs, filled positionally from each (image, text) pair the handler yields.
outputs = [
    gr.Image(label="Overlay: Brain MRI + Tumor Mask", type="numpy"),
    gr.Textbox(label="Doctor's Explanation"),
]

if __name__ == "__main__":
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with live mask overlay, detailed stats, and human-like typing explanation.",
    )
    demo.launch()