|
import gradio as gr |
|
import os |
|
import tensorflow as tf |
|
import numpy as np |
|
import requests |
|
import time |
|
|
|
from langchain_groq import ChatGroq |
|
from langchain.agents import initialize_agent |
|
from langchain.prompts import PromptTemplate |
|
from langchain_core.runnables import RunnableSequence |
|
from langchain.tools import StructuredTool |
|
|
|
|
|
# Target (height, width) every uploaded MRI slice is resized to before it is
# fed to the segmentation model.
IMG_HEIGHT, IMG_WIDTH = 256, 256
|
|
|
|
|
# Location of the pre-trained UNet++ weights on the Hugging Face Hub.
MODEL_URL = (
    "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation"
    "/resolve/main/unet_model.h5"
)
model_path = "unet_model.h5"


def _download_model(url, dest, *, timeout=60):
    """Stream ``url`` to ``dest`` without ever leaving a truncated file at ``dest``.

    The response body is written to ``dest + ".part"`` and only renamed onto
    ``dest`` once it has been received completely. An interrupted download
    therefore cannot masquerade as a valid model file on the next start-up
    (the caller skips the download whenever ``dest`` already exists).

    Args:
        url: HTTP(S) URL of the file to fetch.
        dest: Final path of the downloaded file.
        timeout: Seconds to wait for the connection and between received
            chunks before giving up.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection failures or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic rename: dest either doesn't exist or is the complete file.
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers Ctrl-C mid-download; don't leave partial data behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


if not os.path.exists(model_path):
    print(f"Downloading model from {MODEL_URL}...")
    _download_model(MODEL_URL, model_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)
|
|
|
|
|
|
|
def classify_image_and_stats(
    image_input,
    *,
    mask_threshold=0.5,
    tumor_ratio_threshold=0.00385,
):
    """Segment a brain MRI slice and summarise the predicted tumor region.

    The image is resized to ``IMG_HEIGHT`` x ``IMG_WIDTH``, scaled by 1/255
    and passed through the module-level ``model``. Pixels whose prediction
    exceeds ``mask_threshold`` form a binary tumor mask.

    Args:
        image_input: Image array with values in 0-255, shaped (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4). This is what
            ``gr.Image(type="numpy")`` delivers. Grayscale is replicated to
            3 channels and an alpha channel is dropped. This assumes the
            model expects 3-channel input, as the Gradio RGB upload implies.
        mask_threshold: Per-pixel prediction value above which a pixel is
            counted as tumor.
        tumor_ratio_threshold: Fraction of the image the mask must exceed
            for the slice to be labelled "Tumor Detected". The default is the
            value this app has always used; its derivation is not recorded
            here.

    Returns:
        ``(overlay_img, stats)``.
        - ``overlay_img`` is a uint8 (IMG_HEIGHT, IMG_WIDTH, 3) image in
          which masked pixels are tinted red.
        - ``stats`` is a dict with keys ``tumor_area`` (mask pixel count),
          ``total_area`` (pixel count of the resized image), ``tumor_ratio``
          and ``tumor_label``.

    Raises:
        ValueError: if ``image_input`` is None or has an unsupported shape.
    """
    if image_input is None:
        raise ValueError("No image provided; please upload a brain MRI slice.")

    image = np.asarray(image_input)
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.ndim != 3 or image.shape[-1] not in (1, 3, 4):
        raise ValueError(
            "Expected an (H, W) or (H, W, C) image with 1, 3 or 4 channels, "
            f"got shape {image.shape}."
        )
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    # tf.image.resize returns float32 but keeps the original 0-255 range.
    resized = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    batch = np.expand_dims(resized / 255.0, axis=0)

    # verbose=0 stops Keras printing a progress bar to stdout on every request.
    prediction = model.predict(batch, verbose=0)[0]
    mask = (prediction > mask_threshold).astype(np.uint8)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
    tumor_label = (
        "Tumor Detected" if tumor_ratio > tumor_ratio_threshold else "No Tumor Detected"
    )

    # Tint only the predicted tumor pixels red (60% image / 40% red); other
    # pixels keep their original intensity instead of being dimmed.
    overlay = np.asarray(resized, dtype=np.float32)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = 255.0
    blended = 0.6 * overlay + 0.4 * red_mask
    overlay_img = np.where(mask[..., np.newaxis].astype(bool), blended, overlay)
    overlay_img = np.clip(overlay_img, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return overlay_img, stats
|
|
|
|
|
|
|
def _run_segmentation_agent(llm, stats):
    """Let a zero-shot ReAct agent fetch ``stats`` through a tool; return its final answer."""

    def segment_brain_tool(input_text: str) -> str:
        # input_text is required by the single-input tool signature but is
        # unused: the stats for this request are already computed.
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image.",
    )
    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False,
        # Feed malformed LLM output back to the model instead of raising.
        handle_parsing_errors=True,
    )
    result = agent.invoke({"input": "Give me the segmentation details"})
    return result["output"]


def _summarize_for_patient(llm, result):
    """Rewrite the agent's ``result`` as a calm, patient-facing explanation."""
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, calm, clear tone, with next steps."
        ),
    )
    chain = prompt | llm
    return chain.invoke({"result": result}).content.strip()


def rishigpt_handler(image_input, groq_api_key):
    """Gradio generator: segment an MRI slice, then stream an LLM explanation.

    Yields ``(overlay_image, explanation_text)`` pairs.
    - The overlay is yielded once, as soon as segmentation finishes, so the
      user sees it while the LLM calls run.
    - The typing animation then yields ``gr.update()`` (a no-op) for the
      image slot. This avoids re-encoding and re-sending the unchanged
      picture for every character.

    Args:
        image_input: Uploaded MRI slice from ``gr.Image(type="numpy")``, or
            None if nothing was uploaded.
        groq_api_key: The user's Groq API key. It is passed directly to
            ``ChatGroq`` and never written to ``os.environ``, so concurrent
            sessions cannot overwrite or inherit each other's key.

    Raises:
        gr.Error: on a missing image or key, an unusable image, or a failed
            Groq request, so Gradio shows a readable message.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    api_key = (groq_api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your Groq API key.")

    try:
        overlay_img, stats = classify_image_and_stats(image_input)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Show the segmentation immediately; the LLM round-trips take seconds.
    yield overlay_img, ""

    try:
        llm = ChatGroq(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=0.4,
            api_key=api_key,
        )
        classification = _run_segmentation_agent(llm, stats)
        final_text = _summarize_for_patient(llm, classification)
    except Exception as err:  # UI boundary: auth, network or model errors
        raise gr.Error(f"Groq request failed: {err}") from err

    # Typewriter effect: reveal one more character per yield.
    for end in range(1, len(final_text) + 1):
        time.sleep(0.015)
        yield gr.update(), final_text[:end]
|
|
|
|
|
# UI components, created in the same order as the handler's parameters and
# return values.
mri_image_input = gr.Image(type="numpy", label="Upload Brain MRI Slice")
groq_key_input = gr.Textbox(type="password", label="Groq API Key")
inputs = [mri_image_input, groq_key_input]

overlay_output = gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask")
explanation_output = gr.Textbox(label="Doctor's Explanation")
outputs = [overlay_output, explanation_output]
|
|
|
if __name__ == "__main__":
    # Bind the streaming handler to the UI components and start the server.
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with live mask overlay, detailed stats, and human-like typing explanation.",
    )
    demo.launch()
|
|