File size: 3,915 Bytes
ec387c2 270b631 ec387c2 64183cb ec387c2 448426c 64183cb ec387c2 448426c 755c3f6 27b7bc3 ec387c2 82768ad b31004e ec387c2 755c3f6 27b7bc3 755c3f6 448426c ec387c2 270b631 448426c ec387c2 64183cb ec387c2 448426c 755c3f6 27b7bc3 448426c 270b631 448426c 270b631 448426c 755c3f6 64183cb 448426c 755c3f6 27b7bc3 ec387c2 448426c ec387c2 755c3f6 448426c 64183cb 448426c 64183cb ec387c2 448426c ec387c2 64183cb 448426c 64183cb 448426c 64183cb ec387c2 755c3f6 448426c 270b631 755c3f6 ec387c2 755c3f6 270b631 ec387c2 27b7bc3 270b631 755c3f6 27b7bc3 ec387c2 448426c eb80dfb ec387c2 b31004e 270b631 b31004e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
import time
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain.tools import StructuredTool
# Spatial size every input slice is resized to before being fed to the model.
IMG_HEIGHT = 256
IMG_WIDTH = 256
# Local cache location of the segmentation model weights.
model_path = "unet_model.h5"

if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    # Stream into a temporary ".part" file and only move it into place once the
    # whole body has been written. Writing directly to model_path meant an
    # interrupted download left a truncated file behind, which the exists()
    # check above would then accept as a valid cached model on every restart.
    _tmp_model_path = model_path + ".part"
    try:
        # timeout bounds connect and each socket read, so a stalled server
        # cannot hang application startup indefinitely.
        with requests.get(hf_url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(_tmp_model_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(_tmp_model_path, model_path)
    finally:
        # Remove any partial download if something failed; after a successful
        # os.replace the temp file no longer exists and this is a no-op.
        if os.path.exists(_tmp_model_path):
            os.remove(_tmp_model_path)

print("Loading model...")
# compile=False: the model is only used for inference, so no optimizer/loss
# state needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)
def classify_image_and_stats(image_input, *, mask_threshold=0.5, min_tumor_ratio=0.00385):
    """Segment an MRI slice with the loaded model and summarise the result.

    The image is resized to IMG_HEIGHT x IMG_WIDTH, scaled by 1/255 and run
    through the module-level ``model``. The prediction is binarised into a
    mask, and the fraction of mask pixels decides the tumor label.

    Args:
        image_input: Image array as delivered by ``gr.Image(type="numpy")``;
            presumably (H, W, 3) uint8 RGB — TODO confirm against the model's
            expected channel count.
        mask_threshold: Prediction value above which a pixel is counted as
            tumor. Default 0.5.
        min_tumor_ratio: Mask-area fraction above which the slice is labelled
            "Tumor Detected". Default 0.00385.

    Returns:
        A tuple ``(overlay_img, stats)``. ``overlay_img`` is a uint8 array of
        the resized slice blended 60/40 with a red layer that is 255 in the
        red channel where the mask is set. ``stats`` is a dict with keys
        "tumor_area" (int pixel count), "total_area" (int), "tumor_ratio"
        (float) and "tumor_label" (str).

    Raises:
        ValueError: if ``image_input`` is None (nothing was uploaded).
    """
    if image_input is None:
        raise ValueError("No image supplied; please upload a brain MRI slice.")

    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)
    # verbose=0: avoid printing a Keras progress bar to the server log on
    # every single request.
    prediction = model.predict(img_batch, verbose=0)[0]

    mask = (prediction > mask_threshold).astype(np.uint8)
    # Drop a trailing singleton channel so the mask is 2-D like the image plane.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = float(tumor_area / total_area)
    tumor_label = "Tumor Detected" if tumor_ratio > min_tumor_ratio else "No Tumor Detected"

    # Build the visual overlay on the resized (0-255 range) image.
    overlay = np.array(img)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255
    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return overlay_img, stats
def rishigpt_handler(image_input, groq_api_key):
    """Gradio generator: segment the uploaded slice, then stream an explanation.

    The slice is segmented with ``classify_image_and_stats``. A LangChain agent
    then queries the precomputed stats through a tool. A second LLM call
    rewrites the agent's answer as a patient-facing explanation, which is
    revealed one character at a time.

    Args:
        image_input: Image array from the ``gr.Image`` input (None if nothing
            was uploaded).
        groq_api_key: The user's Groq API key from the password textbox.

    Yields:
        ``(overlay_img, text)`` pairs. The first pair is the overlay with
        empty text, sent as soon as segmentation finishes. Each later pair
        carries the explanation text revealed so far.

    Raises:
        gr.Error: if no image or no API key was provided.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    if not groq_api_key or not groq_api_key.strip():
        raise gr.Error("Please enter your Groq API key.")

    overlay_img, stats = classify_image_and_stats(image_input)
    # Show the mask immediately: the LLM round-trips below can take a while,
    # and if the LLM returns empty text the typing loop would otherwise yield
    # nothing at all, leaving the overlay output blank.
    yield overlay_img, ""

    def segment_brain_tool(input_text: str) -> str:
        """Return the precomputed segmentation stats; ``input_text`` is ignored."""
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image."
    )

    # Hand the key to this client instance instead of writing it to
    # os.environ: the environment is process-wide, so concurrent users of the
    # app would overwrite each other's keys and a key would linger for later
    # requests.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
        api_key=groq_api_key.strip(),
    )
    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=False,
        # ReAct output from the LLM is not guaranteed to parse; feed parse
        # errors back to the model instead of aborting the request.
        handle_parsing_errors=True,
    )
    user_query = "Give me the segmentation details"
    classification = agent.run(user_query)

    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, calm, clear tone, with next steps."
        )
    )
    chain = prompt | llm
    final_text = chain.invoke({"result": classification}).content.strip()

    # Reveal the explanation character by character for a "typing" effect.
    displayed_text = ""
    for char in final_text:
        displayed_text += char
        time.sleep(0.015)
        yield overlay_img, displayed_text
# Gradio input components, in the order rishigpt_handler receives them:
# (image_input, groq_api_key).
inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key"),
]

# Gradio output components, matching the (overlay_img, text) pairs that
# rishigpt_handler yields.
outputs = [
    gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask"),
    gr.Textbox(label="Doctor's Explanation"),
]

if __name__ == "__main__":
    # Build the interface around the generator handler and start the server.
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with live mask overlay, detailed stats, and human-like typing explanation.",
    )
    demo.launch()
|