Update app.py
Browse files
app.py
CHANGED
|
@@ -7,9 +7,12 @@ import requests
|
|
| 7 |
from langchain_groq import ChatGroq
|
| 8 |
from langchain.agents import initialize_agent
|
| 9 |
from langchain.prompts import PromptTemplate
|
| 10 |
-
from langchain_core.runnables import RunnableSequence #
|
| 11 |
from langchain.tools import StructuredTool
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
# === Download model if not exists ===
|
| 14 |
model_path = "unet_model.h5"
|
| 15 |
if not os.path.exists(model_path):
|
|
@@ -24,36 +27,52 @@ if not os.path.exists(model_path):
|
|
| 24 |
print("Loading model...")
|
| 25 |
model = tf.keras.models.load_model(model_path, compile=False)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
img = img / 255.0
|
| 31 |
img = np.expand_dims(img, axis=0)
|
| 32 |
|
| 33 |
-
prediction = model.predict(img)[0] # (256, 256, 1)
|
| 34 |
-
|
| 35 |
-
mask = (prediction > 0.5).astype(np.uint8) * 255 # binary mask
|
| 36 |
|
| 37 |
-
# Squeeze to (H, W) if needed
|
| 38 |
if mask.ndim == 3 and mask.shape[-1] == 1:
|
| 39 |
mask = np.squeeze(mask, axis=-1)
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
| 44 |
def rishigpt_handler(image_input, groq_api_key):
|
| 45 |
os.environ["GROQ_API_KEY"] = groq_api_key
|
| 46 |
|
| 47 |
-
mask =
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
return "A brain tumor mask was generated."
|
| 52 |
|
| 53 |
tool = StructuredTool.from_function(
|
| 54 |
segment_brain_tool,
|
| 55 |
name="segment_brain",
|
| 56 |
-
description="
|
| 57 |
)
|
| 58 |
|
| 59 |
llm = ChatGroq(
|
|
@@ -68,20 +87,23 @@ def rishigpt_handler(image_input, groq_api_key):
|
|
| 68 |
verbose=True
|
| 69 |
)
|
| 70 |
|
| 71 |
-
user_query = "
|
| 72 |
classification = agent.run(user_query)
|
| 73 |
|
| 74 |
-
# New style: RunnableSequence
|
| 75 |
prompt = PromptTemplate(
|
| 76 |
input_variables=["result"],
|
| 77 |
-
template=
|
|
|
|
|
|
|
|
|
|
| 78 |
)
|
| 79 |
-
chain = prompt | llm
|
| 80 |
|
|
|
|
| 81 |
description = chain.invoke({"result": classification})
|
| 82 |
|
| 83 |
return mask, description
|
| 84 |
|
|
|
|
| 85 |
# === Gradio UI ===
|
| 86 |
inputs = [
|
| 87 |
gr.Image(type="numpy", label="Upload Brain MRI Slice"),
|
|
@@ -99,5 +121,5 @@ if __name__ == "__main__":
|
|
| 99 |
inputs=inputs,
|
| 100 |
outputs=outputs,
|
| 101 |
title="RishiGPT Medical Brain Segmentation",
|
| 102 |
-
description="UNet++ Brain Tumor Segmentation with LangChain integration"
|
| 103 |
).launch()
|
|
|
|
| 7 |
from langchain_groq import ChatGroq
|
| 8 |
from langchain.agents import initialize_agent
|
| 9 |
from langchain.prompts import PromptTemplate
|
| 10 |
+
from langchain_core.runnables import RunnableSequence # modern
|
| 11 |
from langchain.tools import StructuredTool
|
| 12 |
|
| 13 |
+
IMG_HEIGHT = 256
|
| 14 |
+
IMG_WIDTH = 256
|
| 15 |
+
|
| 16 |
# === Download model if not exists ===
|
| 17 |
model_path = "unet_model.h5"
|
| 18 |
if not os.path.exists(model_path):
|
|
|
|
| 27 |
print("Loading model...")
|
| 28 |
model = tf.keras.models.load_model(model_path, compile=False)
|
| 29 |
|
| 30 |
+
|
| 31 |
+
# === Segmentation + Stats ===
|
| 32 |
+
def classify_image_and_stats(image_input):
|
| 33 |
+
img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
|
| 34 |
img = img / 255.0
|
| 35 |
img = np.expand_dims(img, axis=0)
|
| 36 |
|
| 37 |
+
prediction = model.predict(img)[0] # (256, 256, 1)
|
| 38 |
+
mask = (prediction > 0.5).astype(np.uint8) * 255
|
|
|
|
| 39 |
|
|
|
|
| 40 |
if mask.ndim == 3 and mask.shape[-1] == 1:
|
| 41 |
mask = np.squeeze(mask, axis=-1)
|
| 42 |
|
| 43 |
+
# Tumor stats
|
| 44 |
+
tumor_area = np.sum(prediction > 0.5)
|
| 45 |
+
total_area = IMG_HEIGHT * IMG_WIDTH
|
| 46 |
+
tumor_ratio = tumor_area / total_area
|
| 47 |
+
|
| 48 |
+
if tumor_ratio > 0.01:
|
| 49 |
+
tumor_label = "Tumor Detected"
|
| 50 |
+
else:
|
| 51 |
+
tumor_label = "No Tumor Detected"
|
| 52 |
+
|
| 53 |
+
stats = {
|
| 54 |
+
"tumor_area": int(tumor_area),
|
| 55 |
+
"total_area": total_area,
|
| 56 |
+
"tumor_ratio": tumor_ratio,
|
| 57 |
+
"tumor_label": tumor_label
|
| 58 |
+
}
|
| 59 |
|
| 60 |
+
return mask, stats
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# === Gradio handler ===
|
| 64 |
def rishigpt_handler(image_input, groq_api_key):
|
| 65 |
os.environ["GROQ_API_KEY"] = groq_api_key
|
| 66 |
|
| 67 |
+
mask, stats = classify_image_and_stats(image_input)
|
| 68 |
|
| 69 |
+
def segment_brain_tool(input_text: str) -> str:
|
| 70 |
+
return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."
|
|
|
|
| 71 |
|
| 72 |
tool = StructuredTool.from_function(
|
| 73 |
segment_brain_tool,
|
| 74 |
name="segment_brain",
|
| 75 |
+
description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
|
| 76 |
)
|
| 77 |
|
| 78 |
llm = ChatGroq(
|
|
|
|
| 87 |
verbose=True
|
| 88 |
)
|
| 89 |
|
| 90 |
+
user_query = "Get segmentation details"
|
| 91 |
classification = agent.run(user_query)
|
| 92 |
|
|
|
|
| 93 |
prompt = PromptTemplate(
|
| 94 |
input_variables=["result"],
|
| 95 |
+
template=(
|
| 96 |
+
"You are a medical imaging expert. Based on this tumor analysis result: {result}, "
|
| 97 |
+
"explain what this means for the patient in simple language."
|
| 98 |
+
)
|
| 99 |
)
|
|
|
|
| 100 |
|
| 101 |
+
chain = prompt | llm
|
| 102 |
description = chain.invoke({"result": classification})
|
| 103 |
|
| 104 |
return mask, description
|
| 105 |
|
| 106 |
+
|
| 107 |
# === Gradio UI ===
|
| 108 |
inputs = [
|
| 109 |
gr.Image(type="numpy", label="Upload Brain MRI Slice"),
|
|
|
|
| 121 |
inputs=inputs,
|
| 122 |
outputs=outputs,
|
| 123 |
title="RishiGPT Medical Brain Segmentation",
|
| 124 |
+
description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation."
|
| 125 |
).launch()
|