{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pickle\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import shutil\n", | |
"from tqdm import tqdm\n", | |
"import os\n", | |
"import shutil\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Skeleton definitions keyed by dataset / category name.\n",
"# Each value is (joint_connections, defined_kpts): the [start, end] keypoint-index pairs\n",
"# drawn as bones, and the number of keypoints defined for that dataset.\n",
"\n",
"_CHEETAH_CONNECTIONS = [[3, 1], [1, 2], [3, 4], [4, 5],\n",
"                        [5, 6], [6, 7], [8, 9], [9, 10], [3, 8],\n",
"                        [3, 11], [11, 12], [12, 13], [5, 17],\n",
"                        [17, 18], [18, 19], [5, 14], [14, 15],\n",
"                        [15, 16]]\n",
"\n",
"# Categories that share the 26-keypoint animal skeleton; each also has a\n",
"# '<name>subset' variant that uses the 18-keypoint skeleton below.\n",
"_ANIMAL_NAMES = ('tiger', 'cow', 'horse', 'hippo', 'dog')\n",
"\n",
"_ANIMAL_CONNECTIONS = [[0, 24], [0, 20], [1, 21], [1, 24], [7, 25], [19, 25], [6, 17],\n",
"                       [4, 15], [3, 14], [9, 15], [8, 14], [9, 13], [8, 12],\n",
"                       [2, 23], [2, 22], [2, 24], [11, 17], [10, 16], [5, 16],\n",
"                       [7, 10], [7, 11], [13, 18], [12, 18], [7, 18], [24, 18]]\n",
"\n",
"_ANIMAL_SUBSET_CONNECTIONS = [[3, 17], [15, 17], [2, 13], [5, 9], [4, 8], [7, 13], [6, 12], [1, 12],\n",
"                              [3, 6], [3, 7], [9, 14], [8, 14], [3, 14], [5, 11], [4, 10], [0, 14], [0, 16]]\n",
"\n",
"# Shared by 'bus' and 'busfull'.\n",
"_BUS_CONNECTIONS = [[5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1],\n",
"                    [2, 3], [0, 2], [2, 10], [0, 8], [8, 9], [10, 11], [6, 11], [4, 9]]\n",
"\n",
"# MBW datasets: shared by 'colobusmonkey' and 'chimpanzee'.\n",
"_MBW_PRIMATE_CONNECTIONS = [[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8],\n",
"                            [8, 9], [9, 10], [10, 11], [11, 12], [9, 13], [13, 14], [14, 15]]\n",
"\n",
"_AMASS_CONNECTIONS = [\n",
"    [9, 13], [13, 16], [16, 18], [18, 20], [6, 9], [3, 6], [0, 3], [0, 1], [0, 2],\n",
"    [1, 4], [4, 7], [7, 10], [2, 5], [5, 8], [8, 11],\n",
"    [9, 14], [14, 17], [17, 19], [19, 21], [6, 12], [12, 15],\n",
"    # Chains rooted at joint 20\n",
"    [20, 34], [34, 35], [35, 36],\n",
"    [20, 22], [22, 23], [23, 24],\n",
"    [20, 25], [25, 26], [26, 27],\n",
"    [20, 31], [31, 32], [32, 33],\n",
"    [20, 28], [28, 29], [29, 30],\n",
"    # Chains rooted at joint 21\n",
"    [21, 49], [49, 50], [50, 51],\n",
"    [21, 37], [37, 38], [38, 39],\n",
"    [21, 40], [40, 41], [41, 42],\n",
"    [21, 46], [46, 47], [47, 48],\n",
"    [21, 43], [43, 44], [44, 45],\n",
"]\n",
"\n",
"_WHOLEBODY_H36M_CONNECTIONS = [\n",
"    # Body\n",
"    [0, 1], [1, 3], [0, 2], [2, 4],\n",
"    [5, 7], [7, 9], [9, 91],\n",
"    [6, 8], [8, 10], [10, 112],\n",
"    [5, 6], [6, 12], [11, 12], [5, 11],\n",
"    [12, 14], [14, 16], [16, 20], [16, 21], [16, 22],\n",
"    [11, 13], [13, 15], [15, 17], [15, 18], [15, 19],\n",
"    # Face\n",
"    [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34],\n",
"    [34, 35], [35, 36], [36, 37], [37, 38], [38, 39], [40, 41], [41, 42], [42, 43], [43, 44], [59, 60], [60, 61],\n",
"    [61, 62], [62, 63], [63, 64], [59, 64], [45, 46], [46, 47], [47, 48], [48, 49], [65, 66], [66, 67], [67, 68],\n",
"    [68, 69], [69, 70], [65, 70], [50, 51], [51, 52], [52, 53], [54, 55], [55, 56], [56, 57], [57, 58], [71, 72],\n",
"    [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82], [82, 83],\n",
"    [83, 84], [84, 85], [85, 86], [86, 87], [87, 88], [88, 89], [89, 90],\n",
"    # Left hand\n",
"    [91, 92], [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],\n",
"    [102, 103], [91, 104], [104, 105], [105, 106], [106, 107], [91, 108], [108, 109], [109, 110], [110, 111],\n",
"    # Right hand\n",
"    [112, 113], [113, 114], [114, 115], [115, 116], [112, 117], [117, 118], [118, 119], [119, 120], [112, 121],\n",
"    [121, 122], [122, 123], [123, 124], [112, 125], [125, 126], [126, 127], [127, 128], [112, 129], [129, 130],\n",
"    [130, 131], [131, 132],\n",
"]\n",
"\n",
"_BP4D_CONNECTIONS = [\n",
"    # Left eyebrow (viewed from the model's perspective)\n",
"    [15, 16], [16, 17], [17, 18], [18, 19],\n",
"    [10, 11], [11, 12], [12, 13], [13, 14],\n",
"    # Right eyebrow\n",
"    [0, 1], [1, 2], [2, 3], [3, 4],\n",
"    [5, 6], [6, 7], [7, 8], [8, 9],\n",
"    # Bridge of the nose (from between the eyebrows to the tip)\n",
"    [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 42], [42, 43], [43, 44], [44, 45], [45, 46], [46, 47],\n",
"    # Left eye\n",
"    [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34], [34, 35], [35, 28],\n",
"    # Right eye\n",
"    [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 20],\n",
"    # Outer part of the lips (outline of the lips)\n",
"    [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],\n",
"    # Inner part of the lips (detail within the lips)\n",
"    [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60],\n",
"    # Jawline (from left ear, around the chin, to right ear)\n",
"    [68, 69], [69, 70], [70, 71], [71, 72], [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79],\n",
"    [79, 80], [80, 81], [81, 82],\n",
"]\n",
"\n",
"_JOINT_CONNECTIONS = {\n",
"    'Human36M': ([[14, 15], [15, 16], [13, 12], [12, 11], [9, 8], [8, 7], [4, 5], [5, 6], [3, 2], [2, 1], [7, 0],\n",
"                  [0, 4], [0, 1], [8, 11], [8, 14], [9, 10]], 17),\n",
"    'face': ([[0, 1], [1, 2], [2, 3], [3, 4], [5, 6], [6, 7], [7, 8], [8, 9], [10, 11], [11, 12], [12, 13],\n",
"              [14, 15], [15, 16], [16, 17], [17, 18], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 19],\n",
"              [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 25], [31, 32], [32, 33], [33, 34], [34, 35],\n",
"              [35, 36], [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 42], [42, 31], [43, 44], [44, 45],\n",
"              [45, 46], [46, 47], [47, 48], [48, 49], [49, 50], [50, 43]], 53),\n",
"    'cheetah': (_CHEETAH_CONNECTIONS, 20),\n",
"    'cheetahtr': (_CHEETAH_CONNECTIONS, 20),\n",
"    'cheetahsub': ([[2, 0], [0, 1], [2, 3], [3, 4], [4, 5], [6, 7], [7, 8], [2, 6], [2, 9], [9, 10],\n",
"                    [10, 11], [3, 15], [15, 16], [16, 17], [3, 12], [12, 13], [13, 14]], 18),\n",
"    'hands': ([[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], [10, 11],\n",
"               [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]], 21),\n",
"    'amass': (_AMASS_CONNECTIONS, 52),\n",
"    'openmonkey': ([[0, 1], [1, 2], [2, 3], [3, 4], [2, 5], [5, 6], [2, 7], [7, 8], [8, 9], [7, 10],\n",
"                    [10, 11], [7, 12]], 13),\n",
"    'wholebodyh36m': (_WHOLEBODY_H36M_CONNECTIONS, 133),\n",
"    'bp4d+': (_BP4D_CONNECTIONS, 83),\n",
"    'panoptic': ([[0, 1], [0, 3], [3, 4], [4, 5], [0, 2], [2, 6], [6, 7], [7, 8], [2, 12], [12, 13],\n",
"                  [13, 14], [0, 9], [9, 10], [10, 11]], 15),\n",
"    'aeroplane': ([[2, 5], [1, 4], [5, 3], [3, 7], [7, 0], [0, 5], [5, 7], [5, 6], [6, 0], [6, 3],\n",
"                   [2, 4], [2, 1]], 8),\n",
"    'bicycle': ([[0, 3], [0, 7], [0, 2], [0, 6], [0, 10], [9, 10], [4, 10], [8, 10], [1, 9], [5, 9]], 11),\n",
"    **{name: (_ANIMAL_CONNECTIONS, 26) for name in _ANIMAL_NAMES},\n",
"    **{name + 'subset': (_ANIMAL_SUBSET_CONNECTIONS, 18) for name in _ANIMAL_NAMES},\n",
"    'boat': ([[0, 2], [0, 3], [0, 1], [1, 2], [1, 3], [2, 4], [3, 5], [4, 5], [1, 5], [1, 4]], 6),\n",
"    'bottle': ([[0, 1], [1, 2], [0, 2], [3, 4], [3, 5], [4, 5], [1, 4], [0, 3], [2, 5], [1, 6], [0, 6],\n",
"                [2, 6]], 8),\n",
"    'busfull': (_BUS_CONNECTIONS, 12),\n",
"    'bus': (_BUS_CONNECTIONS, 12),\n",
"    'car': ([[0, 8], [0, 4], [4, 10], [8, 10], [10, 9], [9, 11], [8, 11], [11, 6], [9, 2], [2, 6], [4, 1],\n",
"             [5, 1], [0, 5], [5, 7], [1, 3], [7, 3], [3, 2], [7, 6]], 12),\n",
"    'busmissing': ([[5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1], [2, 3], [0, 2],\n",
"                    [2, 6], [0, 4]], 8),\n",
"    'diningtable': ([[0, 2], [4, 6], [1, 3], [5, 7], [1, 5], [3, 7], [0, 4], [2, 6], [0, 1], [2, 3],\n",
"                     [4, 5], [6, 7]], 8),\n",
"    'tvmonitor': ([[5, 7], [4, 5], [4, 6], [6, 7], [0, 1], [0, 2], [2, 3], [1, 3], [3, 7], [1, 5],\n",
"                   [2, 6], [0, 4]], 8),\n",
"    'train': ([[4, 5], [4, 6], [6, 7], [5, 7], [0, 1], [1, 3], [2, 3], [0, 2], [1, 5], [0, 4], [2, 6],\n",
"               [3, 7]], 8),\n",
"    'train16': ([[0, 1], [1, 5], [5, 9], [9, 15], [3, 7], [7, 11], [11, 13], [2, 3], [2, 6], [6, 10],\n",
"                 [10, 12], [1, 3], [0, 2], [0, 4], [4, 8], [8, 14], [15, 13], [13, 12], [12, 14], [14, 15]], 16),\n",
"    'motorbike': ([[6, 2], [2, 9], [2, 3], [3, 8], [5, 8], [3, 5], [2, 1], [1, 0], [0, 7], [0, 4],\n",
"                   [4, 7], [1, 4], [1, 7], [1, 5], [1, 8]], 10),\n",
"    'sofa': ([[1, 5], [5, 4], [4, 6], [6, 2], [2, 0], [1, 0], [0, 4], [1, 3], [7, 5], [2, 3],\n",
"              [3, 7], [9, 7], [7, 6], [6, 8], [8, 9]], 10),\n",
"    'chair': ([[7, 3], [6, 2], [9, 5], [8, 4], [7, 9], [8, 6], [6, 7], [9, 8], [9, 1], [8, 0], [1, 0]], 10),\n",
"    # MBW datasets\n",
"    'colobusmonkey': (_MBW_PRIMATE_CONNECTIONS, 16),\n",
"    'chimpanzee': (_MBW_PRIMATE_CONNECTIONS, 16),\n",
"    'tigerzoo': ([[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [1, 6], [6, 7], [1, 8], [8, 9], [3, 10],\n",
"                  [10, 11], [3, 12], [12, 13]], 14),\n",
"    'clownfish': ([[0, 1], [1, 2], [2, 3], [1, 4], [1, 5]], 6),\n",
"    'fish': ([[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5], [5, 6], [6, 7], [5, 7], [5, 8], [8, 9],\n",
"              [9, 10], [8, 10], [10, 11], [11, 0]], 12),\n",
"    'seahorse': ([[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5]], 6),\n",
"}\n",
"\n",
"\n",
"def retrieve_joint_connections(dataset, return_num_kpts=False):\n",
"    \"\"\"Return the skeleton edges used to draw the keypoints of `dataset`.\n",
"\n",
"    Args:\n",
"        dataset: Dataset / category name, e.g. 'aeroplane', 'Human36M' or 'bp4d+'.\n",
"        return_num_kpts: If True, also return the number of keypoints the dataset defines.\n",
"\n",
"    Returns:\n",
"        A new list of [start, end] keypoint-index pairs, or a\n",
"        (joint_connections, defined_kpts) tuple when `return_num_kpts` is True.\n",
"\n",
"    Raises:\n",
"        ValueError: If no skeleton is defined for `dataset`.\n",
"    \"\"\"\n",
"    try:\n",
"        connections, defined_kpts = _JOINT_CONNECTIONS[dataset]\n",
"    except KeyError:\n",
"        known = ', '.join(sorted(_JOINT_CONNECTIONS))\n",
"        raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(dataset, known)) from None\n",
"    # Copy the pairs so a caller mutating the result cannot corrupt the shared table.\n",
"    joint_connections = [list(pair) for pair in connections]\n",
"    if return_num_kpts:\n",
"        return joint_connections, defined_kpts\n",
"    return joint_connections"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Load the input and prediction pickles for one category / split.\n",
"# NOTE(review): pickle.load can execute arbitrary code - only load files you produced yourself.\n",
"\n",
"category_name = 'aeroplane'\n",
"suffix = 'val'\n",
"\n",
"pickle_name = category_name + '_' + suffix\n",
"\n",
"## Load input data: 2D keypoints ('W_GT') and the matching image paths\n",
"input_data_path = pickle_name + '.pkl'\n",
"with open(input_data_path, 'rb') as f:\n",
"    input_data = pickle.load(f)\n",
"\n",
"input_2d = input_data['W_GT']\n",
"image_path = input_data['image_path']\n",
"\n",
"## Load predictions: 3D labels and 3D model outputs\n",
"pred_data_path = category_name + '_3dlfm.pkl'\n",
"with open(pred_data_path, 'rb') as f:\n",
"    pred_data = pickle.load(f)\n",
"labels_3d = pred_data['labels_3d']\n",
"outputs_3d = pred_data['outputs_3d']\n",
"\n",
"# Later cells index all four sequences with the same frame indices, so their lengths must match.\n",
"assert len(image_path) == len(input_2d) == len(labels_3d) == len(outputs_3d), (\n",
"    'frame counts differ between {} and {}'.format(input_data_path, pred_data_path))\n",
"\n",
"joint_connections = retrieve_joint_connections(category_name)\n",
"\n",
"## Print the statistics\n",
"print(\"Number of images: \", len(image_path))\n",
"print(\"Input 2D shape: \", input_2d.shape)\n",
"print(\"Labels 3D shape: \", labels_3d.shape)\n",
"print(\"Outputs 3D shape: \", outputs_3d.shape)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def plot_3d_skeleton(predictions_3d, labels_3d, joint_connections, range_scale=2500, masks=None):\n",
"    \"\"\"Visualize predicted (blue) and ground-truth (red) 3D skeletons in one plotly figure.\n",
"\n",
"    Args:\n",
"        predictions_3d: Predicted keypoints, one row per keypoint with three coordinate columns\n",
"            (NOTE(review): assumed shape (num_kpts, 3) - confirm against the data).\n",
"        labels_3d: Ground-truth keypoints in the same layout as `predictions_3d`.\n",
"        joint_connections: [start, end] keypoint-index pairs drawn as bones.\n",
"        range_scale: Currently unused; kept so existing calls that pass it keep working.\n",
"            Axis extents come from the data via aspectmode='data'.\n",
"        masks: Optional per-keypoint array; only keypoints whose mask equals 1.0, and bones\n",
"            whose endpoints both have mask 1.0, are drawn.\n",
"    \"\"\"\n",
"    # Keep only bones whose two endpoints are both visible.\n",
"    if masks is not None:\n",
"        updated_connections = [connection for connection in joint_connections\n",
"                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]\n",
"        print(\"Updated connections: {}\".format(updated_connections))\n",
"    else:\n",
"        updated_connections = joint_connections\n",
"\n",
"    traces = []\n",
"    # Predicted skeleton\n",
"    traces.extend(get_trace3d(updated_connections, predictions_3d, 'blue', 'blue', \"Predicted KP\", masks=masks))\n",
"    # Ground truth skeleton\n",
"    traces.extend(get_trace3d(updated_connections, labels_3d, 'red', 'red', \"Groundtruth KP\", masks=masks))\n",
"\n",
"    fig = go.Figure(data=traces)\n",
"    # aspectmode='data' draws the axes in proportion to their data ranges; ticks and titles are hidden.\n",
"    fig.update_layout(\n",
"        scene=dict(\n",
"            aspectmode=\"data\",\n",
"            xaxis=dict(title='', showticklabels=False),\n",
"            yaxis=dict(title='', showticklabels=False),\n",
"            zaxis=dict(title='', showticklabels=False),\n",
"        )\n",
"    )\n",
"    fig.show()\n",
"\n", | |
"\n", | |
"def get_trace3d(joint_connections, points3d, point_color, line_color, name, masks=None):\n",
"    \"\"\"Build the plotly traces (markers + bone lines) for one 3D skeleton.\n",
"\n",
"    Args:\n",
"        joint_connections: [start, end] row-index pairs into `points3d` to draw as lines.\n",
"        points3d: Keypoint array whose transpose unpacks into three coordinate rows.\n",
"        point_color: Marker color.\n",
"        line_color: Line color.\n",
"        name: Trace name used for both traces.\n",
"        masks: Optional per-keypoint array; when given, markers are drawn only for rows\n",
"            where it equals 1.0. Lines are built for every connection passed in.\n",
"\n",
"    Returns:\n",
"        [marker_trace, line_trace] as go.Scatter3d objects.\n",
"    \"\"\"\n",
"    marker_rows = points3d if masks is None else points3d[masks == 1.0]\n",
"\n",
"    # The second and third coordinate columns are swapped: column 1 is plotted on\n",
"    # plotly's z axis and column 2 on its y axis.\n",
"    marker_x, marker_z, marker_y = marker_rows.T\n",
"    all_x, all_z, all_y = points3d.T\n",
"\n",
"    trace_pts = go.Scatter3d(\n",
"        x=marker_x, y=marker_y, z=marker_z,\n",
"        mode='markers',\n",
"        name=name,\n",
"        marker=dict(symbol='circle', size=6, color=point_color),\n",
"    )\n",
"\n",
"    # A None after each segment breaks the polyline, so every bone is a separate line.\n",
"    seg_x, seg_y, seg_z = [], [], []\n",
"    for start, end in joint_connections:\n",
"        seg_x += [all_x[start], all_x[end], None]\n",
"        seg_y += [all_y[start], all_y[end], None]\n",
"        seg_z += [all_z[start], all_z[end], None]\n",
"\n",
"    trace_lines = go.Scatter3d(\n",
"        x=seg_x, y=seg_y, z=seg_z,\n",
"        mode='lines',\n",
"        name=name,\n",
"        line=dict(width=6, color=line_color),\n",
"    )\n",
"\n",
"    return [trace_pts, trace_lines]\n",
"\n", | |
"\n", | |
"import plotly.graph_objects as go\n", | |
"\n", | |
"def plot_2d_skeleton(predictions_2d, labels_2d, joint_connections, masks=None):\n",
"    \"\"\"Show predicted (blue) and ground-truth (red) 2D skeletons in a 700x700 plotly figure.\n",
"\n",
"    Args:\n",
"        predictions_2d: Predicted 2D keypoints.\n",
"        labels_2d: Ground-truth 2D keypoints in the same layout.\n",
"        joint_connections: [start, end] keypoint-index pairs drawn as bones.\n",
"        masks: Optional per-keypoint array; bones are kept only when both endpoints\n",
"            have mask 1.0, and the mask is forwarded to get_trace2d for the markers.\n",
"    \"\"\"\n",
"    if masks is None:\n",
"        updated_connections = joint_connections\n",
"    else:\n",
"        updated_connections = [\n",
"            pair for pair in joint_connections\n",
"            if masks[pair[0]] == 1.0 and masks[pair[1]] == 1.0\n",
"        ]\n",
"        print(\"Updated connections: {}\".format(updated_connections))\n",
"\n",
"    # (coordinates, color, legend name) for each skeleton, predicted first.\n",
"    skeletons = (\n",
"        (predictions_2d, 'blue', \"Predicted KP\"),\n",
"        (labels_2d, 'red', \"Groundtruth KP\"),\n",
"    )\n",
"    traces = []\n",
"    for coords, color, label in skeletons:\n",
"        traces += get_trace2d(updated_connections, coords, color, color, label, masks=masks)\n",
"\n",
"    layout = go.Layout(width=700, height=700, margin=dict(r=20, l=10, b=10, t=10))\n",
"    go.Figure(data=traces, layout=layout).show()\n",
"\n", | |
"def get_trace2d(joint_connections, points2d, point_color, line_color, name, masks=None, get_lines=None):\n",
"    \"\"\"Build the plotly traces (markers and, optionally, bone lines) for one 2D skeleton.\n",
"\n",
"    Args:\n",
"        joint_connections: [start, end] row-index pairs into `points2d` to draw as lines.\n",
"        points2d: Keypoint array whose transpose unpacks into (x, y) rows.\n",
"        point_color: Marker color.\n",
"        line_color: Line color.\n",
"        name: Trace name used for both traces.\n",
"        masks: Optional per-keypoint array; when given, markers are drawn only for rows\n",
"            where it equals 1.0. Lines are built for every connection passed in.\n",
"        get_lines: None or truthy returns markers and lines; any other falsy value\n",
"            (e.g. False) returns the marker trace only.\n",
"\n",
"    Returns:\n",
"        [marker_trace] or [marker_trace, line_trace] as go.Scatter objects.\n",
"    \"\"\"\n",
"    marker_rows = points2d if masks is None else points2d[masks == 1.0]\n",
"    marker_x, marker_y = marker_rows.T\n",
"    all_x, all_y = points2d.T\n",
"\n",
"    trace_pts = go.Scatter(\n",
"        x=marker_x, y=marker_y,\n",
"        mode='markers',\n",
"        name=name,\n",
"        marker=dict(symbol='circle', size=6, color=point_color),\n",
"    )\n",
"\n",
"    # A None after each segment breaks the polyline, so every bone is a separate line.\n",
"    seg_x, seg_y = [], []\n",
"    for start, end in joint_connections:\n",
"        seg_x += [all_x[start], all_x[end], None]\n",
"        seg_y += [all_y[start], all_y[end], None]\n",
"\n",
"    trace_lines = go.Scatter(\n",
"        x=seg_x, y=seg_y,\n",
"        mode='lines',\n",
"        name=name,\n",
"        line=dict(width=2, color=line_color),\n",
"    )\n",
"\n",
"    points_only = get_lines is not None and not get_lines\n",
"    return [trace_pts] if points_only else [trace_pts, trace_lines]\n",
"\n", | |
"\n", | |
"import plotly.graph_objs as go\n", | |
"from PIL import Image\n", | |
"import numpy as np\n", | |
"\n", | |
"from PIL import Image\n", | |
"import numpy as np\n", | |
"def plot_2d_skeleton_on_image(predictions_2d, labels_2d, joint_connections, image_path, masks=None, get_lines=None):\n",
"    \"\"\"Overlay predicted (blue) and ground-truth (red) 2D skeletons on the image at `image_path`.\n",
"\n",
"    Args:\n",
"        predictions_2d: Predicted 2D keypoints (NOTE(review): assumed to be in image pixel\n",
"            coordinates, since the axes span 0..width and 0..height - confirm).\n",
"        labels_2d: Ground-truth 2D keypoints in the same layout as `predictions_2d`.\n",
"        joint_connections: [start, end] keypoint-index pairs drawn as bones.\n",
"        image_path: Path of the background image.\n",
"        masks: Optional per-keypoint array; only bones whose endpoints both have mask 1.0\n",
"            are kept, and the mask is forwarded to get_trace2d for the markers.\n",
"        get_lines: Forwarded to get_trace2d; pass False to draw keypoints without bones.\n",
"    \"\"\"\n",
"    # Open the image in a context manager so its file handle is released. Plotly converts a\n",
"    # PIL image given as `source` into an inline data URI immediately, so the file is not\n",
"    # needed after the layout image has been built.\n",
"    with Image.open(image_path) as image:\n",
"        width, height = image.size\n",
"        background = go.layout.Image(source=image, xref=\"x\", yref=\"y\", x=0, y=0, sizex=width, sizey=height,\n",
"                                     sizing=\"stretch\", opacity=1.0, layer=\"below\")\n",
"\n",
"    # Keep only bones whose two endpoints are both visible.\n",
"    if masks is not None:\n",
"        updated_connections = [connection for connection in joint_connections\n",
"                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]\n",
"    else:\n",
"        updated_connections = joint_connections\n",
"\n",
"    print(\"updated connections: {}\".format(updated_connections))\n",
"\n",
"    traces = [\n",
"        # Invisible markers at opposite image corners (presumably to anchor the plotted extent).\n",
"        go.Scatter(x=[0, width], y=[0, height], mode=\"markers\", marker_opacity=0,\n",
"                   hoverinfo=\"none\", showlegend=False),\n",
"    ]\n",
"    # Predicted skeleton\n",
"    traces.extend(get_trace2d(updated_connections, predictions_2d, 'blue', 'blue', None, masks, get_lines=get_lines))\n",
"    # Ground truth skeleton\n",
"    traces.extend(get_trace2d(updated_connections, labels_2d, 'red', 'red', None, masks, get_lines=get_lines))\n",
"\n",
"    layout = go.Layout(\n",
"        width=width,\n",
"        height=height,\n",
"        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[0, width]),\n",
"        # Reversed y range puts y=0 at the top, as in image pixel coordinates.\n",
"        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[height, 0], scaleanchor=\"x\"),\n",
"        images=[background],\n",
"        margin=dict(r=10, l=10, b=10, t=10),\n",
"        hovermode=\"closest\",\n",
"        showlegend=False,  # Hide legend\n",
"    )\n",
"\n",
"    fig = go.Figure(data=traces, layout=layout)\n",
"    fig.show()\n"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Inspect one frame: 3D prediction vs. label (both sign-flipped before plotting), then the 2D input\n",
"# keypoints drawn on the frame's image (the same array is passed as prediction and as label).\n",
"frame_idx = 227\n",
"get_lines = False  # False -> draw the 2D keypoints only, without bones\n",
"\n",
"plot_3d_skeleton(-outputs_3d[frame_idx], -labels_3d[frame_idx], joint_connections,\n",
"                 range_scale=2500, masks=None)\n",
"plot_2d_skeleton_on_image(input_2d[frame_idx], input_2d[frame_idx], joint_connections, image_path[frame_idx],\n",
"                          masks=None, get_lines=get_lines)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#### Final storage #####\n",
"\n",
"## Store a fixed subset of 10 frames (and copies of their images) as a final pickle file.\n",
"# To draw a new random subset instead: random_indices = random.sample(range(len(image_path)), 10)\n",
"random_indices = [167, 29, 1, 3, 4, 7, 10, 18, 123, 227]\n",
"final_data = {}\n",
"final_data['image_path'] = np.asarray(image_path)[random_indices].tolist()\n",
"final_data['inputs_2d'] = input_2d[random_indices]\n",
"final_data['labels_3d'] = labels_3d[random_indices]\n",
"final_data['outputs_3d'] = outputs_3d[random_indices]\n",
"\n",
"images_dir = 'final/' + category_name + '_images'\n",
"\n",
"# Start from an empty image directory so copies from earlier runs do not linger.\n",
"# makedirs also creates the parent 'final/' directory when needed.\n",
"if os.path.exists(images_dir):\n",
"    shutil.rmtree(images_dir)\n",
"os.makedirs(images_dir, exist_ok=True)\n",
"\n",
"# Copy the images into the new directory and record their new paths.\n",
"new_image_paths = []\n",
"for path in tqdm(final_data['image_path']):\n",
"    # os.path.join drops images_dir when `path` is absolute, which would make the copy\n",
"    # target the source file itself; strip the drive/root so it is nested instead.\n",
"    rel_path = path\n",
"    if os.path.isabs(rel_path):\n",
"        rel_path = os.path.splitdrive(rel_path)[1].lstrip(os.sep + (os.altsep or ''))\n",
"    new_path = os.path.join(images_dir, rel_path)\n",
"\n",
"    # Create the directory if it doesn't exist\n",
"    os.makedirs(os.path.dirname(new_path), exist_ok=True)\n",
"\n",
"    shutil.copy(path, new_path)\n",
"    new_image_paths.append(new_path)\n",
"\n",
"# Update the image paths stored in the pickle\n",
"final_data['image_path'] = new_image_paths\n",
"\n",
"# Save the new pickle file\n",
"final_pickle_name = 'final/' + category_name + '.pkl'\n",
"with open(final_pickle_name, 'wb') as f:\n",
"    pickle.dump(final_data, f)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fixed frame subset also used by the 'Final storage' cell; re-bound here for reference.\n",
"random_indices = [167, 29, 1, 3, 4, 7, 10, 18, 123, 227]"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "lifting", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} | |