aoxo
Image-to-Image · English · art

Commit e62b5fc · verified · 1 Parent(s): 4bc5e63
aoxo committed

Update README.md

Files changed (1):
  1. README.md +120 -5
README.md CHANGED
@@ -120,17 +120,132 @@ Images and their corresponding style semantic maps were resized to fit the input
  - Number of attention layers: 8
  - Number of transformer encoder layers (feed-forward): 8
  - Number of transformer decoder layers (feed-forward): 8
- - Activation function: ReLU
+ - Activation function(s): ReLU, GeLU
  - Patch Size: 8
  - Swin Window Size: 7
  - Swin Shift Size: 2
- -
+ - Style Transfer Module: AdaIN

- #### Speeds, Sizes, Times [optional]
+ #### Speeds, Sizes, Times

- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
-
- [More Information Needed]
+ **Model size:** There are currently four versions of the model:
+ - v1_1: 224M params
+ - v1_2: 200M params
+ - v1_3: 93M params
+ - v2_1: 2.9M params
+
+ **Architecture:** The latest model, v2_1, introduces Location-based Multi-head Attention (LbMhA) to improve feature extraction at a much lower parameter count. Its three predecessors attained a similar level of accuracy without the LbMhA layers. The general architecture is as follows:
+
+ ```python
+ 223543305
+ DataParallel(
+   (module): ViTImage2Image(
+     (patch_embed): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
+     (encoder_layers): ModuleList(
+       (0-7): 8 x TransformerEncoderBlock(
+         (attn): LocationBasedMultiheadAttention(
+           (q_proj): Linear(in_features=768, out_features=768, bias=True)
+           (k_proj): Linear(in_features=768, out_features=768, bias=True)
+           (v_proj): Linear(in_features=768, out_features=768, bias=True)
+           (out_proj): Linear(in_features=768, out_features=768, bias=True)
+           (dropout): Dropout(p=0.1, inplace=False)
+         )
+         (ff): Sequential(
+           (0): Linear(in_features=768, out_features=3072, bias=True)
+           (1): ReLU()
+           (2): Linear(in_features=3072, out_features=768, bias=True)
+         )
+         (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (adain): AdaIN(
+           (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
+           (fc): Linear(in_features=768, out_features=1536, bias=True)
+         )
+         (dropout): Dropout(p=0.1, inplace=False)
+       )
+     )
+     (decoder_layers): ModuleList(
+       (0-7): 8 x TransformerDecoderBlock(
+         (attn1): LocationBasedMultiheadAttention(
+           (q_proj): Linear(in_features=768, out_features=768, bias=True)
+           (k_proj): Linear(in_features=768, out_features=768, bias=True)
+           (v_proj): Linear(in_features=768, out_features=768, bias=True)
+           (out_proj): Linear(in_features=768, out_features=768, bias=True)
+           (dropout): Dropout(p=0.1, inplace=False)
+         )
+         (attn2): LocationBasedMultiheadAttention(
+           (q_proj): Linear(in_features=768, out_features=768, bias=True)
+           (k_proj): Linear(in_features=768, out_features=768, bias=True)
+           (v_proj): Linear(in_features=768, out_features=768, bias=True)
+           (out_proj): Linear(in_features=768, out_features=768, bias=True)
+           (dropout): Dropout(p=0.1, inplace=False)
+         )
+         (ff): Sequential(
+           (0): Linear(in_features=768, out_features=3072, bias=True)
+           (1): ReLU()
+           (2): Linear(in_features=3072, out_features=768, bias=True)
+         )
+         (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (norm3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (norm4): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (adain1): AdaIN(
+           (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
+           (fc): Linear(in_features=768, out_features=1536, bias=True)
+         )
+         (adain2): AdaIN(
+           (norm): InstanceNorm1d(768, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
+           (fc): Linear(in_features=768, out_features=1536, bias=True)
+         )
+         (dropout): Dropout(p=0.1, inplace=False)
+       )
+     )
+     (swin_layers): ModuleList(
+       (0-7): 8 x SwinTransformerBlock(
+         (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+         (attn): MultiheadAttention(
+           (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
+         )
+         (mlp): Sequential(
+           (0): Linear(in_features=768, out_features=3072, bias=True)
+           (1): GELU(approximate='none')
+           (2): Linear(in_features=3072, out_features=768, bias=True)
+         )
+         (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+       )
+     )
+     (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
+     (mlp_head): Sequential(
+       (0): Linear(in_features=768, out_features=3072, bias=True)
+       (1): GELU(approximate='none')
+       (2): Linear(in_features=3072, out_features=768, bias=True)
+     )
+     (refinement): RefinementBlock(
+       (conv): Conv2d(768, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+       (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
+       (relu): ReLU(inplace=True)
+     )
+     (style_encoder): Sequential(
+       (0): Conv2d(3, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+       (1): ReLU()
+       (2): AdaptiveAvgPool2d(output_size=1)
+       (3): Flatten(start_dim=1, end_dim=-1)
+       (4): Linear(in_features=768, out_features=768, bias=True)
+     )
+   )
+ )
+ ```
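The printout above only exposes the projection layout of the LbMhA blocks (q_proj/k_proj/v_proj/out_proj, each 768 → 768, plus dropout); it does not show what makes the attention "location-based". A minimal sketch of one plausible reading, standard multi-head attention with a learned per-head location bias added to the attention scores, where the head count, token count, and the bias term itself are assumptions rather than the repository's code:

```python
import torch
import torch.nn as nn

class LocationBasedMultiheadAttention(nn.Module):
    """Sketch: multi-head attention with the projection layout from the repr above,
    plus a hypothetical learned per-head location bias on the attention scores."""

    def __init__(self, dim=768, num_heads=8, num_tokens=196, dropout=0.1):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        # Assumed mechanism: a learned (num_heads, num_tokens, num_tokens) bias table.
        self.loc_bias = nn.Parameter(torch.zeros(num_heads, num_tokens, num_tokens))

    def forward(self, q, k, v):
        B, N, C = q.shape
        split = lambda t: t.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        qh, kh, vh = split(self.q_proj(q)), split(self.k_proj(k)), split(self.v_proj(v))
        scores = qh @ kh.transpose(-2, -1) / self.head_dim ** 0.5        # (B, heads, Nq, Nk)
        scores = scores + self.loc_bias[:, :qh.size(2), :kh.size(2)]     # add location bias
        attn = self.dropout(scores.softmax(dim=-1))
        out = (attn @ vh).transpose(1, 2).reshape(B, N, C)
        return self.out_proj(out)

tokens = torch.randn(2, 196, 768)
print(LocationBasedMultiheadAttention()(tokens, tokens, tokens).shape)   # (2, 196, 768)
```

The AdaIN blocks, by contrast, are fully visible in the repr: an InstanceNorm1d(768) followed by a single Linear(768, 1536) that maps the style embedding to a per-channel scale and shift for the normalized patch tokens. A sketch with that layer layout (tensor shapes and the exact modulation are likewise assumptions):

```python
import torch
import torch.nn as nn

class AdaIN(nn.Module):
    """Sketch: adaptive instance normalization over patch tokens, matching the layer
    layout in the repr above (InstanceNorm1d + Linear(768 -> 1536))."""

    def __init__(self, dim=768):
        super().__init__()
        self.norm = nn.InstanceNorm1d(dim, affine=False, track_running_stats=False)
        self.fc = nn.Linear(dim, dim * 2)  # predicts scale (gamma) and shift (beta)

    def forward(self, x, style):
        # x: (B, N, C) patch tokens; style: (B, C) embedding from the style_encoder
        gamma, beta = self.fc(style).chunk(2, dim=-1)       # (B, C) each
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)    # InstanceNorm1d wants (B, C, N)
        return gamma.unsqueeze(1) * x + beta.unsqueeze(1)

print(AdaIN()(torch.randn(2, 196, 768), torch.randn(2, 768)).shape)      # (2, 196, 768)
```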
+
+ **Training hardware:** Each of the models was trained on 2 x T4 GPUs (multi-GPU training). For this reason, the linear attention modules were implemented as ring (distributed) attention during training.
+ **Total Training Compute Throughput:** 4.13 TFLOPS
+ **Total Logged Training Time:** ~210 hours (total across all four models, including overhead)
+ **Start Time:** 09-13-2024
+ **End Time:** 09-21-2024
+ **Checkpoint Size:**
+ - v1_1: 855 MB
+ - v1_2: 764 MB
+ - v1_3: 355 MB
+ - v2_1: 11 MB
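The **Training hardware** note matches the `DataParallel` wrapper at the top of the printout. A minimal sketch of that kind of two-GPU data-parallel training step, using a tiny stand-in module (`ViTImage2Image`'s constructor arguments and the training objective are not given in this section, and the ring-attention sharding is not shown):

```python
import torch
import torch.nn as nn

class TinyImage2Image(nn.Module):
    """Stand-in for the card's ViTImage2Image; its real constructor arguments are not shown."""

    def __init__(self):
        super().__init__()
        self.body = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, content, style):
        # The real model fuses a style embedding via AdaIN; the stand-in just mixes the inputs.
        return self.body(content + style)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyImage2Image()
if torch.cuda.device_count() > 1:     # e.g. the 2 x T4 machine described above
    model = nn.DataParallel(model)    # splits each input batch across GPUs along dim 0
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
content = torch.randn(4, 3, 224, 224, device=device)
style = torch.randn(4, 3, 224, 224, device=device)
target = torch.randn(4, 3, 224, 224, device=device)

output = model(content, style)                    # gathered back onto the default device
loss = nn.functional.l1_loss(output, target)      # placeholder objective, not the card's loss
loss.backward()
optimizer.step()
```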
 
  ## Evaluation