Paras Shah commited on
Commit
0466118
·
1 Parent(s): 0d17b56

Add cache optimization

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -14,11 +14,17 @@ from SingleTreePointCloudLoader import SingleTreePointCloudLoader
14
  gc.enable()
15
 
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
- with st.spinner("Loading PointNet++ model..."):
18
- checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device(device))
19
- classifier = pn2.get_model(num_class=4, normal_channel=False)
20
- classifier.load_state_dict(checkpoint['model_state_dict'])
21
- classifier.eval()
 
 
 
 
 
 
22
 
23
  side_bg = "static/sidebar.png"
24
  side_bg_ext = "png"
@@ -160,6 +166,7 @@ if uploaded_file:
160
  proceed = st.button("Run model")
161
  except Exception as e:
162
  st.error(f"An error occured: {str(e)}")
 
163
 
164
  if proceed:
165
  try:
@@ -259,6 +266,7 @@ if proceed:
259
  st.write(f"**Height of tree: {height:.2f}m**")
260
  st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
261
  st.write(f"**DBH: {dbh:.2f}m**")
 
262
 
263
  except Exception as e:
264
  st.error(f"An error occured: {str(e)}")
 
14
  gc.enable()
15
 
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+ @st.cache_resource
19
+ def load_model():
20
+ with st.spinner("Loading PointNet++ model..."):
21
+ checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device(device))
22
+ classifier = pn2.get_model(num_class=4, normal_channel=False)
23
+ classifier.load_state_dict(checkpoint['model_state_dict'])
24
+ classifier.eval()
25
+ return classifier
26
+
27
+ classifier = load_model()
28
 
29
  side_bg = "static/sidebar.png"
30
  side_bg_ext = "png"
 
166
  proceed = st.button("Run model")
167
  except Exception as e:
168
  st.error(f"An error occured: {str(e)}")
169
+ gc.collect() # Optimize after file is loaded
170
 
171
  if proceed:
172
  try:
 
266
  st.write(f"**Height of tree: {height:.2f}m**")
267
  st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
268
  st.write(f"**DBH: {dbh:.2f}m**")
269
+ gc.collect() # Optimize after inference is done
270
 
271
  except Exception as e:
272
  st.error(f"An error occured: {str(e)}")