Paras Shah
commited on
Commit
·
0466118
1
Parent(s):
0d17b56
Add cache optimization
Browse files
app.py
CHANGED
@@ -14,11 +14,17 @@ from SingleTreePointCloudLoader import SingleTreePointCloudLoader
|
|
14 |
gc.enable()
|
15 |
|
16 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
side_bg = "static/sidebar.png"
|
24 |
side_bg_ext = "png"
|
@@ -160,6 +166,7 @@ if uploaded_file:
|
|
160 |
proceed = st.button("Run model")
|
161 |
except Exception as e:
|
162 |
st.error(f"An error occured: {str(e)}")
|
|
|
163 |
|
164 |
if proceed:
|
165 |
try:
|
@@ -259,6 +266,7 @@ if proceed:
|
|
259 |
st.write(f"**Height of tree: {height:.2f}m**")
|
260 |
st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
|
261 |
st.write(f"**DBH: {dbh:.2f}m**")
|
|
|
262 |
|
263 |
except Exception as e:
|
264 |
st.error(f"An error occured: {str(e)}")
|
|
|
14 |
gc.enable()
|
15 |
|
16 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
17 |
+
|
18 |
+
@st.cache_resource
|
19 |
+
def load_model():
|
20 |
+
with st.spinner("Loading PointNet++ model..."):
|
21 |
+
checkpoint = torch.load('checkpoints/best_model.pth', map_location=torch.device(device))
|
22 |
+
classifier = pn2.get_model(num_class=4, normal_channel=False)
|
23 |
+
classifier.load_state_dict(checkpoint['model_state_dict'])
|
24 |
+
classifier.eval()
|
25 |
+
return classifier
|
26 |
+
|
27 |
+
classifier = load_model()
|
28 |
|
29 |
side_bg = "static/sidebar.png"
|
30 |
side_bg_ext = "png"
|
|
|
166 |
proceed = st.button("Run model")
|
167 |
except Exception as e:
|
168 |
st.error(f"An error occured: {str(e)}")
|
169 |
+
gc.collect() # Optimize after file is loaded
|
170 |
|
171 |
if proceed:
|
172 |
try:
|
|
|
266 |
st.write(f"**Height of tree: {height:.2f}m**")
|
267 |
st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
|
268 |
st.write(f"**DBH: {dbh:.2f}m**")
|
269 |
+
gc.collect() # Optimize after inference is done
|
270 |
|
271 |
except Exception as e:
|
272 |
st.error(f"An error occured: {str(e)}")
|