import random
import collections.abc

import gradio as gr
import matplotlib.patheffects as mpe
import numpy as np
import psutil
import torch
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from PIL import Image, ImageDraw, ImageEnhance
from torchvision import transforms

from demo import Resize_Pad
from EdgeCape.models import *


def process_img(support_image, global_state):
    # Store the uploaded support image and clear any previously marked keypoints.
    global_state['images']['image_orig'] = support_image
    global_state['images']['image_kp'] = support_image
    reset_kp(global_state)
    return support_image, global_state


def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
    # Build a batch of symmetric adjacency matrices from per-sample edge lists.
    adj_mx = torch.empty(0, device=device)
    batch_size = len(skeleton)
    for b in range(batch_size):
        edges = torch.tensor(skeleton[b])
        adj = torch.zeros(num_pts, num_pts, device=device)
        adj[edges[:, 0], edges[:, 1]] = 1
        adj_mx = torch.cat((adj_mx, adj.unsqueeze(0)), dim=0)
    # Symmetrize: copy each directed edge onto the opposite direction.
    trans_adj_mx = torch.transpose(adj_mx, 1, 2)
    cond = (trans_adj_mx > adj_mx).float()
    adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
    return adj
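

def _example_adj_mx():
    # Illustrative sketch (not wired into the app): adj_mx_from_edges on a
    # 3-keypoint chain 0-1-2, run on CPU. The directed edges are filled first
    # and then mirrored, so the result is the symmetric adjacency of the chain.
    adj = adj_mx_from_edges(num_pts=3, skeleton=[[(0, 1), (1, 2)]], device='cpu')
    # adj[0] == [[0, 1, 0],
    #            [1, 0, 1],
    #            [0, 1, 0]]
    return adj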


def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
                 skeleton=None, prediction=None, radius=6, in_color=None,
                 original_skeleton=None, img_alpha=0.6, target_keypoints=None):
    h, w, c = support_img.shape
    prediction = prediction[-1] * h
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.cpu().numpy()
    if isinstance(skeleton, list):
        skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).cpu().numpy()[0]
        original_skeleton = skeleton
    # Normalize both images to [0, 1] for display.
    support_img = (support_img - np.min(support_img)) / (np.max(support_img) - np.min(support_img))
    query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
    error_mask = None
    # Three panels: support keypoints over the input skeleton, support keypoints
    # over the predicted skeleton, and predicted keypoints on the query image.
    for panel_id, (img, w, keypoint, adj) in enumerate(zip([support_img, support_img, query_img],
                                                           [support_w, support_w, query_w],
                                                           [support_kp, support_kp, prediction],
                                                           [original_skeleton, skeleton, skeleton])):
        color = in_color
        f, axes = plt.subplots()
        plt.imshow(img, alpha=img_alpha)
        # On the query-image panel, flag predictions whose distance from the
        # ground truth exceeds the PCK threshold (5% of the 256px image side).
        if panel_id == 2 and target_keypoints is not None:
            error = np.linalg.norm(keypoint - target_keypoints, axis=-1)
            error_mask = error > (256 * 0.05)
        for k in range(keypoint.shape[0]):
            if w[k] > 0:
                kp = keypoint[k, :2]
                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
                if error_mask is not None and error_mask[k]:
                    # Highlight erroneous predictions in yellow.
                    c = (1, 1, 0, 0.75)
                    patch = plt.Circle(kp,
                                       radius,
                                       color=c,
                                       path_effects=[mpe.withStroke(linewidth=8, foreground='black'),
                                                     mpe.withStroke(linewidth=4, foreground='white'),
                                                     mpe.withStroke(linewidth=2, foreground='black'),
                                                     ],
                                       zorder=260)
                    axes.add_patch(patch)
                    axes.text(kp[0], kp[1], k, fontsize=10, color='black', ha="center", va="center", zorder=320)
                else:
                    patch = plt.Circle(kp,
                                       radius,
                                       color=c,
                                       path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
                                       zorder=200)
                    axes.add_patch(patch)
                    axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white', ha="center", va="center",
                              zorder=300,
                              path_effects=[
                                  mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)), foreground='black')])
        plt.draw()
        if adj is not None:
            # Normalize edge weights so the thickest drawn line has width 6.
            draw_skeleton = adj ** 1
            max_skel_val = np.max(draw_skeleton)
            draw_skeleton = draw_skeleton / max_skel_val * 6
            for i in range(1, keypoint.shape[0]):
                for j in range(0, i):
                    if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
                        if color is None:
                            num_colors = int((skeleton > 0.05).sum() / 2)
                            color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
                            c = next(color)
                        elif isinstance(color, str):
                            c = color
                        elif isinstance(color, collections.abc.Iterable):
                            c = next(color)
                        else:
                            raise ValueError("Color must be a string or an iterable")
                        if w[i] > 0 and w[j] > 0 and skeleton[i][j] > 0:
                            width = draw_skeleton[i][j]
                            stroke_width = width + (width / 3)
                            patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
                                               [keypoint[i, 1], keypoint[j, 1]],
                                               linewidth=width, color=c, alpha=0.6,
                                               path_effects=[mpe.withStroke(linewidth=stroke_width, foreground='black')],
                                               zorder=1)
                            axes.add_artist(patch)
        plt.axis('off')  # hide the axes
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
    return plt
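

def _example_plot_results():
    # Illustrative sketch with hypothetical random data (not part of the demo):
    # render 3 keypoints and a 2-edge chain skeleton on 256x256 noise images.
    # plot_results returns the pyplot module, so .show() displays the figures.
    h = 256
    support = np.random.rand(h, h, 3)
    query = np.random.rand(h, h, 3)
    kps = np.random.rand(3, 2) * h                       # pixel-space keypoints
    weights = np.ones(3)
    adj = adj_mx_from_edges(num_pts=3, skeleton=[[(0, 1), (1, 2)]],
                            device='cpu').cpu().numpy()[0]
    pred = torch.rand(1, 3, 2)                           # normalized; scaled by h inside
    out = plot_results(support, query, kps, weights, None, weights,
                       skeleton=adj, prediction=pred,
                       original_skeleton=adj, img_alpha=1.0)
    out.show()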


def process(query_img, state,
            cfg_path='configs/test/1shot_split1.py',
            checkpoint_path='ckpt/1shot_split1.pth'):
    cfg = Config.fromfile(cfg_path)
    width, height, _ = state['original_support_image'].shape
    # Rescale the clicked support keypoints to the encoder input size.
    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
    kp_src_np[:, 0] = kp_src_np[:, 0] / (width // 4) * cfg.model.encoder_config.img_size
    kp_src_np[:, 1] = kp_src_np[:, 1] / (height // 4) * cfg.model.encoder_config.img_size
    kp_src_np = np.flip(kp_src_np, 1).copy()
    kp_src_tensor = torch.tensor(kp_src_np).float()
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(cfg.model.encoder_config.img_size,
                   cfg.model.encoder_config.img_size)])
    if len(state['skeleton']) == 0:
        state['skeleton'] = [(0, 0)]
    support_img = preprocess(state['images']['image_orig']).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]
    # Create heatmap targets from the support keypoints.
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
                                       cfg.model.encoder_config.img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
                                                                 kp_src_3d,
                                                                 kp_src_3d_weight,
                                                                 sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])
    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [state['skeleton']],
                       'query_skeleton': state['skeleton'],
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_src_tensor.mean(dim=0)],
                       'query_center': kp_src_tensor.mean(dim=0),
                       'sample_scale': [
                           kp_src_tensor.max(dim=0)[0] -
                           kp_src_tensor.min(dim=0)[0]],
                       'query_scale': kp_src_tensor.max(dim=0)[0] -
                                      kp_src_tensor.min(dim=0)[0],
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }
    # Load model
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()
    with torch.no_grad():
        outputs = model(**data)
    # Visualize results.
    vis_s_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       support_kp,
                       vis_s_weight,
                       None,
                       vis_s_weight,
                       outputs['skeleton'],
                       torch.tensor(outputs['points']).squeeze(),
                       original_skeleton=state['skeleton'],
                       img_alpha=1.0,
                       )
    return out, state
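

def _example_process_state():
    # Sketch of the minimal `state` dict that process() reads. Field names are
    # taken from the code above; the image content here is hypothetical noise,
    # and the (row, col) keypoint order is inferred from update_examples below.
    rgb = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    return {
        'original_support_image': rgb[:, :, ::-1].copy(),  # BGR copy, as in update_examples
        'images': {'image_orig': rgb},         # support image fed to preprocess()
        'kp_src': [[64, 64], [128, 192]],      # clicked keypoints, (row, col) order
        'skeleton': [(0, 1)],                  # edge list over kp_src indices
    }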


def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
    state['color_idx'] = 0
    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
    support_img, posed_support, _ = set_query(support_img, state, example=True)
    w, h = support_img.size
    draw_pose = ImageDraw.Draw(support_img)
    draw_limb = ImageDraw.Draw(posed_support)
    # Marker radius and limb width are given as fractions of the image width.
    r = int(r * w)
    width = int(width * w)
    for pixel in state['kp_src']:
        left_up_point = (pixel[1] - r, pixel[0] - r)
        right_down_point = (pixel[1] + r, pixel[0] + r)
        two_point_list = [left_up_point, right_down_point]
        draw_pose.ellipse(two_point_list, fill=(255, 0, 0, 255))
        draw_limb.ellipse(two_point_list, fill=(255, 0, 0, 255))
    for limb in state['skeleton']:
        point_a = state['kp_src'][limb[0]][::-1]
        point_b = state['kp_src'][limb[1]][::-1]
        if state['color_idx'] < len(COLORS):
            c = COLORS[state['color_idx']]
            state['color_idx'] += 1
        else:
            c = random.choices(range(256), k=3)
        draw_limb.line([point_a, point_b], fill=tuple(c), width=width)
    return support_img, posed_support, query_img, state


def get_select_coords(global_state,
                      evt: gr.SelectData
                      ):
    """Record a clicked point and draw it on the keypoint image."""
    xy = evt.index
    global_state["points"].append(xy)
    image_raw = global_state['images']['image_kp']
    image_draw = update_image_draw(
        image_raw,
        xy,
        global_state
    )
    global_state['images']['image_kp'] = image_draw
    return global_state, image_draw


def get_closest_point_idx(pts_list, xy):
    # Return the index of the stored point nearest (in squared Euclidean distance) to xy.
    x, y = xy
    closest_point = min(pts_list, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)
    closest_point_index = pts_list.index(closest_point)
    return closest_point_index
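

def _example_closest_point():
    # Illustrative: a click at (9, 9) snaps to the nearer stored point, index 1.
    assert get_closest_point_idx([(0, 0), (10, 10)], (9, 9)) == 1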


def reset_skeleton(global_state):
    image = global_state["images"]["image_kp"]
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    return image


def reset_kp(global_state):
    image = global_state["images"]["image_orig"]
    global_state["images"]["image_kp"] = image
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["points"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    return image, image


def select_skeleton(global_state,
                    evt: gr.SelectData,
                    ):
    # Two-click edge selection: the first click picks the start keypoint; the
    # second click picks the end keypoint and commits the limb.
    xy = evt.index
    pts_list = global_state["points"]
    closest_point_idx = get_closest_point_idx(pts_list, xy)
    image_raw = global_state['images']['image_skel']
    if global_state["curr_type_point"] == "end":
        prev_point_idx = global_state["prev_point_idx"]
        prev_point = pts_list[prev_point_idx]
        points = [prev_point, xy]
        image_draw = draw_limbs_on_image(image_raw, points)
        global_state['images']['image_skel'] = image_draw
        global_state['skeleton'].append([prev_point_idx, closest_point_idx])
        global_state["curr_type_point"] = "start"
        global_state["prev_point_idx"] = None
    else:
        global_state["prev_point_idx"] = closest_point_idx
        global_state["curr_type_point"] = "end"
    return global_state, global_state['images']['image_skel']


def reverse_point_pairs(points):
    # Swap (x, y) -> (y, x) for every point.
    new_points = []
    for p in points:
        new_points.append([p[1], p[0]])
    return new_points


def update_image_draw(image, points, global_state):
    # While fewer than two points are placed, pass a reduced alpha, which
    # draw_points_on_image uses to brighten the base image.
    alpha = 0.5 if len(global_state["points"]) < 2 else 1.0
    image_draw = draw_points_on_image(image, points, alpha=alpha)
    return image_draw


def print_memory_usage():
    # Print system memory usage.
    print(f"System memory usage: {psutil.virtual_memory().percent}%")
    # Print GPU memory usage.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
        print(
            f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
        device_properties = torch.cuda.get_device_properties(device)
        available_memory = device_properties.total_memory - \
            torch.cuda.max_memory_allocated()
        print(f"Available GPU memory: {available_memory / 1e9} GB")
    else:
        print("No GPU available")


def draw_limbs_on_image(image, points):
    # Draw a single limb (line segment) between two points in a random color.
    color = tuple(random.choices(range(256), k=3))
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    p_start, p_target = points
    if p_start is not None and p_target is not None:
        p_draw = int(p_start[0]), int(p_start[1])
        t_draw = int(p_target[0]), int(p_target[1])
        overlay_draw.line(
            (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
            fill=color,
            width=10,
        )
    return Image.alpha_composite(image.convert("RGBA"),
                                 overlay_rgba).convert("RGB")
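

def _example_draw_limb():
    # Illustrative: draw one randomly colored limb across a blank white canvas.
    canvas = Image.new("RGB", (256, 256), (255, 255, 255))
    return draw_limbs_on_image(canvas, [(32, 32), (224, 224)])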


def draw_points_on_image(image,
                         points,
                         radius_scale=0.01,
                         alpha=1.):
    # Draw a single red keypoint marker; `points` is one (x, y) pair.
    if alpha < 1:
        # Brighten the base image slightly to signal an in-progress selection.
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(1.1)
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    p_color = (255, 0, 0)
    rad_draw = int(image.size[0] * radius_scale)
    if points is not None:
        p_draw = int(points[0]), int(points[1])
        overlay_draw.ellipse(
            (
                p_draw[0] - rad_draw,
                p_draw[1] - rad_draw,
                p_draw[0] + rad_draw,
                p_draw[1] + rad_draw,
            ),
            fill=p_color,
        )
    return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB")
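

def _example_draw_point():
    # Illustrative: mark a single point at the center of a blank white canvas.
    canvas = Image.new("RGB", (256, 256), (255, 255, 255))
    return draw_points_on_image(canvas, (128, 128), alpha=1.0)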