zhuotao.tian commited on
Commit
29680af
·
1 Parent(s): de08d89

add data processing demo script

Browse files
Files changed (1) hide show
  1. utils/data_proc_demo.py +83 -0
utils/data_proc_demo.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import json
4
+ import cv2
5
+ import glob
6
+
7
+ def get_mask_from_json(json_path, img):
8
+ try:
9
+ with open(json_path, 'r') as r:
10
+ anno = json.loads(r.read())
11
+ except:
12
+ with open(json_path, 'r', encoding="cp1252") as r:
13
+ anno = json.loads(r.read())
14
+
15
+ inform = anno['shapes']
16
+ comments = anno['text']
17
+ is_sentence = anno['is_sentence']
18
+
19
+ height, width = img.shape[:2]
20
+
21
+ ### sort polies by area
22
+ area_list = []
23
+ valid_poly_list = []
24
+ for i in inform:
25
+ label_id = i['label']
26
+ points = i['points']
27
+ if 'flag' == label_id.lower(): ## meaningless deprecated annotations
28
+ continue
29
+
30
+ tmp_mask = np.zeros((height, width), dtype=np.uint8)
31
+ cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1)
32
+ cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1)
33
+ tmp_area = tmp_mask.sum()
34
+
35
+ area_list.append(tmp_area)
36
+ valid_poly_list.append(i)
37
+
38
+ ### ground-truth mask
39
+ sort_index = np.argsort(area_list)[::-1].astype(np.int32)
40
+ sort_index = list(sort_index)
41
+ sort_inform = []
42
+ for s_idx in sort_index:
43
+ sort_inform.append(valid_poly_list[s_idx])
44
+
45
+ mask = np.zeros((height, width), dtype=np.uint8)
46
+ for i in sort_inform:
47
+ label_id = i['label']
48
+ points = i['points']
49
+
50
+ if 'ignore' in label_id.lower():
51
+ label_value = 255 # ignored during evaluation
52
+ else:
53
+ label_value = 1 # target
54
+
55
+ cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1)
56
+ cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value)
57
+
58
+ return mask, comments, is_sentence
59
+
60
+
61
+ if __name__ == '__main__':
62
+ data_dir = './train'
63
+ vis_dir = './vis'
64
+
65
+ if not os.path.exists(vis_dir):
66
+ os.makedirs(vis_dir)
67
+
68
+ json_path_list = sorted(glob.glob(data_dir + '/*.json'))
69
+ for json_path in json_path_list:
70
+ img_path = json_path.replace('.json', '.jpg')
71
+ img = cv2.imread(img_path)[:,:,::-1]
72
+
73
+ # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton.
74
+ mask, comments, is_sentence = get_mask_from_json(json_path, img)
75
+
76
+ ## visualization. Green for target, and red for ignore.
77
+ valid_mask = (mask == 1).astype(np.float32)[:,:,None]
78
+ ignore_mask = (mask == 255).astype(np.float32)[:,:,None]
79
+ vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ((np.array([0,255,0]) * 0.6 + img * 0.4) * valid_mask + (np.array([255,0,0]) * 0.6 + img * 0.4) * ignore_mask)
80
+ vis_img = np.concatenate([img, vis_img], 1)
81
+ vis_path = os.path.join(vis_dir, json_path.split('/')[-1].replace('.json', '.jpg'))
82
+ cv2.imwrite(vis_path, vis_img[:,:,::-1])
83
+ print('Visualization has been saved to: ', vis_path)