Stable-X committed on
Commit
53a077e
·
1 Parent(s): a18753a

Fix environment dependency

Browse files
app.py CHANGED
@@ -375,10 +375,9 @@ def run_demo_server(pipe):
375
  def main():
376
 
377
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
378
-
379
  marigold_pipe = Marigold()
380
  geowizard_pipe = Geowizard()
381
- dsine_pipe = DSINE()
382
  our_pipe = StableNormal()
383
 
384
  run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])
 
375
  def main():
376
 
377
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
378
+ dsine_pipe = DSINE()
379
  marigold_pipe = Marigold()
380
  geowizard_pipe = Geowizard()
 
381
  our_pipe = StableNormal()
382
 
383
  run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])
geffnet/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .gen_efficientnet import *
2
+ from .mobilenetv3 import *
3
+ from .model_factory import create_model
4
+ from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
5
+ from .activations import *
geffnet/activations/__init__.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from geffnet import config
2
+ from geffnet.activations.activations_me import *
3
+ from geffnet.activations.activations_jit import *
4
+ from geffnet.activations.activations import *
5
+ import torch
6
+
7
# Use the native SiLU when this PyTorch build exposes one, otherwise fall back
# to the local swish implementations.
_has_silu = hasattr(torch.nn.functional, 'silu')

# Plain Python / Torch builtin activation functions, keyed by name.
_ACT_FN_DEFAULT = {
    'silu': F.silu if _has_silu else swish,
    'swish': F.silu if _has_silu else swish,
    'mish': mish,
    'relu': F.relu,
    'relu6': F.relu6,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'hard_sigmoid': hard_sigmoid,
    'hard_swish': hard_swish,
}

# torch.jit.script'ed activation functions, keyed by name.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': mish_jit,
}
26
+
27
# Memory-efficient (custom autograd) activation functions, keyed by name.
_ACT_FN_ME = dict(
    silu=F.silu if _has_silu else swish_me,
    swish=F.silu if _has_silu else swish_me,
    mish=mish_me,
    hard_swish=hard_swish_me,
    # Registered under the canonical name so that get_act_fn('hard_sigmoid') can
    # find it, matching the 'hard_sigmoid' key used by _ACT_LAYER_ME.
    hard_sigmoid=hard_sigmoid_me,
    # Legacy, misnamed key; kept so lookups by this name keep working.
    hard_sigmoid_jit=hard_sigmoid_me,
)
34
+
35
# Plain nn.Module activation layers, keyed by name.
_ACT_LAYER_DEFAULT = {
    'silu': nn.SiLU if _has_silu else Swish,
    'swish': nn.SiLU if _has_silu else Swish,
    'mish': Mish,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'sigmoid': Sigmoid,
    'tanh': Tanh,
    'hard_sigmoid': HardSigmoid,
    'hard_swish': HardSwish,
}

# Activation layers backed by torch.jit.script'ed functions.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': MishJit,
}

# Activation layers backed by memory-efficient custom autograd functions.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': MishMe,
    'hard_swish': HardSwishMe,
    'hard_sigmoid': HardSigmoidMe,
}

# User-registered overrides; consulted before any of the tables above.
_OVERRIDE_FN = {}
_OVERRIDE_LAYER = {}
63
+
64
+
65
def add_override_act_fn(name, fn):
    """Register `fn` as the activation function override for `name`."""
    # Item assignment mutates the existing dict, so no `global` statement is needed.
    _OVERRIDE_FN[name] = fn
68
+
69
+
70
def update_override_act_fn(overrides):
    """Merge a dict of name -> function overrides into the override registry."""
    assert isinstance(overrides, dict)
    # In-place update of the module-level dict; rebinding is not required.
    _OVERRIDE_FN.update(overrides)
74
+
75
+
76
def clear_override_act_fn():
    """Drop every registered activation function override."""
    global _OVERRIDE_FN
    # Rebinds the module global to a fresh dict rather than emptying the old one.
    _OVERRIDE_FN = {}
79
+
80
+
81
def add_override_act_layer(name, fn):
    """Register `fn` as the activation layer override for `name`."""
    _OVERRIDE_LAYER.update({name: fn})
83
+
84
+
85
def update_override_act_layer(overrides):
    """Merge a dict of name -> layer class overrides into the override registry."""
    assert isinstance(overrides, dict)
    # In-place update of the module-level dict; rebinding is not required.
    _OVERRIDE_LAYER.update(overrides)
89
+
90
+
91
def clear_override_act_layer():
    """Drop every registered activation layer override."""
    global _OVERRIDE_LAYER
    # Rebinds the module global to a fresh dict rather than emptying the old one.
    _OVERRIDE_LAYER = {}
94
+
95
+
96
def get_act_fn(name='relu'):
    """ Activation Function Factory

    Looks up an activation function by name, choosing the variant that suits the
    current global config (export / scripting / no-jit flags).

    Lookup order: user overrides, memory-efficient versions, jit-scripted
    versions, then the plain Python / Torch builtins. Raises KeyError when the
    name is not present in the default table.
    """
    try:
        return _OVERRIDE_FN[name]
    except KeyError:
        pass

    exportable = config.is_exportable()
    no_jit = config.is_no_jit()

    # Memory-optimized custom autograd versions are only used when the model is
    # neither being exported nor scripted and jit has not been disabled.
    if not exportable and not config.is_scriptable() and not no_jit:
        if name in _ACT_FN_ME:
            return _ACT_FN_ME[name]

    if exportable and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return swish

    # NOTE: export tracing should work with jit scripted components, but the
    # original author kept running into issues, so jit is skipped when exporting.
    if not exportable and not no_jit and name in _ACT_FN_JIT:
        return _ACT_FN_JIT[name]

    return _ACT_FN_DEFAULT[name]
116
+
117
+
118
def get_act_layer(name='relu'):
    """ Activation Layer Factory

    Looks up an activation layer class by name, choosing the variant that suits
    the current global config (export / scripting / no-jit flags).

    Lookup order: user overrides, memory-efficient layers, jit-scripted layers,
    then the plain layers. Raises KeyError when the name is not present in the
    default table.
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]
    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]
    if config.is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return Swish
    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues
    # Membership is checked against the layer table that is actually indexed
    # (previously the function table _ACT_FN_JIT was tested here).
    if use_jit and name in _ACT_LAYER_JIT:  # jit scripted models should be okay for export/scripting
        return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]
136
+
137
+
geffnet/activations/activations.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Activations
2
+
3
+ A collection of activations fn and modules with a common interface so that they can
4
+ easily be swapped. All have an `inplace` arg even if not used.
5
+
6
+ Copyright 2020 Ross Wightman
7
+ """
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+
11
+
12
def swish(x, inplace: bool = False):
    """Swish activation, ``x * sigmoid(x)``.

    Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
    and also as Swish (https://arxiv.org/abs/1710.05941).

    When `inplace` is True the input tensor is overwritten and returned.

    TODO Rename to SiLU with addition to PyTorch
    """
    gate = x.sigmoid()
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)
19
+
20
+
21
class Swish(nn.Module):
    """Module wrapper around the `swish` function.

    Args:
        inplace: forwarded to `swish` on every call.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, inplace=self.inplace)
28
+
29
+
30
def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681

    Computes ``x * tanh(softplus(x))``. `inplace` is accepted for interface
    parity only; a new tensor is always returned.
    """
    soft = F.softplus(x)
    return x.mul(soft.tanh())
34
+
35
+
36
class Mish(nn.Module):
    """Module wrapper around the `mish` function.

    Args:
        inplace: stored and forwarded to `mish` on every call.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return mish(x, inplace=self.inplace)
43
+
44
+
45
def sigmoid(x, inplace: bool = False):
    """Sigmoid with an `inplace` switch; overwrites and returns `x` when it is True."""
    if inplace:
        return x.sigmoid_()
    return x.sigmoid()
47
+
48
+
49
+ # PyTorch has this, but not with a consistent inplace argument interface
50
class Sigmoid(nn.Module):
    """Sigmoid layer that accepts an `inplace` constructor argument.

    Args:
        inplace: when True, `forward` overwrites and returns its input.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.sigmoid_()
        return x.sigmoid()
57
+
58
+
59
def tanh(x, inplace: bool = False):
    """Tanh with an `inplace` switch; overwrites and returns `x` when it is True."""
    if inplace:
        return x.tanh_()
    return x.tanh()
61
+
62
+
63
+ # PyTorch has this, but not with a consistent inplace argument interface
64
class Tanh(nn.Module):
    """Tanh layer that accepts an `inplace` constructor argument.

    Args:
        inplace: when True, `forward` overwrites and returns its input.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.tanh_()
        return x.tanh()
71
+
72
+
73
def hard_swish(x, inplace: bool = False):
    """Hard swish, ``x * relu6(x + 3) / 6``.

    When `inplace` is True the input tensor is overwritten and returned.
    """
    gate = F.relu6(x + 3.)
    # `gate` is a fresh tensor, so dividing it in place never touches `x`.
    gate.div_(6.)
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)
76
+
77
+
78
class HardSwish(nn.Module):
    """Module wrapper around the `hard_swish` function.

    Args:
        inplace: forwarded to `hard_swish` on every call.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)
85
+
86
+
87
def hard_sigmoid(x, inplace: bool = False):
    """Hard sigmoid, ``relu6(x + 3) / 6``.

    When `inplace` is True the input tensor is overwritten and returned.
    """
    if not inplace:
        return F.relu6(x + 3.) / 6.
    # In-place chain: shift, clamp to [0, 6], then scale into [0, 1].
    x.add_(3.)
    x.clamp_(0., 6.)
    return x.div_(6.)
92
+
93
+
94
class HardSigmoid(nn.Module):
    """Module wrapper around the `hard_sigmoid` function.

    Args:
        inplace: forwarded to `hard_sigmoid` on every call.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, inplace=self.inplace)
101
+
102
+
geffnet/activations/activations_jit.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Activations (jit)
2
+
3
+ A collection of jit-scripted activations fn and modules with a common interface so that they can
4
+ easily be swapped. All have an `inplace` arg even if not used.
5
+
6
+ All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
7
+ currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
8
+ versions if they contain in-place ops.
9
+
10
+ Copyright 2020 Ross Wightman
11
+ """
12
+
13
+ import torch
14
+ from torch import nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
18
+ 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
19
+
20
+
21
@torch.jit.script
def swish_jit(x, inplace: bool = False):
    """Jit-scripted Swish, ``x * sigmoid(x)``.

    Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
    and also as Swish (https://arxiv.org/abs/1710.05941).
    `inplace` is accepted for interface parity and is not used.

    TODO Rename to SiLU with addition to PyTorch
    """
    gate = x.sigmoid()
    return x.mul(gate)
29
+
30
+
31
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681

    Jit-scripted ``x * tanh(softplus(x))``; `_inplace` is not used.
    """
    soft = F.softplus(x)
    return x.mul(soft.tanh())
36
+
37
+
38
class SwishJit(nn.Module):
    """Layer that applies `swish_jit`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = swish_jit(x)
        return out
44
+
45
+
46
class MishJit(nn.Module):
    """Layer that applies `mish_jit`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = mish_jit(x)
        return out
52
+
53
+
54
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    """Jit-scripted hard sigmoid, ``clamp(x + 3, 0, 6) / 6``; `inplace` is not used."""
    # clamp is used instead of F.relu6; the original author found it ever so
    # slightly faster.
    shifted = x + 3
    return shifted.clamp(min=0, max=6).div(6.)
58
+
59
+
60
class HardSigmoidJit(nn.Module):
    """Layer that applies `hard_sigmoid_jit`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = hard_sigmoid_jit(x)
        return out
66
+
67
+
68
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    """Jit-scripted hard swish, ``x * clamp(x + 3, 0, 6) / 6``; `inplace` is not used."""
    # clamp is used instead of F.relu6; the original author found it ever so
    # slightly faster.
    gate = (x + 3).clamp(min=0, max=6).div(6.)
    return x * gate
72
+
73
+
74
class HardSwishJit(nn.Module):
    """Layer that applies `hard_swish_jit`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = hard_swish_jit(x)
        return out
geffnet/activations/activations_me.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Activations (memory-efficient w/ custom autograd)
2
+
3
+ A collection of activations fn and modules with a common interface so that they can
4
+ easily be swapped. All have an `inplace` arg even if not used.
5
+
6
+ These activations are not compatible with jit scripting or ONNX export of the model, please use either
7
+ the JIT or basic versions of the activations.
8
+
9
+ Copyright 2020 Ross Wightman
10
+ """
11
+
12
+ import torch
13
+ from torch import nn as nn
14
+ from torch.nn import functional as F
15
+
16
+
17
+ __all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
18
+ 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
19
+
20
+
21
@torch.jit.script
def swish_jit_fwd(x):
    """Forward pass of Swish: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x.mul(gate)
24
+
25
+
26
@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """Backward pass of Swish: scales `grad_output` by d/dx of ``x * sigmoid(x)``."""
    sig = torch.sigmoid(x)
    # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
30
+
31
+
32
class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint

    Only the input is saved for backward; the gradient is computed from it by
    `swish_jit_bwd`.

    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200

    Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
    and also as Swish (https://arxiv.org/abs/1710.05941).

    TODO Rename to SiLU with addition to PyTorch
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        return swish_jit_bwd(saved_input, grad_output)
52
+
53
+
54
def swish_me(x, inplace=False):
    """Memory-efficient Swish via `SwishJitAutoFn`; `inplace` is not used."""
    out = SwishJitAutoFn.apply(x)
    return out
56
+
57
+
58
class SwishMe(nn.Module):
    """Layer that applies `SwishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = SwishJitAutoFn.apply(x)
        return out
64
+
65
+
66
@torch.jit.script
def mish_jit_fwd(x):
    """Forward pass of Mish: ``x * tanh(softplus(x))``."""
    tanh_sp = torch.tanh(F.softplus(x))
    return x.mul(tanh_sp)
69
+
70
+
71
@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """Backward pass of Mish: scales `grad_output` by d/dx of ``x * tanh(softplus(x))``."""
    sig = torch.sigmoid(x)
    tanh_sp = F.softplus(x).tanh()
    # d/dx = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
    local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
    return grad_output.mul(local_grad)
76
+
77
+
78
class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish

    Only the input is saved for backward; the gradient is computed from it by
    `mish_jit_bwd`.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        return mish_jit_bwd(saved_input, grad_output)
91
+
92
+
93
def mish_me(x, inplace=False):
    """Memory-efficient Mish via `MishJitAutoFn`; `inplace` is not used."""
    out = MishJitAutoFn.apply(x)
    return out
95
+
96
+
97
class MishMe(nn.Module):
    """Layer that applies `MishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = MishJitAutoFn.apply(x)
        return out
103
+
104
+
105
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    """Forward pass of hard sigmoid: ``clamp(x + 3, 0, 6) / 6``; `inplace` is not used."""
    shifted = x + 3
    return shifted.clamp(min=0, max=6).div(6.)
107
+
108
+
109
def hard_sigmoid_jit_bwd(x, grad_output):
    """Backward pass of hard sigmoid: slope 1/6 for x in [-3, 3], zero elsewhere."""
    in_linear_region = (x >= -3.) & (x <= 3.)
    slope = torch.ones_like(x) * in_linear_region / 6.
    return grad_output * slope
112
+
113
+
114
class HardSigmoidJitAutoFn(torch.autograd.Function):
    """Hard sigmoid with a hand-written backward.

    Only the input is saved for backward; the gradient is computed from it by
    `hard_sigmoid_jit_bwd`.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(saved_input, grad_output)
124
+
125
+
126
def hard_sigmoid_me(x, inplace: bool = False):
    """Memory-efficient hard sigmoid via `HardSigmoidJitAutoFn`; `inplace` is not used."""
    out = HardSigmoidJitAutoFn.apply(x)
    return out
128
+
129
+
130
class HardSigmoidMe(nn.Module):
    """Layer that applies `HardSigmoidJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = HardSigmoidJitAutoFn.apply(x)
        return out
136
+
137
+
138
def hard_swish_jit_fwd(x):
    """Forward pass of hard swish: ``x * clamp(x + 3, 0, 6) / 6``."""
    gate = (x + 3).clamp(min=0, max=6).div(6.)
    return x * gate
140
+
141
+
142
def hard_swish_jit_bwd(x, grad_output):
    """Backward pass of hard swish.

    The local gradient is 0 below -3, ``x / 3 + 0.5`` on [-3, 3] and 1 above 3.
    """
    # Start from the saturated regions: 1 where x >= 3, 0 elsewhere.
    saturated = torch.ones_like(x) * (x >= 3.)
    in_linear_region = (x >= -3.) & (x <= 3.)
    local_grad = torch.where(in_linear_region, x / 3. + .5, saturated)
    return grad_output * local_grad
146
+
147
+
148
class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation

    Only the input is saved for backward; the gradient is computed from it by
    `hard_swish_jit_bwd`.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(saved_input, grad_output)
159
+
160
+
161
def hard_swish_me(x, inplace=False):
    """Memory-efficient hard swish via `HardSwishJitAutoFn`; `inplace` is not used."""
    out = HardSwishJitAutoFn.apply(x)
    return out
163
+
164
+
165
class HardSwishMe(nn.Module):
    """Layer that applies `HardSwishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        out = HardSwishJitAutoFn.apply(x)
        return out
geffnet/config.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Global layer config state
2
+ """
3
+ from typing import Any, Optional
4
+
5
+ __all__ = [
6
+ 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
7
+ 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
8
+ ]
9
+
10
+ # Set to True if prefer to have layers with no jit optimization (includes activations)
11
+ _NO_JIT = False
12
+
13
+ # Set to True if prefer to have activation layers with no jit optimization
14
+ # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
15
+ # the jit flags so far are activations. This will change as more layers are updated and/or added.
16
+ _NO_ACTIVATION_JIT = False
17
+
18
+ # Set to True if exporting a model with Same padding via ONNX
19
+ _EXPORTABLE = False
20
+
21
+ # Set to True if wanting to use torch.jit.script on a model
22
+ _SCRIPTABLE = False
23
+
24
+
25
def is_no_jit():
    """Return the current value of the global no-jit flag."""
    return _NO_JIT
27
+
28
+
29
class set_no_jit:
    """Set the global no-jit flag, restoring the previous value on context exit.

    The flag changes as soon as the object is constructed; the previous value
    is only restored when the object is used as a ``with`` block.
    """

    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        # Right-hand side is evaluated first, so `prev` captures the old value.
        self.prev, _NO_JIT = _NO_JIT, mode

    def __enter__(self) -> None:
        return None

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        # Returning False lets any exception from the with-body propagate.
        return False
42
+
43
+
44
def is_exportable():
    """Return the current value of the global exportable flag."""
    return _EXPORTABLE
46
+
47
+
48
class set_exportable:
    """Set the global exportable flag, restoring the previous value on context exit.

    The flag changes as soon as the object is constructed; the previous value
    is only restored when the object is used as a ``with`` block.
    """

    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        # Right-hand side is evaluated first, so `prev` captures the old value.
        self.prev, _EXPORTABLE = _EXPORTABLE, mode

    def __enter__(self) -> None:
        return None

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        # Returning False lets any exception from the with-body propagate.
        return False
61
+
62
+
63
def is_scriptable():
    """Return the current value of the global scriptable flag."""
    return _SCRIPTABLE
65
+
66
+
67
class set_scriptable:
    """Set the global scriptable flag, restoring the previous value on context exit.

    The flag changes as soon as the object is constructed; the previous value
    is only restored when the object is used as a ``with`` block.
    """

    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        # Right-hand side is evaluated first, so `prev` captures the old value.
        self.prev, _SCRIPTABLE = _SCRIPTABLE, mode

    def __enter__(self) -> None:
        return None

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        # Returning False lets any exception from the with-body propagate.
        return False
80
+
81
+
82
class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.

    The flags change as soon as the object is constructed; the previous values
    are restored when the ``with`` block exits.
    """

    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        # Snapshot all four flags before touching any of them.
        self.prev = (_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT)
        _SCRIPTABLE = _SCRIPTABLE if scriptable is None else scriptable
        _EXPORTABLE = _EXPORTABLE if exportable is None else exportable
        _NO_JIT = _NO_JIT if no_jit is None else no_jit
        _NO_ACTIVATION_JIT = _NO_ACTIVATION_JIT if no_activation_jit is None else no_activation_jit

    def __enter__(self) -> None:
        return None

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        # Returning False lets any exception from the with-body propagate.
        return False
116
+
117
+
118
def layer_config_kwargs(kwargs):
    """ Consume config kwargs and return contextmgr obj

    Removes 'scriptable', 'exportable' and 'no_jit' from `kwargs` (in place)
    and passes them to `set_layer_config`; missing keys are passed as None.
    """
    scriptable = kwargs.pop('scriptable', None)
    exportable = kwargs.pop('exportable', None)
    no_jit = kwargs.pop('no_jit', None)
    return set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit)
geffnet/conv2d_layers.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Conv2D w/ SAME padding, CondConv, MixedConv
2
+
3
+ A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
4
+ MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
5
+
6
+ Copyright 2020 Ross Wightman
7
+ """
8
+ import collections.abc
9
+ import math
10
+ from functools import partial
11
+ from itertools import repeat
12
+ from typing import Tuple, Optional
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from .config import *
20
+
21
+
22
+ # From PyTorch internals
23
+ def _ntuple(n):
24
+ def parse(x):
25
+ if isinstance(x, collections.abc.Iterable):
26
+ return x
27
+ return tuple(repeat(x, n))
28
+ return parse
29
+
30
+
31
# Scalar-to-tuple parsers for 1, 2, 3 and 4 elements.
_single, _pair, _triple, _quadruple = (_ntuple(size) for size in (1, 2, 3, 4))
35
+
36
+
37
+ def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
38
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
39
+
40
+
41
+ def _get_padding(kernel_size, stride=1, dilation=1, **_):
42
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
43
+ return padding
44
+
45
+
46
+ def _calc_same_pad(i: int, k: int, s: int, d: int):
47
+ return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
48
+
49
+
50
def _same_pad_arg(input_size, kernel_size, stride, dilation):
    """Return ``[left, right, top, bottom]`` padding for 'SAME' output size.

    Any odd remainder goes on the right / bottom side.
    """
    in_h, in_w = input_size
    k_h, k_w = kernel_size
    pad_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    pad_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    left, top = pad_w // 2, pad_h // 2
    return [left, pad_w - left, top, pad_h - top]
56
+
57
+
58
+ def _split_channels(num_chan, num_groups):
59
+ split = [num_chan // num_groups for _ in range(num_groups)]
60
+ split[0] += num_chan - sum(split)
61
+ return split
62
+
63
+
64
def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    """2D convolution with padding computed from the input size ('SAME' style).

    The `padding` argument is accepted for signature parity but ignored: the
    input is padded explicitly and the convolution runs with zero padding.
    """
    in_h, in_w = x.size()[-2:]
    k_h, k_w = weight.size()[-2:]
    pad_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    pad_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    left, top = pad_w // 2, pad_h // 2
    # F.pad takes the last dimension first: [left, right, top, bottom].
    x = F.pad(x, [left, pad_w - left, top, pad_h - top])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
73
+
74
+
75
class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The `padding` constructor argument is ignored; nn.Conv2d is initialised
    with padding 0 and the padding is computed per call in `conv2d_same`.
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv2d_same(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
87
+
88
+
89
class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The padding layer is built on the first forward call from that call's
    input size and reused for every later call.

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        input_size = x.size()[-2:]
        if self.pad is None:
            kernel_hw = self.weight.size()[-2:]
            pad_arg = _same_pad_arg(input_size, kernel_hw, self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            # Recorded for reference; nothing in this class reads it back.
            self.pad_input_size = input_size

        # self.pad is always set by this point.
        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
114
+
115
+
116
def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding spec into ``(padding, dynamic)``.

    Non-string `padding` is returned as-is with ``dynamic=False``. For strings
    (case-insensitive):
      * 'same'  - static symmetric padding when possible, otherwise padding 0
                  with ``dynamic=True`` (padding must be computed per input).
      * 'valid' - padding 0.
      * other   - PyTorch style 'same'-ish symmetric padding.
    """
    if not isinstance(padding, str):
        return padding, False

    mode = padding.lower()
    if mode == 'same':
        # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
        if _is_static_pad(kernel_size, **kwargs):
            # static case, no extra overhead
            return _get_padding(kernel_size, **kwargs), False
        # dynamic padding
        return 0, True
    if mode == 'valid':
        # 'VALID' padding, same as padding=0
        return 0, False
    # Default to PyTorch style 'same'-ish symmetric padding
    return _get_padding(kernel_size, **kwargs), False
137
+
138
+
139
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a conv layer, choosing the class from the resolved padding mode.

    'padding' defaults to '' and 'bias' to False. Dynamic 'same' padding gives
    `Conv2dSameExport` when the exportable flag is set, else `Conv2dSame`;
    everything else gives a plain `nn.Conv2d` with the resolved padding.
    """
    padding_spec = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding_spec, kernel_size, **kwargs)
    if not is_dynamic:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
    if is_exportable():
        # The export-friendly layer does not work with torch.jit.script.
        assert not is_scriptable()
        return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
    return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
151
+
152
+
153
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    Channels are split across one conv per kernel size; the per-group outputs
    are concatenated along dim 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super().__init__()

        kernel_sizes = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        num_groups = len(kernel_sizes)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        # One conv per kernel size, registered under its index as the module name.
        for idx, k in enumerate(kernel_sizes):
            in_ch, out_ch = in_splits[idx], out_splits[idx]
            conv_groups = out_ch if depthwise else 1
            conv = create_conv2d_pad(
                in_ch, out_ch, k, stride=stride,
                padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
            self.add_module(str(idx), conv)
        self.splits = in_splits

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        outputs = [conv(chunks[i]) for i, conv in enumerate(self.values())]
        return torch.cat(outputs, 1)
184
+
185
+
186
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap `initializer` so it is applied to each expert's slice of a CondConv weight.

    The returned function expects a ``[num_experts, prod(expert_shape)]`` tensor
    and raises ValueError for any other shape.
    """
    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        shape_ok = (
            len(weight.shape) == 2
            and weight.shape[0] == num_experts
            and weight.shape[1] == num_params
        )
        if not shape_ok:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            # Each row is viewed in the expert's real shape before initialising.
            initializer(weight[expert_idx].view(expert_shape))
    return condconv_initializer
197
+
198
+
199
class CondConv2d(nn.Module):
    """ Conditional Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983

    Expert kernels are stored flattened as ``[num_experts, prod(weight_shape)]``
    and mixed per sample with `routing_weights` in `forward`.
    """
    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        # Stored as a flag and branched on in forward so it works with torchscript.
        self.dynamic_padding = is_padding_dynamic
        self.padding = _pair(padding_val)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        params_per_expert = 1
        for dim in self.weight_shape:
            params_per_expert *= dim
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, params_per_expert))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        weight_init = partial(nn.init.kaiming_uniform_, a=math.sqrt(5))
        init_weight = get_condconv_initializer(weight_init, self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is None:
            return
        fan_in = np.prod(self.weight_shape[1:])
        bound = 1 / math.sqrt(fan_in)
        bias_init = partial(nn.init.uniform_, a=-bound, b=bound)
        init_bias = get_condconv_initializer(bias_init, self.num_experts, self.bias_shape)
        init_bias(self.bias)

    def forward(self, x, routing_weights):
        B, C, H, W = x.shape
        # Mix the expert kernels per sample, then stack the per-sample kernels
        # along the output-channel dimension.
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)
        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        x = x.view(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        out_h, out_w = out.shape[-2], out.shape[-1]
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out_h, out_w)

        # NOTE: the literal port of the TF definition splits x and the mixed
        # weights per sample, convolves each sample with its own kernel and
        # concatenates the results; the grouped convolution above replaces it.
        return out
288
+
289
+
290
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Create the conv layer that matches `kernel_size` and `kwargs`.

    * A list `kernel_size` selects `MixedConv2d` (one conv per kernel size).
    * ``num_experts > 0`` selects `CondConv2d`.
    * Anything else goes through `create_conv2d_pad`.

    'groups' must not be passed; use the 'depthwise' bool instead, which sets
    groups to `out_chs`.
    """
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)

    depthwise = kwargs.pop('depthwise', False)
    groups = out_chs if depthwise else 1
    if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
        return CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
    # A num_experts value that is not > 0 means "no CondConv"; drop it so it is
    # not forwarded to nn.Conv2d, which does not accept that argument.
    kwargs.pop('num_experts', None)
    return create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
geffnet/efficientnet_builder.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ EfficientNet / MobileNetV3 Blocks and Builder
2
+
3
+ Copyright 2020 Ross Wightman
4
+ """
5
+ import re
6
+ from copy import deepcopy
7
+
8
+ from .conv2d_layers import *
9
+ from geffnet.activations import *
10
+
11
+ __all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
12
+ 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
13
+ 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
14
+ 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
15
+ ]
16
+
17
+ # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
18
+ # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
19
+ # NOTE: momentum varies btw .99 and .9997 depending on source
20
+ # .99 in official TF TPU impl
21
+ # .9997 (/w .999 in search space) for paper
22
+ #
23
+ # PyTorch defaults are momentum = .1, eps = 1e-5
24
+ #
25
+ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
26
+ BN_EPS_TF_DEFAULT = 1e-3
27
+ _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
28
+
29
+
30
def get_bn_args_tf():
    """Return a fresh dict copy of the TF-default batch-norm arguments (`_BN_ARGS_TF`)."""
    return dict(_BN_ARGS_TF)
32
+
33
+
34
def resolve_bn_args(kwargs):
    """Extract batch-norm settings from `kwargs` and return them as a dict.

    The keys ``bn_tf``, ``bn_momentum`` and ``bn_eps`` are popped from `kwargs`
    (the dict is mutated). ``bn_tf`` selects the TF defaults as the starting point;
    non-None ``bn_momentum`` / ``bn_eps`` then override ``momentum`` / ``eps``.
    """
    use_tf_defaults = kwargs.pop('bn_tf', False)
    resolved = get_bn_args_tf() if use_tf_defaults else {}
    for source_key, target_key in (('bn_momentum', 'momentum'), ('bn_eps', 'eps')):
        override = kwargs.pop(source_key, None)
        if override is not None:
            resolved[target_key] = override
    return resolved
43
+
44
+
45
+ _SE_ARGS_DEFAULT = dict(
46
+ gate_fn=sigmoid,
47
+ act_layer=None, # None == use containing block's activation layer
48
+ reduce_mid=False,
49
+ divisor=1)
50
+
51
+
52
def resolve_se_args(kwargs, in_chs, act_layer=None):
    """Build the keyword arguments for a squeeze-excite module.

    Starts from a copy of `kwargs` (or an empty dict when it is None), fills in every
    key missing from `_SE_ARGS_DEFAULT`, consumes the ``reduce_mid`` flag and resolves
    the activation layer. The input dict itself is not modified.
    """
    se_kwargs = {} if kwargs is None else kwargs.copy()
    # fill in args that aren't specified with the defaults
    for name, default in _SE_ARGS_DEFAULT.items():
        if name not in se_kwargs:
            se_kwargs[name] = default
    # some models, like MobileNetV3, calculate SE reduction chs from the containing
    # block's mid_ch instead of in_ch
    reduce_from_mid = se_kwargs.pop('reduce_mid')
    if not reduce_from_mid:
        se_kwargs['reduced_base_chs'] = in_chs
    # act_layer override: when left as None, the containing block's act_layer is used
    if se_kwargs['act_layer'] is None:
        assert act_layer is not None
        se_kwargs['act_layer'] = act_layer
    return se_kwargs
65
+
66
+
67
def resolve_act_layer(kwargs, default='relu'):
    """Pop ``act_layer`` from `kwargs` and return it as a layer.

    A string value (including the string `default`) is looked up via `get_act_layer`;
    any other value is returned unchanged.
    """
    chosen = kwargs.pop('act_layer', default)
    return get_act_layer(chosen) if isinstance(chosen, str) else chosen
72
+
73
+
74
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
    """Round `v` to the nearest multiple of `divisor`, never below `min_value`.

    `min_value` falls back to `divisor` when it is None (or otherwise falsy). If the
    rounded result ends up more than 10% below `v`, one extra `divisor` is added.
    """
    lower_bound = min_value or divisor
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    if result < 0.9 * v:  # ensure round down does not go down by more than 10%.
        result += divisor
    return result
80
+
81
+
82
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Scale `channels` by `multiplier` and round the result with `make_divisible`.

    A falsy `multiplier` (0, None, ...) returns `channels` untouched.
    """
    if multiplier:
        return make_divisible(channels * multiplier, divisor, channel_min)
    return channels
88
+
89
+
90
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
    """Apply drop connect.

    Outside of training the input is returned as-is. In training, one random 0/1 value
    is drawn per entry along dim 0 and broadcast over the remaining three dims; the
    input is divided by the keep probability before being multiplied by that mask.
    """
    if training:
        keep_prob = 1 - drop_connect_rate
        mask_shape = (inputs.shape[0], 1, 1, 1)
        mask = torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device) + keep_prob
        mask.floor_()  # binarize
        return inputs.div(keep_prob) * mask
    return inputs
101
+
102
+
103
class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation gate.

    The input is averaged over dims 2 and 3, passed through a 1x1 reduce conv, an
    activation and a 1x1 expand conv; the input is then multiplied by `gate_fn` of
    that result.
    """

    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
        super().__init__()
        # The reduction width derives from `reduced_base_chs` when given, else `in_chs`.
        base_chs = reduced_base_chs or in_chs
        squeezed_chs = make_divisible(base_chs * se_ratio, divisor)
        self.conv_reduce = nn.Conv2d(in_chs, squeezed_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(squeezed_chs, in_chs, 1, bias=True)
        self.gate_fn = gate_fn

    def forward(self, x):
        gate = x.mean((2, 3), keepdim=True)
        gate = self.conv_expand(self.act1(self.conv_reduce(gate)))
        return x * self.gate_fn(gate)
120
+
121
+
122
class ConvBnAct(nn.Module):
    """Convolution followed by a norm layer and an activation."""

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super().__init__()
        assert stride in [1, 2]
        bn_kwargs = norm_kwargs or {}
        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **bn_kwargs)
        self.act1 = act_layer(inplace=True)

    def forward(self, x):
        return self.act1(self.bn1(self.conv(x)))
137
+
138
+
139
class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
    factor of 1.0. This is an alternative to having a IR with optional first pw conv.

    Layer order: depthwise conv -> norm -> act -> (optional SE) -> pointwise conv ->
    norm -> (optional act when `pw_act` is set), plus a skip connection when the block
    keeps stride 1 and channel count and `noskip` is False.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        super().__init__()
        assert stride in [1, 2]
        bn_kwargs = norm_kwargs or {}
        wants_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **bn_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation (identity placeholder when disabled)
        if wants_se:
            resolved_se = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **resolved_se)
        else:
            self.se = nn.Identity()

        # Point-wise convolution
        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **bn_kwargs)
        self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()

    def forward(self, x):
        shortcut = x

        x = self.act1(self.bn1(self.conv_dw(x)))
        x = self.se(x)
        x = self.act2(self.bn2(self.conv_pw(x)))

        if not self.has_residual:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        x += shortcut
        return x
188
+
189
+
190
class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE

    Layer order: point-wise expansion -> depth-wise conv -> (optional SE) -> point-wise
    linear projection, each conv followed by a norm layer; the first two are also
    followed by an activation. A skip connection is added when input and output
    channels match, stride is 1 and `noskip` is False.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_connect_rate=0.):
        super().__init__()
        bn_kwargs = norm_kwargs or {}
        extra_conv_kwargs = conv_kwargs or {}
        mid_chs: int = make_divisible(in_chs * exp_ratio)
        wants_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Point-wise expansion
        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **extra_conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **bn_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True,
            **extra_conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **bn_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if wants_se:
            resolved_se = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **resolved_se)
        else:
            self.se = nn.Identity()  # for jit.script compat

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **extra_conv_kwargs)
        self.bn3 = norm_layer(out_chs, **bn_kwargs)

    def forward(self, x):
        shortcut = x

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x)))

        # Depth-wise convolution
        x = self.act2(self.bn2(self.conv_dw(x)))

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection (no activation afterwards)
        x = self.bn3(self.conv_pwl(x))

        if not self.has_residual:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        x += shortcut
        return x
252
+
253
+
254
class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing

    Builds the parent block with ``num_experts`` forwarded to every conv via
    `conv_kwargs`, and adds a linear routing layer whose sigmoid output is passed to
    each of those convs in `forward`.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 num_experts=0, drop_connect_rate=0.):

        self.num_experts = num_experts
        expert_conv_kwargs = dict(num_experts=self.num_experts)

        super().__init__(
            in_chs, out_chs,
            dw_kernel_size=dw_kernel_size,
            stride=stride,
            pad_type=pad_type,
            act_layer=act_layer,
            noskip=noskip,
            exp_ratio=exp_ratio,
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            se_ratio=se_ratio,
            se_kwargs=se_kwargs,
            norm_layer=norm_layer,
            norm_kwargs=norm_kwargs,
            conv_kwargs=expert_conv_kwargs,
            drop_connect_rate=drop_connect_rate)

        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        shortcut = x

        # CondConv routing: one weight per expert, computed from the pooled input
        pooled = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled))

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x, routing_weights)))

        # Depth-wise convolution
        x = self.act2(self.bn2(self.conv_dw(x, routing_weights)))

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.bn3(self.conv_pwl(x, routing_weights))

        if not self.has_residual:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        x += shortcut
        return x
304
+
305
+
306
class EdgeResidual(nn.Module):
    """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride

    Layer order: expansion conv -> norm -> act -> (optional SE) -> point-wise linear
    conv (carries the stride) -> norm. A skip connection is added when input and output
    channels match, stride is 1 and `noskip` is False.

    Args:
        in_chs: input channel count.
        out_chs: output channel count.
        exp_kernel_size: kernel size of the expansion conv.
        exp_ratio: expansion factor applied to `in_chs` (or to `fake_in_chs` when > 0).
        fake_in_chs: when > 0, used instead of `in_chs` to size the expanded width.
        stride: stride of the point-wise linear conv.
        pad_type: padding spec forwarded to `select_conv2d`.
        act_layer: activation layer class, instantiated with ``inplace=True``.
        noskip: disable the skip connection.
        pw_kernel_size: kernel size of the point-wise linear conv.
        se_ratio: squeeze-excite ratio; None or <= 0 disables SE.
        se_kwargs: extra SE arguments, resolved through `resolve_se_args`.
        norm_layer: norm layer class used after BOTH convs.
        norm_kwargs: keyword arguments for `norm_layer`.
        drop_connect_rate: drop-connect rate applied before the skip addition.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if se_ratio is not None and se_ratio > 0.:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
        # Use the configured norm layer here as well; this was previously hard-coded to
        # nn.BatchNorm2d, which silently ignored a caller-supplied `norm_layer`.
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x
355
+
356
+
357
class EfficientNetBuilder:
    """ Build Trunk Blocks for Efficient/Mobile Networks

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    The constructor stores model-wide defaults (channel rounding, padding, activation,
    SE/norm settings, drop-connect rate). Calling the instance turns decoded block
    argument dicts into modules, one `nn.Sequential` per stage.
    """

    # Block types that receive a per-block drop-connect rate and the SE kwargs.
    _RESIDUAL_BLOCK_TYPES = ('ir', 'ds', 'dsa', 'er')

    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=None, se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_connect_rate = drop_connect_rate

        # updated during build
        self.in_chs = None
        self.block_idx = 0
        self.block_count = 0

    def _round_channels(self, chs):
        """Round `chs` using the builder's multiplier, divisor and minimum."""
        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba):
        """Instantiate one block from its argument dict.

        `ba` is modified in place: ``block_type`` is popped and the builder-level
        settings are written into it before it is splatted into the block constructor.
        Afterwards `self.in_chs` is set to the block's output channels.

        Raises:
            AssertionError: if no activation layer can be resolved, or the block type
                is not one of 'ir', 'ds', 'dsa', 'er', 'cn'.
        """
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if ba.get('fake_in_chs'):
            # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None

        if bt in self._RESIDUAL_BLOCK_TYPES:
            # The rate grows linearly with the global block index.
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            ba['se_kwargs'] = self.se_kwargs

        if bt == 'ir':
            if ba.get('num_experts', 0) > 0:
                block = CondConvResidual(**ba)
            else:
                block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            block = ConvBnAct(**ba)
        else:
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError('Unknown block type (%s) while building model.' % bt)
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def _make_stack(self, stack_args):
        """Build one stage from a list of block argument dicts."""
        blocks = []
        # each stack (stage) contains a list of block arguments
        for i, ba in enumerate(stack_args):
            if i >= 1:
                # only the first block in any stack can have a stride > 1
                ba['stride'] = 1
            blocks.append(self._make_block(ba))
            self.block_idx += 1  # incr global idx (across all stacks)
        return nn.Sequential(*blocks)

    def __call__(self, in_chs, block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input-channels passed to first block
            block_args: A list of lists, outer list defines stages, inner
                list contains the block argument dicts of that stage
        Return:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        self.in_chs = in_chs
        self.block_count = sum(len(stack) for stack in block_args)
        self.block_idx = 0
        blocks = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stack in block_args:
            assert isinstance(stack, list)
            blocks.append(self._make_stack(stack))
        return blocks
454
+
455
+
456
+ def _parse_ksize(ss):
457
+ if ss.isdigit():
458
+ return int(ss)
459
+ else:
460
+ return [int(k) for k in ss.split('.')]
461
+
462
+
463
+ def _decode_block_str(block_str):
464
+ """ Decode block definition string
465
+
466
+ Gets a list of block arg (dicts) through a string notation of arguments.
467
+ E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
468
+
469
+ All args can exist in any order with the exception of the leading string which
470
+ is assumed to indicate the block type.
471
+
472
+ leading string - block type (
473
+ ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
474
+ r - number of repeat blocks,
475
+ k - kernel size,
476
+ s - strides (1-9),
477
+ e - expansion ratio,
478
+ c - output channels,
479
+ se - squeeze/excitation ratio
480
+ n - activation fn ('re', 'r6', 'hs', or 'sw')
481
+ Args:
482
+ block_str: a string representation of block arguments.
483
+ Returns:
484
+ A list of block args (dicts)
485
+ Raises:
486
+ ValueError: if the string def not properly specified (TODO)
487
+ """
488
+ assert isinstance(block_str, str)
489
+ ops = block_str.split('_')
490
+ block_type = ops[0] # take the block type off the front
491
+ ops = ops[1:]
492
+ options = {}
493
+ noskip = False
494
+ for op in ops:
495
+ # string options being checked on individual basis, combine if they grow
496
+ if op == 'noskip':
497
+ noskip = True
498
+ elif op.startswith('n'):
499
+ # activation fn
500
+ key = op[0]
501
+ v = op[1:]
502
+ if v == 're':
503
+ value = get_act_layer('relu')
504
+ elif v == 'r6':
505
+ value = get_act_layer('relu6')
506
+ elif v == 'hs':
507
+ value = get_act_layer('hard_swish')
508
+ elif v == 'sw':
509
+ value = get_act_layer('swish')
510
+ else:
511
+ continue
512
+ options[key] = value
513
+ else:
514
+ # all numeric options
515
+ splits = re.split(r'(\d.*)', op)
516
+ if len(splits) >= 2:
517
+ key, value = splits[:2]
518
+ options[key] = value
519
+
520
+ # if act_layer is None, the model default (passed to model init) will be used
521
+ act_layer = options['n'] if 'n' in options else None
522
+ exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
523
+ pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
524
+ fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
525
+
526
+ num_repeat = int(options['r'])
527
+ # each type of block has different valid arguments, fill accordingly
528
+ if block_type == 'ir':
529
+ block_args = dict(
530
+ block_type=block_type,
531
+ dw_kernel_size=_parse_ksize(options['k']),
532
+ exp_kernel_size=exp_kernel_size,
533
+ pw_kernel_size=pw_kernel_size,
534
+ out_chs=int(options['c']),
535
+ exp_ratio=float(options['e']),
536
+ se_ratio=float(options['se']) if 'se' in options else None,
537
+ stride=int(options['s']),
538
+ act_layer=act_layer,
539
+ noskip=noskip,
540
+ )
541
+ if 'cc' in options:
542
+ block_args['num_experts'] = int(options['cc'])
543
+ elif block_type == 'ds' or block_type == 'dsa':
544
+ block_args = dict(
545
+ block_type=block_type,
546
+ dw_kernel_size=_parse_ksize(options['k']),
547
+ pw_kernel_size=pw_kernel_size,
548
+ out_chs=int(options['c']),
549
+ se_ratio=float(options['se']) if 'se' in options else None,
550
+ stride=int(options['s']),
551
+ act_layer=act_layer,
552
+ pw_act=block_type == 'dsa',
553
+ noskip=block_type == 'dsa' or noskip,
554
+ )
555
+ elif block_type == 'er':
556
+ block_args = dict(
557
+ block_type=block_type,
558
+ exp_kernel_size=_parse_ksize(options['k']),
559
+ pw_kernel_size=pw_kernel_size,
560
+ out_chs=int(options['c']),
561
+ exp_ratio=float(options['e']),
562
+ fake_in_chs=fake_in_chs,
563
+ se_ratio=float(options['se']) if 'se' in options else None,
564
+ stride=int(options['s']),
565
+ act_layer=act_layer,
566
+ noskip=noskip,
567
+ )
568
+ elif block_type == 'cn':
569
+ block_args = dict(
570
+ block_type=block_type,
571
+ kernel_size=int(options['k']),
572
+ out_chs=int(options['c']),
573
+ stride=int(options['s']),
574
+ act_layer=act_layer,
575
+ )
576
+ else:
577
+ assert False, 'Unknown block type (%s)' % block_type
578
+
579
+ return block_args, num_repeat
580
+
581
+
582
+ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
583
+ """ Per-stage depth scaling
584
+ Scales the block repeats in each stage. This depth scaling impl maintains
585
+ compatibility with the EfficientNet scaling method, while allowing sensible
586
+ scaling for other models that may have multiple block arg definitions in each stage.
587
+ """
588
+
589
+ # We scale the total repeat count for each stage, there may be multiple
590
+ # block arg defs per stage so we need to sum.
591
+ num_repeat = sum(repeats)
592
+ if depth_trunc == 'round':
593
+ # Truncating to int by rounding allows stages with few repeats to remain
594
+ # proportionally smaller for longer. This is a good choice when stage definitions
595
+ # include single repeat stages that we'd prefer to keep that way as long as possible
596
+ num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
597
+ else:
598
+ # The default for EfficientNet truncates repeats to int via 'ceil'.
599
+ # Any multiplier > 1.0 will result in an increased depth for every stage.
600
+ num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
601
+
602
+ # Proportionally distribute repeat count scaling to each block definition in the stage.
603
+ # Allocation is done in reverse as it results in the first block being less likely to be scaled.
604
+ # The first block makes less sense to repeat in most of the arch definitions.
605
+ repeats_scaled = []
606
+ for r in repeats[::-1]:
607
+ rs = max(1, round((r / num_repeat * num_repeat_scaled)))
608
+ repeats_scaled.append(rs)
609
+ num_repeat -= r
610
+ num_repeat_scaled -= rs
611
+ repeats_scaled = repeats_scaled[::-1]
612
+
613
+ # Apply the calculated scaling to each block arg in the stage
614
+ sa_scaled = []
615
+ for ba, rep in zip(stack_args, repeats_scaled):
616
+ sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
617
+ return sa_scaled
618
+
619
+
620
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
    """Decode an architecture definition into per-stage lists of block arg dicts.

    `arch_def` is a list of stages, each a list of block definition strings. Every
    string is decoded with `_decode_block_str`, positive ``num_experts`` values are
    multiplied by `experts_multiplier` when it exceeds 1, and each stage is depth-scaled
    with `_scale_stage_depth`. With `fix_first_last` the first and last stage are scaled
    with a multiplier of 1.0 instead of `depth_multiplier`.
    """
    decoded_stages = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args, repeats = [], []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            block_args, num_repeat = _decode_block_str(block_str)
            if block_args.get('num_experts', 0) > 0 and experts_multiplier > 1:
                block_args['num_experts'] *= experts_multiplier
            stack_args.append(block_args)
            repeats.append(num_repeat)
        keep_depth = fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1)
        stage_multiplier = 1.0 if keep_depth else depth_multiplier
        decoded_stages.append(_scale_stage_depth(stack_args, repeats, stage_multiplier, depth_trunc))
    return decoded_stages
638
+
639
+
640
def initialize_weight_goog(m, n='', fix_group_fanout=True):
    """Initialize module `m` in place, following the TensorFlow official implementation.

    weight init as per Tensorflow Official impl
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py

    Args:
        m: the module to initialize; `CondConv2d`, `nn.Conv2d`, `nn.BatchNorm2d` and
            `nn.Linear` are handled, any other module is left untouched.
        n: the module's name; a name containing 'routing_fn' makes the Linear init
            range include fan-in.
        fix_group_fanout: divide the conv fan-out by the number of groups.
    """
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None for a norm layer built with affine=False
        if m.weight is not None:
            m.weight.data.fill_(1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        # bias is None for a Linear built with bias=False; mirror the conv branches
        if m.bias is not None:
            m.bias.data.zero_()
670
+
671
+
672
def initialize_weight_default(m, n=''):
    """Initialize module `m` in place with Kaiming-style defaults.

    Args:
        m: the module to initialize. `CondConv2d` and `nn.Conv2d` weights get
            `kaiming_normal_` (fan_out, relu); `nn.BatchNorm2d` gets weight 1 / bias 0;
            `nn.Linear` weights get `kaiming_uniform_` (fan_in, linear). Other modules
            are left untouched.
        n: the module's name (unused here; kept for a signature parallel to
            `initialize_weight_goog`).
    """
    if isinstance(m, CondConv2d):
        init_fn = get_condconv_initializer(partial(
            nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
        init_fn(m.weight)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None for a norm layer built with affine=False
        if m.weight is not None:
            m.weight.data.fill_(1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
geffnet/gen_efficientnet.py ADDED
@@ -0,0 +1,1450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Generic Efficient Networks
2
+
3
+ A generic MobileNet class with building blocks to support a variety of models:
4
+
5
+ * EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports)
6
+ - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
7
+ - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
8
+ - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
9
+ - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
10
+
11
+ * EfficientNet-Lite
12
+
13
+ * MixNet (Small, Medium, and Large)
14
+ - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
15
+
16
+ * MNasNet B1, A1 (SE), Small
17
+ - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
18
+
19
+ * FBNet-C
20
+ - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
21
+
22
+ * Single-Path NAS Pixel1
23
+ - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
24
+
25
+ * And likely more...
26
+
27
+ Hacked together by / Copyright 2020 Ross Wightman
28
+ """
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ from .config import layer_config_kwargs, is_scriptable
33
+ from .conv2d_layers import select_conv2d
34
+ from .helpers import load_pretrained
35
+ from .efficientnet_builder import *
36
+
37
# Names exported by ``from <this module> import *``: the network class followed by
# one factory function per model variant, grouped by model family.
__all__ = [
    'GenEfficientNet',
    # MNASNet B1
    'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
    # MNASNet A1 (w/ SE) and Small
    'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
    # MobileNet-V2
    'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d',
    # FBNet-C and Single-Path NAS
    'fbnetc_100', 'spnasnet_100',
    # EfficientNet
    'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
    'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
    'efficientnet_l2',
    # EfficientNet-Edge
    'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
    # EfficientNet-CondConv
    'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
    # EfficientNet-Lite
    'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
    # Tensorflow-ported EfficientNet
    'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
    'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
    # Tensorflow-ported EfficientNet, AdvProp
    'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
    'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap',
    'tf_efficientnet_b8_ap',
    # Tensorflow-ported EfficientNet, NoisyStudent
    'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns',
    'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns',
    'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
    # Tensorflow-ported EfficientNet-Edge / CondConv / Lite
    'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
    'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
    'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
    'tf_efficientnet_lite4',
    # MixNet
    'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l',
]
57
+
58
+
59
# Prefix shared by every checkpoint URL in ``model_urls``.
_WEIGHTS_RELEASE_URL = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/'

# Maps each model variant name to the URL of its pretrained checkpoint.
# ``None`` means no checkpoint URL is listed for that variant.
model_urls = {
    'mnasnet_050': None,
    'mnasnet_075': None,
    'mnasnet_100': _WEIGHTS_RELEASE_URL + 'mnasnet_b1-74cb7081.pth',
    'mnasnet_140': None,
    'mnasnet_small': None,

    'semnasnet_050': None,
    'semnasnet_075': None,
    'semnasnet_100': _WEIGHTS_RELEASE_URL + 'mnasnet_a1-d9418771.pth',
    'semnasnet_140': None,

    'mobilenetv2_100': _WEIGHTS_RELEASE_URL + 'mobilenetv2_100_ra-b33bc2c4.pth',
    'mobilenetv2_110d': _WEIGHTS_RELEASE_URL + 'mobilenetv2_110d_ra-77090ade.pth',
    'mobilenetv2_120d': _WEIGHTS_RELEASE_URL + 'mobilenetv2_120d_ra-5987e2ed.pth',
    'mobilenetv2_140': _WEIGHTS_RELEASE_URL + 'mobilenetv2_140_ra-21a4e913.pth',

    'fbnetc_100': _WEIGHTS_RELEASE_URL + 'fbnetc_100-c345b898.pth',
    'spnasnet_100': _WEIGHTS_RELEASE_URL + 'spnasnet_100-048bc3f4.pth',

    'efficientnet_b0': _WEIGHTS_RELEASE_URL + 'efficientnet_b0_ra-3dd342df.pth',
    'efficientnet_b1': _WEIGHTS_RELEASE_URL + 'efficientnet_b1-533bc792.pth',
    'efficientnet_b2': _WEIGHTS_RELEASE_URL + 'efficientnet_b2_ra-bcdf34b7.pth',
    'efficientnet_b3': _WEIGHTS_RELEASE_URL + 'efficientnet_b3_ra2-cf984f9c.pth',
    'efficientnet_b4': None,
    'efficientnet_b5': None,
    'efficientnet_b6': None,
    'efficientnet_b7': None,
    'efficientnet_b8': None,
    'efficientnet_l2': None,

    'efficientnet_es': _WEIGHTS_RELEASE_URL + 'efficientnet_es_ra-f111e99c.pth',
    'efficientnet_em': None,
    'efficientnet_el': None,

    'efficientnet_cc_b0_4e': None,
    'efficientnet_cc_b0_8e': None,
    'efficientnet_cc_b1_8e': None,

    'efficientnet_lite0': _WEIGHTS_RELEASE_URL + 'efficientnet_lite0_ra-37913777.pth',
    'efficientnet_lite1': None,
    'efficientnet_lite2': None,
    'efficientnet_lite3': None,
    'efficientnet_lite4': None,

    'tf_efficientnet_b0': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b0_aa-827b6e33.pth',
    'tf_efficientnet_b1': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b1_aa-ea7a6ee0.pth',
    'tf_efficientnet_b2': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b2_aa-60c94f97.pth',
    'tf_efficientnet_b3': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b3_aa-84b4657e.pth',
    'tf_efficientnet_b4': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b4_aa-818f208c.pth',
    'tf_efficientnet_b5': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b5_ra-9a3e5369.pth',
    'tf_efficientnet_b6': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b6_aa-80ba17e4.pth',
    'tf_efficientnet_b7': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b7_ra-6c08e654.pth',
    'tf_efficientnet_b8': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b8_ra-572d5dd9.pth',

    'tf_efficientnet_b0_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b0_ap-f262efe1.pth',
    'tf_efficientnet_b1_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b1_ap-44ef0a3d.pth',
    'tf_efficientnet_b2_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b2_ap-2f8e7636.pth',
    'tf_efficientnet_b3_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b3_ap-aad25bdd.pth',
    'tf_efficientnet_b4_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b4_ap-dedb23e6.pth',
    'tf_efficientnet_b5_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b5_ap-9e82fae8.pth',
    'tf_efficientnet_b6_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b6_ap-4ffb161f.pth',
    'tf_efficientnet_b7_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b7_ap-ddb28fec.pth',
    'tf_efficientnet_b8_ap': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b8_ap-00e169fa.pth',

    'tf_efficientnet_b0_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b0_ns-c0e6a31c.pth',
    'tf_efficientnet_b1_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b1_ns-99dd0c41.pth',
    'tf_efficientnet_b2_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b2_ns-00306e48.pth',
    'tf_efficientnet_b3_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b3_ns-9d44bf68.pth',
    'tf_efficientnet_b4_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b4_ns-d6313a46.pth',
    'tf_efficientnet_b5_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b5_ns-6f26d0cf.pth',
    'tf_efficientnet_b6_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b6_ns-51548356.pth',
    'tf_efficientnet_b7_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_b7_ns-1dbc32de.pth',
    'tf_efficientnet_l2_ns_475': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_l2_ns_475-bebbd00a.pth',
    'tf_efficientnet_l2_ns': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_l2_ns-df73bb44.pth',

    'tf_efficientnet_es': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_es-ca1afbfe.pth',
    'tf_efficientnet_em': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_em-e78cfe58.pth',
    'tf_efficientnet_el': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_el-5143854e.pth',

    'tf_efficientnet_cc_b0_4e': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_cc_b0_4e-4362b6b2.pth',
    'tf_efficientnet_cc_b0_8e': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_cc_b0_8e-66184a25.pth',
    'tf_efficientnet_cc_b1_8e': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_cc_b1_8e-f7c79ae1.pth',

    'tf_efficientnet_lite0': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_lite0-0aa007d2.pth',
    'tf_efficientnet_lite1': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_lite1-bde8b488.pth',
    'tf_efficientnet_lite2': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_lite2-dcccb7df.pth',
    'tf_efficientnet_lite3': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_lite3-b733e338.pth',
    'tf_efficientnet_lite4': _WEIGHTS_RELEASE_URL + 'tf_efficientnet_lite4-741542c3.pth',

    'mixnet_s': _WEIGHTS_RELEASE_URL + 'mixnet_s-a907afbc.pth',
    'mixnet_m': _WEIGHTS_RELEASE_URL + 'mixnet_m-4647fc68.pth',
    'mixnet_l': _WEIGHTS_RELEASE_URL + 'mixnet_l-5a9a2ed8.pth',
    'mixnet_xl': _WEIGHTS_RELEASE_URL + 'mixnet_xl_ra-aac3c00c.pth',

    'tf_mixnet_s': _WEIGHTS_RELEASE_URL + 'tf_mixnet_s-89d3354b.pth',
    'tf_mixnet_m': _WEIGHTS_RELEASE_URL + 'tf_mixnet_m-0f4d8805.pth',
    'tf_mixnet_l': _WEIGHTS_RELEASE_URL + 'tf_mixnet_l-6c92e0c8.pth',
}
213
+
214
+
215
class GenEfficientNet(nn.Module):
    """ Generic EfficientNets

    An implementation of mobile optimized networks that covers:
      * EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
      * MixNet (Small, Medium, and Large, XL)
      * MNASNet A1, B1, and small
      * FBNet C
      * Single-Path NAS Pixel1

    The network is a conv stem, a stack of blocks produced by ``EfficientNetBuilder``
    from ``block_args``, a 1x1 conv head, global average pooling and a linear classifier.
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 weight_init='goog'):
        """Assemble the stem, blocks, head and classifier.

        Args:
            block_args: decoded architecture definition handed to ``EfficientNetBuilder``.
            num_classes: output size of the final linear classifier.
            in_chans: number of input channels of the stem convolution.
            num_features: output channels of the 1x1 head convolution.
            stem_size: stem output channels; scaled with ``round_channels`` unless ``fix_stem``.
            fix_stem: when true, ``stem_size`` is used as given.
            channel_multiplier, channel_divisor, channel_min: channel scaling arguments
                passed to ``round_channels`` and to the builder.
            pad_type: padding argument passed to ``select_conv2d`` and to the builder.
            act_layer: activation module class, instantiated with ``inplace=True``.
            drop_rate: dropout probability applied before the classifier.
            drop_connect_rate: passed to the builder.
            se_kwargs: passed to the builder unchanged.
            norm_layer: normalization module class.
            norm_kwargs: keyword arguments for ``norm_layer``; ``None`` means none.
            weight_init: ``'goog'`` selects ``initialize_weight_goog``, anything else
                selects ``initialize_weight_default``.
        """
        super(GenEfficientNet, self).__init__()
        self.drop_rate = drop_rate

        # norm_kwargs defaults to None but is expanded with ** below; expanding None
        # raises TypeError, so normalise it to an empty dict first.
        norm_kwargs = norm_kwargs or {}

        if not fix_stem:
            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, channel_divisor, channel_min,
            pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        # The builder exposes the channel count left by its last block.
        in_chs = builder.in_chs

        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
        self.bn2 = norm_layer(num_features, **norm_kwargs)
        self.act2 = act_layer(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)

        for n, m in self.named_modules():
            if weight_init == 'goog':
                initialize_weight_goog(m, n)
            else:
                initialize_weight_default(m, n)

    def features(self, x):
        """Run stem, blocks and head; return the activated head output (before pooling)."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.conv_head(x)
        x = self.bn2(x)
        x = self.act2(x)
        return x

    def as_sequential(self):
        """Return an ``nn.Sequential`` built from this model's own sub-modules.

        The existing module objects are reused, with pooling, flatten, dropout and the
        classifier appended. Unlike ``forward``, the dropout layer is always included.
        """
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.conv_head, self.bn2, self.act2,
            self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return classifier outputs for input ``x``."""
        x = self.features(x)
        x = self.global_pool(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)
284
+
285
+
286
def _create_model(model_kwargs, variant, pretrained=False):
    """Instantiate a ``GenEfficientNet`` from ``model_kwargs``.

    Args:
        model_kwargs: constructor arguments; an optional ``'as_sequential'`` entry is
            removed from this dict (in place) before the model is built.
        variant: key into ``model_urls`` used to look up the checkpoint URL.
        pretrained: when true, ``load_pretrained`` is called with that URL.

    Returns:
        The model, or its ``as_sequential()`` form when that flag was set.
    """
    want_sequential = model_kwargs.pop('as_sequential', False)
    net = GenEfficientNet(**model_kwargs)
    if pretrained:
        load_pretrained(net, model_urls[variant])
    return net.as_sequential() if want_sequential else net
294
+
295
+
296
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build an MNASNet-A1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_e1_c16_noskip'],  # stage 0, 112x112 in
        ['ir_r2_k3_s2_e6_c24'],  # stage 1, 112x112 in
        ['ir_r3_k5_s2_e3_c40_se0.25'],  # stage 2, 56x56 in
        ['ir_r4_k3_s2_e6_c80'],  # stage 3, 28x28 in
        ['ir_r2_k3_s1_e6_c112_se0.25'],  # stage 4, 14x14in
        ['ir_r3_k5_s2_e6_c160_se0.25'],  # stage 5, 14x14in
        ['ir_r1_k3_s1_e6_c320'],  # stage 6, 7x7 in
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
332
+
333
+
334
def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build an MNASNet-B1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_c16_noskip'],  # stage 0, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],  # stage 1, 112x112 in
        ['ir_r3_k5_s2_e3_c40'],  # stage 2, 56x56 in
        ['ir_r3_k5_s2_e6_c80'],  # stage 3, 28x28 in
        ['ir_r2_k3_s1_e6_c96'],  # stage 4, 14x14in
        ['ir_r4_k5_s2_e6_c192'],  # stage 5, 14x14in
        ['ir_r1_k3_s1_e6_c320_noskip'],  # stage 6, 7x7 in
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
370
+
371
+
372
def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build an MNASNet-Small model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_c8'],
        ['ir_r1_k3_s2_e3_c16'],
        ['ir_r2_k3_s2_e6_c16'],
        ['ir_r4_k5_s2_e6_c32_se0.25'],
        ['ir_r3_k3_s1_e6_c32_se0.25'],
        ['ir_r3_k5_s2_e6_c88_se0.25'],
        ['ir_r1_k3_s1_e6_c144'],
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            stem_size=8,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
401
+
402
+
403
def _gen_mobilenet_v2(
        variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
    """Build a MobileNet-V2 model.

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
    Paper: https://arxiv.org/abs/1801.04381

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      depth_multiplier: passed to ``decode_arch_def``.
      fix_stem_head: when true, the stem is built with ``fix_stem=True``, the head keeps
        1280 features, and ``fix_first_last=True`` is passed to ``decode_arch_def``.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and ``resolve_bn_args``,
        then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_c16'],
        ['ir_r2_k3_s2_e6_c24'],
        ['ir_r3_k3_s2_e6_c32'],
        ['ir_r4_k3_s2_e6_c64'],
        ['ir_r3_k3_s1_e6_c96'],
        ['ir_r3_k3_s2_e6_c160'],
        ['ir_r1_k3_s1_e6_c320'],
    ]
    with layer_config_kwargs(kwargs):
        blocks = decode_arch_def(stage_defs, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head)
        if fix_stem_head:
            head_features = 1280
        else:
            head_features = round_channels(1280, channel_multiplier, 8, None)
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=head_features,
            stem_size=32,
            fix_stem=fix_stem_head,
            channel_multiplier=channel_multiplier,
            norm_kwargs=bn_args,
            act_layer=nn.ReLU6,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
431
+
432
+
433
def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build an FBNet-C model.

    Paper: https://arxiv.org/abs/1812.03443
    Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

    NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper,
    it was used to confirm some building block details

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ir_r1_k3_s1_e1_c16'],
        ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
        ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
        ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
        ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
        ['ir_r4_k5_s2_e6_c184'],
        ['ir_r1_k3_s1_e6_c352'],
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            stem_size=16,
            num_features=1984,  # paper suggests this, but is not 100% clear
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
463
+
464
+
465
def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build the Single-Path NAS model from search targeted for Pixel1 phone.

    Paper: https://arxiv.org/abs/1904.02877

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_c16_noskip'],  # stage 0, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],  # stage 1, 112x112 in
        ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],  # stage 2, 56x56 in
        ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],  # stage 3, 28x28 in
        ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],  # stage 4, 14x14in
        ['ir_r4_k5_s2_e6_c192'],  # stage 5, 14x14in
        ['ir_r1_k3_s1_e6_c320_noskip'],  # stage 6, 7x7 in
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
500
+
501
+
502
def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Build an EfficientNet model.

    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    'efficientnet-b8': (2.2, 3.6, 672, 0.5),

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: multiplier to number of repeats per stage
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25'],
        ['ir_r4_k5_s2_e6_c192_se0.25'],
        ['ir_r1_k3_s1_e6_c320_se0.25'],
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs, depth_multiplier)
        head_features = round_channels(1280, channel_multiplier, 8, None)
        activation = resolve_act_layer(kwargs, 'swish')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=head_features,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
546
+
547
+
548
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Build an EfficientNet-Edge model.

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: passed to ``decode_arch_def``.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        # NOTE `fc` is present to override a mismatch between stem channels and in chs not
        # present in other models
        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
        ['er_r2_k3_s2_e8_c32'],
        ['er_r4_k3_s2_e8_c48'],
        ['ir_r5_k5_s2_e8_c96'],
        ['ir_r4_k5_s1_e8_c144'],
        ['ir_r2_k5_s2_e8_c192'],
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs, depth_multiplier)
        head_features = round_channels(1280, channel_multiplier, 8, None)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=head_features,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
571
+
572
+
573
def _gen_efficientnet_condconv(
        variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
    """Build an efficientnet-condconv model.

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: passed to ``decode_arch_def``.
      experts_multiplier: passed to ``decode_arch_def``.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
        ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
        ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs, depth_multiplier, experts_multiplier=experts_multiplier)
        head_features = round_channels(1280, channel_multiplier, 8, None)
        activation = resolve_act_layer(kwargs, 'swish')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=head_features,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
597
+
598
+
599
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Build an EfficientNet-Lite model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: multiplier to number of repeats per stage
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and ``resolve_bn_args``,
        then expanded into the model constructor arguments.
    """
    stage_defs = [
        ['ds_r1_k3_s1_e1_c16'],
        ['ir_r2_k3_s2_e6_c24'],
        ['ir_r2_k5_s2_e6_c40'],
        ['ir_r3_k3_s2_e6_c80'],
        ['ir_r3_k5_s1_e6_c112'],
        ['ir_r4_k5_s2_e6_c192'],
        ['ir_r1_k3_s1_e6_c320'],
    ]
    with layer_config_kwargs(kwargs):
        blocks = decode_arch_def(stage_defs, depth_multiplier, fix_first_last=True)
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=1280,
            stem_size=32,
            fix_stem=True,
            channel_multiplier=channel_multiplier,
            act_layer=nn.ReLU6,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
639
+
640
+
641
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build a MixNet Small model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
    Paper: https://arxiv.org/abs/1907.09595

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'],  # relu
        # stage 2, 56x56 in
        ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
        # stage 3, 28x28 in
        ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
        # stage 4, 14x14in
        ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
        # stage 5, 14x14in
        ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
        # 7x7
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=1536,
            stem_size=16,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
674
+
675
+
676
def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Build a MixNet Medium-Large model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
    Paper: https://arxiv.org/abs/1907.09595

    Args:
      variant: name passed to ``_create_model`` for the checkpoint lookup.
      channel_multiplier: multiplier to number of channels per layer.
      depth_multiplier: passed to ``decode_arch_def`` together with ``depth_trunc='round'``.
      pretrained: passed to ``_create_model``.
      **kwargs: extra options; handed to ``layer_config_kwargs`` and the ``resolve_*``
        helpers, then expanded into the model constructor arguments.
    """
    stage_defs = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c24'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'],  # relu
        # stage 2, 56x56 in
        ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
        # stage 3, 28x28 in
        ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
        # stage 4, 14x14in
        ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
        # stage 5, 14x14in
        ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
        # 7x7
    ]
    with layer_config_kwargs(kwargs):
        # Helpers run in this order before kwargs is expanded into the dict.
        blocks = decode_arch_def(stage_defs, depth_multiplier, depth_trunc='round')
        activation = resolve_act_layer(kwargs, 'relu')
        bn_args = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=blocks,
            num_features=1536,
            stem_size=24,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_args,
            **kwargs
        )
        net = _create_model(model_kwargs, variant, pretrained)
    return net
709
+
710
+
711
def mnasnet_050(pretrained=False, **kwargs):
    """Return MNASNet B1 built with a channel multiplier of 0.5."""
    return _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
715
+
716
+
717
def mnasnet_075(pretrained=False, **kwargs):
    """Return MNASNet B1 built with a channel multiplier of 0.75."""
    return _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
721
+
722
+
723
def mnasnet_100(pretrained=False, **kwargs):
    """Return MNASNet B1 built with a channel multiplier of 1.0."""
    return _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
727
+
728
+
729
def mnasnet_b1(pretrained=False, **kwargs):
    """Alias for ``mnasnet_100`` (MNASNet B1, channel multiplier of 1.0)."""
    return mnasnet_100(pretrained=pretrained, **kwargs)
732
+
733
+
734
def mnasnet_140(pretrained=False, **kwargs):
    """Return MNASNet B1 built with a channel multiplier of 1.4."""
    return _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
738
+
739
+
740
def semnasnet_050(pretrained=False, **kwargs):
    """Return MNASNet A1 (w/ SE) built with a channel multiplier of 0.5."""
    return _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
744
+
745
+
746
def semnasnet_075(pretrained=False, **kwargs):
    """Return MNASNet A1 (w/ SE) built with a channel multiplier of 0.75."""
    return _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
750
+
751
+
752
def semnasnet_100(pretrained=False, **kwargs):
    """Return MNASNet A1 (w/ SE) built with a channel multiplier of 1.0."""
    return _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
756
+
757
+
758
def mnasnet_a1(pretrained=False, **kwargs):
    """Alias for ``semnasnet_100`` (MNASNet A1 w/ SE, channel multiplier of 1.0)."""
    return semnasnet_100(pretrained=pretrained, **kwargs)
761
+
762
+
763
def semnasnet_140(pretrained=False, **kwargs):
    """Return MNASNet A1 (w/ SE) built with a channel multiplier of 1.4."""
    return _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
767
+
768
+
769
def mnasnet_small(pretrained=False, **kwargs):
    """Return MNASNet Small built with a channel multiplier of 1.0."""
    return _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
773
+
774
+
775
def mobilenetv2_100(pretrained=False, **kwargs):
    """Return MobileNet V2 built with a 1.0 channel multiplier."""
    return _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
779
+
780
+
781
def mobilenetv2_140(pretrained=False, **kwargs):
    """ MobileNet V2 with a 1.4 channel multiplier. """
    return _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
785
+
786
+
787
def mobilenetv2_110d(pretrained=False, **kwargs):
    """ MobileNet V2 with a 1.1 channel multiplier and a 1.2 depth multiplier.

    Built with ``fix_stem_head=True``.
    """
    return _gen_mobilenet_v2(
        'mobilenetv2_110d',
        1.1,
        depth_multiplier=1.2,
        fix_stem_head=True,
        pretrained=pretrained,
        **kwargs)
792
+
793
+
794
def mobilenetv2_120d(pretrained=False, **kwargs):
    """ MobileNet V2 with a 1.2 channel multiplier and a 1.4 depth multiplier.

    Built with ``fix_stem_head=True``.
    """
    return _gen_mobilenet_v2(
        'mobilenetv2_120d',
        1.2,
        depth_multiplier=1.4,
        fix_stem_head=True,
        pretrained=pretrained,
        **kwargs)
799
+
800
+
801
def fbnetc_100(pretrained=False, **kwargs):
    """ FBNet-C.

    When ``pretrained`` is truthy, ``bn_eps`` is forced to
    ``BN_EPS_TF_DEFAULT`` because the pretrained weights were trained with a
    non-default BN epsilon.
    """
    if pretrained:
        kwargs.update(bn_eps=BN_EPS_TF_DEFAULT)
    return _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
808
+
809
+
810
def spnasnet_100(pretrained=False, **kwargs):
    """ Single-Path NAS Pixel1. """
    return _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
814
+
815
+
816
def efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 (channel multiplier 1.0, depth multiplier 1.0). """
    # NOTE: for training, set drop_rate=0.2, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b0',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
822
+
823
+
824
def efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1 (channel multiplier 1.0, depth multiplier 1.1). """
    # NOTE: for training, set drop_rate=0.2, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b1',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
830
+
831
+
832
def efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2 (channel multiplier 1.1, depth multiplier 1.2). """
    # NOTE: for training, set drop_rate=0.3, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b2',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
838
+
839
+
840
def efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 (channel multiplier 1.2, depth multiplier 1.4). """
    # NOTE: for training, set drop_rate=0.3, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b3',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
846
+
847
+
848
def efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4 (channel multiplier 1.4, depth multiplier 1.8). """
    # NOTE: for training, set drop_rate=0.4, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b4',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
854
+
855
+
856
def efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5 (channel multiplier 1.6, depth multiplier 2.2). """
    # NOTE: for training, set drop_rate=0.4, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b5',
        channel_multiplier=1.6,
        depth_multiplier=2.2,
        pretrained=pretrained,
        **kwargs)
862
+
863
+
864
def efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6 (channel multiplier 1.8, depth multiplier 2.6). """
    # NOTE: for training, set drop_rate=0.5, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b6',
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        pretrained=pretrained,
        **kwargs)
870
+
871
+
872
def efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7 (channel multiplier 2.0, depth multiplier 3.1). """
    # NOTE: for training, set drop_rate=0.5, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b7',
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        pretrained=pretrained,
        **kwargs)
878
+
879
+
880
def efficientnet_b8(pretrained=False, **kwargs):
    """ EfficientNet-B8 (channel multiplier 2.2, depth multiplier 3.6). """
    # NOTE: for training, set drop_rate=0.5, drop_connect_rate=0.2
    return _gen_efficientnet(
        'efficientnet_b8',
        channel_multiplier=2.2,
        depth_multiplier=3.6,
        pretrained=pretrained,
        **kwargs)
886
+
887
+
888
def efficientnet_l2(pretrained=False, **kwargs):
    """ EfficientNet-L2 (channel multiplier 4.3, depth multiplier 5.3). """
    # NOTE: for training, drop_rate should be 0.5
    return _gen_efficientnet(
        'efficientnet_l2',
        channel_multiplier=4.3,
        depth_multiplier=5.3,
        pretrained=pretrained,
        **kwargs)
894
+
895
+
896
def efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-Edge Small (channel multiplier 1.0, depth multiplier 1.0). """
    return _gen_efficientnet_edge(
        'efficientnet_es',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
901
+
902
+
903
def efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-Edge Medium (channel multiplier 1.0, depth multiplier 1.1). """
    return _gen_efficientnet_edge(
        'efficientnet_em',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
908
+
909
+
910
def efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-Edge Large (channel multiplier 1.2, depth multiplier 1.4). """
    return _gen_efficientnet_edge(
        'efficientnet_el',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
915
+
916
+
917
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts.

    Unlike the ``_8e`` variants, no ``experts_multiplier`` is passed here, so
    the generator's own default applies.
    """
    # NOTE: for training, set drop_rate=0.25, drop_connect_rate=0.2
    return _gen_efficientnet_condconv(
        'efficientnet_cc_b0_4e',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
923
+
924
+
925
def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts (``experts_multiplier=2``). """
    # NOTE: for training, set drop_rate=0.25, drop_connect_rate=0.2
    return _gen_efficientnet_condconv(
        'efficientnet_cc_b0_8e',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        experts_multiplier=2,
        pretrained=pretrained,
        **kwargs)
932
+
933
+
934
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts (``experts_multiplier=2``). """
    # NOTE: for training, set drop_rate=0.25, drop_connect_rate=0.2
    return _gen_efficientnet_condconv(
        'efficientnet_cc_b1_8e',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        experts_multiplier=2,
        pretrained=pretrained,
        **kwargs)
941
+
942
+
943
def efficientnet_lite0(pretrained=False, **kwargs):
    """ EfficientNet-Lite0 (channel multiplier 1.0, depth multiplier 1.0). """
    return _gen_efficientnet_lite(
        'efficientnet_lite0',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
948
+
949
+
950
def efficientnet_lite1(pretrained=False, **kwargs):
    """ EfficientNet-Lite1 (channel multiplier 1.0, depth multiplier 1.1). """
    return _gen_efficientnet_lite(
        'efficientnet_lite1',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
955
+
956
+
957
def efficientnet_lite2(pretrained=False, **kwargs):
    """ EfficientNet-Lite2 (channel multiplier 1.1, depth multiplier 1.2). """
    return _gen_efficientnet_lite(
        'efficientnet_lite2',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
962
+
963
+
964
def efficientnet_lite3(pretrained=False, **kwargs):
    """ EfficientNet-Lite3 (channel multiplier 1.2, depth multiplier 1.4). """
    return _gen_efficientnet_lite(
        'efficientnet_lite3',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
969
+
970
+
971
def efficientnet_lite4(pretrained=False, **kwargs):
    """ EfficientNet-Lite4 (channel multiplier 1.4, depth multiplier 1.8). """
    return _gen_efficientnet_lite(
        'efficientnet_lite4',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
976
+
977
+
978
def tf_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
985
+
986
+
987
def tf_efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b1',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
994
+
995
+
996
def tf_efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b2',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
1003
+
1004
+
1005
def tf_efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b3',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
1012
+
1013
+
1014
def tf_efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b4',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
1021
+
1022
+
1023
def tf_efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5 RandAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b5',
        channel_multiplier=1.6,
        depth_multiplier=2.2,
        pretrained=pretrained,
        **kwargs)
1030
+
1031
+
1032
def tf_efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6 AutoAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b6',
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        pretrained=pretrained,
        **kwargs)
1039
+
1040
+
1041
def tf_efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7 RandAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b7',
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        pretrained=pretrained,
        **kwargs)
1048
+
1049
+
1050
def tf_efficientnet_b8(pretrained=False, **kwargs):
    """ EfficientNet-B8 RandAug, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b8',
        channel_multiplier=2.2,
        depth_multiplier=3.6,
        pretrained=pretrained,
        **kwargs)
1057
+
1058
+
1059
def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
    """ EfficientNet-B0 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0_ap',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1068
+
1069
+
1070
def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
    """ EfficientNet-B1 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b1_ap',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
1079
+
1080
+
1081
def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
    """ EfficientNet-B2 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b2_ap',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
1090
+
1091
+
1092
def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
    """ EfficientNet-B3 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b3_ap',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
1101
+
1102
+
1103
def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
    """ EfficientNet-B4 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b4_ap',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
1112
+
1113
+
1114
def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
    """ EfficientNet-B5 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b5_ap',
        channel_multiplier=1.6,
        depth_multiplier=2.2,
        pretrained=pretrained,
        **kwargs)
1123
+
1124
+
1125
def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
    """ EfficientNet-B6 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b6_ap',
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        pretrained=pretrained,
        **kwargs)
1135
+
1136
+
1137
def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
    """ EfficientNet-B7 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b7_ap',
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        pretrained=pretrained,
        **kwargs)
1147
+
1148
+
1149
def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
    """ EfficientNet-B8 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b8_ap',
        channel_multiplier=2.2,
        depth_multiplier=3.6,
        pretrained=pretrained,
        **kwargs)
1159
+
1160
+
1161
def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
    """ EfficientNet-B0 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0_ns',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1170
+
1171
+
1172
def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
    """ EfficientNet-B1 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b1_ns',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
1181
+
1182
+
1183
def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
    """ EfficientNet-B2 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b2_ns',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
1192
+
1193
+
1194
def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
    """ EfficientNet-B3 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b3_ns',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
1203
+
1204
+
1205
def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
    """ EfficientNet-B4 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b4_ns',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
1214
+
1215
+
1216
def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
    """ EfficientNet-B5 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b5_ns',
        channel_multiplier=1.6,
        depth_multiplier=2.2,
        pretrained=pretrained,
        **kwargs)
1225
+
1226
+
1227
def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
    """ EfficientNet-B6 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b6_ns',
        channel_multiplier=1.8,
        depth_multiplier=2.6,
        pretrained=pretrained,
        **kwargs)
1237
+
1238
+
1239
def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
    """ EfficientNet-B7 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b7_ns',
        channel_multiplier=2.0,
        depth_multiplier=3.1,
        pretrained=pretrained,
        **kwargs)
1249
+
1250
+
1251
def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
    """ EfficientNet-L2 NoisyStudent @ 475x475, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_l2_ns_475',
        channel_multiplier=4.3,
        depth_multiplier=5.3,
        pretrained=pretrained,
        **kwargs)
1261
+
1262
+
1263
def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
    """ EfficientNet-L2 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    # NOTE: for training, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_l2_ns',
        channel_multiplier=4.3,
        depth_multiplier=5.3,
        pretrained=pretrained,
        **kwargs)
1273
+
1274
+
1275
def tf_efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-Edge Small, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_es',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1282
+
1283
+
1284
def tf_efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-Edge Medium, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_em',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
1291
+
1292
+
1293
def tf_efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-Edge Large, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_el',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
1300
+
1301
+
1302
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_4e',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1309
+
1310
+
1311
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs and
    passes ``experts_multiplier=2``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_8e',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        experts_multiplier=2,
        pretrained=pretrained,
        **kwargs)
1319
+
1320
+
1321
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs and
    passes ``experts_multiplier=2``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b1_8e',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        experts_multiplier=2,
        pretrained=pretrained,
        **kwargs)
1329
+
1330
+
1331
def tf_efficientnet_lite0(pretrained=False, **kwargs):
    """ EfficientNet-Lite0, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_lite(
        'tf_efficientnet_lite0',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1338
+
1339
+
1340
def tf_efficientnet_lite1(pretrained=False, **kwargs):
    """ EfficientNet-Lite1, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_lite(
        'tf_efficientnet_lite1',
        channel_multiplier=1.0,
        depth_multiplier=1.1,
        pretrained=pretrained,
        **kwargs)
1347
+
1348
+
1349
def tf_efficientnet_lite2(pretrained=False, **kwargs):
    """ EfficientNet-Lite2, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_lite(
        'tf_efficientnet_lite2',
        channel_multiplier=1.1,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
1356
+
1357
+
1358
def tf_efficientnet_lite3(pretrained=False, **kwargs):
    """ EfficientNet-Lite3, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_lite(
        'tf_efficientnet_lite3',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs)
1365
+
1366
+
1367
def tf_efficientnet_lite4(pretrained=False, **kwargs):
    """ EfficientNet-Lite4, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_lite(
        'tf_efficientnet_lite4',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs)
1374
+
1375
+
1376
def mixnet_s(pretrained=False, **kwargs):
    """Build a MixNet Small model (channel multiplier 1.0).
    """
    # NOTE: for training, set drop_rate=0.2
    return _gen_mixnet_s(
        'mixnet_s',
        channel_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1383
+
1384
+
1385
def mixnet_m(pretrained=False, **kwargs):
    """Build a MixNet Medium model (channel multiplier 1.0).
    """
    # NOTE: for training, set drop_rate=0.25
    return _gen_mixnet_m(
        'mixnet_m',
        channel_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1392
+
1393
+
1394
def mixnet_l(pretrained=False, **kwargs):
    """Build a MixNet Large model.

    Uses the medium generator with a 1.3 channel multiplier.
    """
    # NOTE: for training, set drop_rate=0.25
    return _gen_mixnet_m(
        'mixnet_l',
        channel_multiplier=1.3,
        pretrained=pretrained,
        **kwargs)
1401
+
1402
+
1403
def mixnet_xl(pretrained=False, **kwargs):
    """Build a MixNet Extra-Large model.

    Not a paper spec, experimental def by RW w/ depth scaling. Uses the medium
    generator with a 1.6 channel multiplier and a 1.2 depth multiplier.
    """
    # NOTE: for training, set drop_rate=0.25, drop_connect_rate=0.2
    return _gen_mixnet_m(
        'mixnet_xl',
        channel_multiplier=1.6,
        depth_multiplier=1.2,
        pretrained=pretrained,
        **kwargs)
1411
+
1412
+
1413
def mixnet_xxl(pretrained=False, **kwargs):
    """Build a MixNet Double Extra Large model.

    Not a paper spec, experimental def by RW w/ depth scaling. Uses the medium
    generator with a 2.4 channel multiplier and a 1.3 depth multiplier.
    """
    # NOTE: for training, set drop_rate=0.3, drop_connect_rate=0.2
    return _gen_mixnet_m(
        'mixnet_xxl',
        channel_multiplier=2.4,
        depth_multiplier=1.3,
        pretrained=pretrained,
        **kwargs)
1421
+
1422
+
1423
def tf_mixnet_s(pretrained=False, **kwargs):
    """Build a MixNet Small model, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_s(
        'tf_mixnet_s',
        channel_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1431
+
1432
+
1433
def tf_mixnet_m(pretrained=False, **kwargs):
    """Build a MixNet Medium model, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_m(
        'tf_mixnet_m',
        channel_multiplier=1.0,
        pretrained=pretrained,
        **kwargs)
1441
+
1442
+
1443
def tf_mixnet_l(pretrained=False, **kwargs):
    """Build a MixNet Large model, Tensorflow compatible variant.

    Always overrides ``bn_eps`` and ``pad_type`` in the forwarded kwargs.
    Uses the medium generator with a 1.3 channel multiplier.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mixnet_m(
        'tf_mixnet_l',
        channel_multiplier=1.3,
        pretrained=pretrained,
        **kwargs)
geffnet/helpers.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Checkpoint loading / state_dict helpers
2
+ Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import os
6
+ from collections import OrderedDict
7
+ try:
8
+ from torch.hub import load_state_dict_from_url
9
+ except ImportError:
10
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
11
+
12
+
13
def load_checkpoint(model, checkpoint_path):
    """Load weights from a checkpoint file into ``model``.

    The file may contain either a bare state dict or a dict with a
    ``'state_dict'`` entry. In the latter case a leading ``module.`` prefix is
    stripped from every key before loading.

    Args:
        model: object exposing ``load_state_dict``.
        checkpoint_path: path to the checkpoint file.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is empty or is not a file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    print("=> Loading checkpoint '{}'".format(checkpoint_path))
    # Map storages to CPU so a checkpoint saved on GPU also loads on CPU-only
    # hosts; load_state_dict copies values into the model's own parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Strip only the exact `module.` prefix. Matching on 'module' alone
        # would mangle keys such as 'module_a.weight'.
        prefix = 'module.'
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint['state_dict'].items())
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print("=> Loaded checkpoint '{}'".format(checkpoint_path))
32
+
33
+
34
def load_pretrained(model, url, filter_fn=None, strict=True):
    """Fetch a pretrained state dict from ``url`` and load it into ``model``.

    If ``url`` is falsy a warning is printed and the model is left untouched.
    Otherwise the state dict is adapted to the model before loading:

    * If the model's ``conv_stem`` input channel count differs from the
      checkpoint's, the stem weight is summed over the input-channel dim when
      the model has a single input channel, and dropped otherwise.
    * If the model's ``classifier`` output count differs, the classifier weight
      and bias are dropped.

    Dropping any entry forces ``strict=False`` for the final load.

    Args:
        model: module with ``conv_stem`` and ``classifier`` attributes.
        url: location of the pretrained weights.
        filter_fn: optional callable applied to the state dict before loading.
        strict: forwarded to ``model.load_state_dict`` unless weights were dropped.
    """
    if not url:
        print("=> Warning: Pretrained model URL is empty, using random initialization.")
        return

    weights = load_state_dict_from_url(url, progress=False, map_location='cpu')

    stem_name, head_name = 'conv_stem', 'classifier'
    model_in_chans = getattr(model, stem_name).weight.shape[1]
    model_num_classes = getattr(model, head_name).weight.shape[0]

    # Reconcile the stem conv with the model's input channel count.
    stem_key = stem_name + '.weight'
    ckpt_in_chans = weights[stem_key].shape[1]
    if model_in_chans != ckpt_in_chans:
        if model_in_chans == 1:
            print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
                stem_key, ckpt_in_chans))
            weights[stem_key] = weights[stem_key].sum(dim=1, keepdim=True)
        else:
            print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
                stem_key, ckpt_in_chans))
            del weights[stem_key]
            strict = False

    # Drop the classifier head when the class counts disagree.
    head_key = head_name + '.weight'
    ckpt_num_classes = weights[head_key].shape[0]
    if model_num_classes != ckpt_num_classes:
        print('=> Discarding pretrained classifier since num_classes != {}'.format(ckpt_num_classes))
        del weights[head_key]
        del weights[head_name + '.bias']
        strict = False

    if filter_fn is not None:
        weights = filter_fn(weights)

    model.load_state_dict(weights, strict=strict)
geffnet/mobilenetv3.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ MobileNet-V3
2
+
3
+ A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
4
+
5
+ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
6
+
7
+ Hacked together by / Copyright 2020 Ross Wightman
8
+ """
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .activations import get_act_fn, get_act_layer, HardSwish
13
+ from .config import layer_config_kwargs
14
+ from .conv2d_layers import select_conv2d
15
+ from .helpers import load_pretrained
16
+ from .efficientnet_builder import *
17
+
18
+ __all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
19
+ 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
20
+ 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
21
+ 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
22
+
23
# Pretrained weight URL for each model variant; a value of ``None`` means no
# URL is registered for that variant.
model_urls = dict(
    mobilenetv3_rw=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'),
    mobilenetv3_large_075=None,
    mobilenetv3_large_100=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
    mobilenetv3_large_minimal_100=None,
    mobilenetv3_small_075=None,
    mobilenetv3_small_100=None,
    mobilenetv3_small_minimal_100=None,
    tf_mobilenetv3_large_075=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth'),
    tf_mobilenetv3_large_100=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth'),
    tf_mobilenetv3_large_minimal_100=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth'),
    tf_mobilenetv3_small_075=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth'),
    tf_mobilenetv3_small_100=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth'),
    tf_mobilenetv3_small_minimal_100=(
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth'),
)
46
+
47
+
48
class MobileNetV3(nn.Module):
    """ MobileNet-V3

    This model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
    head convolution without a final batch-norm layer before the classifier.

    Paper: https://arxiv.org/abs/1905.02244
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
        """Build the stem, the block stack, and the pooled 'efficient head'.

        Args:
            block_args: decoded block definitions handed to ``EfficientNetBuilder``.
            num_classes: output size of the final linear classifier.
            in_chans: number of input channels of the stem convolution.
            stem_size: stem output channels before ``channel_multiplier`` rounding.
            num_features: output channels of the 1x1 head convolution.
            head_bias: whether the head convolution has a bias term.
            channel_multiplier: width multiplier applied to the stem and passed to the builder.
            pad_type: padding mode forwarded to ``select_conv2d`` and the builder.
            act_layer: activation layer class used for the stem, blocks and head.
            drop_rate: dropout probability applied before the classifier.
            drop_connect_rate: forwarded to the builder.
            se_kwargs: forwarded to the builder.
            norm_layer: normalization layer class used for the stem and the blocks.
            norm_kwargs: keyword arguments for ``norm_layer``; ``None`` means no extra arguments.
            weight_init: ``'goog'`` selects ``initialize_weight_goog``; any other value
                selects ``initialize_weight_default``.
        """
        super(MobileNetV3, self).__init__()
        self.drop_rate = drop_rate
        # The signature defaults norm_kwargs to None, which cannot be unpacked with ``**``.
        norm_kwargs = norm_kwargs or {}

        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        # Honour the requested norm layer for the stem as well, consistent with the blocks below.
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        in_chs = builder.in_chs

        # Efficient head: pool first, then a 1x1 conv and activation, no norm before the classifier.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
        self.act2 = act_layer(inplace=True)
        self.classifier = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if weight_init == 'goog':
                initialize_weight_goog(m)
            else:
                initialize_weight_default(m)

    def as_sequential(self):
        """Return the same layers arranged as a flat ``nn.Sequential``."""
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.global_pool, self.conv_head, self.act2,
            nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def features(self, x):
        """Run the stem, blocks, global pooling and head conv; returns the pooled feature map."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        """Compute classifier outputs for ``x``, applying dropout when ``drop_rate`` > 0."""
        x = self.features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)
110
+
111
+
112
def _create_model(model_kwargs, variant, pretrained=False):
    """Instantiate a ``MobileNetV3`` and optionally load weights registered for ``variant``.

    ``model_kwargs`` may carry an ``as_sequential`` flag; it is removed before
    construction and, when true, the model is returned via ``as_sequential()``.
    Weights are only loaded when ``pretrained`` is truthy and ``model_urls``
    holds a non-empty URL for the variant.
    """
    want_sequential = model_kwargs.pop('as_sequential', False)
    net = MobileNetV3(**model_kwargs)
    if pretrained:
        weights_url = model_urls[variant]
        if weights_url:
            load_pretrained(net, weights_url)
    return net.as_sequential() if want_sequential else net
120
+
121
+
122
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build the RW flavour of MobileNet-V3.

    Paper: https://arxiv.org/abs/1905.02244

    Author's note: this was a first attempt at reproducing MobileNet-V3 from the paper alone. It is
    close to the eventual Tensorflow reference impl, with a few differences:
      1. No bias on the head convolution.
      2. No residual (noskip) is forced on the first DWS block, which differs from MnasNet.
      3. ReLU is always used for the SE activation layer; other models in the family inherit the
         act layer from their parent block.
      4. The divisible-by-8 limitation on the SE reduction channel count is not enforced.

    The changes are described as fairly minor, with a very small parameter count difference
    (the upstream note also mentions top-1/5 but is cut off).

    Args:
        variant: key into ``model_urls`` used when loading pretrained weights.
        channel_multiplier: multiplier to number of channels per layer.
        pretrained: forwarded to ``_create_model``.
        **kwargs: extra model arguments; also consumed by the act-layer / BN resolvers.
    """
    stage_defs = [
        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # stage 0, 112x112 in (relu)
        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # stage 1, 112x112 in (relu)
        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # stage 2, 56x56 in (relu)
        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # stage 3, 28x28 in (hard-swish)
        ['ir_r2_k3_s1_e6_c112_se0.25'],  # stage 4, 14x14 in (hard-swish)
        ['ir_r3_k5_s2_e6_c160_se0.25'],  # stage 5, 14x14 in (hard-swish)
        ['cn_r1_k1_s1_c960'],  # stage 6, 7x7 in (hard-swish)
    ]
    with layer_config_kwargs(kwargs):
        # Resolve in the original order; the resolvers receive ``kwargs`` before it is expanded below.
        decoded_blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'hard_swish')
        squeeze_excite = dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True)
        bn_args = resolve_bn_args(kwargs)
        build_args = dict(
            block_args=decoded_blocks,
            head_bias=False,  # one of my mistakes
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            se_kwargs=squeeze_excite,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(build_args, variant, pretrained)
    return net
169
+
170
+
171
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build a MobileNet-V3 large / small model, in regular or minimal form.

    The architecture is chosen from substrings of ``variant``: ``'small'`` selects the small
    layout (1024 head features, otherwise 1280) and ``'minimal'`` selects the ReLU-only layout
    whose arch strings carry no SE entries and use 3x3 kernels.

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
    Paper: https://arxiv.org/abs/1905.02244

    Args:
        variant: model name; also the key into ``model_urls`` for pretrained weights.
        channel_multiplier: multiplier to number of channels per layer.
        pretrained: forwarded to ``_create_model``.
        **kwargs: extra model arguments; also consumed by the act-layer / BN resolvers.
    """
    is_small = 'small' in variant
    is_minimal = 'minimal' in variant

    num_features = 1024 if is_small else 1280
    default_act = 'relu' if is_minimal else 'hard_swish'

    if is_small and is_minimal:
        stage_defs = [
            ['ds_r1_k3_s2_e1_c16'],  # 112x112 in
            ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],  # 56x56 in
            ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],  # 28x28 in
            ['ir_r2_k3_s1_e3_c48'],  # 14x14 in
            ['ir_r3_k3_s2_e6_c96'],  # 14x14 in
            ['cn_r1_k1_s1_c576'],  # 7x7 in
        ]
    elif is_small:
        stage_defs = [
            ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # 112x112 in, relu
            ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # 56x56 in, relu
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # 28x28 in, hard-swish
            ['ir_r2_k5_s1_e3_c48_se0.25'],  # 14x14 in, hard-swish
            ['ir_r3_k5_s2_e6_c96_se0.25'],  # 14x14 in, hard-swish
            ['cn_r1_k1_s1_c576'],  # 7x7 in, hard-swish
        ]
    elif is_minimal:
        stage_defs = [
            ['ds_r1_k3_s1_e1_c16'],  # 112x112 in
            ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],  # 112x112 in
            ['ir_r3_k3_s2_e3_c40'],  # 56x56 in
            ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # 28x28 in
            ['ir_r2_k3_s1_e6_c112'],  # 14x14 in
            ['ir_r3_k3_s2_e6_c160'],  # 14x14 in
            ['cn_r1_k1_s1_c960'],  # 7x7 in
        ]
    else:
        stage_defs = [
            ['ds_r1_k3_s1_e1_c16_nre'],  # 112x112 in, relu
            ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # 112x112 in, relu
            ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # 56x56 in, relu
            ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # 28x28 in, hard-swish
            ['ir_r2_k3_s1_e6_c112_se0.25'],  # 14x14 in, hard-swish
            ['ir_r3_k5_s2_e6_c160_se0.25'],  # 14x14 in, hard-swish
            ['cn_r1_k1_s1_c960'],  # 7x7 in, hard-swish
        ]

    with layer_config_kwargs(kwargs):
        # Resolve in the original order; the resolvers receive ``kwargs`` before it is expanded below.
        decoded_blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, default_act)
        squeeze_excite = dict(
            act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8)
        bn_args = resolve_bn_args(kwargs)
        build_args = dict(
            block_args=decoded_blocks,
            num_features=num_features,
            stem_size=16,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            se_kwargs=squeeze_excite,
            norm_kwargs=bn_args,
            **kwargs,
        )
        net = _create_model(build_args, variant, pretrained)
    return net
266
+
267
+
268
def mobilenetv3_rw(pretrained=False, **kwargs):
    """MobileNet-V3 RW at channel multiplier 1.0.

    See the note in ``_gen_mobilenet_v3_rw`` for how this variant differs.
    NOTE: for training, set drop_rate=0.2.
    """
    if pretrained:
        # The pretrained weights were trained with a non-default BN epsilon.
        kwargs.update(bn_eps=BN_EPS_TF_DEFAULT)
    return _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
278
+
279
+
280
def mobilenetv3_large_075(pretrained=False, **kwargs):
    """MobileNet V3 Large with channel multiplier 0.75.

    NOTE: for training, set drop_rate=0.2.
    """
    return _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
285
+
286
+
287
def mobilenetv3_large_100(pretrained=False, **kwargs):
    """MobileNet V3 Large with channel multiplier 1.0.

    NOTE: for training, set drop_rate=0.2.
    """
    return _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
292
+
293
+
294
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """MobileNet V3 Large (Minimalistic) with channel multiplier 1.0.

    NOTE: for training, set drop_rate=0.2.
    """
    return _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
299
+
300
+
301
def mobilenetv3_small_075(pretrained=False, **kwargs):
    """MobileNet V3 Small with channel multiplier 0.75."""
    return _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
305
+
306
+
307
def mobilenetv3_small_100(pretrained=False, **kwargs):
    """MobileNet V3 Small with channel multiplier 1.0."""
    return _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
311
+
312
+
313
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    """MobileNet V3 Small (Minimalistic) with channel multiplier 1.0."""
    return _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
317
+
318
+
319
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Large, channel multiplier 0.75.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
325
+
326
+
327
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Large, channel multiplier 1.0.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
333
+
334
+
335
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Large (Minimalistic), channel multiplier 1.0.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
341
+
342
+
343
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Small, channel multiplier 0.75.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
349
+
350
+
351
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Small, channel multiplier 1.0.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
357
+
358
+
359
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    """Tensorflow-compatible MobileNet V3 Small (Minimalistic), channel multiplier 1.0.

    Always overrides ``bn_eps`` with ``BN_EPS_TF_DEFAULT`` and ``pad_type`` with ``'same'``.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
geffnet/model_factory.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .config import set_layer_config
2
+ from .helpers import load_checkpoint
3
+
4
+ from .gen_efficientnet import *
5
+ from .mobilenetv3 import *
6
+
7
+
8
def create_model(
        model_name='mnasnet_100',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):
    """Create a model by looking up its factory function in this module's namespace.

    Args:
        model_name: name of a factory callable available in this module's globals
            (populated by the star imports at the top of the file).
        pretrained: forwarded to the factory; when truthy, ``checkpoint_path`` is ignored.
        num_classes: forwarded to the factory.
        in_chans: forwarded to the factory.
        checkpoint_path: if non-empty and ``pretrained`` is falsy, weights are loaded
            from this path with ``load_checkpoint`` after construction.
        **kwargs: extra keyword arguments forwarded to the factory.

    Returns:
        The model returned by the factory.

    Raises:
        RuntimeError: if ``model_name`` does not resolve to a callable in this module.
    """
    model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)

    # A plain membership test would also accept non-callable globals (e.g. '__name__')
    # and then fail with an unrelated TypeError; require a callable instead.
    create_fn = globals().get(model_name)
    if not callable(create_fn):
        raise RuntimeError('Unknown model (%s)' % model_name)
    model = create_fn(**model_kwargs)

    if checkpoint_path and not pretrained:
        load_checkpoint(model, checkpoint_path)

    return model
geffnet/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '1.0.2'
requirements.txt CHANGED
@@ -109,12 +109,13 @@ sympy==1.12.1
109
  tokenizers==0.15.2
110
  tomlkit==0.12.0
111
  toolz==0.12.1
112
- torch==2.0.1
113
- torchvision==v0.15.2
 
 
114
  tqdm==4.66.4
115
  transformers==4.36.1
116
  trimesh==4.0.5
117
- triton==2.0.0
118
  typer==0.12.3
119
  typing-inspect==0.9.0
120
  typing_extensions==4.11.0
@@ -126,10 +127,8 @@ uvloop==0.19.0
126
  watchfiles==0.22.0
127
  websockets==11.0.3
128
  wrapt==1.16.0
129
- xformers==0.0.20
130
  xxhash==3.4.1
131
  yarl==1.9.4
132
  zipp==3.19.1
133
  einops==0.7.0
134
- opencv-python-headless==4.8.1.78
135
- geffnet==1.0.2
 
109
  tokenizers==0.15.2
110
  tomlkit==0.12.0
111
  toolz==0.12.1
112
+ torch==2.2.0
113
+ torchvision==0.18.0
114
+ xformers==0.0.24
115
+ triton==2.2.0
116
  tqdm==4.66.4
117
  transformers==4.36.1
118
  trimesh==4.0.5
 
119
  typer==0.12.3
120
  typing-inspect==0.9.0
121
  typing_extensions==4.11.0
 
127
  watchfiles==0.22.0
128
  websockets==11.0.3
129
  wrapt==1.16.0
 
130
  xxhash==3.4.1
131
  yarl==1.9.4
132
  zipp==3.19.1
133
  einops==0.7.0
134
+ opencv-python-headless==4.8.1.78