File size: 5,241 Bytes
26e5c1d c4592bc 26e5c1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gc
import laspy
import torch
import tempfile
import numpy as np
import open3d as o3d
import streamlit as st
import plotly.graph_objs as go
import pointnet2_cls_msg as pn2
from utils import calculate_dbh, calc_canopy_volume, CLASSES
from SingleTreePointCloudLoader import SingleTreePointCloudLoader
# Automatic garbage collection is already on by default in CPython; this call is
# kept only as an explicit safeguard.
gc.enable()

CHECKPOINT_PATH = 'checkpoints/best_model.pth'
# Number of output classes the checkpoint was trained with; presumably should
# equal len(CLASSES) from utils — TODO confirm.
NUM_CLASSES = 4


@st.cache_resource(show_spinner=False)
def _load_classifier(checkpoint_path=CHECKPOINT_PATH, num_class=NUM_CLASSES):
    """Build the PointNet++ MSG classifier and load its trained weights on the CPU.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``. Without the cache, the
    checkpoint would be read and the network rebuilt on every slider move. The
    returned model is shared across reruns and sessions; in this script it is
    only used for inference under ``torch.no_grad()``.

    Args:
        checkpoint_path: Path to a ``torch.save``'d dict containing a
            ``'model_state_dict'`` entry.
        num_class: Number of output classes passed to ``pn2.get_model``.

    Returns:
        The model in eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model = pn2.get_model(num_class=num_class, normal_channel=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


# The spinner is shown here rather than by the cache so the original message is
# kept. On cache hits it only flashes briefly.
with st.spinner("Loading PointNet++ model..."):
    classifier = _load_classifier()
st.title("Tree Species Identification")

# Single point-cloud upload, restricted to the formats the app knows how to read.
uploaded_file = st.file_uploader(
    "Upload Point Cloud Data",
    type=['laz', 'las', 'pcd'],
    help="Please upload trees with ground points removed",
)

# Integer percentage (5-100) handed to calc_canopy_volume as the canopy Z-threshold.
Z_THRESHOLD = st.slider(
    "Z-Threshold(%)",
    value=50,
    min_value=5,
    max_value=100,
    step=1,
    help="Please select a Z-Threshold for canopy volume calculation",
)

# Height in metres (1.30-1.40) handed to calculate_dbh.
DBH_HEIGHT = st.slider(
    "DBH Height(m)",
    value=1.4,
    min_value=1.3,
    max_value=1.4,
    step=0.01,
    help="Enter height used for DBH calculation",
)
proceed = None
temp_file_path = None


def _remove_temp_file(path):
    """Best-effort deletion of the temporary copy of the upload.

    Failures are ignored on purpose. A cleanup problem must not turn a
    successful run into an error page.
    """
    import os  # local import: os is not among the file's top-level imports

    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _read_point_cloud(path, file_type):
    """Return the XYZ coordinates stored in ``path`` as an ``(N, 3)`` array.

    ``pcd`` files are read with Open3D. Every other accepted type
    (``las``/``laz``) is read with laspy, and its ``x``/``y``/``z`` columns are
    stacked.

    Note: Open3D may log a warning and return an empty cloud instead of raising
    on an unreadable file, so callers should check for an empty result.
    """
    if file_type == 'pcd':
        return np.asarray(o3d.io.read_point_cloud(path).points)
    las = laspy.read(path)
    return np.vstack((las.x, las.y, las.z)).transpose()


def _scatter(pts, name, **marker):
    """Build a marker-only 3-D scatter trace from an ``(N, 3)`` XYZ array."""
    return go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode='markers',
        marker=marker,
        name=name,
    )


def _build_figure(points, canopy_points, trunk_points):
    """Overlay the whole tree, the canopy points and the DBH points in one figure.

    ``canopy_points`` and ``trunk_points`` are the point sets returned by
    ``calc_canopy_volume`` and ``calculate_dbh`` respectively.
    """
    fig = go.Figure()
    fig.add_trace(_scatter(points, 'Tree', size=0.5, color=points[:, 2],
                           colorscale='Viridis', opacity=1.0))
    fig.add_trace(_scatter(canopy_points, 'Canopy points', size=2, color='blue', opacity=0.8))
    fig.add_trace(_scatter(trunk_points, 'DBH', size=2, color='red', opacity=0.9))
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode='data',
        ),
    )
    return fig


def _classify(path, file_type):
    """Classify the uploaded tree with the module-level ``classifier``.

    Returns:
        ``(label, confidence_percent)`` for the first sample produced by
        ``SingleTreePointCloudLoader``.
    """
    dataset = SingleTreePointCloudLoader(path, file_type)
    # Only the first sample is ever classified, so a single-item batch is enough.
    # The DataLoader is kept for its default collation into tensors.
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    point_set, _ = next(iter(loader))
    # Swap the last two axes. This presumably turns (batch, points, channels) into
    # the (batch, channels, points) layout the model expects. TODO confirm against
    # SingleTreePointCloudLoader.
    point_set = point_set.transpose(2, 1)
    with torch.no_grad():
        logits, _ = classifier(point_set)
    probabilities = torch.softmax(logits[0], dim=-1)
    predicted_class = int(torch.argmax(probabilities))
    confidence = probabilities[predicted_class].item() * 100
    return CLASSES[predicted_class], confidence


try:
    if uploaded_file:
        try:
            with st.spinner("Reading point cloud file..."):
                file_type = uploaded_file.name.split('.')[-1].lower()
                # delete=False keeps the file after the ``with`` closes it, so the
                # readers and the model's data loader can reopen it by path. It is
                # removed in the outer ``finally``.
                with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
                    temp_file_path = tmp.name  # record first so cleanup works even if the write fails
                    # getvalue() returns the whole upload regardless of stream position.
                    tmp.write(uploaded_file.getvalue())
                points = _read_point_cloud(temp_file_path, file_type)
            if points.size == 0:
                st.error("No points could be read from the uploaded file.")
            else:
                proceed = st.button("Run model")
        except Exception as e:
            st.error(f"An error occurred: {e}")

    if proceed:
        try:
            with st.spinner("Calculating tree inventory..."):
                dbh, trunk_points = calculate_dbh(points, DBH_HEIGHT)
                z_min = np.min(points[:, 2])
                z_max = np.max(points[:, 2])
                height = z_max - z_min
                canopy_volume, canopy_points = calc_canopy_volume(points, Z_THRESHOLD, height, z_min)
            with st.spinner("Visualizing point cloud..."):
                st.plotly_chart(_build_figure(points, canopy_points, trunk_points),
                                use_container_width=True)
            with st.spinner("Running inference..."):
                predicted_label, confidence_score = _classify(temp_file_path, file_type)
            st.write(f"**Predicted class: {predicted_label}**")
            st.write(f"**Confidence score: {confidence_score:.2f}%**")
            st.write(f"**Height of tree: {height:.2f}m**")
            st.write(f"**Canopy volume: {canopy_volume:.2f}m\u00b3**")
            st.write(f"**DBH: {dbh:.2f}m**")
        except Exception as e:
            st.error(f"An error occurred: {e}")
finally:
    # Each script run writes a fresh temporary copy of the upload, so remove it
    # once this run is finished.
    _remove_temp_file(temp_file_path)
|