File size: 4,129 Bytes
ec387c2 64183cb ec387c2 448426c 64183cb ec387c2 448426c 755c3f6 64183cb ec387c2 82768ad b31004e ec387c2 755c3f6 448426c 755c3f6 448426c ec387c2 448426c ec387c2 64183cb ec387c2 755c3f6 448426c 755c3f6 448426c 755c3f6 64183cb 448426c 755c3f6 ec387c2 448426c ec387c2 755c3f6 448426c 64183cb 448426c 64183cb ec387c2 448426c ec387c2 64183cb 448426c 64183cb 448426c 64183cb 448426c ec387c2 755c3f6 448426c 755c3f6 ec387c2 755c3f6 448426c ec387c2 448426c ec387c2 755c3f6 64183cb ec387c2 448426c ec387c2 b31004e 448426c b31004e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain.tools import StructuredTool
# Spatial size every input image is resized to before it is fed to the model.
IMG_HEIGHT = 256
IMG_WIDTH = 256

# === Download model if not exists ===
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    # Stream into a temporary sibling file and rename it only after the whole
    # download succeeded. Writing straight to model_path would leave a
    # truncated file behind on an interrupted transfer, and the existence
    # check above would then treat it as a complete model on the next start.
    partial_path = model_path + ".part"
    try:
        # (connect, read) timeouts so a stalled connection cannot hang startup.
        with requests.get(hf_url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, model_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # on any failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

print("Loading model...")
# compile=False: the model is only used for prediction in this file, so no
# optimizer/loss configuration is requested when loading.
model = tf.keras.models.load_model(model_path, compile=False)
# === Segmentation + Stats + Overlay ===
def classify_image_and_stats(image_input, threshold=0.5, min_tumor_ratio=0.005):
    """Segment an image with the loaded model and summarise the result.

    The image is resized to ``IMG_HEIGHT`` x ``IMG_WIDTH``, scaled by 1/255,
    run through ``model``, and the prediction is binarised into a mask. The
    mask is then blended in red over the resized input.

    Args:
        image_input: Image array accepted by ``tf.image.resize``. The Gradio
            UI in this file passes a NumPy array (presumably ``(H, W, 3)``
            RGB with values in 0-255 -- confirm against the model's training
            data).
        threshold: Per-pixel cut-off applied to the model output; pixels
            strictly above it count as tumor. Defaults to 0.5.
        min_tumor_ratio: Fraction of the image that must be masked (strictly
            greater than) for the label to read "Tumor Detected".
            Defaults to 0.005.

    Returns:
        A ``(overlay_img, stats)`` tuple. ``overlay_img`` is a ``uint8``
        array of the resized image blended with the red mask; ``stats`` is a
        dict with the keys ``tumor_area`` (int, masked pixel count),
        ``total_area`` (int), ``tumor_ratio`` (float) and ``tumor_label``
        (str).

    Raises:
        ValueError: If ``image_input`` is ``None`` (e.g. nothing uploaded).
    """
    if image_input is None:
        raise ValueError("No image provided for segmentation.")

    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)

    # verbose=0 keeps Keras from printing a progress bar for every request.
    prediction = model.predict(img_batch, verbose=0)[0]
    mask = (prediction > threshold).astype(np.uint8)
    # Drop a trailing singleton channel so the mask is a plain 2-D array.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    # Tumor stats -- converted to native Python numbers so the dict holds no
    # NumPy scalars.
    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
    if tumor_ratio > min_tumor_ratio:
        tumor_label = "Tumor Detected"
    else:
        tumor_label = "No Tumor Detected"

    # === Overlay mask on original ===
    overlay = np.array(img)  # resized (not normalised) input
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255  # first channel only, i.e. red for RGB
    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return overlay_img, stats
# === Gradio handler ===
def rishigpt_handler(image_input, groq_api_key):
    """Gradio callback: segment the image and produce a patient-facing summary.

    Runs ``classify_image_and_stats`` on the upload, exposes the resulting
    statistics to a LangChain agent through a single tool, and then asks the
    same LLM to rephrase the agent's answer in calm, patient-friendly
    language.

    Args:
        image_input: Image supplied by the Gradio ``Image`` input.
        groq_api_key: Groq API key typed into the UI by the user.

    Returns:
        A ``(overlay_img, description)`` tuple: the overlay image from
        ``classify_image_and_stats`` and the LLM-written explanation text.

    Raises:
        gr.Error: If no image was uploaded or the API key is empty, so the
            user sees a readable message in the UI.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI image first.")
    api_key = (groq_api_key or "").strip()
    if not api_key:
        raise gr.Error("Please provide a Groq API key.")

    overlay_img, stats = classify_image_and_stats(image_input)

    # The tool ignores its argument: it only reports the stats computed above
    # for this request.
    def segment_brain_tool(input_text: str) -> str:
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image."
    )

    # Hand the key to the client directly instead of writing it into
    # os.environ: the environment is process-global, so concurrent requests
    # from different users would overwrite (and reuse) each other's keys.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
        api_key=api_key,
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False
    )

    user_query = "Give me the segmentation details"
    classification = agent.run(user_query)

    # Second pass: rewrite the agent's raw answer for the patient.
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation like you're talking to the patient in calm, clear language. "
            "Add any recommendations for next steps too, but keep it easy to understand."
        )
    )
    chain = prompt | llm
    description = chain.invoke({"result": classification}).content.strip()

    return overlay_img, description
# === Gradio UI ===
# Input widgets: the MRI slice to analyse and the user's Groq credential
# (rendered as a password field).
mri_upload = gr.Image(type="numpy", label="Upload Brain MRI Slice")
api_key_box = gr.Textbox(type="password", label="Groq API Key")
inputs = [mri_upload, api_key_box]

# Output widgets: the blended overlay image and the generated explanation.
overlay_view = gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask")
explanation_box = gr.Textbox(label="Doctor's Explanation")
outputs = [overlay_view, explanation_box]

if __name__ == "__main__":
    # Build the interface first, then start the server.
    demo = gr.Interface(
        title="🧠 RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with mask overlay, detailed stats, and human-like explanation.",
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
    )
    demo.launch()
|