Spaces: Runtime error
sunshineatnoon committed · 1b2a9b1
1 Parent(s): 1d90a68
Add application file
This view is limited to 50 files because it contains too many changes.
- data/___init__.py +0 -0
- data/color150.mat +0 -0
- data/images/108073.jpg +0 -0
- data/images/12003.jpg +0 -0
- data/images/12074.jpg +0 -0
- data/images/134008.jpg +0 -0
- data/images/134052.jpg +0 -0
- data/images/138032.jpg +0 -0
- data/images/145053.jpg +0 -0
- data/images/164074.jpg +0 -0
- data/images/169012.jpg +0 -0
- data/images/198023.jpg +0 -0
- data/images/25098.jpg +0 -0
- data/images/277095.jpg +0 -0
- data/images/45077.jpg +0 -0
- data/palette.txt +256 -0
- data/test_images/100039.jpg +0 -0
- data/test_images/108004.jpg +0 -0
- data/test_images/130014.jpg +0 -0
- data/test_images/130066.jpg +0 -0
- data/test_images/16068.jpg +0 -0
- data/test_images/2018.jpg +0 -0
- data/test_images/208078.jpg +0 -0
- data/test_images/223060.jpg +0 -0
- data/test_images/226033.jpg +0 -0
- data/test_images/388006.jpg +0 -0
- data/test_images/78098.jpg +0 -0
- libs/__init__.py +0 -0
- libs/__pycache__/__init__.cpython-37.pyc +0 -0
- libs/__pycache__/__init__.cpython-38.pyc +0 -0
- libs/__pycache__/flow_transforms.cpython-37.pyc +0 -0
- libs/__pycache__/flow_transforms.cpython-38.pyc +0 -0
- libs/__pycache__/nnutils.cpython-37.pyc +0 -0
- libs/__pycache__/nnutils.cpython-38.pyc +0 -0
- libs/__pycache__/options.cpython-37.pyc +0 -0
- libs/__pycache__/options.cpython-38.pyc +0 -0
- libs/__pycache__/test_base.cpython-37.pyc +0 -0
- libs/__pycache__/test_base.cpython-38.pyc +0 -0
- libs/__pycache__/utils.cpython-37.pyc +0 -0
- libs/__pycache__/utils.cpython-38.pyc +0 -0
- libs/blocks.py +739 -0
- libs/custom_transform.py +249 -0
- libs/data_coco_stuff.py +166 -0
- libs/data_coco_stuff_geo_pho.py +145 -0
- libs/data_geo.py +176 -0
- libs/data_geo_pho.py +130 -0
- libs/data_slic.py +175 -0
- libs/discriminator.py +60 -0
- libs/flow_transforms.py +393 -0
- libs/losses.py +416 -0
data/___init__.py
ADDED
File without changes
data/color150.mat
ADDED
Binary file (502 Bytes).
data/images/108073.jpg
ADDED
data/images/12003.jpg
ADDED
data/images/12074.jpg
ADDED
data/images/134008.jpg
ADDED
data/images/134052.jpg
ADDED
data/images/138032.jpg
ADDED
data/images/145053.jpg
ADDED
data/images/164074.jpg
ADDED
data/images/169012.jpg
ADDED
data/images/198023.jpg
ADDED
data/images/25098.jpg
ADDED
data/images/277095.jpg
ADDED
data/images/45077.jpg
ADDED
data/palette.txt
ADDED
@@ -0,0 +1,256 @@
0 0 0
128 0 0
0 128 0
128 128 0
0 0 128
128 0 128
0 128 128
128 128 128
64 0 0
191 0 0
64 128 0
191 128 0
64 0 128
191 0 128
64 128 128
191 128 128
0 64 0
128 64 0
0 191 0
128 191 0
0 64 128
128 64 128
22 22 22
23 23 23
24 24 24
25 25 25
26 26 26
27 27 27
28 28 28
29 29 29
30 30 30
31 31 31
32 32 32
33 33 33
34 34 34
35 35 35
36 36 36
37 37 37
38 38 38
39 39 39
40 40 40
41 41 41
42 42 42
43 43 43
44 44 44
45 45 45
46 46 46
47 47 47
48 48 48
49 49 49
50 50 50
51 51 51
52 52 52
53 53 53
54 54 54
55 55 55
56 56 56
57 57 57
58 58 58
59 59 59
60 60 60
61 61 61
62 62 62
63 63 63
64 64 64
65 65 65
66 66 66
67 67 67
68 68 68
69 69 69
70 70 70
71 71 71
72 72 72
73 73 73
74 74 74
75 75 75
76 76 76
77 77 77
78 78 78
79 79 79
80 80 80
81 81 81
82 82 82
83 83 83
84 84 84
85 85 85
86 86 86
87 87 87
88 88 88
89 89 89
90 90 90
91 91 91
92 92 92
93 93 93
94 94 94
95 95 95
96 96 96
97 97 97
98 98 98
99 99 99
100 100 100
101 101 101
102 102 102
103 103 103
104 104 104
105 105 105
106 106 106
107 107 107
108 108 108
109 109 109
110 110 110
111 111 111
112 112 112
113 113 113
114 114 114
115 115 115
116 116 116
117 117 117
118 118 118
119 119 119
120 120 120
121 121 121
122 122 122
123 123 123
124 124 124
125 125 125
126 126 126
127 127 127
128 128 128
129 129 129
130 130 130
131 131 131
132 132 132
133 133 133
134 134 134
135 135 135
136 136 136
137 137 137
138 138 138
139 139 139
140 140 140
141 141 141
142 142 142
143 143 143
144 144 144
145 145 145
146 146 146
147 147 147
148 148 148
149 149 149
150 150 150
151 151 151
152 152 152
153 153 153
154 154 154
155 155 155
156 156 156
157 157 157
158 158 158
159 159 159
160 160 160
161 161 161
162 162 162
163 163 163
164 164 164
165 165 165
166 166 166
167 167 167
168 168 168
169 169 169
170 170 170
171 171 171
172 172 172
173 173 173
174 174 174
175 175 175
176 176 176
177 177 177
178 178 178
179 179 179
180 180 180
181 181 181
182 182 182
183 183 183
184 184 184
185 185 185
186 186 186
187 187 187
188 188 188
189 189 189
190 190 190
191 191 191
192 192 192
193 193 193
194 194 194
195 195 195
196 196 196
197 197 197
198 198 198
199 199 199
200 200 200
201 201 201
202 202 202
203 203 203
204 204 204
205 205 205
206 206 206
207 207 207
208 208 208
209 209 209
210 210 210
211 211 211
212 212 212
213 213 213
214 214 214
215 215 215
216 216 216
217 217 217
218 218 218
219 219 219
220 220 220
221 221 221
222 222 222
223 223 223
224 224 224
225 225 225
226 226 226
227 227 227
228 228 228
229 229 229
230 230 230
231 231 231
232 232 232
233 233 233
234 234 234
235 235 235
236 236 236
237 237 237
238 238 238
239 239 239
240 240 240
241 241 241
242 242 242
243 243 243
244 244 244
245 245 245
246 246 246
247 247 247
248 248 248
249 249 249
250 250 250
251 251 251
252 252 252
253 253 253
254 254 254
255 255 255
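Note: palette.txt stores one "R G B" triple per line (256 entries). A minimal sketch of how such a palette can be used to colorize a per-pixel index map with PIL; the label_map array and the output filename are illustrative, not part of this commit.

import numpy as np
from PIL import Image

palette = np.loadtxt("data/palette.txt", dtype=np.uint8)   # (256, 3) RGB entries
label_map = np.zeros((224, 224), dtype=np.uint8)           # hypothetical class-index map

img = Image.fromarray(label_map, mode="P")                 # paletted image
img.putpalette(palette.flatten().tolist())                 # 768 ints: R, G, B per entry
img.save("colorized.png")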
data/test_images/100039.jpg
ADDED
data/test_images/108004.jpg
ADDED
data/test_images/130014.jpg
ADDED
data/test_images/130066.jpg
ADDED
data/test_images/16068.jpg
ADDED
data/test_images/2018.jpg
ADDED
data/test_images/208078.jpg
ADDED
data/test_images/223060.jpg
ADDED
data/test_images/226033.jpg
ADDED
data/test_images/388006.jpg
ADDED
data/test_images/78098.jpg
ADDED
libs/__init__.py
ADDED
File without changes
libs/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (151 Bytes).
libs/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (155 Bytes).
libs/__pycache__/flow_transforms.cpython-37.pyc
ADDED
Binary file (14.1 kB).
libs/__pycache__/flow_transforms.cpython-38.pyc
ADDED
Binary file (13.7 kB).
libs/__pycache__/nnutils.cpython-37.pyc
ADDED
Binary file (3.39 kB).
libs/__pycache__/nnutils.cpython-38.pyc
ADDED
Binary file (3.4 kB).
libs/__pycache__/options.cpython-37.pyc
ADDED
Binary file (5.43 kB).
libs/__pycache__/options.cpython-38.pyc
ADDED
Binary file (5.49 kB).
libs/__pycache__/test_base.cpython-37.pyc
ADDED
Binary file (4.01 kB).
libs/__pycache__/test_base.cpython-38.pyc
ADDED
Binary file (4.07 kB).
libs/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (4.51 kB).
libs/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.53 kB).
libs/blocks.py
ADDED
@@ -0,0 +1,739 @@
"""Network Modules
    - encoder3: vgg encoder up to relu31
    - decoder3: mirror decoder to encoder3
    - encoder4: vgg encoder up to relu41
    - decoder4: mirror decoder to encoder4
    - encoder5: vgg encoder up to relu51
    - styleLoss: gram matrix loss for all style layers
    - styleLossMask: gram matrix loss for all style layers, compare between each part defined by a mask
    - GramMatrix: compute gram matrix for one layer
    - LossCriterion: style transfer loss that include both content & style losses
    - LossCriterionMask: style transfer loss that include both content & style losses, use the styleLossMask
    - VQEmbedding: codebook class for VQVAE
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .vq_functions import vq, vq_st
from collections import OrderedDict

class MetaModule(nn.Module):
    """
    Base class for PyTorch meta-learning modules. These modules accept an
    additional argument `params` in their `forward` method.

    Notes
    -----
    Objects inherited from `MetaModule` are fully compatible with PyTorch
    modules from `torch.nn.Module`. The argument `params` is a dictionary of
    tensors, with full support of the computation graph (for differentiation).
    """
    def meta_named_parameters(self, prefix='', recurse=True):
        gen = self._named_members(
            lambda module: module._parameters.items()
            if isinstance(module, MetaModule) else [],
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def meta_parameters(self, recurse=True):
        for name, param in self.meta_named_parameters(recurse=recurse):
            yield param

class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
    hypernetwork.'''
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        output = input.matmul(weight.permute(*[i for i in range(len(weight.shape) - 2)], -1, -2))
        output += bias.unsqueeze(-2)
        return output

class decoder1(nn.Module):
    def __init__(self):
        super(decoder1,self).__init__()
        self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226
        self.conv3 = nn.Conv2d(64,3,3,1,0)
        # 224 x 224

    def forward(self,x):
        out = self.reflecPad2(x)
        out = self.conv3(out)
        return out


class decoder2(nn.Module):
    def __init__(self):
        super(decoder2,self).__init__()
        # decoder
        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,64,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(64,64,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        out = self.reflecPad5(x)
        out = self.conv5(out)
        out = self.relu5(out)
        out = self.unpool(out)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        out = self.reflecPad7(out)
        out = self.conv7(out)
        return out

class encoder3(nn.Module):
    def __init__(self):
        super(encoder3,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
    def forward(self,x):
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.reflecPad3(out)
        out = self.conv3(out)
        pool1 = self.relu3(out)
        out,pool_idx = self.maxPool(pool1)
        out = self.reflecPad4(out)
        out = self.conv4(out)
        out = self.relu4(out)
        out = self.reflecPad5(out)
        out = self.conv5(out)
        pool2 = self.relu5(out)
        out,pool_idx2 = self.maxPool2(pool2)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        return out

class decoder3(nn.Module):
    def __init__(self):
        super(decoder3,self).__init__()
        # decoder
        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,128,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(128,128,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(128,64,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(64,64,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)

        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        output = {}
        out = self.reflecPad7(x)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.unpool(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out_relu9 = self.relu9(out)
        out = self.unpool2(out_relu9)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        return out

class encoder4(nn.Module):
    def __init__(self):
        super(encoder4,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28

    def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])

        out = self.conv3(out)
        output['r12'] = self.relu3(out)

        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])

        out = self.conv5(out)
        output['r22'] = self.relu5(out)

        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if(matrix31 is not None):
            feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)

        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)

        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)

        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)

        return output

class decoder4(nn.Module):
    def __init__(self):
        super(decoder4,self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,256,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(256,256,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(256,256,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(256,256,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
        self.conv15 = nn.Conv2d(256,128,3,1,0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112

        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
        self.conv16 = nn.Conv2d(128,128,3,1,0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
        self.conv17 = nn.Conv2d(128,64,3,1,0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112

        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
        self.conv18 = nn.Conv2d(64,64,3,1,0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
        self.conv19 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        # decoder
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)

        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out

class encoder5(nn.Module):
    def __init__(self):
        super(encoder5,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)

        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,512,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(512,512,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(512,512,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)

        self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(512,512,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        out = self.reflecPad7(output['r11'])

        #out = self.reflecPad3(output['r11'])
        out = self.conv3(out)
        output['r12'] = self.relu3(out)

        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])

        #out = self.reflecPad5(output['r21'])
        out = self.conv5(out)
        output['r22'] = self.relu5(out)

        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if(styleV256 is not None):
            feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
            out = self.reflecPad7(feature)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)

        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)

        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)

        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)

        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        output['r51'] = out
        return output

class styleLoss(nn.Module):
    def forward(self, input, target):
        ib,ic,ih,iw = input.size()
        iF = input.view(ib,ic,-1)
        iMean = torch.mean(iF,dim=2)
        iCov = GramMatrix()(input)

        tb,tc,th,tw = target.size()
        tF = target.view(tb,tc,-1)
        tMean = torch.mean(tF,dim=2)
        tCov = GramMatrix()(target)

        loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov)
        return loss/tb

class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b,c,h*w) # bxcx(hxw)
        # torch.bmm(batch1, batch2, out=None) #
        # batch1: bxmxp, batch2: bxpxn -> bxmxn #
        G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
        return G.div_(c*h*w)

class LossCriterion(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterion,self).__init__()

        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLoss()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))

        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, content=True, style=True):
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}

        # content loss
        if content:
            totalContentLoss = 0
            for i,layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i,cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i,layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i,sf_i)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses

class styleLossMask(nn.Module):
    def forward(self, input, target, mask):
        ib,ic,ih,iw = input.size()
        iF = input.view(ib,ic,-1)
        tb,tc,th,tw = target.size()
        tF = target.view(tb,tc,-1)

        loss = 0
        mb, mc, mh, mw = mask.shape
        for i in range(mb):
            # resize mask to have the same size of the feature
            maski = F.interpolate(mask[i:i+1], size = (ih, iw), mode = 'nearest')
            mask_flat = maski.view(mc, -1)
            for j in range(mc):
                # get features for each part
                idx = torch.nonzero(mask_flat[j]).squeeze()
                if len(idx.shape) == 0 or idx.shape[0] == 0:
                    continue
                ipart = torch.index_select(iF, 2, idx)
                tpart = torch.index_select(tF, 2, idx)

                iMean = torch.mean(ipart,dim=2)
                iGram = torch.bmm(ipart, ipart.transpose(1,2)).div_(ic*ih*iw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc

                tMean = torch.mean(tpart,dim=2)
                tGram = torch.bmm(tpart, tpart.transpose(1,2)).div_(tc*th*tw) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc

                loss += nn.MSELoss()(iMean,tMean) + nn.MSELoss()(iGram,tGram)
        return loss/tb

class LossCriterionMask(nn.Module):
    def __init__(self, style_layers, content_layers, style_weight, content_weight,
                 model_path = '/home/xtli/Documents/GITHUB/LinearStyleTransfer/models/'):
        super(LossCriterionMask,self).__init__()

        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        self.styleLosses = [styleLossMask()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

        self.vgg5 = encoder5()
        self.vgg5.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))

        for param in self.vgg5.parameters():
            param.requires_grad = True

    def forward(self, transfer, image, mask, content=True, style=True):
        # mask: B, N, H, W
        cF = self.vgg5(image)
        sF = self.vgg5(image)
        tF = self.vgg5(transfer)

        losses = {}

        # content loss
        if content:
            totalContentLoss = 0
            for i,layer in enumerate(self.content_layers):
                cf_i = cF[layer]
                cf_i = cf_i.detach()
                tf_i = tF[layer]
                loss_i = self.contentLosses[i]
                totalContentLoss += loss_i(tf_i,cf_i)
            totalContentLoss = totalContentLoss * self.content_weight
            losses['content'] = totalContentLoss

        # style loss
        if style:
            totalStyleLoss = 0
            for i,layer in enumerate(self.style_layers):
                sf_i = sF[layer]
                sf_i = sf_i.detach()
                tf_i = tF[layer]
                loss_i = self.styleLosses[i]
                totalStyleLoss += loss_i(tf_i,sf_i, mask)
            totalStyleLoss = totalStyleLoss * self.style_weight
            losses['style'] = totalStyleLoss

        return losses

class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x, return_index=False):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        if return_index:
            return z_q_x, z_q_x_bar, indices
        else:
            return z_q_x, z_q_x_bar
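Note: a minimal usage sketch for the modules defined in blocks.py above, run with random weights and a placeholder input just to show tensor shapes; loading the pretrained VGG weights that LossCriterion expects (vgg_r51.pth) is omitted, and the libs.blocks import path assumes the repository root is on the Python path.

import torch
from libs.blocks import encoder3, decoder3, GramMatrix

enc, dec = encoder3(), decoder3()
x = torch.randn(1, 3, 224, 224)      # placeholder content batch
feat = enc(x)                        # relu3_1 features: (1, 256, 56, 56)
recon = dec(feat)                    # decoded image: (1, 3, 224, 224)
gram = GramMatrix()(feat)            # normalized Gram matrix: (1, 256, 256)
print(feat.shape, recon.shape, gram.shape)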
libs/custom_transform.py
ADDED
@@ -0,0 +1,249 @@
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image, ImageFilter
import random

class BaseTransform(object):
    """
    Resize and center crop.
    """
    def __init__(self, res):
        self.res = res

    def __call__(self, index, image):
        image = TF.resize(image, self.res, Image.BILINEAR)
        w, h = image.size
        left = int(round((w - self.res) / 2.))
        top = int(round((h - self.res) / 2.))

        return TF.crop(image, top, left, self.res, self.res)


class ComposeTransform(object):
    def __init__(self, tlist):
        self.tlist = tlist

    def __call__(self, index, image):
        for trans in self.tlist:
            image = trans(index, image)

        return image

class RandomResize(object):
    def __init__(self, rmin, rmax, N):
        self.reslist = [random.randint(rmin, rmax) for _ in range(N)]

    def __call__(self, index, image):
        return TF.resize(image, self.reslist[index], Image.BILINEAR)

class RandomCrop(object):
    def __init__(self, res, N):
        self.res = res
        self.cons = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def __call__(self, index, image):
        ws, hs = self.cons[index]
        w, h = image.size
        left = int(round((w-self.res)*ws))
        top = int(round((h-self.res)*hs))

        return TF.crop(image, top, left, self.res, self.res)

class RandomHorizontalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index.cpu()] < self.p_ref:
            return TF.hflip(image)
        else:
            return image


class TensorTransform(object):
    def __init__(self):
        self.to_tensor = transforms.ToTensor()
        #self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image):
        image = self.to_tensor(image)
        #image = self.normalize(image)

        return image


class RandomGaussianBlur(object):
    def __init__(self, sigma, p, N):
        self.min_x = sigma[0]
        self.max_x = sigma[1]
        self.del_p = 1 - p
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            x = self.plist[index] - self.p_ref
            m = (self.max_x - self.min_x) / self.del_p
            b = self.min_x
            s = m * x + b

            return image.filter(ImageFilter.GaussianBlur(radius=s))
        else:
            return image


class RandomGrayScale(object):
    def __init__(self, p, N):
        self.grayscale = transforms.RandomGrayscale(p=1.) # Deterministic (We still want flexible out_dim).
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return self.grayscale(image)
        else:
            return image


class RandomColorBrightness(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_brightness(image, self.rlist[index])
        else:
            return image


class RandomColorContrast(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_contrast(image, self.rlist[index])
        else:
            return image


class RandomColorSaturation(object):
    def __init__(self, x, p, N):
        self.min_x = max(0, 1 - x)
        self.max_x = 1 + x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_saturation(image, self.rlist[index])
        else:
            return image


class RandomColorHue(object):
    def __init__(self, x, p, N):
        self.min_x = -x
        self.max_x = x
        self.p_ref = p
        self.plist = np.random.random_sample(N)
        self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)]

    def __call__(self, index, image):
        if self.plist[index] < self.p_ref:
            return TF.adjust_hue(image, self.rlist[index])
        else:
            return image


class RandomVerticalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])


class RandomHorizontalTensorFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])


class RandomResizedCrop(object):
    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx]*res1)
        i1 = int(round((res1-res2)*ws))
        j1 = int(round((res1-res2)*hs))

        return img[:, :, i1:i1+res2, j1:j1+res2]


    def __call__(self, indice, image):
        new_image = []
        res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2?

        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)

            new_image.append(img)

        new_image = torch.cat(new_image)

        return new_image
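Note: each transform above pre-samples its random parameters for N images, so calling it with the same index always applies the same augmentation. A minimal sketch of composing them, assuming the repository root is on the Python path; N, the index, and the image path are illustrative only.

from PIL import Image
from libs.custom_transform import ComposeTransform, RandomResize, RandomCrop, TensorTransform

N = 100                                                    # hypothetical dataset size
geometric = ComposeTransform([RandomResize(256, 320, N),   # per-index resize
                              RandomCrop(224, N)])         # per-index crop

img = Image.open("data/images/108073.jpg").convert("RGB")
out = geometric(0, img)              # deterministic for index 0
tensor = TensorTransform()(out)      # CxHxW float tensor in [0, 1]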
libs/data_coco_stuff.py
ADDED
@@ -0,0 +1,166 @@
import cv2
import torch
from PIL import Image
import os.path as osp
import numpy as np
from torch.utils import data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random

class RandomResizedCrop(object):
    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx]*res1)
        i1 = int(round((res1-res2)*ws))
        j1 = int(round((res1-res2)*hs))

        return img[:, :, i1:i1+res2, j1:j1+res2]


    def __call__(self, indice, image):
        new_image = []
        res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2?

        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)

            new_image.append(img)

        new_image = torch.cat(new_image)

        return new_image

class RandomVerticalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class RandomHorizontalTensorFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class _Coco164kCuratedFew(data.Dataset):
    """Base class
    This contains fields and methods common to all COCO 164k curated few datasets:

    (curated) Coco164kFew_Stuff
    (curated) Coco164kFew_Stuff_People
    (curated) Coco164kFew_Stuff_Animals
    (curated) Coco164kFew_Stuff_People_Animals

    """
    def __init__(self, root, img_size, crop_size, split = "train2017"):
        super(_Coco164kCuratedFew, self).__init__()

        # work out name
        self.split = split
        self.root = root
        self.include_things_labels = False # people
        self.incl_animal_things = False # animals

        version = 6

        name = "Coco164kFew_Stuff"
        if self.include_things_labels and self.incl_animal_things:
            name += "_People_Animals"
        elif self.include_things_labels:
            name += "_People"
        elif self.incl_animal_things:
            name += "_Animals"

        self.name = (name + "_%d" % version)

        print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name)

        self._set_files()


        self.transform = transforms.Compose([
            transforms.RandomChoice([
                transforms.ColorJitter(brightness=0.05),
                transforms.ColorJitter(contrast=0.05),
                transforms.ColorJitter(saturation=0.01),
                transforms.ColorJitter(hue=0.01)]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Resize(int(img_size)),
            transforms.RandomCrop(crop_size)])

        N = len(self.files)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=self.res1, scale=self.scale)


    def _set_files(self):
        # Create data list by parsing the "images" folder
        if self.split in ["train2017", "val2017"]:
            file_list = osp.join(self.root, "curated", self.split, self.name + ".txt")
            file_list = tuple(open(file_list, "r"))
            file_list = [id_.rstrip() for id_ in file_list]

            self.files = file_list
            print("In total {} images.".format(len(self.files)))
        else:
            raise ValueError("Invalid split name: {}".format(self.split))

    def __getitem__(self, index):
        # same as _Coco164k
        # Set paths
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")
        label_path = osp.join(self.root, "annotations", self.split,
                              image_id + ".png")
        # Load an image
        #image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.uint8)
        ori_img = Image.open(image_path)
        ori_img = self.transform(ori_img)
        ori_img = np.array(ori_img)
        if ori_img.ndim < 3:
            ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2)
        ori_img = ori_img[:, :, :3]
        ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
        ori_img = ori_img / 255.0

        #label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE).astype(np.int32)

        #label[label == 255] = -1 # to be consistent with 10k

        rets = []
        rets.append(ori_img)
        #rets.append(label)
        return rets

    def __len__(self):
        return len(self.files)
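Note: as committed, _Coco164kCuratedFew reads self.res1 and self.scale in __init__ without defining them, and its RandomResizedCrop helper calls F.interpolate although torch.nn.functional is not imported in this file, so the class is presumably meant to be subclassed or patched. A hypothetical sketch that supplies the missing attributes; the dataset root path and loader settings are placeholders, assuming the standard curated COCO-Stuff layout.

from torch.utils.data import DataLoader
from libs.data_coco_stuff import _Coco164kCuratedFew

class CocoFewStuff(_Coco164kCuratedFew):
    # class attributes become visible as self.res1 / self.scale in the base __init__
    res1 = 288
    scale = (0.5, 1.0)

dataset = CocoFewStuff(root="/path/to/cocostuff", img_size=320,
                       crop_size=288, split="train2017")
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
for (ori_img,) in loader:            # __getitem__ returns [ori_img]
    print(ori_img.shape)             # (B, 3, 288, 288), values in [0, 1]
    break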
libs/data_coco_stuff_geo_pho.py
ADDED
@@ -0,0 +1,145 @@
import cv2
import torch
from PIL import Image
import os.path as osp
import numpy as np
from torch.utils import data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from .custom_transform import *

class _Coco164kCuratedFew(data.Dataset):
    """Base class
    This contains fields and methods common to all COCO 164k curated few datasets:

    (curated) Coco164kFew_Stuff
    (curated) Coco164kFew_Stuff_People
    (curated) Coco164kFew_Stuff_Animals
    (curated) Coco164kFew_Stuff_People_Animals

    """
    def __init__(self, root, img_size, crop_size, split="train2017"):
        super(_Coco164kCuratedFew, self).__init__()

        # work out name
        self.split = split
        self.root = root
        self.include_things_labels = False  # people
        self.incl_animal_things = False  # animals

        version = 6

        name = "Coco164kFew_Stuff"
        if self.include_things_labels and self.incl_animal_things:
            name += "_People_Animals"
        elif self.include_things_labels:
            name += "_People"
        elif self.incl_animal_things:
            name += "_Animals"

        self.name = (name + "_%d" % version)

        print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name)

        self._set_files()

        self.transform = transforms.Compose([
            transforms.Resize(int(img_size)),
            transforms.RandomCrop(crop_size)])

        N = len(self.files)
        # eqv transform
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=288)

        # photometric transform
        self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)]  # Control this later (NOTE)
        self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)]      # Control this later (NOTE)
        self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)]  # Control this later (NOTE)
        self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)]                # Control this later (NOTE)
        self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]

        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']

        self.transform_tensor = TensorTransform()

    def _set_files(self):
        # Create data list by parsing the "images" folder
        if self.split in ["train2017", "val2017"]:
            file_list = osp.join(self.root, "curated", self.split, self.name + ".txt")
            file_list = tuple(open(file_list, "r"))
            file_list = [id_.rstrip() for id_ in file_list]

            self.files = file_list
            print("In total {} images.".format(len(self.files)))
        else:
            raise ValueError("Invalid split name: {}".format(self.split))

    def transform_eqv(self, indice, image):
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)

        return image

    def transform_inv(self, index, image, ver):
        """
        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)
        """
        if 'brightness' in self.inv_list:
            image = self.random_color_brightness[ver](index, image)
        if 'contrast' in self.inv_list:
            image = self.random_color_contrast[ver](index, image)
        if 'saturation' in self.inv_list:
            image = self.random_color_saturation[ver](index, image)
        if 'hue' in self.inv_list:
            image = self.random_color_hue[ver](index, image)
        if 'gray' in self.inv_list:
            image = self.random_gray_scale[ver](index, image)
        if 'blur' in self.inv_list:
            image = self.random_gaussian_blur[ver](index, image)

        return image

    def transform_image(self, index, image):
        image1 = self.transform_inv(index, image, 0)
        image1 = self.transform_tensor(image1)

        image2 = self.transform_inv(index, image, 1)
        #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR)
        image2 = self.transform_tensor(image2)
        return image1, image2

    def __getitem__(self, index):
        # same as _Coco164k
        # Set paths
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")
        # Load an image
        ori_img = Image.open(image_path)
        ori_img = self.transform(ori_img)

        image1, image2 = self.transform_image(index, ori_img)
        if image1.shape[0] < 3:
            image1 = image1.repeat(3, 1, 1)
        if image2.shape[0] < 3:
            image2 = image2.repeat(3, 1, 1)

        rets = []
        rets.append(image1)
        rets.append(image2)
        rets.append(index)

        return rets

    def __len__(self):
        return len(self.files)
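Each item yields two photometrically jittered views of the same crop plus the sample index; a minimal consumption sketch, assuming placeholder root and sizes (not part of the commit):

# Sketch only: root, sizes and batch size are hypothetical.
from torch.utils.data import DataLoader

dataset = _Coco164kCuratedFew(root="/data/coco", img_size=256, crop_size=128, split="train2017")
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

for image1, image2, index in loader:
    # image1/image2: two differently jittered views of the same crop;
    # index: per-sample indices used later to replay the stored equivariant transforms.
    print(image1.shape, image2.shape, index.shape)
    break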
libs/data_geo.py
ADDED
@@ -0,0 +1,176 @@
1 |
+
"""SLIC dataset
|
2 |
+
- Returns an image together with its SLIC segmentation map.
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.utils.data as data
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from glob import glob
|
10 |
+
from PIL import Image
|
11 |
+
from skimage.segmentation import slic
|
12 |
+
from skimage.color import rgb2lab
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from .utils import label2one_hot_torch
|
16 |
+
|
17 |
+
class RandomResizedCrop(object):
|
18 |
+
def __init__(self, N, res, scale=(0.5, 1.0)):
|
19 |
+
self.res = res
|
20 |
+
self.scale = scale
|
21 |
+
self.rscale = [np.random.uniform(*scale) for _ in range(N)]
|
22 |
+
self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]
|
23 |
+
|
24 |
+
def random_crop(self, idx, img):
|
25 |
+
ws, hs = self.rcrop[idx]
|
26 |
+
res1 = int(img.size(-1))
|
27 |
+
res2 = int(self.rscale[idx]*res1)
|
28 |
+
i1 = int(round((res1-res2)*ws))
|
29 |
+
j1 = int(round((res1-res2)*hs))
|
30 |
+
|
31 |
+
return img[:, :, i1:i1+res2, j1:j1+res2]
|
32 |
+
|
33 |
+
|
34 |
+
def __call__(self, indice, image):
|
35 |
+
new_image = []
|
36 |
+
res_tar = self.res // 8 if image.size(1) > 5 else self.res # View 1 or View 2?
|
37 |
+
|
38 |
+
for i, idx in enumerate(indice):
|
39 |
+
img = image[[i]]
|
40 |
+
img = self.random_crop(idx, img)
|
41 |
+
img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)
|
42 |
+
|
43 |
+
new_image.append(img)
|
44 |
+
|
45 |
+
new_image = torch.cat(new_image)
|
46 |
+
|
47 |
+
return new_image
|
48 |
+
|
49 |
+
class RandomVerticalFlip(object):
|
50 |
+
def __init__(self, N, p=0.5):
|
51 |
+
self.p_ref = p
|
52 |
+
self.plist = np.random.random_sample(N)
|
53 |
+
|
54 |
+
def __call__(self, indice, image):
|
55 |
+
I = np.nonzero(self.plist[indice] < self.p_ref)[0]
|
56 |
+
|
57 |
+
if len(image.size()) == 3:
|
58 |
+
image_t = image[I].flip([1])
|
59 |
+
else:
|
60 |
+
image_t = image[I].flip([2])
|
61 |
+
|
62 |
+
return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
|
63 |
+
|
64 |
+
class RandomHorizontalTensorFlip(object):
|
65 |
+
def __init__(self, N, p=0.5):
|
66 |
+
self.p_ref = p
|
67 |
+
self.plist = np.random.random_sample(N)
|
68 |
+
|
69 |
+
def __call__(self, indice, image, is_label=False):
|
70 |
+
I = np.nonzero(self.plist[indice.cpu()] < self.p_ref)[0]
|
71 |
+
|
72 |
+
if len(image.size()) == 3:
|
73 |
+
image_t = image[I].flip([2])
|
74 |
+
else:
|
75 |
+
image_t = image[I].flip([3])
|
76 |
+
|
77 |
+
return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])
|
78 |
+
|
79 |
+
class Dataset(data.Dataset):
|
80 |
+
def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
|
81 |
+
sp_num=256, slic = True, lab = False):
|
82 |
+
super(Dataset, self).__init__()
|
83 |
+
#self.data_list = glob(os.path.join(data_dir, "*.jpg"))
|
84 |
+
ext = ["*.jpg"]
|
85 |
+
dl = []
|
86 |
+
[dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
|
87 |
+
self.data_list = dl
|
88 |
+
self.sp_num = sp_num
|
89 |
+
self.slic = slic
|
90 |
+
self.lab = lab
|
91 |
+
if test:
|
92 |
+
self.transform = transforms.Compose([
|
93 |
+
transforms.Resize(img_size),
|
94 |
+
transforms.CenterCrop(crop_size)])
|
95 |
+
else:
|
96 |
+
self.transform = transforms.Compose([
|
97 |
+
transforms.RandomChoice([
|
98 |
+
transforms.ColorJitter(brightness=0.05),
|
99 |
+
transforms.ColorJitter(contrast=0.05),
|
100 |
+
transforms.ColorJitter(saturation=0.01),
|
101 |
+
transforms.ColorJitter(hue=0.01)]),
|
102 |
+
transforms.RandomHorizontalFlip(),
|
103 |
+
transforms.RandomVerticalFlip(),
|
104 |
+
transforms.Resize(int(img_size)),
|
105 |
+
transforms.RandomCrop(crop_size)])
|
106 |
+
|
107 |
+
N = len(self.data_list)
|
108 |
+
self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
|
109 |
+
self.random_vertical_flip = RandomVerticalFlip(N=N)
|
110 |
+
self.random_resized_crop = RandomResizedCrop(N=N, res=224)
|
111 |
+
self.eqv_list = ['random_crop', 'h_flip']
|
112 |
+
|
113 |
+
def transform_eqv(self, indice, image):
|
114 |
+
if 'random_crop' in self.eqv_list:
|
115 |
+
image = self.random_resized_crop(indice, image)
|
116 |
+
if 'h_flip' in self.eqv_list:
|
117 |
+
image = self.random_horizontal_flip(indice, image)
|
118 |
+
if 'v_flip' in self.eqv_list:
|
119 |
+
image = self.random_vertical_flip(indice, image)
|
120 |
+
|
121 |
+
return image
|
122 |
+
|
123 |
+
def __getitem__(self, index):
|
124 |
+
data_path = self.data_list[index]
|
125 |
+
ori_img = Image.open(data_path)
|
126 |
+
ori_img = self.transform(ori_img)
|
127 |
+
ori_img = np.array(ori_img)
|
128 |
+
|
129 |
+
# compute slic
|
130 |
+
if self.slic:
|
131 |
+
slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3)
|
132 |
+
slic_i = torch.from_numpy(slic_i)
|
133 |
+
slic_i[slic_i >= self.sp_num] = self.sp_num - 1
|
134 |
+
oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C = self.sp_num).squeeze()
|
135 |
+
|
136 |
+
if ori_img.ndim < 3:
|
137 |
+
ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis = 2)
|
138 |
+
ori_img = ori_img[:, :, :3]
|
139 |
+
|
140 |
+
rets = []
|
141 |
+
if self.lab:
|
142 |
+
lab_img = rgb2lab(ori_img)
|
143 |
+
rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))
|
144 |
+
|
145 |
+
ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
|
146 |
+
rets.append(ori_img/255.0)
|
147 |
+
|
148 |
+
if self.slic:
|
149 |
+
rets.append(oh)
|
150 |
+
|
151 |
+
rets.append(index)
|
152 |
+
|
153 |
+
return rets
|
154 |
+
|
155 |
+
def __len__(self):
|
156 |
+
return len(self.data_list)
|
157 |
+
|
158 |
+
if __name__ == '__main__':
|
159 |
+
import torchvision.utils as vutils
|
160 |
+
dataset = Dataset('/home/xtli/DATA/texture_data/',
|
161 |
+
sampled_num=3000)
|
162 |
+
loader_ = torch.utils.data.DataLoader(dataset = dataset,
|
163 |
+
batch_size = 1,
|
164 |
+
shuffle = True,
|
165 |
+
num_workers = 1,
|
166 |
+
drop_last = True)
|
167 |
+
loader = iter(loader_)
|
168 |
+
img, points, pixs = loader.next()
|
169 |
+
|
170 |
+
crop_size = 128
|
171 |
+
canvas = torch.zeros((1, 3, crop_size, crop_size))
|
172 |
+
for i in range(points.shape[-2]):
|
173 |
+
p = (points[0, i] + 1) / 2.0 * (crop_size - 1)
|
174 |
+
canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i]
|
175 |
+
vutils.save_image(canvas, 'canvas.png')
|
176 |
+
vutils.save_image(img, 'img.png')
|
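The one-hot SLIC map returned above is convenient for superpixel-wise pooling; a small sketch of one possible use, not taken from the commit (the feature tensor and function name are hypothetical):

# Hypothetical helper: average features inside each SLIC superpixel using the
# (B, S, H, W) one-hot map `oh` that Dataset.__getitem__ returns per sample.
import torch

def superpixel_avg_pool(feat, oh, eps=1e-6):
    # feat: (B, C, H, W) features, oh: (B, S, H, W) one-hot superpixel map
    B, C, H, W = feat.shape
    S = oh.shape[1]
    flat_feat = feat.view(B, C, H * W)                       # (B, C, HW)
    flat_oh = oh.view(B, S, H * W).float()                   # (B, S, HW)
    area = flat_oh.sum(dim=-1, keepdim=True)                 # pixels per superpixel
    pooled = torch.bmm(flat_oh, flat_feat.transpose(1, 2))   # (B, S, C) summed features
    return pooled / (area + eps)                             # per-superpixel means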
libs/data_geo_pho.py
ADDED
@@ -0,0 +1,130 @@
"""SLIC dataset
- Returns an image together with its SLIC segmentation map.
"""
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
from glob import glob
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from .custom_transform import *

class Dataset(data.Dataset):
    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super(Dataset, self).__init__()
        #self.data_list = glob(os.path.join(data_dir, "*.jpg"))
        ext = ["*.jpg"]
        dl = []
        [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
        self.data_list = dl
        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab
        if test:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size)])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(int(img_size)),
                transforms.RandomCrop(crop_size)])

        N = len(self.data_list)
        # eqv transform
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=256)

        # photometric transform
        self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)]  # Control this later (NOTE)
        self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)]      # Control this later (NOTE)
        self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)]  # Control this later (NOTE)
        self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)]                # Control this later (NOTE)
        self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)]
        self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)]

        self.eqv_list = ['random_crop', 'h_flip']
        self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur']

        self.transform_tensor = TensorTransform()

    def transform_eqv(self, indice, image):
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)

        return image

    def transform_inv(self, index, image, ver):
        """
        Hyperparameters same as MoCo v2.
        (https://github.com/facebookresearch/moco/blob/master/main_moco.py)
        """
        if 'brightness' in self.inv_list:
            image = self.random_color_brightness[ver](index, image)
        if 'contrast' in self.inv_list:
            image = self.random_color_contrast[ver](index, image)
        if 'saturation' in self.inv_list:
            image = self.random_color_saturation[ver](index, image)
        if 'hue' in self.inv_list:
            image = self.random_color_hue[ver](index, image)
        if 'gray' in self.inv_list:
            image = self.random_gray_scale[ver](index, image)
        if 'blur' in self.inv_list:
            image = self.random_gaussian_blur[ver](index, image)

        return image

    def transform_image(self, index, image):
        image1 = self.transform_inv(index, image, 0)
        image1 = self.transform_tensor(image1)

        image2 = self.transform_inv(index, image, 1)
        #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR)
        image2 = self.transform_tensor(image2)
        return image1, image2

    def __getitem__(self, index):
        data_path = self.data_list[index]
        ori_img = Image.open(data_path)
        ori_img = self.transform(ori_img)

        image1, image2 = self.transform_image(index, ori_img)

        rets = []
        rets.append(image1)
        rets.append(image2)
        rets.append(index)

        return rets

    def __len__(self):
        return len(self.data_list)

if __name__ == '__main__':
    import torchvision.utils as vutils
    dataset = Dataset('/home/xtli/DATA/texture_data/',
                      sampled_num=3000)
    loader_ = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = 1,
                                          shuffle = True,
                                          num_workers = 1,
                                          drop_last = True)
    loader = iter(loader_)
    img, points, pixs = loader.next()

    crop_size = 128
    canvas = torch.zeros((1, 3, crop_size, crop_size))
    for i in range(points.shape[-2]):
        p = (points[0, i] + 1) / 2.0 * (crop_size - 1)
        canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i]
    vutils.save_image(canvas, 'canvas.png')
    vutils.save_image(img, 'img.png')
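For reference, the per-index photometric transforms configured above (jitter 0.3/0.3/0.3/0.1 at p=0.8, grayscale p=0.2, blur sigma 0.1-2.0 at p=0.5) roughly mirror the stock MoCo v2 augmentation, except that here each jitter is applied independently with its own probability. A sketch of the equivalent torchvision pipeline, purely for comparison (the blur kernel size is an assumption, not stated in the commit):

# Comparison sketch only; not used by this code.
import torchvision.transforms as T

moco_v2_style = T.Compose([
    T.RandomApply([T.ColorJitter(brightness=0.3, contrast=0.3,
                                 saturation=0.3, hue=0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),  # kernel_size assumed
])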
libs/data_slic.py
ADDED
@@ -0,0 +1,175 @@
"""SLIC dataset
- Returns an image together with its SLIC segmentation map.
"""
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F  # needed for F.interpolate below

import numpy as np
from glob import glob
from PIL import Image
from skimage.segmentation import slic
from skimage.color import rgb2lab

from .utils import label2one_hot_torch

class RandomResizedCrop(object):
    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        ws, hs = self.rcrop[idx]
        res1 = int(img.size(-1))
        res2 = int(self.rscale[idx]*res1)
        i1 = int(round((res1-res2)*ws))
        j1 = int(round((res1-res2)*hs))

        return img[:, :, i1:i1+res2, j1:j1+res2]

    def __call__(self, indice, image):
        new_image = []
        res_tar = self.res // 4 if image.size(1) > 5 else self.res  # View 1 or View 2?

        for i, idx in enumerate(indice):
            img = image[[i]]
            img = self.random_crop(idx, img)
            img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False)

            new_image.append(img)

        new_image = torch.cat(new_image)

        return new_image

class RandomVerticalFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([1])
        else:
            image_t = image[I].flip([2])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class RandomHorizontalTensorFlip(object):
    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        I = np.nonzero(self.plist[indice] < self.p_ref)[0]

        if len(image.size()) == 3:
            image_t = image[I].flip([2])
        else:
            image_t = image[I].flip([3])

        return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))])

class Dataset(data.Dataset):
    def __init__(self, data_dir, img_size=256, crop_size=128, test=False,
                 sp_num=256, slic=True, lab=False):
        super(Dataset, self).__init__()
        #self.data_list = glob(os.path.join(data_dir, "*.jpg"))
        ext = ["*.jpg"]
        dl = []
        [dl.extend(glob(data_dir + '/**/' + e, recursive=True)) for e in ext]
        self.data_list = dl
        self.sp_num = sp_num
        self.slic = slic
        self.lab = lab
        if test:
            self.transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size)])
        else:
            self.transform = transforms.Compose([
                transforms.RandomChoice([
                    transforms.ColorJitter(brightness=0.05),
                    transforms.ColorJitter(contrast=0.05),
                    transforms.ColorJitter(saturation=0.01),
                    transforms.ColorJitter(hue=0.01)]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.Resize(int(img_size)),
                transforms.RandomCrop(crop_size)])

        N = len(self.data_list)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=img_size)
        self.eqv_list = ['random_crop', 'h_flip']

    def transform_eqv(self, indice, image):
        if 'random_crop' in self.eqv_list:
            image = self.random_resized_crop(indice, image)
        if 'h_flip' in self.eqv_list:
            image = self.random_horizontal_flip(indice, image)
        if 'v_flip' in self.eqv_list:
            image = self.random_vertical_flip(indice, image)

        return image

    def __getitem__(self, index):
        data_path = self.data_list[index]
        ori_img = Image.open(data_path)
        ori_img = self.transform(ori_img)
        ori_img = np.array(ori_img)

        # compute slic
        if self.slic:
            slic_i = slic(ori_img, n_segments=self.sp_num, compactness=10, start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            slic_i[slic_i >= self.sp_num] = self.sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=self.sp_num).squeeze()

        if ori_img.ndim < 3:
            ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis=2)
        ori_img = ori_img[:, :, :3]

        rets = []
        if self.lab:
            lab_img = rgb2lab(ori_img)
            rets.append(torch.from_numpy(lab_img).float().permute(2, 0, 1))

        ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
        rets.append(ori_img/255.0)

        if self.slic:
            rets.append(oh)

        rets.append(index)

        return rets

    def __len__(self):
        return len(self.data_list)

if __name__ == '__main__':
    import torchvision.utils as vutils
    dataset = Dataset('/home/xtli/DATA/texture_data/',
                      sampled_num=3000)
    loader_ = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = 1,
                                          shuffle = True,
                                          num_workers = 1,
                                          drop_last = True)
    loader = iter(loader_)
    img, points, pixs = loader.next()

    crop_size = 128
    canvas = torch.zeros((1, 3, crop_size, crop_size))
    for i in range(points.shape[-2]):
        p = (points[0, i] + 1) / 2.0 * (crop_size - 1)
        canvas[0, :, int(p[0]), int(p[1])] = pixs[0, :, i]
    vutils.save_image(canvas, 'canvas.png')
    vutils.save_image(img, 'img.png')
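With the default flags (slic=True, lab=False), each item is [image, one-hot SLIC map, index]; a quick inspection sketch, assuming a placeholder image folder and that label2one_hot_torch produces a standard one-hot encoding (both are assumptions, not stated in the commit):

# Inspection sketch only; the folder path is hypothetical.
dataset = Dataset('/path/to/jpg/folder', img_size=256, crop_size=128, sp_num=256)
img, oh, index = dataset[0]
labels = oh.argmax(dim=0)                 # back to an (H, W) superpixel-id map
print(img.shape, oh.shape)                # (3, 128, 128) and (256, 128, 128) for these sizes
print(labels.unique().numel(), "superpixels used")
print(oh.sum(dim=0).max())                # each pixel should belong to exactly one superpixel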
libs/discriminator.py
ADDED
@@ -0,0 +1,60 @@
import functools
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = nn.BatchNorm2d
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
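The discriminator outputs a patch-wise realness map rather than a single scalar; a minimal forward-pass sketch (the dummy input size is arbitrary):

# Sketch: dummy forward through the PatchGAN discriminator defined above.
import torch

netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
netD.apply(weights_init)            # DCGAN-style initialization from this file

fake = torch.randn(2, 3, 128, 128)
pred = netD(fake)
print(pred.shape)                   # torch.Size([2, 1, 14, 14]): one logit per image patch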
libs/flow_transforms.py
ADDED
@@ -0,0 +1,393 @@
1 |
+
from __future__ import division
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import numbers
|
6 |
+
import types
|
7 |
+
import scipy.ndimage as ndimage
|
8 |
+
import cv2
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from PIL import Image
|
11 |
+
# import torchvision.transforms.functional as FF
|
12 |
+
|
13 |
+
'''
|
14 |
+
Data argumentation file
|
15 |
+
modifed from
|
16 |
+
https://github.com/ClementPinard/FlowNetPytorch
|
17 |
+
|
18 |
+
|
19 |
+
'''
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
'''Set of tranform random routines that takes both input and target as arguments,
|
24 |
+
in order to have random but coherent transformations.
|
25 |
+
inputs are PIL Image pairs and targets are ndarrays'''
|
26 |
+
|
27 |
+
_pil_interpolation_to_str = {
|
28 |
+
Image.NEAREST: 'PIL.Image.NEAREST',
|
29 |
+
Image.BILINEAR: 'PIL.Image.BILINEAR',
|
30 |
+
Image.BICUBIC: 'PIL.Image.BICUBIC',
|
31 |
+
Image.LANCZOS: 'PIL.Image.LANCZOS',
|
32 |
+
Image.HAMMING: 'PIL.Image.HAMMING',
|
33 |
+
Image.BOX: 'PIL.Image.BOX',
|
34 |
+
}
|
35 |
+
|
36 |
+
class Compose(object):
|
37 |
+
""" Composes several co_transforms together.
|
38 |
+
For example:
|
39 |
+
>>> co_transforms.Compose([
|
40 |
+
>>> co_transforms.CenterCrop(10),
|
41 |
+
>>> co_transforms.ToTensor(),
|
42 |
+
>>> ])
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(self, co_transforms):
|
46 |
+
self.co_transforms = co_transforms
|
47 |
+
|
48 |
+
def __call__(self, input, target):
|
49 |
+
for t in self.co_transforms:
|
50 |
+
input,target = t(input,target)
|
51 |
+
return input,target
|
52 |
+
|
53 |
+
|
54 |
+
class ArrayToTensor(object):
|
55 |
+
"""Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
|
56 |
+
|
57 |
+
def __call__(self, array):
|
58 |
+
assert(isinstance(array, np.ndarray))
|
59 |
+
|
60 |
+
array = np.transpose(array, (2, 0, 1))
|
61 |
+
# handle numpy array
|
62 |
+
tensor = torch.from_numpy(array)
|
63 |
+
# put it from HWC to CHW format
|
64 |
+
|
65 |
+
return tensor.float()
|
66 |
+
|
67 |
+
|
68 |
+
class ArrayToPILImage(object):
|
69 |
+
"""Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
|
70 |
+
|
71 |
+
def __call__(self, array):
|
72 |
+
assert(isinstance(array, np.ndarray))
|
73 |
+
|
74 |
+
img = Image.fromarray(array.astype(np.uint8))
|
75 |
+
|
76 |
+
return img
|
77 |
+
|
78 |
+
class PILImageToTensor(object):
|
79 |
+
"""Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
|
80 |
+
|
81 |
+
def __call__(self, img):
|
82 |
+
assert(isinstance(img, Image.Image))
|
83 |
+
|
84 |
+
array = np.asarray(img)
|
85 |
+
array = np.transpose(array, (2, 0, 1))
|
86 |
+
tensor = torch.from_numpy(array)
|
87 |
+
|
88 |
+
return tensor.float()
|
89 |
+
|
90 |
+
|
91 |
+
class Lambda(object):
|
92 |
+
"""Applies a lambda as a transform"""
|
93 |
+
|
94 |
+
def __init__(self, lambd):
|
95 |
+
assert isinstance(lambd, types.LambdaType)
|
96 |
+
self.lambd = lambd
|
97 |
+
|
98 |
+
def __call__(self, input,target):
|
99 |
+
return self.lambd(input,target)
|
100 |
+
|
101 |
+
|
102 |
+
class CenterCrop(object):
|
103 |
+
"""Crops the given inputs and target arrays at the center to have a region of
|
104 |
+
the given size. size can be a tuple (target_height, target_width)
|
105 |
+
or an integer, in which case the target will be of a square shape (size, size)
|
106 |
+
Careful, img1 and img2 may not be the same size
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(self, size):
|
110 |
+
if isinstance(size, numbers.Number):
|
111 |
+
self.size = (int(size), int(size))
|
112 |
+
else:
|
113 |
+
self.size = size
|
114 |
+
|
115 |
+
def __call__(self, inputs, target):
|
116 |
+
h1, w1, _ = inputs[0].shape
|
117 |
+
# h2, w2, _ = inputs[1].shape
|
118 |
+
th, tw = self.size
|
119 |
+
x1 = int(round((w1 - tw) / 2.))
|
120 |
+
y1 = int(round((h1 - th) / 2.))
|
121 |
+
# x2 = int(round((w2 - tw) / 2.))
|
122 |
+
# y2 = int(round((h2 - th) / 2.))
|
123 |
+
for i in range(len(inputs)):
|
124 |
+
inputs[i] = inputs[i][y1: y1 + th, x1: x1 + tw]
|
125 |
+
# inputs[0] = inputs[0][y1: y1 + th, x1: x1 + tw]
|
126 |
+
# inputs[1] = inputs[1][y2: y2 + th, x2: x2 + tw]
|
127 |
+
target = target[y1: y1 + th, x1: x1 + tw]
|
128 |
+
return inputs,target
|
129 |
+
|
130 |
+
class myRandomResized(object):
|
131 |
+
"""
|
132 |
+
based on RandomResizedCrop in
|
133 |
+
https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#RandomResizedCrop
|
134 |
+
"""
|
135 |
+
|
136 |
+
def __init__(self, expect_min_size, scale=(0.8, 1.5), interpolation=cv2.INTER_NEAREST):
|
137 |
+
# assert (min(input_size) * min(scale) > max(expect_size))
|
138 |
+
# one consider one decimal !!
|
139 |
+
assert (isinstance(scale,tuple) and len(scale)==2)
|
140 |
+
self.interpolation = interpolation
|
141 |
+
self.scale = [ x*0.1 for x in range(int(scale[0]*10),int(scale[1])*10 )]
|
142 |
+
self.min_size = expect_min_size
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def get_params(img, scale, min_size):
|
146 |
+
"""Get parameters for ``crop`` for a random sized crop.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
img (PIL Image): Image to be cropped.
|
150 |
+
scale (tuple): range of size of the origin size cropped
|
151 |
+
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
155 |
+
sized crop.
|
156 |
+
"""
|
157 |
+
# area = img.size[0] * img.size[1]
|
158 |
+
h, w, _ = img.shape
|
159 |
+
for attempt in range(10):
|
160 |
+
rand_scale_ = random.choice(scale)
|
161 |
+
|
162 |
+
if random.random() < 0.5:
|
163 |
+
rand_scale = rand_scale_
|
164 |
+
else:
|
165 |
+
rand_scale = -1.
|
166 |
+
|
167 |
+
if min_size[0] <= rand_scale * h and min_size[1] <= rand_scale * w\
|
168 |
+
and rand_scale * h % 16 == 0 and rand_scale * w %16 ==0 :
|
169 |
+
# the 16*n condition is for network architecture
|
170 |
+
return (int(rand_scale * h),int(rand_scale * w ))
|
171 |
+
|
172 |
+
# Fallback
|
173 |
+
return (h, w)
|
174 |
+
|
175 |
+
def __call__(self, inputs, tgt):
|
176 |
+
"""
|
177 |
+
Args:
|
178 |
+
img (PIL Image): Image to be cropped and resized.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
PIL Image: Randomly cropped and resized image.
|
182 |
+
"""
|
183 |
+
h,w = self.get_params(inputs[0], self.scale, self.min_size)
|
184 |
+
for i in range(len(inputs)):
|
185 |
+
inputs[i] = cv2.resize(inputs[i], (w,h), self.interpolation)
|
186 |
+
|
187 |
+
tgt = cv2.resize(tgt, (w,h), self.interpolation) #for input as h*w*1 the output is h*w
|
188 |
+
return inputs, np.expand_dims(tgt,-1)
|
189 |
+
|
190 |
+
def __repr__(self):
|
191 |
+
interpolate_str = _pil_interpolation_to_str[self.interpolation]
|
192 |
+
format_string = self.__class__.__name__ + '(min_size={0}'.format(self.min_size)
|
193 |
+
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
|
194 |
+
format_string += ', interpolation={0})'.format(interpolate_str)
|
195 |
+
return format_string
|
196 |
+
|
197 |
+
|
198 |
+
class Scale(object):
|
199 |
+
""" Rescales the inputs and target arrays to the given 'size'.
|
200 |
+
'size' will be the size of the smaller edge.
|
201 |
+
For example, if height > width, then image will be
|
202 |
+
rescaled to (size * height / width, size)
|
203 |
+
size: size of the smaller edge
|
204 |
+
interpolation order: Default: 2 (bilinear)
|
205 |
+
"""
|
206 |
+
|
207 |
+
def __init__(self, size, order=2):
|
208 |
+
self.size = size
|
209 |
+
self.order = order
|
210 |
+
|
211 |
+
def __call__(self, inputs, target):
|
212 |
+
h, w, _ = inputs[0].shape
|
213 |
+
if (w <= h and w == self.size) or (h <= w and h == self.size):
|
214 |
+
return inputs,target
|
215 |
+
if w < h:
|
216 |
+
ratio = self.size/w
|
217 |
+
else:
|
218 |
+
ratio = self.size/h
|
219 |
+
|
220 |
+
for i in range(len(inputs)):
|
221 |
+
inputs[i] = ndimage.interpolation.zoom(inputs[i], ratio, order=self.order)[:, :, :3]
|
222 |
+
|
223 |
+
target = ndimage.interpolation.zoom(target, ratio, order=self.order)[:, :, :1]
|
224 |
+
#target *= ratio
|
225 |
+
return inputs, target
|
226 |
+
|
227 |
+
|
228 |
+
class RandomCrop(object):
|
229 |
+
"""Crops the given PIL.Image at a random location to have a region of
|
230 |
+
the given size. size can be a tuple (target_height, target_width)
|
231 |
+
or an integer, in which case the target will be of a square shape (size, size)
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, size):
|
235 |
+
if isinstance(size, numbers.Number):
|
236 |
+
self.size = (int(size), int(size))
|
237 |
+
else:
|
238 |
+
self.size = size
|
239 |
+
|
240 |
+
def __call__(self, inputs,target):
|
241 |
+
h, w, _ = inputs[0].shape
|
242 |
+
th, tw = self.size
|
243 |
+
if w == tw and h == th:
|
244 |
+
return inputs,target
|
245 |
+
|
246 |
+
x1 = random.randint(0, w - tw)
|
247 |
+
y1 = random.randint(0, h - th)
|
248 |
+
for i in range(len(inputs)):
|
249 |
+
inputs[i] = inputs[i][y1: y1 + th,x1: x1 + tw]
|
250 |
+
# inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw]
|
251 |
+
# inputs[2] = inputs[2][y1: y1 + th, x1: x1 + tw]
|
252 |
+
|
253 |
+
return inputs, target[y1: y1 + th,x1: x1 + tw]
|
254 |
+
|
255 |
+
class MyScale(object):
|
256 |
+
def __init__(self, size, order=2):
|
257 |
+
self.size = size
|
258 |
+
self.order = order
|
259 |
+
|
260 |
+
def __call__(self, inputs, target):
|
261 |
+
h, w, _ = inputs[0].shape
|
262 |
+
if (w <= h and w == self.size) or (h <= w and h == self.size):
|
263 |
+
return inputs,target
|
264 |
+
if w < h:
|
265 |
+
for i in range(len(inputs)):
|
266 |
+
inputs[i] = cv2.resize(inputs[i], (self.size, int(h * self.size / w)))
|
267 |
+
target = cv2.resize(target.squeeze(), (self.size, int(h * self.size / w)), cv2.INTER_NEAREST)
|
268 |
+
else:
|
269 |
+
for i in range(len(inputs)):
|
270 |
+
inputs[i] = cv2.resize(inputs[i], (int(w * self.size / h), self.size))
|
271 |
+
target = cv2.resize(target.squeeze(), (int(w * self.size / h), self.size), cv2.INTER_NEAREST)
|
272 |
+
target = np.expand_dims(target, axis=2)
|
273 |
+
return inputs, target
|
274 |
+
|
275 |
+
class RandomHorizontalFlip(object):
|
276 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
277 |
+
"""
|
278 |
+
|
279 |
+
def __call__(self, inputs, target):
|
280 |
+
if random.random() < 0.5:
|
281 |
+
for i in range(len(inputs)):
|
282 |
+
inputs[i] = np.copy(np.fliplr(inputs[i]))
|
283 |
+
# inputs[1] = np.copy(np.fliplr(inputs[1]))
|
284 |
+
# inputs[2] = np.copy(np.fliplr(inputs[2]))
|
285 |
+
|
286 |
+
target = np.copy(np.fliplr(target))
|
287 |
+
# target[:,:,0] *= -1
|
288 |
+
return inputs,target
|
289 |
+
|
290 |
+
|
291 |
+
class RandomVerticalFlip(object):
|
292 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __call__(self, inputs, target):
|
296 |
+
if random.random() < 0.5:
|
297 |
+
for i in range(len(inputs)):
|
298 |
+
inputs[i] = np.copy(np.flipud(inputs[i]))
|
299 |
+
# inputs[1] = np.copy(np.flipud(inputs[1]))
|
300 |
+
# inputs[2] = np.copy(np.flipud(inputs[2]))
|
301 |
+
|
302 |
+
target = np.copy(np.flipud(target))
|
303 |
+
# target[:,:,1] *= -1 #for disp there is no y dim
|
304 |
+
return inputs,target
|
305 |
+
|
306 |
+
|
307 |
+
class RandomRotate(object):
|
308 |
+
"""Random rotation of the image from -angle to angle (in degrees)
|
309 |
+
This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
|
310 |
+
angle: max angle of the rotation
|
311 |
+
interpolation order: Default: 2 (bilinear)
|
312 |
+
reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
|
313 |
+
diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off.
|
314 |
+
"""
|
315 |
+
|
316 |
+
def __init__(self, angle, diff_angle=0, order=2, reshape=False):
|
317 |
+
self.angle = angle
|
318 |
+
self.reshape = reshape
|
319 |
+
self.order = order
|
320 |
+
self.diff_angle = diff_angle
|
321 |
+
|
322 |
+
def __call__(self, inputs,target):
|
323 |
+
applied_angle = random.uniform(-self.angle,self.angle)
|
324 |
+
diff = random.uniform(-self.diff_angle,self.diff_angle)
|
325 |
+
angle1 = applied_angle - diff/2
|
326 |
+
angle2 = applied_angle + diff/2
|
327 |
+
angle1_rad = angle1*np.pi/180
|
328 |
+
|
329 |
+
h, w, _ = target.shape
|
330 |
+
|
331 |
+
def rotate_flow(i,j,k):
|
332 |
+
return -k*(j-w/2)*(diff*np.pi/180) + (1-k)*(i-h/2)*(diff*np.pi/180)
|
333 |
+
|
334 |
+
rotate_flow_map = np.fromfunction(rotate_flow, target.shape)
|
335 |
+
target += rotate_flow_map
|
336 |
+
|
337 |
+
inputs[0] = ndimage.interpolation.rotate(inputs[0], angle1, reshape=self.reshape, order=self.order)
|
338 |
+
inputs[1] = ndimage.interpolation.rotate(inputs[1], angle2, reshape=self.reshape, order=self.order)
|
339 |
+
target = ndimage.interpolation.rotate(target, angle1, reshape=self.reshape, order=self.order)
|
340 |
+
# flow vectors must be rotated too! careful about Y flow which is upside down
|
341 |
+
target_ = np.copy(target)
|
342 |
+
target[:,:,0] = np.cos(angle1_rad)*target_[:,:,0] + np.sin(angle1_rad)*target_[:,:,1]
|
343 |
+
target[:,:,1] = -np.sin(angle1_rad)*target_[:,:,0] + np.cos(angle1_rad)*target_[:,:,1]
|
344 |
+
return inputs,target
|
345 |
+
|
346 |
+
|
347 |
+
class RandomTranslate(object):
|
348 |
+
def __init__(self, translation):
|
349 |
+
if isinstance(translation, numbers.Number):
|
350 |
+
self.translation = (int(translation), int(translation))
|
351 |
+
else:
|
352 |
+
self.translation = translation
|
353 |
+
|
354 |
+
def __call__(self, inputs,target):
|
355 |
+
h, w, _ = inputs[0].shape
|
356 |
+
th, tw = self.translation
|
357 |
+
tw = random.randint(-tw, tw)
|
358 |
+
th = random.randint(-th, th)
|
359 |
+
if tw == 0 and th == 0:
|
360 |
+
return inputs, target
|
361 |
+
# compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2
|
362 |
+
x1,x2,x3,x4 = max(0,tw), min(w+tw,w), max(0,-tw), min(w-tw,w)
|
363 |
+
y1,y2,y3,y4 = max(0,th), min(h+th,h), max(0,-th), min(h-th,h)
|
364 |
+
|
365 |
+
inputs[0] = inputs[0][y1:y2,x1:x2]
|
366 |
+
inputs[1] = inputs[1][y3:y4,x3:x4]
|
367 |
+
target = target[y1:y2,x1:x2]
|
368 |
+
target[:,:,0] += tw
|
369 |
+
target[:,:,1] += th
|
370 |
+
|
371 |
+
return inputs, target
|
372 |
+
|
373 |
+
|
374 |
+
class RandomColorWarp(object):
|
375 |
+
def __init__(self, mean_range=0, std_range=0):
|
376 |
+
self.mean_range = mean_range
|
377 |
+
self.std_range = std_range
|
378 |
+
|
379 |
+
def __call__(self, inputs, target):
|
380 |
+
random_std = np.random.uniform(-self.std_range, self.std_range, 3)
|
381 |
+
random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3)
|
382 |
+
random_order = np.random.permutation(3)
|
383 |
+
|
384 |
+
inputs[0] *= (1 + random_std)
|
385 |
+
inputs[0] += random_mean
|
386 |
+
|
387 |
+
inputs[1] *= (1 + random_std)
|
388 |
+
inputs[1] += random_mean
|
389 |
+
|
390 |
+
inputs[0] = inputs[0][:,:,random_order]
|
391 |
+
inputs[1] = inputs[1][:,:,random_order]
|
392 |
+
|
393 |
+
return inputs, target
|
libs/losses.py
ADDED
@@ -0,0 +1,416 @@
1 |
+
from libs.blocks import encoder5
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import init
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from .normalization import get_nonspade_norm_layer
|
8 |
+
from .blocks import encoder5
|
9 |
+
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
class BaseNetwork(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super(BaseNetwork, self).__init__()
|
16 |
+
|
17 |
+
def print_network(self):
|
18 |
+
if isinstance(self, list):
|
19 |
+
self = self[0]
|
20 |
+
num_params = 0
|
21 |
+
for param in self.parameters():
|
22 |
+
num_params += param.numel()
|
23 |
+
print('Network [%s] was created. Total number of parameters: %.1f million. '
|
24 |
+
'To see the architecture, do print(network).'
|
25 |
+
% (type(self).__name__, num_params / 1000000))
|
26 |
+
|
27 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
28 |
+
def init_func(m):
|
29 |
+
classname = m.__class__.__name__
|
30 |
+
if classname.find('BatchNorm2d') != -1:
|
31 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
32 |
+
init.normal_(m.weight.data, 1.0, gain)
|
33 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
34 |
+
init.constant_(m.bias.data, 0.0)
|
35 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
36 |
+
if init_type == 'normal':
|
37 |
+
init.normal_(m.weight.data, 0.0, gain)
|
38 |
+
elif init_type == 'xavier':
|
39 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
40 |
+
elif init_type == 'xavier_uniform':
|
41 |
+
init.xavier_uniform_(m.weight.data, gain=1.0)
|
42 |
+
elif init_type == 'kaiming':
|
43 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
44 |
+
elif init_type == 'orthogonal':
|
45 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
46 |
+
elif init_type == 'none': # uses pytorch's default init method
|
47 |
+
m.reset_parameters()
|
48 |
+
else:
|
49 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
50 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
51 |
+
init.constant_(m.bias.data, 0.0)
|
52 |
+
|
53 |
+
self.apply(init_func)
|
54 |
+
|
55 |
+
# propagate to children
|
56 |
+
for m in self.children():
|
57 |
+
if hasattr(m, 'init_weights'):
|
58 |
+
m.init_weights(init_type, gain)
|
59 |
+
|
60 |
+
class NLayerDiscriminator(BaseNetwork):
|
61 |
+
def __init__(self):
|
62 |
+
super().__init__()
|
63 |
+
|
64 |
+
kw = 4
|
65 |
+
padw = int(np.ceil((kw - 1.0) / 2))
|
66 |
+
nf = 64
|
67 |
+
n_layers_D = 4
|
68 |
+
input_nc = 3
|
69 |
+
|
70 |
+
norm_layer = get_nonspade_norm_layer('spectralinstance')
|
71 |
+
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
|
72 |
+
nn.LeakyReLU(0.2, False)]]
|
73 |
+
|
74 |
+
for n in range(1, n_layers_D):
|
75 |
+
nf_prev = nf
|
76 |
+
nf = min(nf * 2, 512)
|
77 |
+
stride = 1 if n == n_layers_D - 1 else 2
|
78 |
+
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
|
79 |
+
stride=stride, padding=padw)),
|
80 |
+
nn.LeakyReLU(0.2, False)
|
81 |
+
]]
|
82 |
+
|
83 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
84 |
+
|
85 |
+
# We divide the layers into groups to extract intermediate layer outputs
|
86 |
+
for n in range(len(sequence)):
|
87 |
+
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
|
88 |
+
|
89 |
+
def forward(self, input, get_intermediate_features = True):
|
90 |
+
results = [input]
|
91 |
+
for submodel in self.children():
|
92 |
+
intermediate_output = submodel(results[-1])
|
93 |
+
results.append(intermediate_output)
|
94 |
+
|
95 |
+
if get_intermediate_features:
|
96 |
+
return results[1:]
|
97 |
+
else:
|
98 |
+
return results[-1]
|
99 |
+
|
100 |
+
class VGG19(torch.nn.Module):
|
101 |
+
def __init__(self, requires_grad=False):
|
102 |
+
super().__init__()
|
103 |
+
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
104 |
+
self.slice1 = torch.nn.Sequential()
|
105 |
+
self.slice2 = torch.nn.Sequential()
|
106 |
+
self.slice3 = torch.nn.Sequential()
|
107 |
+
self.slice4 = torch.nn.Sequential()
|
108 |
+
self.slice5 = torch.nn.Sequential()
|
109 |
+
for x in range(2):
|
110 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(2, 7):
|
112 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(7, 12):
|
114 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(12, 21):
|
116 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
for x in range(21, 30):
|
118 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
119 |
+
import pdb; pdb.set_trace()
|
120 |
+
if not requires_grad:
|
121 |
+
for param in self.parameters():
|
122 |
+
param.requires_grad = False
|
123 |
+
|
124 |
+
def forward(self, X):
|
125 |
+
h_relu1 = self.slice1(X)
|
126 |
+
h_relu2 = self.slice2(h_relu1)
|
127 |
+
h_relu3 = self.slice3(h_relu2)
|
128 |
+
h_relu4 = self.slice4(h_relu3)
|
129 |
+
h_relu5 = self.slice5(h_relu4)
|
130 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
131 |
+
return out
|
132 |
+
|
133 |
+
class encoder5(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super(encoder5,self).__init__()
|
136 |
+
# vgg
|
137 |
+
# 224 x 224
|
138 |
+
self.conv1 = nn.Conv2d(3,3,1,1,0)
|
139 |
+
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
|
140 |
+
# 226 x 226
|
141 |
+
|
142 |
+
self.conv2 = nn.Conv2d(3,64,3,1,0)
|
143 |
+
self.relu2 = nn.ReLU(inplace=True)
|
144 |
+
# 224 x 224
|
145 |
+
        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)

        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,512,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(512,512,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(512,512,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)

        self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(512,512,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self,x):
        output = []
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        output.append(out)

        out = self.reflecPad3(out)
        out = self.conv3(out)
        out = self.relu3(out)
        out = self.maxPool(out)
        out = self.reflecPad4(out)
        out = self.conv4(out)
        out = self.relu4(out)
        output.append(out)

        out = self.reflecPad5(out)
        out = self.conv5(out)
        out = self.relu5(out)
        out = self.maxPool2(out)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        output.append(out)

        out = self.reflecPad7(out)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out = self.relu9(out)
        out = self.maxPool3(out)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        output.append(out)

        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)

        output.append(out)
        return output

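For reference, a minimal sketch of how this encoder could be exercised on its own; the batch size, input resolution, and dummy data are illustrative, and the early layers (conv1, reflecPad1, conv2, relu2) are assumed to be defined earlier in this file:

import torch

enc = encoder5().cuda().eval()                 # weights would normally come from vgg_r51.pth
image = torch.rand(1, 3, 224, 224).cuda()      # dummy input batch
with torch.no_grad():
    feats = enc(image)                         # list of 5 feature maps, roughly relu1_1 ... relu5_1
for f in feats:
    print(f.shape)                             # channels grow 64 -> 128 -> 256 -> 512 -> 512
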
class VGGLoss(nn.Module):
    def __init__(self, model_path):
        super(VGGLoss, self).__init__()
        self.vgg = encoder5().cuda()
        self.vgg.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        self.criterion = nn.MSELoss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(4):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

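A minimal usage sketch for this perceptual term, assuming vgg_r51.pth sits under the given model_path; the path, batch size, and resolution here are illustrative:

import torch

vgg_loss = VGGLoss(model_path='models')        # expects models/vgg_r51.pth on disk (assumed path)
fake = torch.rand(4, 3, 256, 256, device='cuda', requires_grad=True)
real = torch.rand(4, 3, 256, 256, device='cuda')
loss = vgg_loss(fake, real)                    # weighted MSE over the first four feature levels only
loss.backward()                                # target features are detached inside the loss
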
class GANLoss(nn.Module):
    def __init__(self, gan_mode = 'hinge', target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.cuda.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        if gan_mode == 'ls':
            pass
        elif gan_mode == 'original':
            pass
        elif gan_mode == 'w':
            pass
        elif gan_mode == 'hinge':
            pass
        else:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
                self.fake_label_tensor.requires_grad_(False)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        # computing loss is a bit complicated because |input| may not be
        # a tensor, but a list of tensors in case of a multiscale discriminator
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        else:
            return self.loss(input, target_is_real, for_discriminator)

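A short sketch of how the default hinge mode might be called from a training loop; the patch-shaped logits below are illustrative stand-ins for discriminator outputs:

import torch

criterion = GANLoss('hinge')
logits_real = torch.randn(4, 1, 30, 30).cuda()   # patch-level scores on real images (illustrative)
logits_fake = torch.randn(4, 1, 30, 30).cuda()   # patch-level scores on generated images

# discriminator step: push real scores above +1 and fake scores below -1
d_loss = criterion(logits_real, True, for_discriminator=True) \
       + criterion(logits_fake, False, for_discriminator=True)

# generator step: raise the scores of generated samples
g_loss = criterion(logits_fake, True, for_discriminator=False)
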
class SPADE_LOSS(nn.Module):
    def __init__(self, model_path, lambda_feat = 1):
        super(SPADE_LOSS, self).__init__()
        self.criterionVGG = VGGLoss(model_path)
        self.criterionGAN = GANLoss('hinge')
        self.criterionL1 = nn.L1Loss()
        self.discriminator = NLayerDiscriminator()
        self.lambda_feat = lambda_feat

    def forward(self, x, y, for_discriminator = False):
        pred_real = self.discriminator(y)
        if not for_discriminator:
            pred_fake = self.discriminator(x)
            VGGLoss = self.criterionVGG(x, y)
            GANLoss = self.criterionGAN(pred_fake, True, for_discriminator = False)

            # feature matching loss
            # last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake) - 1
            GAN_Feat_loss = 0
            for j in range(num_intermediate_outputs):  # for each layer output
                unweighted_loss = self.criterionL1(pred_fake[j], pred_real[j].detach())
                GAN_Feat_loss += unweighted_loss * self.lambda_feat
            L1Loss = self.criterionL1(x, y)
            return VGGLoss, GANLoss, GAN_Feat_loss, L1Loss
        else:
            pred_fake = self.discriminator(x.detach())
            GANLoss = self.criterionGAN(pred_fake, False, for_discriminator = True)
            GANLoss += self.criterionGAN(pred_real, True, for_discriminator = True)
            return GANLoss

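A sketch of how this combined criterion might drive alternating generator/discriminator updates; it assumes NLayerDiscriminator returns a list of intermediate feature maps (as the feature-matching loop above expects), and the tensors, weights path, and equal weighting of the four terms are assumptions:

import torch

criterion = SPADE_LOSS(model_path='models').cuda()     # expects models/vgg_r51.pth (assumed path)
fake = torch.rand(2, 3, 256, 256, device='cuda', requires_grad=True)   # stand-in for a generator output
real = torch.rand(2, 3, 256, 256, device='cuda')

# generator-side terms: perceptual, adversarial, feature-matching, and L1
vgg_l, gan_l, feat_l, l1_l = criterion(fake, real, for_discriminator=False)
g_loss = vgg_l + gan_l + feat_l + l1_l                  # equal weighting is an assumption
g_loss.backward()

# discriminator-side term: `fake` is detached inside SPADE_LOSS
d_loss = criterion(fake, real, for_discriminator=True)
d_loss.backward()
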
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss
    Takes embeddings of two samples and a target label == 1 if the samples are from the same class and label == 0 otherwise
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, out1, out2, target, size_average=True, norm = True):
        if norm:
            # normalize each embedding to unit L2 norm
            output1 = out1 / out1.pow(2).sum(1, keepdim=True).sqrt()
            output2 = out2 / out2.pow(2).sum(1, keepdim=True).sqrt()
        else:
            output1, output2 = out1, out2
        distances = (output2 - output1).pow(2).sum(1)  # squared distances
        losses = 0.5 * (target.float() * distances +
                        (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses.mean() if size_average else losses.sum()
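A minimal sketch of applying this criterion to a pair of embedding batches; the batch size, embedding dimension, and margin value are illustrative:

import torch

criterion = ContrastiveLoss(margin=1.0)
emb_a = torch.randn(8, 128, requires_grad=True)    # illustrative 128-d embeddings
emb_b = torch.randn(8, 128, requires_grad=True)
target = torch.randint(0, 2, (8,))                 # 1 = same class (pull together), 0 = different (push apart)
loss = criterion(emb_a, emb_b, target)             # embeddings are L2-normalized first (norm=True)
loss.backward()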