Qihang Yu commited on
Commit
40177ed
·
1 Parent(s): 158e9fd

Add kMaX-DeepLab

Browse files
kmax_deeplab/modeling/pixel_decoder/kmax_pixel_decoder.py CHANGED
@@ -1,6 +1,7 @@
1
  # Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py
2
  # Modified by Qihang Yu
3
 
 
4
  from typing import Dict, List
5
 
6
  import torch
@@ -28,13 +29,17 @@ def get_activation(name):
28
  elif name == 'gelu':
29
  return nn.GELU()
30
 
 
 
 
 
31
 
32
  def get_norm(name, channels):
33
  if name is None or name.lower() == 'none':
34
  return nn.Identity()
35
 
36
  if name.lower() == 'syncbn':
37
- return nn.SyncBatchNorm(channels, eps=1e-3, momentum=0.01)
38
 
39
 
40
  class ConvBN(nn.Module):
 
1
  # Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py
2
  # Modified by Qihang Yu
3
 
4
+ from turtle import forward
5
  from typing import Dict, List
6
 
7
  import torch
 
29
  elif name == 'gelu':
30
  return nn.GELU()
31
 
32
+ class SyncBNCPU(nn.SyncBatchNorm):
33
+ def forward(self, input):
34
+ self.eval()
35
+ return super().forward(input)
36
 
37
  def get_norm(name, channels):
38
  if name is None or name.lower() == 'none':
39
  return nn.Identity()
40
 
41
  if name.lower() == 'syncbn':
42
+ return SyncBNCPU(channels, eps=1e-3, momentum=0.01)
43
 
44
 
45
  class ConvBN(nn.Module):