Update app.py
Browse files
app.py
CHANGED
@@ -9,16 +9,21 @@ from langchain.agents import initialize_agent
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
from langchain.chains import LLMChain
|
11 |
from langchain.tools import StructuredTool
|
12 |
-
from tensorflow.keras.preprocessing import image
|
13 |
|
14 |
|
15 |
model_path = "unet_model.h5"
|
|
|
|
|
16 |
if not os.path.exists(model_path):
|
17 |
hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
|
18 |
-
|
19 |
-
with
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
model = tf.keras.models.load_model(model_path, compile=False)
|
23 |
|
24 |
|
@@ -38,11 +43,12 @@ def rishigpt_handler(image_input, groq_api_key):
|
|
38 |
|
39 |
mask = classify_image(image_input)
|
40 |
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
tool = StructuredTool.from_function(
|
45 |
-
|
46 |
name="segment_brain",
|
47 |
description="Segment brain MRI for tumor detection."
|
48 |
)
|
@@ -87,10 +93,11 @@ outputs = [
|
|
87 |
gr.Textbox(label="Medical Explanation")
|
88 |
]
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
from langchain.chains import LLMChain
|
11 |
from langchain.tools import StructuredTool
|
|
|
12 |
|
13 |
|
14 |
model_path = "unet_model.h5"
|
15 |
+
|
16 |
+
# Safe download with streaming to avoid incomplete file
|
17 |
if not os.path.exists(model_path):
|
18 |
hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
|
19 |
+
print(f"Downloading model from {hf_url}...")
|
20 |
+
with requests.get(hf_url, stream=True) as r:
|
21 |
+
r.raise_for_status()
|
22 |
+
with open(model_path, "wb") as f:
|
23 |
+
for chunk in r.iter_content(chunk_size=8192):
|
24 |
+
f.write(chunk)
|
25 |
+
|
26 |
+
print("Loading model...")
|
27 |
model = tf.keras.models.load_model(model_path, compile=False)
|
28 |
|
29 |
|
|
|
43 |
|
44 |
mask = classify_image(image_input)
|
45 |
|
46 |
+
# The LLM tool just reports a dummy text here for now
|
47 |
+
def segment_brain_tool():
|
48 |
+
return "A brain tumor mask was generated."
|
49 |
|
50 |
tool = StructuredTool.from_function(
|
51 |
+
segment_brain_tool,
|
52 |
name="segment_brain",
|
53 |
description="Segment brain MRI for tumor detection."
|
54 |
)
|
|
|
93 |
gr.Textbox(label="Medical Explanation")
|
94 |
]
|
95 |
|
96 |
+
if __name__ == "__main__":
|
97 |
+
gr.Interface(
|
98 |
+
fn=rishigpt_handler,
|
99 |
+
inputs=inputs,
|
100 |
+
outputs=outputs,
|
101 |
+
title="RishiGPT Medical Brain Segmentation",
|
102 |
+
description="UNet++ Brain Tumor Segmentation with LangChain integration"
|
103 |
+
).launch()
|