rishirajbal's picture
Create app.py
ec387c2 verified
raw
history blame
2.5 kB
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.tools import StructuredTool
from tensorflow.keras.preprocessing import image
# Local cache path for the pretrained UNet++ weights; fetched from the
# Hugging Face Hub on first start if not already present.
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    # Download into a temporary file and only rename it into place after the
    # whole body was received with a success status. Otherwise an HTTP error
    # page or a dropped connection would leave a corrupt .h5 file that the
    # os.path.exists() check above would trust on every later start.
    tmp_path = model_path + ".part"
    with requests.get(hf_url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, model_path)
# compile=False: only model.predict() is used in this file, so the saved
# training configuration (optimizer/loss) does not need to be restored.
model = tf.keras.models.load_model(model_path, compile=False)
def classify_image(image_input, threshold=0.5):
    """Run the segmentation model on one MRI slice and return a binary mask.

    Args:
        image_input: image array from the ``gr.Image(type="numpy")`` input.
            Presumably (H, W, 3) uint8 with values 0-255 (TODO confirm for
            grayscale uploads).
        threshold: cutoff applied to the model output. Pixels with a
            prediction strictly above it are marked as tumor. Presumably the
            output is a per-pixel probability (sigmoid); verify against the
            model.

    Returns:
        uint8 array at the model's 256x256 input resolution, with 255 for
        flagged pixels and 0 elsewhere. A trailing singleton channel axis is
        dropped so the result is a plain 2-D grayscale image. Pillow's
        ``Image.fromarray`` (used when the mask is displayed) cannot handle
        (H, W, 1) uint8 arrays.
    """
    img = tf.image.resize(image_input, (256, 256))
    img = img / 255.0  # scale 0-255 pixel values into [0, 1]
    img = np.expand_dims(img, axis=0)  # add the batch dimension predict() expects
    prediction = model.predict(img)[0]  # first (only) item of the batch
    mask = (prediction > threshold).astype(np.uint8) * 255
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    return mask
def _summarize_mask(mask):
    """Return a one-sentence plain-text description of a binary tumor mask."""
    flagged = int(np.count_nonzero(mask))
    total = int(np.size(mask))
    if flagged == 0 or total == 0:
        return "The segmentation model flagged no tumor pixels in this slice."
    return (
        f"The segmentation model flagged {flagged} of {total} pixels "
        f"({flagged / total:.2%}) in this slice as tumor."
    )


def rishigpt_handler(image_input, groq_api_key):
    """Gradio callback: segment an MRI slice and have an LLM explain the result.

    Args:
        image_input: uploaded image from ``gr.Image(type="numpy")``, or
            ``None`` when nothing was uploaded.
        groq_api_key: Groq API key typed in by the user.

    Returns:
        ``(mask, description)``: the mask produced by ``classify_image`` and
        the LLM's explanation text.

    Raises:
        gr.Error: if the image or the API key is missing, so the UI shows a
            readable message instead of a stack trace.
    """
    if image_input is None:
        raise gr.Error("Please upload a brain MRI slice.")
    if not groq_api_key:
        raise gr.Error("Please enter your Groq API key.")

    mask = classify_image(image_input)
    # Turn the mask into text so the LLM sees the actual segmentation result;
    # previously the tool returned a fixed placeholder regardless of the mask.
    summary = _summarize_mask(mask)

    def classify_image_tool(img_path: str) -> str:
        """Report the segmentation computed above (the argument is unused)."""
        return summary

    tool = StructuredTool.from_function(
        classify_image_tool,
        name="segment_brain",
        description="Segment brain MRI for tumor detection.",
    )
    # Hand the key to the client directly rather than writing it to
    # os.environ. The environment is process-wide, so concurrent requests
    # from different users could race on it, and the key would persist after
    # this request.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3,
        api_key=groq_api_key,
    )
    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=True,
    )
    user_query = "I uploaded a brain MRI. What does the segmentation say?"
    classification = agent.run(user_query)

    prompt = PromptTemplate(
        input_variables=["result"],
        template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis.",
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
    )
    # Include the measured summary as well as the agent's answer, so the
    # explanation stays grounded even if the agent never calls the tool.
    description = llm_chain.run({"result": f"{summary} {classification}"})
    return mask, description
# Gradio UI wiring. The order of the input components matches the parameters
# of rishigpt_handler (image, API key). The order of the output components
# matches its return tuple (mask, explanation text).
mri_upload = gr.Image(type="numpy", label="Upload Brain MRI Slice")
api_key_box = gr.Textbox(type="password", label="Groq API Key")
inputs = [mri_upload, api_key_box]

mask_view = gr.Image(type="numpy", label="Tumor Segmentation Mask")
explanation_box = gr.Textbox(label="Medical Explanation")
outputs = [mask_view, explanation_box]

demo = gr.Interface(
    fn=rishigpt_handler,
    inputs=inputs,
    outputs=outputs,
    title="RishiGPT Medical Brain Segmentation",
    description="UNet++ Brain Tumor Segmentation",
)
demo.launch()