rishirajbal commited on
Commit
b31004e
·
verified ·
1 Parent(s): d7e726d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -9,16 +9,21 @@ from langchain.agents import initialize_agent
9
  from langchain.prompts import PromptTemplate
10
  from langchain.chains import LLMChain
11
  from langchain.tools import StructuredTool
12
- from tensorflow.keras.preprocessing import image
13
 
14
 
15
  model_path = "unet_model.h5"
 
 
16
  if not os.path.exists(model_path):
17
  hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
18
- r = requests.get(hf_url)
19
- with open(model_path, "wb") as f:
20
- f.write(r.content)
21
-
 
 
 
 
22
  model = tf.keras.models.load_model(model_path, compile=False)
23
 
24
 
@@ -38,11 +43,12 @@ def rishigpt_handler(image_input, groq_api_key):
38
 
39
  mask = classify_image(image_input)
40
 
41
- def classify_image_tool(img_path):
42
- return "Brain tumor mask generated."
 
43
 
44
  tool = StructuredTool.from_function(
45
- classify_image_tool,
46
  name="segment_brain",
47
  description="Segment brain MRI for tumor detection."
48
  )
@@ -87,10 +93,11 @@ outputs = [
87
  gr.Textbox(label="Medical Explanation")
88
  ]
89
 
90
- gr.Interface(
91
- fn=rishigpt_handler,
92
- inputs=inputs,
93
- outputs=outputs,
94
- title="RishiGPT Medical Brain Segmentation",
95
- description="UNet++ Brain Tumor Segmentation"
96
- ).launch()
 
 
9
  from langchain.prompts import PromptTemplate
10
  from langchain.chains import LLMChain
11
  from langchain.tools import StructuredTool
 
12
 
13
 
14
  model_path = "unet_model.h5"
15
+
16
+ # Safe download with streaming to avoid incomplete file
17
  if not os.path.exists(model_path):
18
  hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
19
+ print(f"Downloading model from {hf_url}...")
20
+ with requests.get(hf_url, stream=True) as r:
21
+ r.raise_for_status()
22
+ with open(model_path, "wb") as f:
23
+ for chunk in r.iter_content(chunk_size=8192):
24
+ f.write(chunk)
25
+
26
+ print("Loading model...")
27
  model = tf.keras.models.load_model(model_path, compile=False)
28
 
29
 
 
43
 
44
  mask = classify_image(image_input)
45
 
46
+ # The LLM tool just reports a dummy text here for now
47
+ def segment_brain_tool():
48
+ return "A brain tumor mask was generated."
49
 
50
  tool = StructuredTool.from_function(
51
+ segment_brain_tool,
52
  name="segment_brain",
53
  description="Segment brain MRI for tumor detection."
54
  )
 
93
  gr.Textbox(label="Medical Explanation")
94
  ]
95
 
96
+ if __name__ == "__main__":
97
+ gr.Interface(
98
+ fn=rishigpt_handler,
99
+ inputs=inputs,
100
+ outputs=outputs,
101
+ title="RishiGPT Medical Brain Segmentation",
102
+ description="UNet++ Brain Tumor Segmentation with LangChain integration"
103
+ ).launch()