Spaces:
Running
Running
Colin Leong
commited on
Commit
·
6dbb7a6
1
Parent(s):
6e290b7
Updates to the app such that downloading and visualization work more intuitively
Browse files
app.py
CHANGED
|
@@ -7,7 +7,6 @@ from pose_format.pose_visualizer import PoseVisualizer
|
|
| 7 |
from pathlib import Path
|
| 8 |
from pyzstd import decompress
|
| 9 |
from PIL import Image
|
| 10 |
-
import cv2
|
| 11 |
import mediapipe as mp
|
| 12 |
|
| 13 |
mp_holistic = mp.solutions.holistic
|
|
@@ -18,6 +17,10 @@ FACEMESH_CONTOURS_POINTS = [
|
|
| 18 |
)
|
| 19 |
]
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def pose_normalization_info(pose_header):
|
| 23 |
if pose_header.components[0].name == "POSE_LANDMARKS":
|
|
@@ -76,12 +79,12 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
|
|
| 76 |
return frames, images
|
| 77 |
|
| 78 |
|
| 79 |
-
def get_pose_gif(pose: Pose, step: int = 1, fps: int = None):
|
| 80 |
if fps is not None:
|
| 81 |
pose.body.fps = fps
|
| 82 |
v = PoseVisualizer(pose)
|
| 83 |
frames = [frame_data for frame_data in v.draw()]
|
| 84 |
-
frames = frames[
|
| 85 |
return v.save_gif(None, frames=frames)
|
| 86 |
|
| 87 |
|
|
@@ -110,28 +113,64 @@ if uploaded_file is not None:
|
|
| 110 |
"How to select components?", options=["manual", "signclip"]
|
| 111 |
)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
if component_selection == "manual":
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
elif component_selection == "signclip":
|
| 121 |
st.write("Selected landmarks used for SignCLIP.")
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
|
| 127 |
# Filter button logic
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
with st.expander("Show header"):
|
| 131 |
-
st.write(
|
| 132 |
with st.expander("Show body"):
|
| 133 |
-
st.write(
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
|
| 136 |
with pose_file_out.open("wb") as f:
|
| 137 |
pose.write(f)
|
|
@@ -139,22 +178,19 @@ if uploaded_file is not None:
|
|
| 139 |
with pose_file_out.open("rb") as f:
|
| 140 |
st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
# Visualization button logic
|
| 143 |
if st.button("Visualize"):
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
st.image(get_pose_gif(pose=
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
# st.write(pose.body.data.shape)
|
| 154 |
-
|
| 155 |
-
# st.write(visualize_pose(pose=pose)) # bunch of ndarrays
|
| 156 |
-
# st.write([Image.fromarray(v.cv2.cvtColor(frame, cv_code)) for frame in frames])
|
| 157 |
-
|
| 158 |
-
# for i, image in enumerate(images[::n]):
|
| 159 |
-
# print(f"i={i}")
|
| 160 |
-
# st.image(image=image, width=width)
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from pyzstd import decompress
|
| 9 |
from PIL import Image
|
|
|
|
| 10 |
import mediapipe as mp
|
| 11 |
|
| 12 |
mp_holistic = mp.solutions.holistic
|
|
|
|
| 17 |
)
|
| 18 |
]
|
| 19 |
|
| 20 |
+
# Initialize session state
|
| 21 |
+
# if "filtered_pose" not in st.session_state:
|
| 22 |
+
# st.session_state.filtered_pose = None
|
| 23 |
+
|
| 24 |
|
| 25 |
def pose_normalization_info(pose_header):
|
| 26 |
if pose_header.components[0].name == "POSE_LANDMARKS":
|
|
|
|
| 79 |
return frames, images
|
| 80 |
|
| 81 |
|
| 82 |
+
def get_pose_gif(pose: Pose, step: int = 1, start_frame:int=None, end_frame:int=None, fps: int = None):
|
| 83 |
if fps is not None:
|
| 84 |
pose.body.fps = fps
|
| 85 |
v = PoseVisualizer(pose)
|
| 86 |
frames = [frame_data for frame_data in v.draw()]
|
| 87 |
+
frames = frames[start_frame:end_frame:step]
|
| 88 |
return v.save_gif(None, frames=frames)
|
| 89 |
|
| 90 |
|
|
|
|
| 113 |
"How to select components?", options=["manual", "signclip"]
|
| 114 |
)
|
| 115 |
|
| 116 |
+
component_names = [c.name for c in pose.header.components]
|
| 117 |
+
chosen_component_names = []
|
| 118 |
+
points_dict = {}
|
| 119 |
+
hide_legs = False
|
| 120 |
+
|
| 121 |
if component_selection == "manual":
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
chosen_component_names = st.pills(
|
| 125 |
+
"Select components to visualize", options=component_names, default=component_names,selection_mode="multi"
|
| 126 |
)
|
| 127 |
+
|
| 128 |
+
for component in pose.header.components:
|
| 129 |
+
if component.name in chosen_component_names:
|
| 130 |
+
with st.expander(f"Points for {component.name}"):
|
| 131 |
+
selected_points = st.multiselect(
|
| 132 |
+
f"Select points for component {component.name}:",
|
| 133 |
+
options=component.points,
|
| 134 |
+
default=component.points,
|
| 135 |
+
)
|
| 136 |
+
if selected_points != component.points: # Only add entry if not all points are selected
|
| 137 |
+
points_dict[component.name] = selected_points
|
| 138 |
+
|
| 139 |
+
|
| 140 |
|
| 141 |
elif component_selection == "signclip":
|
| 142 |
st.write("Selected landmarks used for SignCLIP.")
|
| 143 |
+
chosen_component_names = ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
|
| 144 |
+
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
|
| 148 |
# Filter button logic
|
| 149 |
+
# Filter section
|
| 150 |
+
st.write("### Filter .pose File")
|
| 151 |
+
filtered = st.button("Filter")
|
| 152 |
+
if filtered:
|
| 153 |
+
pose = pose.get_components(chosen_component_names, points=points_dict if points_dict else None)
|
| 154 |
+
if hide_legs:
|
| 155 |
+
pose = pose_hide_legs(pose)
|
| 156 |
+
|
| 157 |
+
st.session_state.filtered_pose = pose
|
| 158 |
+
|
| 159 |
+
filtered_pose = st.session_state.get('filtered_pose', pose)
|
| 160 |
+
if filtered_pose:
|
| 161 |
+
filtered_pose = st.session_state.get('filtered_pose', pose)
|
| 162 |
+
st.write(f"#### Filtered .pose file")
|
| 163 |
+
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
| 164 |
with st.expander("Show header"):
|
| 165 |
+
st.write(filtered_pose.header)
|
| 166 |
with st.expander("Show body"):
|
| 167 |
+
st.write(filtered_pose.body)
|
| 168 |
+
# with st.expander("Show data:"):
|
| 169 |
+
# for frame in filtered_pose.body.data:
|
| 170 |
+
# st.write(f"Frame:{frame}")
|
| 171 |
+
# for person in frame:
|
| 172 |
+
# st.write(person)
|
| 173 |
+
|
| 174 |
pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
|
| 175 |
with pose_file_out.open("wb") as f:
|
| 176 |
pose.write(f)
|
|
|
|
| 178 |
with pose_file_out.open("rb") as f:
|
| 179 |
st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
|
| 180 |
|
| 181 |
+
|
| 182 |
+
st.write("### Visualization")
|
| 183 |
+
step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
|
| 184 |
+
fps = st.slider("FPS for visualization", min_value=1.0, max_value=filtered_pose.body.fps, value=filtered_pose.body.fps)
|
| 185 |
+
start_frame, end_frame = st.slider(
|
| 186 |
+
"Select Frame Range",
|
| 187 |
+
0,
|
| 188 |
+
len(frames),
|
| 189 |
+
(0, len(frames)), # Default range
|
| 190 |
+
)
|
| 191 |
# Visualization button logic
|
| 192 |
if st.button("Visualize"):
|
| 193 |
+
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
st.image(get_pose_gif(pose=filtered_pose, step=step, start_frame=start_frame, end_frame=end_frame, fps=fps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|