File size: 6,007 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116

import torch
from torch.utils.data import DataLoader
import cv2
import glob
import os
import sys
# Put the `src` directory located three levels above this file's directory at
# the front of the import search path. The function-local `stacked_hourglass`
# imports below presumably resolve through this entry -- TODO confirm.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))


def get_evaluation_dataset(cfg_data_dataset, cfg_data_val_opt, cfg_data_V12, cfg_optim_batch_size, args_workers, drop_last=False):
    """Build a StanExt evaluation dataset and a non-shuffling DataLoader for it.

    Args:
        cfg_data_dataset: dataset variant (cfg.data.DATASET); one of
            'stanext24_easy', 'stanext24', 'stanext24_withgc',
            'stanext24_withgc_big'.
        cfg_data_val_opt: split to evaluate on (cfg.data.VAL_OPT); one of
            'train', 'test', 'val'.
        cfg_data_V12: forwarded unchanged as the dataset's ``V12`` argument
            (cfg.data.V12).
        cfg_optim_batch_size: DataLoader batch size (cfg.optim.BATCH_SIZE).
        args_workers: DataLoader ``num_workers`` (args.workers).
        drop_last: forwarded to the DataLoader's ``drop_last``.

    Returns:
        Tuple ``(val_dataset, val_loader, len_val_dataset, test_name_list,
        stanext_data_info, stanext_acc_joints)``. ``test_name_list`` is the
        dataset's ``train_name_list`` for the 'train' split and its
        ``test_name_list`` otherwise; the last two entries are the
        ``DATA_INFO`` and ``ACC_JOINTS`` attributes of the selected dataset
        class.

    Raises:
        ValueError: if ``cfg_data_dataset`` or ``cfg_data_val_opt`` is not one
            of the supported options.
    """
    valid_datasets = ('stanext24_easy', 'stanext24', 'stanext24_withgc', 'stanext24_withgc_big')
    valid_val_opts = ('train', 'test', 'val')
    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O`, which would let a bad option slip through to a confusing
    # failure further down. Both options are checked before any import so a
    # typo is reported immediately.
    if cfg_data_dataset not in valid_datasets:
        raise ValueError('unknown dataset {!r}; expected one of {}'.format(cfg_data_dataset, valid_datasets))
    if cfg_data_val_opt not in valid_val_opts:
        raise ValueError('unknown val_opt {!r}; expected one of {}'.format(cfg_data_val_opt, valid_val_opts))

    # The dataset modules are imported lazily so only the requested one is loaded.
    if cfg_data_dataset == 'stanext24_easy':
        from stacked_hourglass.datasets.stanext24_easy import StanExtEasy as StanExt
        dataset_mode = 'complete'
    elif cfg_data_dataset == 'stanext24':
        from stacked_hourglass.datasets.stanext24 import StanExt
        dataset_mode = 'complete'
    elif cfg_data_dataset == 'stanext24_withgc':
        from stacked_hourglass.datasets.stanext24_withgc import StanExtGC as StanExt
        dataset_mode = 'complete_with_gc'
    else:  # 'stanext24_withgc_big' (guaranteed by the validation above)
        from stacked_hourglass.datasets.stanext24_withgc_v2 import StanExtGC as StanExt
        dataset_mode = 'complete_with_gc'

    # Initialise the validation set dataloader
    if cfg_data_val_opt == 'train':
        # Evaluate on the training images, with augmentation switched off.
        val_dataset = StanExt(image_path=None, is_train=True, do_augment='no', dataset_mode=dataset_mode, V12=cfg_data_V12)
        test_name_list = val_dataset.train_name_list
    else:
        # 'test' and 'val' differ only in the val_opt passed to the dataset.
        val_dataset = StanExt(image_path=None, is_train=False, dataset_mode=dataset_mode, V12=cfg_data_V12, val_opt=cfg_data_val_opt)
        test_name_list = val_dataset.test_name_list
    val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False,
                            num_workers=args_workers, pin_memory=True, drop_last=drop_last)
    len_val_dataset = len(val_dataset)
    stanext_data_info = StanExt.DATA_INFO
    stanext_acc_joints = StanExt.ACC_JOINTS
    return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints


def get_sketchfab_evaluation_dataset(cfg_optim_batch_size, args_workers):
    """Build the SketchfabScans evaluation dataset and its DataLoader.

    Args:
        cfg_optim_batch_size: DataLoader batch size (cfg.optim.BATCH_SIZE).
        args_workers: DataLoader ``num_workers`` (args.workers).

    Returns:
        Tuple ``(val_dataset, val_loader, len_val_dataset, test_name_list,
        stanext_data_info, stanext_acc_joints)``; the last two entries are
        ``StanExt.DATA_INFO`` and ``StanExt.ACC_JOINTS`` from the stanext24
        dataset module.
    """
    from stacked_hourglass.datasets.sketchfab import SketchfabScans
    scans = SketchfabScans(image_path=None, is_train=False, dataset_mode='complete')
    names = scans.test_name_list
    # Fixed order, every sample kept: no shuffling and no dropped final batch.
    loader_options = dict(
        batch_size=cfg_optim_batch_size,
        shuffle=False,
        num_workers=args_workers,
        pin_memory=True,
        drop_last=False,
    )
    scans_loader = DataLoader(scans, **loader_options)
    from stacked_hourglass.datasets.stanext24 import StanExt
    return scans, scans_loader, len(scans), names, StanExt.DATA_INFO, StanExt.ACC_JOINTS

def get_crop_evaluation_dataset(cfg_optim_batch_size, args_workers, input_folder):
    """Build an ImgCrops dataset and DataLoader from the images in a folder.

    All ``*.jpg`` and ``*.png`` files directly inside ``input_folder`` are read
    with OpenCV and converted from BGR to RGB before being handed to
    ``ImgCrops`` (without bounding boxes).

    Args:
        cfg_optim_batch_size: DataLoader batch size.
        args_workers: DataLoader ``num_workers``.
        input_folder: directory that is searched for ``*.jpg`` / ``*.png``.

    Returns:
        Tuple ``(val_dataset, val_loader, len_val_dataset, test_name_list,
        stanext_data_info, stanext_acc_joints)``. ``test_name_list`` holds, for
        each image in dataset order, its file name up to the first ``'.'``;
        the last two entries are ``StanExt.DATA_INFO`` and
        ``StanExt.ACC_JOINTS`` from the stanext24 dataset module.

    Raises:
        IOError: if a matched file cannot be decoded as an image.
    """
    from stacked_hourglass.datasets.imgcropslist import ImgCrops
    # glob returns matches in arbitrary (filesystem-dependent) order; sorting
    # each extension group makes the sample order reproducible across runs
    # while keeping all jpg files ahead of the png files, as before.
    image_list_paths = (sorted(glob.glob(os.path.join(input_folder, '*.jpg')))
                        + sorted(glob.glob(os.path.join(input_folder, '*.png'))))
    image_list = []
    test_name_list = []
    for image_path in image_list_paths:
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None;
            # fail with the offending path rather than an opaque cvtColor error.
            raise IOError('could not read image file: {}'.format(image_path))
        # NOTE: the name is cut at the FIRST dot, so 'a.1.jpg' and 'a.2.jpg'
        # both map to 'a'. Kept as-is because downstream code may rely on it.
        test_name_list.append(os.path.basename(image_path).split('.')[0])
        image_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    val_dataset = ImgCrops(image_list=image_list, bbox_list=None)
    val_loader = DataLoader(val_dataset, batch_size=cfg_optim_batch_size, shuffle=False,
                            num_workers=args_workers, pin_memory=True, drop_last=False)
    from stacked_hourglass.datasets.stanext24 import StanExt
    len_val_dataset = len(val_dataset)
    stanext_data_info = StanExt.DATA_INFO
    stanext_acc_joints = StanExt.ACC_JOINTS
    return val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints

def get_single_crop_dataset_from_image(input_image, bbox=None):
    """Wrap one image (and optionally one bounding box) in an ImgCrops dataset.

    Args:
        input_image: the image, passed to ``ImgCrops`` as a one-element list.
        bbox: optional bounding box; when given it is passed as a one-element
            list, otherwise ``bbox_list`` is ``None``.

    Returns:
        Tuple ``(val_dataset, val_loader, len_val_dataset, test_name_list,
        stanext_data_info, stanext_acc_joints)``. The loader uses batch size 1
        and no worker processes; the last two entries are
        ``StanExt.DATA_INFO`` and ``StanExt.ACC_JOINTS`` from the stanext24
        dataset module.
    """
    from stacked_hourglass.datasets.imgcropslist import ImgCrops
    bboxes = None if bbox is None else [bbox]
    # prepare data loader
    crop_dataset = ImgCrops(image_list=[input_image], bbox_list=bboxes, dataset_mode='complete')
    names = crop_dataset.test_name_list
    crop_loader = DataLoader(
        crop_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    from stacked_hourglass.datasets.stanext24 import StanExt
    return crop_dataset, crop_loader, len(crop_dataset), names, StanExt.DATA_INFO, StanExt.ACC_JOINTS




def get_norm_dict(data_info=None, device="cuda"):
    """Collect normalisation statistics from ``data_info`` as float tensors.

    Args:
        data_info: object exposing the attributes ``pose_rot6d_mean``,
            ``trans_mean``, ``trans_std``, ``flength_mean`` and
            ``flength_std`` as numpy arrays. Defaults to ``StanExt.DATA_INFO``
            from the stanext24 dataset module.
        device: device the tensors are moved to.

    Returns:
        Dict mapping each of the five attribute names to a float32 tensor on
        ``device``.
    """
    if data_info is None:
        from stacked_hourglass.datasets.stanext24 import StanExt
        data_info = StanExt.DATA_INFO
    stat_names = ('pose_rot6d_mean', 'trans_mean', 'trans_std', 'flength_mean', 'flength_std')
    return {
        stat: torch.from_numpy(getattr(data_info, stat)).float().to(device)
        for stat in stat_names
    }