# Author: rishirajbal
# Last commit: "Update app.py" (27b7bc3, verified)
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
import time
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain.tools import StructuredTool
# Spatial size the segmentation model is fed; inputs are resized to this.
IMG_HEIGHT = 256
IMG_WIDTH = 256

model_path = "unet_model.h5"
hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"


def _download_model(url, dest, chunk_size=8192, timeout=60):
    """Stream *url* to *dest*, only creating *dest* once the download completes.

    The payload is written to ``dest + ".part"`` and renamed into place with
    ``os.replace``. An interrupted or failed download therefore never leaves
    a truncated file at *dest*. Such a file would otherwise pass the
    ``os.path.exists`` check on the next start-up and then fail to load.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final path of the downloaded file.
        chunk_size: Bytes per streamed chunk.
        timeout: ``requests`` connect/read timeout in seconds, so a stalled
            server cannot hang start-up indefinitely.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on network errors or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # Only present if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


if not os.path.exists(model_path):
    print(f"Downloading model from {hf_url}...")
    _download_model(hf_url, model_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# state needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)
def _ensure_rgb(image):
    """Return *image* as a 3-D array, expanding grayscale and dropping alpha.

    2-D grayscale and single-channel ``(H, W, 1)`` arrays are replicated to
    three channels, and 4-channel (RGBA) arrays lose their last channel.
    Any other 3-D array is returned unchanged.

    Raises:
        ValueError: if *image* is neither 2-D nor 3-D.
    """
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.ndim != 3:
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {image.shape}."
        )
    if image.shape[-1] == 1:
        return np.repeat(image, 3, axis=-1)
    if image.shape[-1] == 4:
        return image[..., :3]
    return image


def classify_image_and_stats(image_input, mask_threshold=0.5, min_tumor_ratio=0.00385):
    """Segment an MRI slice with the loaded model and summarise the mask.

    Args:
        image_input: Image array (as delivered by ``gr.Image(type="numpy")``).
            Grayscale, single-channel and RGBA arrays are converted to three
            channels first. NOTE(review): assumes the model takes 3-channel
            RGB input, matching the red-channel overlay below. Confirm
            against the training pipeline.
        mask_threshold: Per-pixel model output above which a pixel is
            counted as tumour.
        min_tumor_ratio: Tumour-pixel fraction of the resized image above
            which the slice is labelled "Tumor Detected".

    Returns:
        ``(overlay_img, stats)``, where the two elements are:

        - ``overlay_img``: uint8 image at ``IMG_HEIGHT x IMG_WIDTH`` with the
          mask tinted red.
        - ``stats``: dict with keys ``tumor_area`` (int pixel count),
          ``total_area`` (int), ``tumor_ratio`` (float) and ``tumor_label``
          (str).

    Raises:
        ValueError: if *image_input* is None or not a 2-D/3-D array.
    """
    if image_input is None:
        raise ValueError("No image provided.")
    rgb = _ensure_rgb(np.asarray(image_input))

    # Resize to the model's input size. The /255 normalisation assumes
    # 0-255 pixel values. The un-normalised resized image is reused for the
    # overlay.
    img = tf.image.resize(rgb, [IMG_HEIGHT, IMG_WIDTH])
    img_batch = np.expand_dims(img / 255.0, axis=0)
    # verbose=0 keeps Keras from printing a progress bar for every request.
    prediction = model.predict(img_batch, verbose=0)[0]

    mask = (prediction > mask_threshold).astype(np.uint8)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
    tumor_label = (
        "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"
    )

    # Dim the whole image to 60% and add 40% pure red on masked pixels.
    overlay = np.array(img)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255
    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return overlay_img, stats
def rishigpt_handler(image_input, groq_api_key):
    """Gradio generator: segment the MRI slice, then stream an LLM explanation.

    Args:
        image_input: Uploaded image array from ``gr.Image(type="numpy")``,
            or None when nothing was uploaded.
        groq_api_key: Groq API key entered by the user.

    Yields:
        ``(overlay_img, text)`` tuples.

        - The first is yielded as soon as segmentation finishes, with empty
          text, so the mask is visible while the LLM call runs.
        - The rest reveal the explanation one character at a time (a
          typing effect).

    Raises:
        gr.Error: if the image or API key is missing, or the Groq call fails.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    api_key = (groq_api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your Groq API key.")

    overlay_img, stats = classify_image_and_stats(image_input)
    yield overlay_img, ""

    # The stats are already computed locally. Format them directly rather
    # than routing through a ReAct agent, which costs an extra LLM
    # round-trip, can fail on output parsing, and may restate the numbers
    # inaccurately before they reach the patient-facing summary.
    classification = (
        f"Tumor label: {stats['tumor_label']}. "
        f"Tumor area: {stats['tumor_area']}. "
        f"Ratio: {stats['tumor_ratio']:.4f}."
    )

    # Pass the key to the client rather than writing os.environ. The
    # environment is process-global, so concurrent sessions would otherwise
    # overwrite (and use) each other's keys.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
        api_key=api_key,
    )
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, calm, clear tone, with next steps."
        ),
    )
    chain = prompt | llm
    try:
        final_text = chain.invoke({"result": classification}).content.strip()
    except Exception as err:
        # UI boundary: report API failures (bad key, rate limit, network)
        # to the user instead of Gradio's generic error.
        raise gr.Error(f"Groq request failed: {err}") from err

    displayed_text = ""
    for char in final_text:
        displayed_text += char
        time.sleep(0.015)
        yield overlay_img, displayed_text
# Gradio components, in the order of rishigpt_handler's parameters and of
# the values it yields.
inputs = [
    gr.Image(label="Upload Brain MRI Slice", type="numpy"),
    gr.Textbox(label="Groq API Key", type="password"),
]
outputs = [
    gr.Image(label="Overlay: Brain MRI + Tumor Mask", type="numpy"),
    gr.Textbox(label="Doctor's Explanation"),
]

if __name__ == "__main__":
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with live mask overlay, detailed stats, and human-like typing explanation.",
    )
    demo.launch()