import copy
import collections.abc
import json
import pickle
import random
from pprint import pformat as pf

import gradio as gr
import matplotlib.patheffects as mpe
import numpy as np
import psutil
import spaces
import torch
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from PIL import Image, ImageDraw, ImageEnhance
from torchvision import transforms

from demo import Resize_Pad
from EdgeCape import TopDownGenerateTargetFewShot
from EdgeCape.models import *


def process_img(support_image, global_state):
    """Store a newly uploaded support image and reset any stale annotations."""
    global_state['images']['image_orig'] = support_image
    if global_state["load_example"]:
        # The image comes from an example row: keep its pre-drawn keypoints.
        global_state["load_example"] = False
        return global_state['images']['image_kp'], global_state
    _, _ = reset_kp(global_state)
    return support_image, global_state


def adj_mx_from_edges(num_pts, skeleton, device='cpu', normalization_fix=True):
    """Build one symmetric (num_pts x num_pts) adjacency matrix per edge list."""
    adj_mx = torch.empty(0, device=device)
    batch_size = len(skeleton)
    for b in range(batch_size):
        edges = torch.tensor(skeleton[b]).long()
        adj = torch.zeros(num_pts, num_pts, device=device)
        adj[edges[:, 0], edges[:, 1]] = 1
        adj_mx = torch.cat((adj_mx, adj.unsqueeze(0)), dim=0)
    # Symmetrize: copy each directed edge onto its transposed position.
    trans_adj_mx = torch.transpose(adj_mx, 1, 2)
    cond = (trans_adj_mx > adj_mx).float()
    adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
    return adj


def plot_results(support_img, query_img, query_w, skeleton=None,
                 prediction=None, radius=6, in_color=None,
                 original_skeleton=None, img_alpha=0.6, target_keypoints=None):
    """Draw the predicted keypoints and weighted skeleton on the query image."""
    h, w, c = support_img.shape
    prediction = prediction[-1] * h
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.numpy()
    if isinstance(original_skeleton, list):
        original_skeleton = adj_mx_from_edges(
            num_pts=prediction.shape[0], skeleton=[original_skeleton]).numpy()[0]
    # Normalize the query image to [0, 1] for display.
    query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
    img = query_img
    w = query_w
    keypoint = prediction
    adj = skeleton
    color = None
    f, axes = plt.subplots()
    plt.imshow(img, alpha=img_alpha)
    for k in range(keypoint.shape[0]):
        if w[k] > 0:
            kp = keypoint[k, :2]
            c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
            patch = plt.Circle(kp, radius, color=c,
                               path_effects=[mpe.withStroke(linewidth=2, foreground='black')],
                               zorder=200)
            axes.add_patch(patch)
            axes.text(kp[0], kp[1], k, fontsize=(radius + 4), color='white',
                      ha="center", va="center", zorder=300,
                      path_effects=[mpe.withStroke(linewidth=max(1, int((radius + 4) / 5)),
                                                   foreground='black')])
            plt.draw()
    if adj is not None:
        # Scale limb line widths by the predicted edge weights.
        max_skel_val = np.max(adj)
        draw_skeleton = adj / max_skel_val * 6
        for i in range(1, keypoint.shape[0]):
            for j in range(0, i):
                if w[i] > 0 and w[j] > 0 and original_skeleton[i][j] > 0:
                    if color is None:
                        num_colors = int((adj > 0.05).sum() / 2)
                        color = iter(plt.cm.rainbow(np.linspace(0, 1, num_colors + 1)))
                        c = next(color)
                    elif isinstance(color, str):
                        c = color
                    elif isinstance(color, collections.abc.Iterable):
                        c = next(color)
                    else:
                        raise ValueError("Color must be a string or an iterable")
                if w[i] > 0 and w[j] > 0 and adj[i][j] > 0:
                    width = draw_skeleton[i][j]
                    stroke_width = width + (width / 3)
                    patch = plt.Line2D([keypoint[i, 0], keypoint[j, 0]],
                                       [keypoint[i, 1], keypoint[j, 1]],
                                       linewidth=width, color=c, alpha=0.6,
                                       path_effects=[mpe.withStroke(linewidth=stroke_width,
                                                                    foreground='black')],
                                       zorder=1)
                    axes.add_artist(patch)
    plt.axis('off')  # hide the axes
    return plt
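
# Hedged usage sketch (illustrative, not part of the app flow): adj_mx_from_edges
# expects a *batch* of edge lists and returns one symmetric adjacency matrix per
# item. The helper name `_demo_adj_mx` is hypothetical.
def _demo_adj_mx():
    # A 3-keypoint chain 0-1-2, batch of size 1.
    adj = adj_mx_from_edges(num_pts=3, skeleton=[[(0, 1), (1, 2)]])
    assert adj.shape == (1, 3, 3)
    # Directed edges are mirrored, so both (0, 1) and (1, 0) are set.
    assert adj[0, 0, 1] == 1
    assert adj[0, 1, 0] == 1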
@spaces.GPU(duration=30)
def estimate(model, data):
    """Move the model and inputs to the GPU and run a no-grad forward pass."""
    model.cuda()
    data['img_s'] = [s.cuda() for s in data['img_s']]
    data['img_q'] = data['img_q'].cuda()
    data['target_s'] = [s.cuda() for s in data['target_s']]
    data['target_weight_s'] = [s.cuda() for s in data['target_weight_s']]
    with torch.no_grad():
        return model(**data)


def process(query_img, state,
            cfg_path='configs/test/1shot_split1.py',
            checkpoint_path='ckpt/1shot_split1.pth'):
    """Run few-shot pose estimation on the query image using the annotated support."""
    print(state)
    device = print_memory_usage()
    cfg = Config.fromfile(cfg_path)
    # np.array(PIL image).shape is (height, width, channels).
    height, width, _ = np.array(state['images']['image_orig']).shape
    # Rescale the clicked keypoints to the 256x256 model input.
    kp_src_np = np.array(state['points']).copy().astype(np.float32)
    kp_src_np[:, 0] = kp_src_np[:, 0] / width * 256
    kp_src_np[:, 1] = kp_src_np[:, 1] / height * 256
    kp_src_tensor = torch.tensor(kp_src_np).float()
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(256, 256),
    ])
    if len(state['skeleton']) == 0:
        # No limbs drawn: fall back to a dummy self-edge.
        state['skeleton'] = [(0, 0)]
    support_img = preprocess(state['images']['image_orig']).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]

    # Create heatmaps from the support keypoints.
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([256, 256])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(torch.tensor(target_weight_s).float()[None])

    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [state['skeleton']],
                       'query_skeleton': state['skeleton'],
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_src_tensor.mean(dim=0)],
                       'query_center': kp_src_tensor.mean(dim=0),
                       'sample_scale': [
                           kp_src_tensor.max(dim=0)[0] - kp_src_tensor.min(dim=0)[0]],
                       'query_scale':
                           kp_src_tensor.max(dim=0)[0] - kp_src_tensor.min(dim=0)[0],
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }

    # Load the model.
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()
    outputs = estimate(model, data)

    # Visualize the results.
    vis_s_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       vis_s_weight,
                       skeleton=outputs['skeleton'][1],
                       prediction=torch.tensor(outputs['points']).squeeze().cpu(),
                       original_skeleton=state['skeleton'],
                       img_alpha=1.0)
    return out


def update_examples(support_img, query_image, global_state_str):
    """Rebuild keypoint and skeleton overlays for a pre-annotated example row."""
    example_state = json.loads(global_state_str)
    example_state["load_example"] = True
    example_state["curr_type_point"] = "start"
    example_state["prev_point"] = None
    example_state['images'] = {}
    example_state['images']['image_orig'] = support_img
    example_state['images']['image_kp'] = support_img
    example_state['images']['image_skel'] = support_img
    image_draw = example_state['images']['image_orig'].copy()
    for xy in example_state['points']:
        image_draw = update_image_draw(image_draw, xy, example_state)
    kp_image = image_draw.copy()
    example_state['images']['image_kp'] = kp_image
    pts_list = example_state['points']
    for limb in example_state['skeleton']:
        prev_point = pts_list[limb[0]]
        curr_point = pts_list[limb[1]]
        points = [prev_point, curr_point]
        image_draw = draw_limbs_on_image(image_draw, points)
    skel_image = image_draw.copy()
    example_state['images']['image_skel'] = skel_image
    return support_img, kp_image, skel_image, query_image, example_state
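
# Hedged sketch: update_examples is written to act as the handler of an example
# row whose last column carries the annotation state as a JSON string. The
# payload below is made up to show the expected schema, not a real example row.
def _demo_example_row():
    state_str = json.dumps({
        "points": [[50, 60], [120, 80]],  # clicked (x, y) support keypoints
        "skeleton": [[0, 1]],             # limb between point 0 and point 1
    })
    support = Image.new("RGB", (256, 256))
    query = Image.new("RGB", (256, 256))
    return update_examples(support, query, state_str)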
def get_select_coords(global_state, evt: gr.SelectData):
    """Handle a click on the support image; only click selection is supported."""
    xy = evt.index
    global_state["points"].append(xy)
    image_raw = global_state['images']['image_kp']
    image_draw = update_image_draw(image_raw, xy, global_state)
    global_state['images']['image_kp'] = image_draw
    return global_state, image_draw


def get_closest_point_idx(pts_list, xy):
    """Return the index of the annotated point closest to the clicked location."""
    x, y = xy
    closest_point = min(pts_list, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)
    closest_point_index = pts_list.index(closest_point)
    return closest_point_index


def reset_skeleton(global_state):
    """Clear all drawn limbs but keep the keypoints."""
    image = global_state["images"]["image_kp"]
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    return image


def reset_kp(global_state):
    """Clear all keypoints and limbs, restoring the original support image."""
    image = global_state["images"]["image_orig"]
    global_state["images"]["image_kp"] = image
    global_state["images"]["image_skel"] = image
    global_state["skeleton"] = []
    global_state["points"] = []
    global_state["curr_type_point"] = "start"
    global_state["prev_point"] = None
    return image, image


def select_skeleton(global_state, evt: gr.SelectData):
    """Add a limb between the keypoints closest to two consecutive clicks."""
    xy = evt.index
    pts_list = global_state["points"]
    closest_point_idx = get_closest_point_idx(pts_list, xy)
    image_raw = global_state['images']['image_skel']
    if global_state["curr_type_point"] == "end":
        # Second click: draw the limb from the previously selected point.
        prev_point_idx = global_state["prev_point_idx"]
        prev_point = pts_list[prev_point_idx]
        current_point = pts_list[closest_point_idx]
        points = [prev_point, current_point]
        image_draw = draw_limbs_on_image(image_raw, points)
        global_state['images']['image_skel'] = image_draw
        global_state['skeleton'].append([prev_point_idx, closest_point_idx])
        global_state["curr_type_point"] = "start"
        global_state["prev_point_idx"] = None
    else:
        # First click: remember the start point and wait for the end point.
        global_state["prev_point_idx"] = closest_point_idx
        global_state["curr_type_point"] = "end"
    return global_state, global_state['images']['image_skel']


def reverse_point_pairs(points):
    """Swap (x, y) pairs to (y, x)."""
    return [[p[1], p[0]] for p in points]


def update_image_draw(image, points, global_state):
    """Draw a newly clicked keypoint; alpha < 1 brightens the canvas while fewer than two points exist."""
    alpha = 0.5 if len(global_state["points"]) < 2 else 1.0
    return draw_points_on_image(image, points, alpha=alpha)


def print_memory_usage():
    """Log system and GPU memory usage, and return the available device."""
    print(f"System memory usage: {psutil.virtual_memory().percent}%")
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
        print(f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
        device_properties = torch.cuda.get_device_properties(device)
        available_memory = device_properties.total_memory - torch.cuda.max_memory_allocated()
        print(f"Available GPU memory: {available_memory / 1e9} GB")
    else:
        device = "cpu"
        print("No GPU available")
    return device
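
# Hedged sketch: the two-click limb flow in select_skeleton relies on
# get_closest_point_idx to snap each click to an existing keypoint. The
# values below are illustrative.
def _demo_closest_point():
    pts = [[10, 10], [100, 40], [60, 200]]
    # A click at (95, 45) snaps to the second point (index 1).
    assert get_closest_point_idx(pts, (95, 45)) == 1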
= "cpu" print("No GPU available") return device def draw_limbs_on_image(image, points,): color = tuple(random.choices(range(256), k=3)) overlay_rgba = Image.new("RGBA", image.size, 0) overlay_draw = ImageDraw.Draw(overlay_rgba) p_start, p_target = points if p_start is not None and p_target is not None: p_draw = int(p_start[0]), int(p_start[1]) t_draw = int(p_target[0]), int(p_target[1]) overlay_draw.line( (p_draw[0], p_draw[1], t_draw[0], t_draw[1]), fill=color, width=10, ) return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB") def draw_points_on_image(image, points, radius_scale=0.01, alpha=1.): if alpha < 1: enhancer = ImageEnhance.Brightness(image) image = enhancer.enhance(1.1) overlay_rgba = Image.new("RGBA", image.size, 0) overlay_draw = ImageDraw.Draw(overlay_rgba) p_color = (255, 0, 0) rad_draw = int(image.size[0] * radius_scale) if points is not None: p_draw = int(points[0]), int(points[1]) overlay_draw.ellipse( ( p_draw[0] - rad_draw, p_draw[1] - rad_draw, p_draw[0] + rad_draw, p_draw[1] + rad_draw, ), fill=p_color, ) return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB") def pickle_trick(obj, max_depth=10): output = {} if max_depth <= 0: return output try: pickle.dumps(obj) except (pickle.PicklingError, TypeError) as e: failing_children = [] if hasattr(obj, "__dict__"): for k, v in obj.__dict__.items(): result = pickle_trick(v, max_depth=max_depth - 1) if result: failing_children.append(result) output = { "fail": obj, "err": e, "depth": max_depth, "failing_children": failing_children } return output