Update app.py
app.py
CHANGED
@@ -7,9 +7,12 @@ import requests
 from langchain_groq import ChatGroq
 from langchain.agents import initialize_agent
 from langchain.prompts import PromptTemplate
-from langchain_core.runnables import RunnableSequence #
+from langchain_core.runnables import RunnableSequence # modern
 from langchain.tools import StructuredTool

+IMG_HEIGHT = 256
+IMG_WIDTH = 256
+
 # === Download model if not exists ===
 model_path = "unet_model.h5"
 if not os.path.exists(model_path):
@@ -24,36 +27,52 @@ if not os.path.exists(model_path):
 print("Loading model...")
 model = tf.keras.models.load_model(model_path, compile=False)

-
-
-
+
+# === Segmentation + Stats ===
+def classify_image_and_stats(image_input):
+    img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
     img = img / 255.0
     img = np.expand_dims(img, axis=0)

-    prediction = model.predict(img)[0] # (256, 256, 1)
-
-    mask = (prediction > 0.5).astype(np.uint8) * 255 # binary mask
+    prediction = model.predict(img)[0] # (256, 256, 1)
+    mask = (prediction > 0.5).astype(np.uint8) * 255

-    # Squeeze to (H, W) if needed
     if mask.ndim == 3 and mask.shape[-1] == 1:
         mask = np.squeeze(mask, axis=-1)

-
+    # Tumor stats
+    tumor_area = np.sum(prediction > 0.5)
+    total_area = IMG_HEIGHT * IMG_WIDTH
+    tumor_ratio = tumor_area / total_area
+
+    if tumor_ratio > 0.01:
+        tumor_label = "Tumor Detected"
+    else:
+        tumor_label = "No Tumor Detected"
+
+    stats = {
+        "tumor_area": int(tumor_area),
+        "total_area": total_area,
+        "tumor_ratio": tumor_ratio,
+        "tumor_label": tumor_label
+    }

-
+    return mask, stats
+
+
+# === Gradio handler ===
 def rishigpt_handler(image_input, groq_api_key):
     os.environ["GROQ_API_KEY"] = groq_api_key

-    mask =
+    mask, stats = classify_image_and_stats(image_input)

-
-
-        return "A brain tumor mask was generated."
+    def segment_brain_tool(input_text: str) -> str:
+        return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."

     tool = StructuredTool.from_function(
         segment_brain_tool,
         name="segment_brain",
-        description="
+        description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
     )

     llm = ChatGroq(
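As a quick sanity check of the new classify_image_and_stats helper added above, a minimal sketch along these lines could be run locally. It assumes model, IMG_HEIGHT, and IMG_WIDTH are loaded exactly as in this commit; the dummy_slice array is a made-up stand-in for a real MRI slice, not real data.

import numpy as np

# Hypothetical smoke test (not part of app.py): push a dummy RGB "slice"
# through the segmentation helper and inspect what comes back.
dummy_slice = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)
mask, stats = classify_image_and_stats(dummy_slice)

print(mask.shape)            # expected (256, 256) after the squeeze
print(stats["tumor_label"])  # "Tumor Detected" or "No Tumor Detected"
print(round(stats["tumor_ratio"], 4))

For scale, the 0.01 cutoff on tumor_ratio means roughly 655 of the 256 x 256 = 65,536 mask pixels must exceed the 0.5 threshold before the label flips to "Tumor Detected"; 1,500 such pixels would give a ratio of about 0.0229.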
@@ -68,20 +87,23 @@ def rishigpt_handler(image_input, groq_api_key):
         verbose=True
     )

-    user_query = "
+    user_query = "Get segmentation details"
     classification = agent.run(user_query)

-    # New style: RunnableSequence
     prompt = PromptTemplate(
         input_variables=["result"],
-        template=
+        template=(
+            "You are a medical imaging expert. Based on this tumor analysis result: {result}, "
+            "explain what this means for the patient in simple language."
+        )
     )
-    chain = prompt | llm

+    chain = prompt | llm
     description = chain.invoke({"result": classification})

     return mask, description

+
 # === Gradio UI ===
 inputs = [
     gr.Image(type="numpy", label="Upload Brain MRI Slice"),
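One caveat on the prompt | llm chain above: with a chat model such as ChatGroq, chain.invoke(...) returns an AIMessage rather than a plain string, so the Gradio textbox may end up showing the full message object. If that is the observed behaviour, a small variation like the following sketch (adding LangChain's StrOutputParser, everything else as in this commit) would yield plain text; this is a suggestion, not something the commit itself does.

from langchain_core.output_parsers import StrOutputParser

# Same prompt and llm as in rishigpt_handler; the extra parser step
# turns the AIMessage returned by ChatGroq into a plain str.
chain = prompt | llm | StrOutputParser()
description = chain.invoke({"result": classification})  # now a str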
@@ -99,5 +121,5 @@ if __name__ == "__main__":
         inputs=inputs,
         outputs=outputs,
         title="RishiGPT Medical Brain Segmentation",
-        description="UNet++ Brain Tumor Segmentation with LangChain integration"
+        description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation."
     ).launch()
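The outputs list passed to gr.Interface sits outside the hunks shown above, so the exact wiring is not visible in this diff. A purely illustrative sketch of how the two values returned by rishigpt_handler (the mask array and the generated description) could be exposed in the UI; the component names and labels here are guesses, not the file's actual definitions.

# Hypothetical output components, mirroring rishigpt_handler's return values;
# the real outputs definition in app.py is not part of this diff.
outputs = [
    gr.Image(type="numpy", label="Tumor Mask"),
    gr.Textbox(label="Medical Explanation")
]

if __name__ == "__main__":
    gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation."
    ).launch()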