import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests

from langchain_groq import ChatGroq
from langchain.agents import initialize_agent
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.tools import StructuredTool
from tensorflow.keras.preprocessing import image

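# Download the pretrained UNet++ weights from the Hugging Face Hub if they are
# not already cached next to the app.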
model_path = "unet_model.h5"
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    r = requests.get(hf_url)
    r.raise_for_status()  # abort if the weight download failed
    with open(model_path, "wb") as f:
        f.write(r.content)

# Load the segmentation model once at startup; compile=False since we only run inference.
model = tf.keras.models.load_model(model_path, compile=False)

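# Preprocess an MRI slice and return the model's binary tumor mask.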
def classify_image(image_input):
    # Resize and scale the slice to the 256x256 input the network expects.
    img = tf.image.resize(image_input, (256, 256))
    img = img / 255.0
    img = np.expand_dims(img, axis=0)

    # Threshold the predicted probability map into a 0/255 mask.
    prediction = model.predict(img)[0]
    mask = (prediction > 0.5).astype(np.uint8) * 255

    return mask

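# Gradio callback: segment the uploaded scan, let a Groq-backed agent report on
# the result, then ask an LLM chain for a plain-language explanation.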
def rishigpt_handler(image_input, groq_api_key):
    # The user supplies their own Groq key at runtime; ChatGroq reads it from the environment.
    os.environ["GROQ_API_KEY"] = groq_api_key

    mask = classify_image(image_input)

    def classify_image_tool(img_path: str) -> str:
        # The segmentation has already run above; the tool only reports the
        # outcome so the agent has something concrete to reason about.
        return "Brain tumor mask generated."

    # Expose the segmentation step to the agent as a LangChain tool.
    tool = StructuredTool.from_function(
        classify_image_tool,
        name="segment_brain",
        description="Segment brain MRI for tumor detection."
    )

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3
    )

    # Zero-shot ReAct agent that can call the segmentation tool.
    agent = initialize_agent(
        tools=[tool],
        llm=llm,
        agent="zero-shot-react-description",
        verbose=True
    )

    user_query = "I uploaded a brain MRI. What does the segmentation say?"
    classification = agent.run(user_query)

    # Turn the agent's answer into a diagnostic-style explanation for the user.
    prompt = PromptTemplate(
        input_variables=["result"],
        template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
    )

    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt
    )

    description = llm_chain.run({"result": classification})

    return mask, description

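# Gradio components: MRI slice and Groq API key in, tumor mask and explanation out.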
inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key")
]

outputs = [
    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
    gr.Textbox(label="Medical Explanation")
]

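# Build and launch the web app.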
gr.Interface(
    fn=rishigpt_handler,
    inputs=inputs,
    outputs=outputs,
    title="RishiGPT Medical Brain Segmentation",
    description="UNet++ Brain Tumor Segmentation"
).launch()