import gradio as gr
import os
import time

import numpy as np
import requests
import tensorflow as tf
from langchain_groq import ChatGroq
from langchain.agents import AgentType, initialize_agent
from langchain.prompts import PromptTemplate
from langchain.tools import StructuredTool

# Input resolution expected by the UNet++ segmentation model.
IMG_HEIGHT = 256
IMG_WIDTH = 256

# Download the model weights from the Hugging Face Hub on first run.
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    with requests.get(hf_url, stream=True) as r:
        r.raise_for_status()
        with open(model_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

print("Loading model...")
model = tf.keras.models.load_model(model_path, compile=False)


def classify_image_and_stats(image_input):
    """Segment one MRI slice; return a mask overlay image and summary stats."""
    # Resize to the model's input size and scale pixel values to [0, 1].
    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
    img_norm = img / 255.0
    img_batch = np.expand_dims(img_norm, axis=0)

    # Binarize the predicted probability map into a tumor mask.
    prediction = model.predict(img_batch)[0]
    mask = (prediction > 0.5).astype(np.uint8)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = np.squeeze(mask, axis=-1)

    # Area-based decision: flag a tumor if the mask covers more than
    # ~0.4% of the slice.
    tumor_area = int(np.sum(mask))
    total_area = IMG_HEIGHT * IMG_WIDTH
    tumor_ratio = tumor_area / total_area
    tumor_label = "Tumor Detected" if tumor_ratio > 0.00385 else "No Tumor Detected"

    # Blend a red tumor mask over the original slice for display.
    overlay = np.array(img)
    red_mask = np.zeros_like(overlay)
    red_mask[..., 0] = mask * 255
    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)

    stats = {
        "tumor_area": tumor_area,
        "total_area": total_area,
        "tumor_ratio": tumor_ratio,
        "tumor_label": tumor_label,
    }
    return overlay_img, stats


def rishigpt_handler(image_input, groq_api_key):
    """Gradio handler: segment the image, then stream an LLM explanation."""
    os.environ["GROQ_API_KEY"] = groq_api_key
    overlay_img, stats = classify_image_and_stats(image_input)

    def segment_brain_tool(input_text: str) -> str:
        """Expose the precomputed segmentation stats to the agent."""
        return (
            f"Tumor label: {stats['tumor_label']}. "
            f"Tumor area: {stats['tumor_area']}. "
            f"Ratio: {stats['tumor_ratio']:.4f}."
        )

    tool = StructuredTool.from_function(
        segment_brain_tool,
        name="segment_brain",
        description="Provide tumor segmentation stats for the MRI image.",
    )

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.4,
    )

    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=False,
    )

    classification = agent.run("Give me the segmentation details")

    # Rewrite the raw stats as a patient-friendly explanation.
    prompt = PromptTemplate(
        input_variables=["result"],
        template=(
            "You are a compassionate AI radiologist. "
            "Read this tumor analysis result: {result}. "
            "Summarize the situation for the patient in natural paragraphs, "
            "with a calm, clear tone and concrete next steps."
        ),
    )
    chain = prompt | llm
    final_text = chain.invoke({"result": classification}).content.strip()

    # Stream the explanation character by character for a typing effect.
    displayed_text = ""
    for char in final_text:
        displayed_text += char
        time.sleep(0.015)
        yield overlay_img, displayed_text


inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key"),
]

outputs = [
    gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask"),
    gr.Textbox(label="Doctor's Explanation"),
]

if __name__ == "__main__":
    gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description=(
            "UNet++ brain tumor segmentation with a live mask overlay, "
            "detailed stats, and a human-like typing explanation."
        ),
    ).launch()