3dlfm / demo-data /validation /boat /.goutputstream-5O3CF2
dylanebert's picture
dylanebert HF staff
initial commit
8df98cb
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"import numpy as np\n",
"import os\n",
"import shutil\n",
"from tqdm import tqdm\n",
"import os\n",
"import shutil\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def retrieve_joint_connections(dataset, return_num_kpts=False):\n",
"    \"\"\"Return the skeleton edges (joint-index pairs) for a dataset or object category.\n",
"\n",
"    Args:\n",
"        dataset: Dataset / category name, e.g. \"Human36M\", \"aeroplane\" or \"boat\".\n",
"        return_num_kpts: If True, also return the number of keypoints the\n",
"            skeleton is defined over. Defaults to False, which keeps the\n",
"            original single-value return.\n",
"\n",
"    Returns:\n",
"        A list of [start, end] keypoint-index pairs, or the tuple\n",
"        (joint_connections, defined_kpts) when return_num_kpts is True.\n",
"\n",
"    Raises:\n",
"        ValueError: If `dataset` is not one of the names handled below.\n",
"    \"\"\"\n",
"    if dataset == \"Human36M\":\n",
"        joint_connections = [[14, 15], [15, 16], [13, 12], [12, 11], [9, 8], [8, 7], [4, 5], [5, 6],\n",
"                             [3, 2], [2, 1], [7, 0], [0, 4], [0, 1], [8, 11], [8, 14], [9, 10]]\n",
"        defined_kpts = 17\n",
"\n",
"    elif dataset == \"face\":\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [3, 4], [5, 6], [6, 7], [7, 8], [8, 9],\n",
"                             [10, 11], [11, 12], [12, 13], [14, 15], [15, 16], [16, 17], [17, 18],\n",
"                             [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 19],\n",
"                             [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 25],\n",
"                             [31, 32], [32, 33], [33, 34], [34, 35], [35, 36], [36, 37], [37, 38],\n",
"                             [38, 39], [39, 40], [40, 41], [41, 42], [42, 31],\n",
"                             [43, 44], [44, 45], [45, 46], [46, 47], [47, 48], [48, 49], [49, 50], [50, 43]]\n",
"        defined_kpts = 53\n",
"\n",
"    elif dataset in (\"cheetah\", \"cheetahtr\"):\n",
"        joint_connections = [[3, 1], [1, 2], [3, 4], [4, 5],\n",
"                             [5, 6], [6, 7], [8, 9], [9, 10], [3, 8],\n",
"                             [3, 11], [11, 12], [12, 13], [5, 17],\n",
"                             [17, 18], [18, 19], [5, 14], [14, 15],\n",
"                             [15, 16]]\n",
"        defined_kpts = 20\n",
"\n",
"    elif dataset == \"cheetahsub\":\n",
"        joint_connections = [[2, 0], [0, 1], [2, 3], [3, 4],\n",
"                             [4, 5], [6, 7], [7, 8], [2, 6],\n",
"                             [2, 9], [9, 10], [10, 11], [3, 15],\n",
"                             [15, 16], [16, 17], [3, 12], [12, 13], [13, 14]]\n",
"        defined_kpts = 18\n",
"\n",
"    elif dataset == \"hands\":\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8],\n",
"                             [0, 9], [9, 10], [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16],\n",
"                             [0, 17], [17, 18], [18, 19], [19, 20]]\n",
"        defined_kpts = 21\n",
"\n",
"    elif dataset == \"amass\":\n",
"        joint_connections = [[9, 13], [13, 16], [16, 18], [18, 20], [6, 9], [3, 6], [0, 3], [0, 1], [0, 2],\n",
"                             [1, 4], [4, 7], [7, 10], [2, 5], [5, 8], [8, 11],\n",
"                             [9, 14], [14, 17], [17, 19], [19, 21], [6, 12], [12, 15],\n",
"                             [20, 34], [34, 35], [35, 36],\n",
"                             [20, 22], [22, 23], [23, 24],\n",
"                             [20, 25], [25, 26], [26, 27],\n",
"                             [20, 31], [31, 32], [32, 33],\n",
"                             [20, 28], [28, 29], [29, 30],\n",
"                             [21, 49], [49, 50], [50, 51],\n",
"                             [21, 37], [37, 38], [38, 39],\n",
"                             [21, 40], [40, 41], [41, 42],\n",
"                             [21, 46], [46, 47], [47, 48],\n",
"                             [21, 43], [43, 44], [44, 45]]\n",
"        defined_kpts = 52\n",
"\n",
"    elif dataset == \"openmonkey\":\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [3, 4], [2, 5], [5, 6], [2, 7], [7, 8], [8, 9],\n",
"                             [7, 10], [10, 11], [7, 12]]\n",
"        defined_kpts = 13\n",
"\n",
"    elif dataset == \"wholebodyh36m\":\n",
"        joint_connections = [\n",
"            # Body\n",
"            [0, 1], [1, 3], [0, 2], [2, 4],\n",
"            [5, 7], [7, 9], [9, 91],\n",
"            [6, 8], [8, 10], [10, 112],\n",
"            [5, 6], [6, 12], [11, 12], [5, 11],\n",
"            [12, 14], [14, 16], [16, 20], [16, 21], [16, 22],\n",
"            [11, 13], [13, 15], [15, 17], [15, 18], [15, 19],\n",
"\n",
"            # Face\n",
"            [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34],\n",
"            [34, 35], [35, 36], [36, 37], [37, 38], [38, 39], [40, 41], [41, 42], [42, 43], [43, 44], [59, 60], [60, 61],\n",
"            [61, 62], [62, 63], [63, 64], [59, 64], [45, 46], [46, 47], [47, 48], [48, 49], [65, 66], [66, 67], [67, 68],\n",
"            [68, 69], [69, 70], [65, 70], [50, 51], [51, 52], [52, 53], [54, 55], [55, 56], [56, 57], [57, 58], [71, 72],\n",
"            [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82], [82, 83],\n",
"            [83, 84], [84, 85], [85, 86], [86, 87], [87, 88], [88, 89], [89, 90],\n",
"\n",
"            # Left hand ([91, 92] used to be listed a second time under Face)\n",
"            [91, 92], [92, 93], [93, 94], [94, 95], [91, 96], [96, 97], [97, 98], [98, 99], [91, 100], [100, 101], [101, 102],\n",
"            [102, 103], [91, 104], [104, 105], [105, 106], [106, 107], [91, 108], [108, 109], [109, 110], [110, 111],\n",
"\n",
"            # Right hand\n",
"            [112, 113], [113, 114], [114, 115], [115, 116], [112, 117], [117, 118], [118, 119], [119, 120], [112, 121],\n",
"            [121, 122], [122, 123], [123, 124], [112, 125], [125, 126], [126, 127], [127, 128], [112, 129], [129, 130],\n",
"            [130, 131], [131, 132],\n",
"        ]\n",
"        defined_kpts = 133\n",
"\n",
"    elif dataset == \"bp4d+\":\n",
"        joint_connections = [\n",
"            # Left eyebrow (viewed from the model's perspective)\n",
"            [15, 16], [16, 17], [17, 18], [18, 19],\n",
"            [10, 11], [11, 12], [12, 13], [13, 14],\n",
"\n",
"            # Right eyebrow\n",
"            [0, 1], [1, 2], [2, 3], [3, 4],\n",
"            [5, 6], [6, 7], [7, 8], [8, 9],\n",
"\n",
"            # Bridge of the nose (from between the eyebrows to the tip)\n",
"            [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 42], [42, 43], [43, 44], [44, 45], [45, 46], [46, 47],\n",
"\n",
"            # Left eye\n",
"            [28, 29], [29, 30], [30, 31], [31, 32], [32, 33], [33, 34], [34, 35], [35, 28],\n",
"\n",
"            # Right eye\n",
"            [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 20],\n",
"\n",
"            # Outer part of the lips (outline of the lips)\n",
"            [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],\n",
"\n",
"            # Inner part of the lips (detail within the lips)\n",
"            [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], [66, 67], [67, 60],\n",
"\n",
"            # Jawline (from left ear, around the chin, to right ear)\n",
"            [68, 69], [69, 70], [70, 71], [71, 72], [72, 73], [73, 74], [74, 75], [75, 76], [76, 77], [77, 78], [78, 79], [79, 80], [80, 81], [81, 82],\n",
"        ]\n",
"        defined_kpts = 83\n",
"\n",
"    elif dataset == \"panoptic\":\n",
"        joint_connections = [[0, 1], [0, 3], [3, 4], [4, 5], [0, 2], [2, 6], [6, 7], [7, 8],\n",
"                             [2, 12], [12, 13], [13, 14], [0, 9], [9, 10], [10, 11]]\n",
"        defined_kpts = 15\n",
"\n",
"    elif dataset == 'aeroplane':\n",
"        joint_connections = [[2, 5], [1, 4], [5, 3], [3, 7], [7, 0], [0, 5], [5, 7], [5, 6], [6, 0], [6, 3], [2, 4], [2, 1]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset == 'bicycle':\n",
"        joint_connections = [[0, 3], [0, 7], [0, 2], [0, 6], [0, 10], [9, 10], [4, 10], [8, 10], [1, 9], [5, 9]]\n",
"        defined_kpts = 11\n",
"\n",
"    elif dataset in (\"tiger\", \"cow\", \"horse\", \"hippo\", \"dog\"):\n",
"        joint_connections = [[0, 24], [0, 20], [1, 21], [1, 24], [7, 25], [19, 25], [6, 17],\n",
"                             [4, 15], [3, 14], [9, 15], [8, 14], [9, 13], [8, 12],\n",
"                             [2, 23], [2, 22], [2, 24], [11, 17], [10, 16], [5, 16],\n",
"                             [7, 10], [7, 11], [13, 18], [12, 18], [7, 18], [24, 18]]\n",
"        defined_kpts = 26\n",
"\n",
"    elif dataset in (\"tigersubset\", \"cowsubset\", \"horsesubset\", \"hipposubset\", \"dogsubset\"):\n",
"        joint_connections = [[3, 17], [15, 17], [2, 13], [5, 9], [4, 8], [7, 13], [6, 12], [1, 12], [3, 6],\n",
"                             [3, 7], [9, 14], [8, 14], [3, 14], [5, 11], [4, 10], [0, 14], [0, 16]]\n",
"        defined_kpts = 18\n",
"\n",
"    elif dataset == 'boat':\n",
"        joint_connections = [[0, 2], [0, 3], [0, 1], [1, 2], [1, 3], [2, 4], [3, 5], [4, 5], [1, 5], [1, 4]]\n",
"        defined_kpts = 6\n",
"\n",
"    elif dataset == 'bottle':\n",
"        joint_connections = [[0, 1], [1, 2], [0, 2], [3, 4], [3, 5], [4, 5], [1, 4], [0, 3], [2, 5], [1, 6], [0, 6], [2, 6]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset in ('busfull', 'bus'):\n",
"        joint_connections = [[5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1],\n",
"                             [2, 3], [0, 2], [2, 10], [0, 8], [8, 9], [10, 11], [6, 11], [4, 9]]\n",
"        defined_kpts = 12\n",
"\n",
"    elif dataset == 'car':\n",
"        joint_connections = [[0, 8], [0, 4], [4, 10], [8, 10],\n",
"                             [10, 9], [9, 11], [8, 11], [11, 6],\n",
"                             [9, 2], [2, 6], [4, 1], [5, 1],\n",
"                             [0, 5], [5, 7], [1, 3], [7, 3], [3, 2], [7, 6]]\n",
"        defined_kpts = 12\n",
"\n",
"    elif dataset == 'busmissing':\n",
"        joint_connections = [[5, 7], [4, 5], [6, 7], [4, 6], [1, 5], [1, 3], [3, 7], [0, 1], [2, 3], [0, 2], [2, 6], [0, 4]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset == 'diningtable':\n",
"        joint_connections = [[0, 2], [4, 6], [1, 3], [5, 7], [1, 5], [3, 7], [0, 4], [2, 6], [0, 1], [2, 3], [4, 5], [6, 7]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset == 'tvmonitor':\n",
"        joint_connections = [[5, 7], [4, 5], [4, 6], [6, 7], [0, 1], [0, 2], [2, 3], [1, 3], [3, 7], [1, 5], [2, 6], [0, 4]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset == 'train':\n",
"        # [1, 5] used to be listed twice; each edge now appears once.\n",
"        joint_connections = [[4, 5], [4, 6], [6, 7], [5, 7], [0, 1], [1, 3], [2, 3], [0, 2], [1, 5], [0, 4],\n",
"                             [2, 6], [3, 7]]\n",
"        defined_kpts = 8\n",
"\n",
"    elif dataset == 'train16':\n",
"        joint_connections = [[0, 1], [1, 5], [5, 9], [9, 15], [3, 7], [7, 11], [11, 13], [2, 3], [2, 6],\n",
"                             [6, 10], [10, 12], [1, 3], [0, 2],\n",
"                             [0, 4], [4, 8], [8, 14], [15, 13], [13, 12], [12, 14], [14, 15]]\n",
"        defined_kpts = 16\n",
"\n",
"    elif dataset == 'motorbike':\n",
"        joint_connections = [[6, 2], [2, 9], [2, 3], [3, 8], [5, 8],\n",
"                             [3, 5], [2, 1], [1, 0], [0, 7], [0, 4],\n",
"                             [4, 7], [1, 4], [1, 7], [1, 5], [1, 8]]\n",
"        defined_kpts = 10\n",
"\n",
"    elif dataset == 'sofa':\n",
"        joint_connections = [[1, 5], [5, 4], [4, 6], [6, 2], [2, 0],\n",
"                             [1, 0], [0, 4], [1, 3], [7, 5], [2, 3],\n",
"                             [3, 7], [9, 7], [7, 6], [6, 8], [8, 9]]\n",
"        defined_kpts = 10\n",
"\n",
"    elif dataset == 'chair':\n",
"        joint_connections = [[7, 3], [6, 2], [9, 5], [8, 4], [7, 9],\n",
"                             [8, 6], [6, 7], [9, 8], [9, 1], [8, 0], [1, 0]]\n",
"        defined_kpts = 10\n",
"\n",
"    # MBW datasets\n",
"    elif dataset in ('colobusmonkey', 'chimpanzee'):\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8],\n",
"                             [8, 9], [9, 10], [10, 11], [11, 12], [9, 13], [13, 14], [14, 15]]\n",
"        defined_kpts = 16\n",
"\n",
"    elif dataset == 'tigerzoo':\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [1, 6], [6, 7], [1, 8], [8, 9],\n",
"                             [3, 10], [10, 11], [3, 12], [12, 13]]\n",
"        defined_kpts = 14\n",
"\n",
"    elif dataset == 'clownfish':\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [1, 4], [1, 5]]\n",
"        defined_kpts = 6\n",
"\n",
"    elif dataset == 'fish':\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5], [5, 6], [6, 7],\n",
"                             [5, 7], [5, 8], [8, 9], [9, 10], [8, 10], [10, 11], [11, 0]]\n",
"        defined_kpts = 12\n",
"\n",
"    elif dataset == 'seahorse':\n",
"        joint_connections = [[0, 1], [1, 2], [2, 3], [1, 3], [3, 4], [4, 5]]\n",
"        defined_kpts = 6\n",
"\n",
"    else:\n",
"        # Previously an unknown name fell through and failed later with an\n",
"        # UnboundLocalError on `joint_connections`.\n",
"        raise ValueError(\"Unknown dataset: {!r}\".format(dataset))\n",
"\n",
"    if return_num_kpts:\n",
"        return joint_connections, defined_kpts\n",
"    return joint_connections"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the inputs and the 3D-LFM predictions for one category / split.\n",
"# NOTE: pickle.load can execute arbitrary code; only open pickles you trust.\n",
"\n",
"category_name = 'aeroplane'\n",
"suffix = 'val'\n",
"pickle_name = '{}_{}'.format(category_name, suffix)\n",
"\n",
"# Inputs pickle: 'W_GT' (used below as the 2D input keypoints) and 'image_path'.\n",
"input_data_path = '{}.pkl'.format(pickle_name)\n",
"with open(input_data_path, 'rb') as f:\n",
"    input_data = pickle.load(f)\n",
"input_2d, image_path = input_data['W_GT'], input_data['image_path']\n",
"\n",
"# Predictions pickle: ground-truth ('labels_3d') and predicted ('outputs_3d') 3D keypoints.\n",
"pred_data_path = '{}_3dlfm.pkl'.format(category_name)\n",
"with open(pred_data_path, 'rb') as f:\n",
"    pred_data = pickle.load(f)\n",
"labels_3d, outputs_3d = pred_data['labels_3d'], pred_data['outputs_3d']\n",
"\n",
"# Skeleton edges used by the plotting helpers.\n",
"joint_connections = retrieve_joint_connections(category_name)\n",
"\n",
"# Quick sanity check of what was loaded.\n",
"print(\"Number of images: \", len(image_path))\n",
"print(\"Input 2D shape: \", input_2d.shape)\n",
"print(\"Labels 3D shape: \", labels_3d.shape)\n",
"print(\"Outputs 3D shape: \", outputs_3d.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_3d_skeleton(predictions_3d, labels_3d, joint_connections, range_scale=2500, masks=None):\n",
"    \"\"\"Visualize predicted (blue) and ground-truth (red) 3D skeletons with plotly.\n",
"\n",
"    Args:\n",
"        predictions_3d: Predicted keypoints for one sample, one row per joint\n",
"            (assumed shape (J, 3) -- confirm against the prediction pickle).\n",
"        labels_3d: Ground-truth keypoints in the same layout as `predictions_3d`.\n",
"        joint_connections: List of [start, end] joint-index pairs to draw.\n",
"        range_scale: Currently unused. It only fed a fixed-range axis layout\n",
"            that was built but never applied; kept so existing calls still work.\n",
"        masks: Optional per-joint values; when given, only joints whose value\n",
"            equals 1.0 are drawn, and only edges with both endpoints visible.\n",
"    \"\"\"\n",
"    if masks is not None:\n",
"        # Keep an edge only if both of its endpoints are visible.\n",
"        updated_connections = [connection for connection in joint_connections\n",
"                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]\n",
"        print(\"Updated connections: {}\".format(updated_connections))\n",
"    else:\n",
"        updated_connections = joint_connections\n",
"\n",
"    traces = []\n",
"    traces.extend(get_trace3d(updated_connections, predictions_3d, 'blue', 'blue', \"Predicted KP\", masks=masks))\n",
"    traces.extend(get_trace3d(updated_connections, labels_3d, 'red', 'red', \"Groundtruth KP\", masks=masks))\n",
"\n",
"    fig = go.Figure(data=traces)\n",
"    # aspectmode=\"data\" keeps the skeleton's true proportions; axis titles and\n",
"    # tick labels are hidden for a cleaner view.\n",
"    fig.update_layout(scene=dict(\n",
"        aspectmode=\"data\",\n",
"        xaxis=dict(title='', showticklabels=False),\n",
"        yaxis=dict(title='', showticklabels=False),\n",
"        zaxis=dict(title='', showticklabels=False),\n",
"    ))\n",
"    fig.show()\n",
"\n",
"\n",
"def get_trace3d(joint_connections, points3d, point_color, line_color, name, masks=None):\n",
"    \"\"\"Build plotly Scatter3d traces (markers + edges) for one 3D skeleton.\n",
"\n",
"    The second and third coordinate columns are swapped when plotting:\n",
"    column 1 is drawn on plotly's z axis and column 2 on its y axis.\n",
"\n",
"    Args:\n",
"        joint_connections: List of [start, end] joint-index pairs drawn as lines.\n",
"        points3d: Keypoints, one row per joint with three coordinate columns.\n",
"        point_color: Marker color.\n",
"        line_color: Edge color.\n",
"        name: Legend name used for both traces.\n",
"        masks: Optional per-joint values (list or array); only joints whose\n",
"            value equals 1.0 get a marker. Edges are drawn for every pair in\n",
"            `joint_connections`, so callers filter those themselves.\n",
"\n",
"    Returns:\n",
"        [markers_trace, lines_trace]\n",
"    \"\"\"\n",
"    points3d = np.asarray(points3d)\n",
"    if masks is not None:\n",
"        # np.asarray makes the comparison element-wise even for a plain list;\n",
"        # `list == 1.0` is a single False rather than a per-joint mask.\n",
"        visible_points = points3d[np.asarray(masks) == 1.0]\n",
"    else:\n",
"        visible_points = points3d\n",
"\n",
"    x, z, y = visible_points.T  # column 1 -> z, column 2 -> y\n",
"    x_all, z_all, y_all = points3d.T\n",
"\n",
"    trace_pts = go.Scatter3d(\n",
"        x=x, y=y, z=z,\n",
"        mode='markers',\n",
"        name=name,\n",
"        marker=dict(symbol='circle', size=6, color=point_color)\n",
"    )\n",
"\n",
"    # One segment per edge; None breaks the polyline between segments.\n",
"    x_lines, y_lines, z_lines = [], [], []\n",
"    for start, end in joint_connections:\n",
"        x_lines.extend([x_all[start], x_all[end], None])\n",
"        y_lines.extend([y_all[start], y_all[end], None])\n",
"        z_lines.extend([z_all[start], z_all[end], None])\n",
"\n",
"    trace_lines = go.Scatter3d(\n",
"        x=x_lines, y=y_lines, z=z_lines,\n",
"        mode='lines',\n",
"        name=name,\n",
"        line=dict(width=6, color=line_color)\n",
"    )\n",
"\n",
"    return [trace_pts, trace_lines]\n",
"\n",
"\n",
"import plotly.graph_objects as go\n",
"\n",
"def plot_2d_skeleton(predictions_2d, labels_2d, joint_connections, masks=None):\n",
"    \"\"\"Draw predicted (blue) and ground-truth (red) 2D skeletons in a 700x700 figure.\n",
"\n",
"    If `masks` is given, only edges whose two joints both have a mask value of\n",
"    1.0 are kept (the kept list is printed), and the same mask is forwarded to\n",
"    get_trace2d to filter the markers.\n",
"    \"\"\"\n",
"    if masks is None:\n",
"        edges = joint_connections\n",
"    else:\n",
"        edges = []\n",
"        for pair in joint_connections:\n",
"            if masks[pair[0]] == 1.0 and masks[pair[1]] == 1.0:\n",
"                edges.append(pair)\n",
"        print(\"Updated connections: {}\".format(edges))\n",
"\n",
"    predicted = get_trace2d(edges, predictions_2d, 'blue', 'blue', \"Predicted KP\", masks=masks)\n",
"    ground_truth = get_trace2d(edges, labels_2d, 'red', 'red', \"Groundtruth KP\", masks=masks)\n",
"\n",
"    figure = go.Figure(\n",
"        data=predicted + ground_truth,\n",
"        layout=go.Layout(width=700, height=700, margin=dict(r=20, l=10, b=10, t=10)),\n",
"    )\n",
"    figure.show()\n",
"\n",
"def get_trace2d(joint_connections, points2d, point_color, line_color, name, masks=None, get_lines=None):\n",
"    \"\"\"Build plotly Scatter traces for one 2D skeleton.\n",
"\n",
"    Args:\n",
"        joint_connections: List of [start, end] joint-index pairs drawn as lines.\n",
"        points2d: Keypoints, one (x, y) row per joint.\n",
"        point_color: Marker color.\n",
"        line_color: Edge color.\n",
"        name: Legend name used for both traces (may be None).\n",
"        masks: Optional per-joint values (list or array); only joints whose\n",
"            value equals 1.0 get a marker. Edges are drawn for every pair in\n",
"            `joint_connections`, so callers filter those themselves.\n",
"        get_lines: None or truthy -> return markers and lines; an explicit\n",
"            falsy value (e.g. False) -> return markers only.\n",
"\n",
"    Returns:\n",
"        [markers_trace, lines_trace], or [markers_trace] when lines are disabled.\n",
"    \"\"\"\n",
"    points2d = np.asarray(points2d)\n",
"    if masks is not None:\n",
"        # np.asarray makes the comparison element-wise even for a plain list;\n",
"        # `list == 1.0` is a single False rather than a per-joint mask.\n",
"        visible_points = points2d[np.asarray(masks) == 1.0]\n",
"    else:\n",
"        visible_points = points2d\n",
"\n",
"    # No axis swap in 2D: column 0 is x, column 1 is y.\n",
"    x, y = visible_points.T\n",
"    x_all, y_all = points2d.T\n",
"\n",
"    trace_pts = go.Scatter(\n",
"        x=x, y=y,\n",
"        mode='markers',\n",
"        name=name,\n",
"        marker=dict(symbol='circle', size=6, color=point_color)\n",
"    )\n",
"\n",
"    if get_lines is not None and not get_lines:\n",
"        return [trace_pts]\n",
"\n",
"    # One segment per edge; None breaks the polyline between segments.\n",
"    x_lines, y_lines = [], []\n",
"    for start, end in joint_connections:\n",
"        x_lines.extend([x_all[start], x_all[end], None])\n",
"        y_lines.extend([y_all[start], y_all[end], None])\n",
"\n",
"    trace_lines = go.Scatter(\n",
"        x=x_lines, y=y_lines,\n",
"        mode='lines',\n",
"        name=name,\n",
"        line=dict(width=2, color=line_color)\n",
"    )\n",
"\n",
"    return [trace_pts, trace_lines]\n",
"\n",
"\n",
"import plotly.graph_objs as go\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"from PIL import Image\n",
"import numpy as np\n",
"def plot_2d_skeleton_on_image(predictions_2d, labels_2d, joint_connections, image_path, masks=None, get_lines=None):\n",
"    \"\"\"Overlay predicted (blue) and ground-truth (red) 2D skeletons on an image.\n",
"\n",
"    Args:\n",
"        predictions_2d: Predicted 2D keypoints, one (x, y) row per joint; drawn\n",
"            over the image with y growing downwards (assumed to be pixel\n",
"            coordinates -- confirm against the data).\n",
"        labels_2d: Ground-truth 2D keypoints in the same layout.\n",
"        joint_connections: List of [start, end] joint-index pairs.\n",
"        image_path: Path of the background image.\n",
"        masks: Optional per-joint values; joints/edges are drawn only where the\n",
"            value equals 1.0.\n",
"        get_lines: Forwarded to get_trace2d; an explicit False draws markers only.\n",
"    \"\"\"\n",
"    # Image.open is lazy and keeps the file handle open; copy() forces the\n",
"    # pixels to load so the file can be closed before plotting.\n",
"    with Image.open(image_path) as opened_image:\n",
"        image = opened_image.copy()\n",
"    width, height = image.size\n",
"\n",
"    if masks is not None:\n",
"        # Keep an edge only if both of its endpoints are visible.\n",
"        updated_connections = [connection for connection in joint_connections\n",
"                               if masks[connection[0]] == 1.0 and masks[connection[1]] == 1.0]\n",
"    else:\n",
"        updated_connections = joint_connections\n",
"\n",
"    print(\"updated connections: {}\".format(updated_connections))\n",
"\n",
"    traces = [\n",
"        # Invisible corner markers pin the data extent to the image size.\n",
"        go.Scatter(\n",
"            x=[0, width],\n",
"            y=[0, height],\n",
"            mode=\"markers\",\n",
"            marker_opacity=0,\n",
"            hoverinfo=\"none\",\n",
"            showlegend=False\n",
"        )\n",
"    ]\n",
"    traces.extend(get_trace2d(updated_connections, predictions_2d, 'blue', 'blue', None, masks, get_lines=get_lines))\n",
"    traces.extend(get_trace2d(updated_connections, labels_2d, 'red', 'red', None, masks, get_lines=get_lines))\n",
"\n",
"    layout = go.Layout(\n",
"        width=width,\n",
"        height=height,\n",
"        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[0, width]),\n",
"        # Reversed y range puts the origin at the top-left, as in image coordinates.\n",
"        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[height, 0], scaleanchor=\"x\"),\n",
"        images=[go.layout.Image(source=image, xref=\"x\", yref=\"y\", x=0, y=0, sizex=width, sizey=height, sizing=\"stretch\", opacity=1.0, layer=\"below\")],\n",
"        margin=dict(r=10, l=10, b=10, t=10),\n",
"        hovermode=\"closest\",\n",
"        showlegend=False,\n",
"    )\n",
"\n",
"    fig = go.Figure(data=traces, layout=layout)\n",
"    fig.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect one frame: 3D prediction vs. ground truth (both negated before plotting),\n",
"# then the 2D input keypoints overlaid on the source image, markers only.\n",
"frame_idx = 227\n",
"get_lines = False\n",
"\n",
"plot_3d_skeleton(-outputs_3d[frame_idx], -labels_3d[frame_idx], joint_connections,\n",
"                 range_scale=2500, masks=None)\n",
"plot_2d_skeleton_on_image(input_2d[frame_idx], input_2d[frame_idx], joint_connections,\n",
"                          image_path[frame_idx], masks=None, get_lines=get_lines)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#### Final storage #####\n",
"\n",
"## Export a fixed set of 10 frames as a compact pickle plus copies of their images.\n",
"## The indices are hard-coded (instead of random.sample) so re-runs export the same frames.\n",
"random_indices = [167, 29, 1, 3, 4, 7, 10, 18, 123, 227]\n",
"final_data = {}\n",
"final_data['image_path'] = np.asarray(image_path)[random_indices].tolist()\n",
"final_data['inputs_2d'] = input_2d[random_indices]\n",
"final_data['labels_3d'] = labels_3d[random_indices]\n",
"final_data['outputs_3d'] = outputs_3d[random_indices]\n",
"\n",
"# Recreate a clean per-category image directory (makedirs also creates 'final/').\n",
"images_dir = 'final/' + category_name + '_images'\n",
"if os.path.exists(images_dir):\n",
"    shutil.rmtree(images_dir)\n",
"os.makedirs(images_dir, exist_ok=True)\n",
"\n",
"# Copy each image under images_dir, mirroring its original relative path.\n",
"new_image_paths = []\n",
"for path in tqdm(final_data['image_path']):\n",
"    # os.path.join(images_dir, path) discards images_dir when `path` is absolute,\n",
"    # so the copy would target the source file itself (shutil.SameFileError), and\n",
"    # '..' components could escape images_dir. Keep only the drive-less,\n",
"    # root-less, '..'-free components of the path.\n",
"    _, tail = os.path.splitdrive(path)\n",
"    rel_parts = [part for part in os.path.normpath(tail).split(os.sep) if part not in ('', '.', '..')]\n",
"    new_path = os.path.join(images_dir, *rel_parts)\n",
"\n",
"    os.makedirs(os.path.dirname(new_path), exist_ok=True)\n",
"    shutil.copy(path, new_path)\n",
"    new_image_paths.append(new_path)\n",
"\n",
"# Point the stored paths at the copies.\n",
"final_data['image_path'] = new_image_paths\n",
"\n",
"# Save the new pickle file\n",
"final_pickle_name = 'final/' + category_name + '.pkl'\n",
"with open(final_pickle_name, 'wb') as f:\n",
"    pickle.dump(final_data, f)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Frame indices used for the demo export.\n",
"random_indices = [\n",
"    167, 29, 1, 3, 4,\n",
"    7, 10, 18, 123, 227,\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lifting",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}