"""Gradio app: UNet++ brain-tumor segmentation plus an LLM-written explanation.

On first run the Keras segmentation model is downloaded from Hugging Face and
cached locally as ``unet_model.h5``. Each request segments the uploaded MRI
slice and asks a Groq-hosted LLM (via LangChain) to explain the result, using
the API key the user typed into the UI.
"""

import contextlib
import os
import tempfile

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq

model_path = "unet_model.h5"
hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"


def _download_model(url, dest, *, timeout=60.0):
    """Stream ``url`` to ``dest`` so that ``dest`` only ever holds a complete file.

    The response is written to a temporary file in the same directory as
    ``dest`` and moved into place with ``os.replace`` only after the whole
    body has been received. An interrupted download therefore never leaves a
    truncated ``.h5`` that a later run (which only checks ``os.path.exists``)
    would try to load.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or a stalled read
            longer than ``timeout`` seconds.
    """
    dest_dir = os.path.dirname(os.path.abspath(dest))
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f, requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        os.replace(tmp_path, dest)
    except BaseException:
        # Best-effort cleanup (also on Ctrl-C) so the next run retries cleanly.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


# Safe download with streaming to avoid incomplete file
if not os.path.exists(model_path):
    print(f"Downloading model from {hf_url}...")
    _download_model(hf_url, model_path)

print("Loading model...")
model = tf.keras.models.load_model(model_path, compile=False)


def classify_image(image_input, threshold=0.5):
    """Segment one MRI slice and return a binary tumor mask.

    Args:
        image_input: Image array as delivered by ``gr.Image(type="numpy")``.
            Assumed to be (H, W, C) with 0-255 values and a channel count the
            model accepts — TODO confirm against the model's input layer.
        threshold: Per-pixel probability above which a pixel counts as tumor.

    Returns:
        A ``uint8`` array whose pixels are 0 (background) or 255 (tumor), at
        the model's output resolution for a 256x256 input. A trailing
        singleton channel is dropped so the mask is a plain 2-D image.

    Raises:
        gr.Error: if no image was provided.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    img = tf.image.resize(image_input, (256, 256))
    img = img / 255.0
    img = np.expand_dims(img, axis=0)  # add the batch dimension
    prediction = model.predict(img)[0]  # first (only) item of the batch
    mask = (prediction > threshold).astype(np.uint8) * 255
    # PIL, which Gradio uses to render numpy images, rejects (H, W, 1) arrays.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    return mask


# Built once; only the LLM (which carries the per-user key) varies per request.
_EXPLANATION_PROMPT = PromptTemplate(
    input_variables=["result"],
    template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis.",
)


def _summarize_mask(mask):
    """Describe the mask in words so the LLM sees image-specific information."""
    summary = "The brain tumor mask has been generated and segmentation is complete."
    fraction = np.count_nonzero(mask) / mask.size if mask.size else 0.0
    if fraction == 0:
        return summary + " No pixels were flagged as tumor."
    return summary + f" About {fraction:.2%} of the pixels in the segmented slice were flagged as tumor."


def rishigpt_handler(image_input, groq_api_key):
    """Gradio callback: segment the image and generate an LLM explanation.

    Args:
        image_input: Uploaded MRI slice (see ``classify_image``).
        groq_api_key: The user's Groq API key from the password textbox.

    Returns:
        ``(mask, description)``: the binary mask from ``classify_image`` and
        the LLM's text explanation.

    Raises:
        gr.Error: if the image or the API key is missing.
    """
    if not groq_api_key or not groq_api_key.strip():
        raise gr.Error("Please enter your Groq API key.")
    mask = classify_image(image_input)
    # Pass the key to the client directly instead of writing it to the
    # process-wide os.environ, which would leak it to other users' sessions
    # and race between concurrent requests.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3,
        api_key=groq_api_key.strip(),
    )
    llm_chain = LLMChain(llm=llm, prompt=_EXPLANATION_PROMPT)
    classification = _summarize_mask(mask)
    # LLMChain.run is deprecated; invoke returns a dict keyed by output_key ("text").
    description = llm_chain.invoke({"result": classification})["text"]
    return mask, description


inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key"),
]

outputs = [
    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
    gr.Textbox(label="Medical Explanation"),
]

if __name__ == "__main__":
    gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration",
    ).launch()