# NOTE: non-code extraction residue removed here (file-size line, git commit
# hashes, and a line-number gutter) — it was not part of the Python source.
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence # modern
from langchain.tools import StructuredTool
# Model input size expected by the UNet++ checkpoint (see resize below).
IMG_HEIGHT = 256
IMG_WIDTH = 256

# === Download model if not exists ===
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = (
        "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation"
        "/resolve/main/unet_model.h5"
    )
    print(f"Downloading model from {hf_url}...")
    # Stream to a temporary name first: an interrupted download must never
    # leave a truncated unet_model.h5 behind, because the exists() check above
    # would then skip re-downloading and load_model would fail on the stub.
    tmp_path = model_path + ".part"
    with requests.get(hf_url, stream=True) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    os.replace(tmp_path, model_path)  # atomic publish of the finished file

print("Loading model...")
# compile=False: inference only, so the training config (optimizer/loss)
# embedded in the HDF5 file is not needed and may not even deserialize.
model = tf.keras.models.load_model(model_path, compile=False)
# === Segmentation + Stats ===
# === Segmentation + Stats ===
def classify_image_and_stats(image_input, threshold=0.5, min_tumor_ratio=0.01):
    """Segment a brain MRI slice and derive simple tumor statistics.

    Args:
        image_input: H x W x C image array (as produced by ``gr.Image`` with
            ``type="numpy"``); resized to the model's 256x256 input.
        threshold: Probability cutoff above which a pixel counts as tumor.
        min_tumor_ratio: Minimum tumor-pixel fraction needed to report
            "Tumor Detected" (filters out a few spurious pixels).

    Returns:
        Tuple ``(mask, stats)``: a uint8 256x256 mask (0 or 255) and a dict
        with ``tumor_area``, ``total_area``, ``tumor_ratio``, ``tumor_label``.
    """
    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img = img / 255.0  # assumes 0-255 pixel range from Gradio — TODO confirm
    img = np.expand_dims(img, axis=0)  # add batch dimension

    prediction = model.predict(img)[0]  # (256, 256, 1) probability map

    # Binarize once and reuse for both the visual mask and the statistics
    # (the original thresholded the prediction twice).
    tumor_pixels = prediction > threshold
    mask = tumor_pixels.astype(np.uint8) * 255
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)  # drop channel dim for display

    tumor_area = int(np.sum(tumor_pixels))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
    tumor_label = (
        "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"
    )

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return mask, stats
# === Gradio handler ===
# === Gradio handler ===
def rishigpt_handler(image_input, groq_api_key):
    """Gradio entry point: segment the MRI, then have the LLM explain it.

    Args:
        image_input: numpy image from the Gradio image widget.
        groq_api_key: user-supplied Groq API key, exported for this process.

    Returns:
        Tuple ``(mask, description)``: the segmentation mask and a plain-text
        medical explanation for the Textbox output.
    """
    os.environ["GROQ_API_KEY"] = groq_api_key

    mask, stats = classify_image_and_stats(image_input)

    def segment_brain_tool(input_text: str) -> str:
        # Closure over `stats`; input_text is a dummy the agent must supply.
        return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
    )

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=True
    )

    user_query = "Get segmentation details"
    classification = agent.run(user_query)

    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a medical imaging expert. Based on this tumor analysis result: {result}, "
            "explain what this means for the patient in simple language."
        )
    )

    chain = prompt | llm
    # BUG FIX: `prompt | llm` with a chat model returns an AIMessage, not a
    # string, but this value feeds a gr.Textbox. Unwrap .content (falling back
    # to the raw value for string-returning LLM wrappers).
    response = chain.invoke({"result": classification})
    description = getattr(response, "content", response)

    return mask, description
# === Gradio UI ===
# === Gradio UI ===
# Input widgets: the MRI slice as a numpy array plus the user's Groq key
# (password-masked so it is not echoed on screen).
inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key"),
]

# Output widgets: the binary segmentation mask and the LLM's explanation.
outputs = [
    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
    gr.Textbox(label="Medical Explanation"),
]

if __name__ == "__main__":
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation.",
    )
    demo.launch()