Spaces:
Running
on
Zero
Running
on
Zero
Fix environment dependency
Browse files- app.py +1 -2
- geffnet/__init__.py +5 -0
- geffnet/activations/__init__.py +137 -0
- geffnet/activations/activations.py +102 -0
- geffnet/activations/activations_jit.py +79 -0
- geffnet/activations/activations_me.py +170 -0
- geffnet/config.py +123 -0
- geffnet/conv2d_layers.py +304 -0
- geffnet/efficientnet_builder.py +683 -0
- geffnet/gen_efficientnet.py +1450 -0
- geffnet/helpers.py +71 -0
- geffnet/mobilenetv3.py +364 -0
- geffnet/model_factory.py +27 -0
- geffnet/version.py +1 -0
- requirements.txt +5 -6
app.py
CHANGED
@@ -375,10 +375,9 @@ def run_demo_server(pipe):
|
|
375 |
def main():
|
376 |
|
377 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
378 |
-
|
379 |
marigold_pipe = Marigold()
|
380 |
geowizard_pipe = Geowizard()
|
381 |
-
dsine_pipe = DSINE()
|
382 |
our_pipe = StableNormal()
|
383 |
|
384 |
run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])
|
|
|
375 |
def main():
|
376 |
|
377 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
378 |
+
dsine_pipe = DSINE()
|
379 |
marigold_pipe = Marigold()
|
380 |
geowizard_pipe = Geowizard()
|
|
|
381 |
our_pipe = StableNormal()
|
382 |
|
383 |
run_demo_server([dsine_pipe, marigold_pipe, geowizard_pipe, our_pipe])
|
geffnet/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .gen_efficientnet import *
|
2 |
+
from .mobilenetv3 import *
|
3 |
+
from .model_factory import create_model
|
4 |
+
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
|
5 |
+
from .activations import *
|
geffnet/activations/__init__.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from geffnet import config
|
2 |
+
from geffnet.activations.activations_me import *
|
3 |
+
from geffnet.activations.activations_jit import *
|
4 |
+
from geffnet.activations.activations import *
|
5 |
+
import torch
|
6 |
+
|
7 |
+
_has_silu = 'silu' in dir(torch.nn.functional)
|
8 |
+
|
9 |
+
_ACT_FN_DEFAULT = dict(
|
10 |
+
silu=F.silu if _has_silu else swish,
|
11 |
+
swish=F.silu if _has_silu else swish,
|
12 |
+
mish=mish,
|
13 |
+
relu=F.relu,
|
14 |
+
relu6=F.relu6,
|
15 |
+
sigmoid=sigmoid,
|
16 |
+
tanh=tanh,
|
17 |
+
hard_sigmoid=hard_sigmoid,
|
18 |
+
hard_swish=hard_swish,
|
19 |
+
)
|
20 |
+
|
21 |
+
_ACT_FN_JIT = dict(
|
22 |
+
silu=F.silu if _has_silu else swish_jit,
|
23 |
+
swish=F.silu if _has_silu else swish_jit,
|
24 |
+
mish=mish_jit,
|
25 |
+
)
|
26 |
+
|
27 |
+
_ACT_FN_ME = dict(
|
28 |
+
silu=F.silu if _has_silu else swish_me,
|
29 |
+
swish=F.silu if _has_silu else swish_me,
|
30 |
+
mish=mish_me,
|
31 |
+
hard_swish=hard_swish_me,
|
32 |
+
hard_sigmoid_jit=hard_sigmoid_me,
|
33 |
+
)
|
34 |
+
|
35 |
+
_ACT_LAYER_DEFAULT = dict(
|
36 |
+
silu=nn.SiLU if _has_silu else Swish,
|
37 |
+
swish=nn.SiLU if _has_silu else Swish,
|
38 |
+
mish=Mish,
|
39 |
+
relu=nn.ReLU,
|
40 |
+
relu6=nn.ReLU6,
|
41 |
+
sigmoid=Sigmoid,
|
42 |
+
tanh=Tanh,
|
43 |
+
hard_sigmoid=HardSigmoid,
|
44 |
+
hard_swish=HardSwish,
|
45 |
+
)
|
46 |
+
|
47 |
+
_ACT_LAYER_JIT = dict(
|
48 |
+
silu=nn.SiLU if _has_silu else SwishJit,
|
49 |
+
swish=nn.SiLU if _has_silu else SwishJit,
|
50 |
+
mish=MishJit,
|
51 |
+
)
|
52 |
+
|
53 |
+
_ACT_LAYER_ME = dict(
|
54 |
+
silu=nn.SiLU if _has_silu else SwishMe,
|
55 |
+
swish=nn.SiLU if _has_silu else SwishMe,
|
56 |
+
mish=MishMe,
|
57 |
+
hard_swish=HardSwishMe,
|
58 |
+
hard_sigmoid=HardSigmoidMe
|
59 |
+
)
|
60 |
+
|
61 |
+
_OVERRIDE_FN = dict()
|
62 |
+
_OVERRIDE_LAYER = dict()
|
63 |
+
|
64 |
+
|
65 |
+
def add_override_act_fn(name, fn):
|
66 |
+
global _OVERRIDE_FN
|
67 |
+
_OVERRIDE_FN[name] = fn
|
68 |
+
|
69 |
+
|
70 |
+
def update_override_act_fn(overrides):
|
71 |
+
assert isinstance(overrides, dict)
|
72 |
+
global _OVERRIDE_FN
|
73 |
+
_OVERRIDE_FN.update(overrides)
|
74 |
+
|
75 |
+
|
76 |
+
def clear_override_act_fn():
|
77 |
+
global _OVERRIDE_FN
|
78 |
+
_OVERRIDE_FN = dict()
|
79 |
+
|
80 |
+
|
81 |
+
def add_override_act_layer(name, fn):
|
82 |
+
_OVERRIDE_LAYER[name] = fn
|
83 |
+
|
84 |
+
|
85 |
+
def update_override_act_layer(overrides):
|
86 |
+
assert isinstance(overrides, dict)
|
87 |
+
global _OVERRIDE_LAYER
|
88 |
+
_OVERRIDE_LAYER.update(overrides)
|
89 |
+
|
90 |
+
|
91 |
+
def clear_override_act_layer():
|
92 |
+
global _OVERRIDE_LAYER
|
93 |
+
_OVERRIDE_LAYER = dict()
|
94 |
+
|
95 |
+
|
96 |
+
def get_act_fn(name='relu'):
|
97 |
+
""" Activation Function Factory
|
98 |
+
Fetching activation fns by name with this function allows export or torch script friendly
|
99 |
+
functions to be returned dynamically based on current config.
|
100 |
+
"""
|
101 |
+
if name in _OVERRIDE_FN:
|
102 |
+
return _OVERRIDE_FN[name]
|
103 |
+
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
104 |
+
if use_me and name in _ACT_FN_ME:
|
105 |
+
# If not exporting or scripting the model, first look for a memory optimized version
|
106 |
+
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
|
107 |
+
return _ACT_FN_ME[name]
|
108 |
+
if config.is_exportable() and name in ('silu', 'swish'):
|
109 |
+
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
110 |
+
return swish
|
111 |
+
use_jit = not (config.is_exportable() or config.is_no_jit())
|
112 |
+
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
113 |
+
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
114 |
+
return _ACT_FN_JIT[name]
|
115 |
+
return _ACT_FN_DEFAULT[name]
|
116 |
+
|
117 |
+
|
118 |
+
def get_act_layer(name='relu'):
|
119 |
+
""" Activation Layer Factory
|
120 |
+
Fetching activation layers by name with this function allows export or torch script friendly
|
121 |
+
functions to be returned dynamically based on current config.
|
122 |
+
"""
|
123 |
+
if name in _OVERRIDE_LAYER:
|
124 |
+
return _OVERRIDE_LAYER[name]
|
125 |
+
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
126 |
+
if use_me and name in _ACT_LAYER_ME:
|
127 |
+
return _ACT_LAYER_ME[name]
|
128 |
+
if config.is_exportable() and name in ('silu', 'swish'):
|
129 |
+
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
130 |
+
return Swish
|
131 |
+
use_jit = not (config.is_exportable() or config.is_no_jit())
|
132 |
+
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
133 |
+
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
134 |
+
return _ACT_LAYER_JIT[name]
|
135 |
+
return _ACT_LAYER_DEFAULT[name]
|
136 |
+
|
137 |
+
|
geffnet/activations/activations.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Activations
|
2 |
+
|
3 |
+
A collection of activations fn and modules with a common interface so that they can
|
4 |
+
easily be swapped. All have an `inplace` arg even if not used.
|
5 |
+
|
6 |
+
Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
from torch import nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def swish(x, inplace: bool = False):
|
13 |
+
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
14 |
+
and also as Swish (https://arxiv.org/abs/1710.05941).
|
15 |
+
|
16 |
+
TODO Rename to SiLU with addition to PyTorch
|
17 |
+
"""
|
18 |
+
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
19 |
+
|
20 |
+
|
21 |
+
class Swish(nn.Module):
|
22 |
+
def __init__(self, inplace: bool = False):
|
23 |
+
super(Swish, self).__init__()
|
24 |
+
self.inplace = inplace
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
return swish(x, self.inplace)
|
28 |
+
|
29 |
+
|
30 |
+
def mish(x, inplace: bool = False):
|
31 |
+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
32 |
+
"""
|
33 |
+
return x.mul(F.softplus(x).tanh())
|
34 |
+
|
35 |
+
|
36 |
+
class Mish(nn.Module):
|
37 |
+
def __init__(self, inplace: bool = False):
|
38 |
+
super(Mish, self).__init__()
|
39 |
+
self.inplace = inplace
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
return mish(x, self.inplace)
|
43 |
+
|
44 |
+
|
45 |
+
def sigmoid(x, inplace: bool = False):
|
46 |
+
return x.sigmoid_() if inplace else x.sigmoid()
|
47 |
+
|
48 |
+
|
49 |
+
# PyTorch has this, but not with a consistent inplace argmument interface
|
50 |
+
class Sigmoid(nn.Module):
|
51 |
+
def __init__(self, inplace: bool = False):
|
52 |
+
super(Sigmoid, self).__init__()
|
53 |
+
self.inplace = inplace
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return x.sigmoid_() if self.inplace else x.sigmoid()
|
57 |
+
|
58 |
+
|
59 |
+
def tanh(x, inplace: bool = False):
|
60 |
+
return x.tanh_() if inplace else x.tanh()
|
61 |
+
|
62 |
+
|
63 |
+
# PyTorch has this, but not with a consistent inplace argmument interface
|
64 |
+
class Tanh(nn.Module):
|
65 |
+
def __init__(self, inplace: bool = False):
|
66 |
+
super(Tanh, self).__init__()
|
67 |
+
self.inplace = inplace
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return x.tanh_() if self.inplace else x.tanh()
|
71 |
+
|
72 |
+
|
73 |
+
def hard_swish(x, inplace: bool = False):
|
74 |
+
inner = F.relu6(x + 3.).div_(6.)
|
75 |
+
return x.mul_(inner) if inplace else x.mul(inner)
|
76 |
+
|
77 |
+
|
78 |
+
class HardSwish(nn.Module):
|
79 |
+
def __init__(self, inplace: bool = False):
|
80 |
+
super(HardSwish, self).__init__()
|
81 |
+
self.inplace = inplace
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return hard_swish(x, self.inplace)
|
85 |
+
|
86 |
+
|
87 |
+
def hard_sigmoid(x, inplace: bool = False):
|
88 |
+
if inplace:
|
89 |
+
return x.add_(3.).clamp_(0., 6.).div_(6.)
|
90 |
+
else:
|
91 |
+
return F.relu6(x + 3.) / 6.
|
92 |
+
|
93 |
+
|
94 |
+
class HardSigmoid(nn.Module):
|
95 |
+
def __init__(self, inplace: bool = False):
|
96 |
+
super(HardSigmoid, self).__init__()
|
97 |
+
self.inplace = inplace
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
return hard_sigmoid(x, self.inplace)
|
101 |
+
|
102 |
+
|
geffnet/activations/activations_jit.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Activations (jit)
|
2 |
+
|
3 |
+
A collection of jit-scripted activations fn and modules with a common interface so that they can
|
4 |
+
easily be swapped. All have an `inplace` arg even if not used.
|
5 |
+
|
6 |
+
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
|
7 |
+
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
|
8 |
+
versions if they contain in-place ops.
|
9 |
+
|
10 |
+
Copyright 2020 Ross Wightman
|
11 |
+
"""
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn as nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
|
18 |
+
'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
|
19 |
+
|
20 |
+
|
21 |
+
@torch.jit.script
|
22 |
+
def swish_jit(x, inplace: bool = False):
|
23 |
+
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
24 |
+
and also as Swish (https://arxiv.org/abs/1710.05941).
|
25 |
+
|
26 |
+
TODO Rename to SiLU with addition to PyTorch
|
27 |
+
"""
|
28 |
+
return x.mul(x.sigmoid())
|
29 |
+
|
30 |
+
|
31 |
+
@torch.jit.script
|
32 |
+
def mish_jit(x, _inplace: bool = False):
|
33 |
+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
34 |
+
"""
|
35 |
+
return x.mul(F.softplus(x).tanh())
|
36 |
+
|
37 |
+
|
38 |
+
class SwishJit(nn.Module):
|
39 |
+
def __init__(self, inplace: bool = False):
|
40 |
+
super(SwishJit, self).__init__()
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return swish_jit(x)
|
44 |
+
|
45 |
+
|
46 |
+
class MishJit(nn.Module):
|
47 |
+
def __init__(self, inplace: bool = False):
|
48 |
+
super(MishJit, self).__init__()
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
return mish_jit(x)
|
52 |
+
|
53 |
+
|
54 |
+
@torch.jit.script
|
55 |
+
def hard_sigmoid_jit(x, inplace: bool = False):
|
56 |
+
# return F.relu6(x + 3.) / 6.
|
57 |
+
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
58 |
+
|
59 |
+
|
60 |
+
class HardSigmoidJit(nn.Module):
|
61 |
+
def __init__(self, inplace: bool = False):
|
62 |
+
super(HardSigmoidJit, self).__init__()
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
return hard_sigmoid_jit(x)
|
66 |
+
|
67 |
+
|
68 |
+
@torch.jit.script
|
69 |
+
def hard_swish_jit(x, inplace: bool = False):
|
70 |
+
# return x * (F.relu6(x + 3.) / 6)
|
71 |
+
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
72 |
+
|
73 |
+
|
74 |
+
class HardSwishJit(nn.Module):
|
75 |
+
def __init__(self, inplace: bool = False):
|
76 |
+
super(HardSwishJit, self).__init__()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
return hard_swish_jit(x)
|
geffnet/activations/activations_me.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Activations (memory-efficient w/ custom autograd)
|
2 |
+
|
3 |
+
A collection of activations fn and modules with a common interface so that they can
|
4 |
+
easily be swapped. All have an `inplace` arg even if not used.
|
5 |
+
|
6 |
+
These activations are not compatible with jit scripting or ONNX export of the model, please use either
|
7 |
+
the JIT or basic versions of the activations.
|
8 |
+
|
9 |
+
Copyright 2020 Ross Wightman
|
10 |
+
"""
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn as nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
|
17 |
+
__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
|
18 |
+
'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
|
19 |
+
|
20 |
+
|
21 |
+
@torch.jit.script
|
22 |
+
def swish_jit_fwd(x):
|
23 |
+
return x.mul(torch.sigmoid(x))
|
24 |
+
|
25 |
+
|
26 |
+
@torch.jit.script
|
27 |
+
def swish_jit_bwd(x, grad_output):
|
28 |
+
x_sigmoid = torch.sigmoid(x)
|
29 |
+
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
30 |
+
|
31 |
+
|
32 |
+
class SwishJitAutoFn(torch.autograd.Function):
|
33 |
+
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
|
34 |
+
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
35 |
+
https://twitter.com/jeremyphoward/status/1188251041835315200
|
36 |
+
|
37 |
+
Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
38 |
+
and also as Swish (https://arxiv.org/abs/1710.05941).
|
39 |
+
|
40 |
+
TODO Rename to SiLU with addition to PyTorch
|
41 |
+
"""
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def forward(ctx, x):
|
45 |
+
ctx.save_for_backward(x)
|
46 |
+
return swish_jit_fwd(x)
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def backward(ctx, grad_output):
|
50 |
+
x = ctx.saved_tensors[0]
|
51 |
+
return swish_jit_bwd(x, grad_output)
|
52 |
+
|
53 |
+
|
54 |
+
def swish_me(x, inplace=False):
|
55 |
+
return SwishJitAutoFn.apply(x)
|
56 |
+
|
57 |
+
|
58 |
+
class SwishMe(nn.Module):
|
59 |
+
def __init__(self, inplace: bool = False):
|
60 |
+
super(SwishMe, self).__init__()
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return SwishJitAutoFn.apply(x)
|
64 |
+
|
65 |
+
|
66 |
+
@torch.jit.script
|
67 |
+
def mish_jit_fwd(x):
|
68 |
+
return x.mul(torch.tanh(F.softplus(x)))
|
69 |
+
|
70 |
+
|
71 |
+
@torch.jit.script
|
72 |
+
def mish_jit_bwd(x, grad_output):
|
73 |
+
x_sigmoid = torch.sigmoid(x)
|
74 |
+
x_tanh_sp = F.softplus(x).tanh()
|
75 |
+
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
|
76 |
+
|
77 |
+
|
78 |
+
class MishJitAutoFn(torch.autograd.Function):
|
79 |
+
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
80 |
+
A memory efficient, jit scripted variant of Mish
|
81 |
+
"""
|
82 |
+
@staticmethod
|
83 |
+
def forward(ctx, x):
|
84 |
+
ctx.save_for_backward(x)
|
85 |
+
return mish_jit_fwd(x)
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def backward(ctx, grad_output):
|
89 |
+
x = ctx.saved_tensors[0]
|
90 |
+
return mish_jit_bwd(x, grad_output)
|
91 |
+
|
92 |
+
|
93 |
+
def mish_me(x, inplace=False):
|
94 |
+
return MishJitAutoFn.apply(x)
|
95 |
+
|
96 |
+
|
97 |
+
class MishMe(nn.Module):
|
98 |
+
def __init__(self, inplace: bool = False):
|
99 |
+
super(MishMe, self).__init__()
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
return MishJitAutoFn.apply(x)
|
103 |
+
|
104 |
+
|
105 |
+
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
|
106 |
+
return (x + 3).clamp(min=0, max=6).div(6.)
|
107 |
+
|
108 |
+
|
109 |
+
def hard_sigmoid_jit_bwd(x, grad_output):
|
110 |
+
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
|
111 |
+
return grad_output * m
|
112 |
+
|
113 |
+
|
114 |
+
class HardSigmoidJitAutoFn(torch.autograd.Function):
|
115 |
+
@staticmethod
|
116 |
+
def forward(ctx, x):
|
117 |
+
ctx.save_for_backward(x)
|
118 |
+
return hard_sigmoid_jit_fwd(x)
|
119 |
+
|
120 |
+
@staticmethod
|
121 |
+
def backward(ctx, grad_output):
|
122 |
+
x = ctx.saved_tensors[0]
|
123 |
+
return hard_sigmoid_jit_bwd(x, grad_output)
|
124 |
+
|
125 |
+
|
126 |
+
def hard_sigmoid_me(x, inplace: bool = False):
|
127 |
+
return HardSigmoidJitAutoFn.apply(x)
|
128 |
+
|
129 |
+
|
130 |
+
class HardSigmoidMe(nn.Module):
|
131 |
+
def __init__(self, inplace: bool = False):
|
132 |
+
super(HardSigmoidMe, self).__init__()
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
return HardSigmoidJitAutoFn.apply(x)
|
136 |
+
|
137 |
+
|
138 |
+
def hard_swish_jit_fwd(x):
|
139 |
+
return x * (x + 3).clamp(min=0, max=6).div(6.)
|
140 |
+
|
141 |
+
|
142 |
+
def hard_swish_jit_bwd(x, grad_output):
|
143 |
+
m = torch.ones_like(x) * (x >= 3.)
|
144 |
+
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
|
145 |
+
return grad_output * m
|
146 |
+
|
147 |
+
|
148 |
+
class HardSwishJitAutoFn(torch.autograd.Function):
|
149 |
+
"""A memory efficient, jit-scripted HardSwish activation"""
|
150 |
+
@staticmethod
|
151 |
+
def forward(ctx, x):
|
152 |
+
ctx.save_for_backward(x)
|
153 |
+
return hard_swish_jit_fwd(x)
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def backward(ctx, grad_output):
|
157 |
+
x = ctx.saved_tensors[0]
|
158 |
+
return hard_swish_jit_bwd(x, grad_output)
|
159 |
+
|
160 |
+
|
161 |
+
def hard_swish_me(x, inplace=False):
|
162 |
+
return HardSwishJitAutoFn.apply(x)
|
163 |
+
|
164 |
+
|
165 |
+
class HardSwishMe(nn.Module):
|
166 |
+
def __init__(self, inplace: bool = False):
|
167 |
+
super(HardSwishMe, self).__init__()
|
168 |
+
|
169 |
+
def forward(self, x):
|
170 |
+
return HardSwishJitAutoFn.apply(x)
|
geffnet/config.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Global layer config state
|
2 |
+
"""
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
|
7 |
+
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
|
8 |
+
]
|
9 |
+
|
10 |
+
# Set to True if prefer to have layers with no jit optimization (includes activations)
|
11 |
+
_NO_JIT = False
|
12 |
+
|
13 |
+
# Set to True if prefer to have activation layers with no jit optimization
|
14 |
+
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
|
15 |
+
# the jit flags so far are activations. This will change as more layers are updated and/or added.
|
16 |
+
_NO_ACTIVATION_JIT = False
|
17 |
+
|
18 |
+
# Set to True if exporting a model with Same padding via ONNX
|
19 |
+
_EXPORTABLE = False
|
20 |
+
|
21 |
+
# Set to True if wanting to use torch.jit.script on a model
|
22 |
+
_SCRIPTABLE = False
|
23 |
+
|
24 |
+
|
25 |
+
def is_no_jit():
|
26 |
+
return _NO_JIT
|
27 |
+
|
28 |
+
|
29 |
+
class set_no_jit:
|
30 |
+
def __init__(self, mode: bool) -> None:
|
31 |
+
global _NO_JIT
|
32 |
+
self.prev = _NO_JIT
|
33 |
+
_NO_JIT = mode
|
34 |
+
|
35 |
+
def __enter__(self) -> None:
|
36 |
+
pass
|
37 |
+
|
38 |
+
def __exit__(self, *args: Any) -> bool:
|
39 |
+
global _NO_JIT
|
40 |
+
_NO_JIT = self.prev
|
41 |
+
return False
|
42 |
+
|
43 |
+
|
44 |
+
def is_exportable():
|
45 |
+
return _EXPORTABLE
|
46 |
+
|
47 |
+
|
48 |
+
class set_exportable:
|
49 |
+
def __init__(self, mode: bool) -> None:
|
50 |
+
global _EXPORTABLE
|
51 |
+
self.prev = _EXPORTABLE
|
52 |
+
_EXPORTABLE = mode
|
53 |
+
|
54 |
+
def __enter__(self) -> None:
|
55 |
+
pass
|
56 |
+
|
57 |
+
def __exit__(self, *args: Any) -> bool:
|
58 |
+
global _EXPORTABLE
|
59 |
+
_EXPORTABLE = self.prev
|
60 |
+
return False
|
61 |
+
|
62 |
+
|
63 |
+
def is_scriptable():
|
64 |
+
return _SCRIPTABLE
|
65 |
+
|
66 |
+
|
67 |
+
class set_scriptable:
|
68 |
+
def __init__(self, mode: bool) -> None:
|
69 |
+
global _SCRIPTABLE
|
70 |
+
self.prev = _SCRIPTABLE
|
71 |
+
_SCRIPTABLE = mode
|
72 |
+
|
73 |
+
def __enter__(self) -> None:
|
74 |
+
pass
|
75 |
+
|
76 |
+
def __exit__(self, *args: Any) -> bool:
|
77 |
+
global _SCRIPTABLE
|
78 |
+
_SCRIPTABLE = self.prev
|
79 |
+
return False
|
80 |
+
|
81 |
+
|
82 |
+
class set_layer_config:
|
83 |
+
""" Layer config context manager that allows setting all layer config flags at once.
|
84 |
+
If a flag arg is None, it will not change the current value.
|
85 |
+
"""
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
scriptable: Optional[bool] = None,
|
89 |
+
exportable: Optional[bool] = None,
|
90 |
+
no_jit: Optional[bool] = None,
|
91 |
+
no_activation_jit: Optional[bool] = None):
|
92 |
+
global _SCRIPTABLE
|
93 |
+
global _EXPORTABLE
|
94 |
+
global _NO_JIT
|
95 |
+
global _NO_ACTIVATION_JIT
|
96 |
+
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
|
97 |
+
if scriptable is not None:
|
98 |
+
_SCRIPTABLE = scriptable
|
99 |
+
if exportable is not None:
|
100 |
+
_EXPORTABLE = exportable
|
101 |
+
if no_jit is not None:
|
102 |
+
_NO_JIT = no_jit
|
103 |
+
if no_activation_jit is not None:
|
104 |
+
_NO_ACTIVATION_JIT = no_activation_jit
|
105 |
+
|
106 |
+
def __enter__(self) -> None:
|
107 |
+
pass
|
108 |
+
|
109 |
+
def __exit__(self, *args: Any) -> bool:
|
110 |
+
global _SCRIPTABLE
|
111 |
+
global _EXPORTABLE
|
112 |
+
global _NO_JIT
|
113 |
+
global _NO_ACTIVATION_JIT
|
114 |
+
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
|
115 |
+
return False
|
116 |
+
|
117 |
+
|
118 |
+
def layer_config_kwargs(kwargs):
|
119 |
+
""" Consume config kwargs and return contextmgr obj """
|
120 |
+
return set_layer_config(
|
121 |
+
scriptable=kwargs.pop('scriptable', None),
|
122 |
+
exportable=kwargs.pop('exportable', None),
|
123 |
+
no_jit=kwargs.pop('no_jit', None))
|
geffnet/conv2d_layers.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Conv2D w/ SAME padding, CondConv, MixedConv
|
2 |
+
|
3 |
+
A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
|
4 |
+
MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
|
5 |
+
|
6 |
+
Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
import collections.abc
|
9 |
+
import math
|
10 |
+
from functools import partial
|
11 |
+
from itertools import repeat
|
12 |
+
from typing import Tuple, Optional
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
from .config import *
|
20 |
+
|
21 |
+
|
22 |
+
# From PyTorch internals
|
23 |
+
def _ntuple(n):
|
24 |
+
def parse(x):
|
25 |
+
if isinstance(x, collections.abc.Iterable):
|
26 |
+
return x
|
27 |
+
return tuple(repeat(x, n))
|
28 |
+
return parse
|
29 |
+
|
30 |
+
|
31 |
+
_single = _ntuple(1)
|
32 |
+
_pair = _ntuple(2)
|
33 |
+
_triple = _ntuple(3)
|
34 |
+
_quadruple = _ntuple(4)
|
35 |
+
|
36 |
+
|
37 |
+
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
38 |
+
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
39 |
+
|
40 |
+
|
41 |
+
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
42 |
+
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
43 |
+
return padding
|
44 |
+
|
45 |
+
|
46 |
+
def _calc_same_pad(i: int, k: int, s: int, d: int):
|
47 |
+
return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
48 |
+
|
49 |
+
|
50 |
+
def _same_pad_arg(input_size, kernel_size, stride, dilation):
|
51 |
+
ih, iw = input_size
|
52 |
+
kh, kw = kernel_size
|
53 |
+
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
54 |
+
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
55 |
+
return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
|
56 |
+
|
57 |
+
|
58 |
+
def _split_channels(num_chan, num_groups):
|
59 |
+
split = [num_chan // num_groups for _ in range(num_groups)]
|
60 |
+
split[0] += num_chan - sum(split)
|
61 |
+
return split
|
62 |
+
|
63 |
+
|
64 |
+
def conv2d_same(
|
65 |
+
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
|
66 |
+
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
|
67 |
+
ih, iw = x.size()[-2:]
|
68 |
+
kh, kw = weight.size()[-2:]
|
69 |
+
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
70 |
+
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
71 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
72 |
+
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
73 |
+
|
74 |
+
|
75 |
+
class Conv2dSame(nn.Conv2d):
|
76 |
+
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
77 |
+
"""
|
78 |
+
|
79 |
+
# pylint: disable=unused-argument
|
80 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
81 |
+
padding=0, dilation=1, groups=1, bias=True):
|
82 |
+
super(Conv2dSame, self).__init__(
|
83 |
+
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
87 |
+
|
88 |
+
|
89 |
+
class Conv2dSameExport(nn.Conv2d):
|
90 |
+
""" ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
91 |
+
|
92 |
+
NOTE: This does not currently work with torch.jit.script
|
93 |
+
"""
|
94 |
+
|
95 |
+
# pylint: disable=unused-argument
|
96 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
97 |
+
padding=0, dilation=1, groups=1, bias=True):
|
98 |
+
super(Conv2dSameExport, self).__init__(
|
99 |
+
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
100 |
+
self.pad = None
|
101 |
+
self.pad_input_size = (0, 0)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
input_size = x.size()[-2:]
|
105 |
+
if self.pad is None:
|
106 |
+
pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
|
107 |
+
self.pad = nn.ZeroPad2d(pad_arg)
|
108 |
+
self.pad_input_size = input_size
|
109 |
+
|
110 |
+
if self.pad is not None:
|
111 |
+
x = self.pad(x)
|
112 |
+
return F.conv2d(
|
113 |
+
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
114 |
+
|
115 |
+
|
116 |
+
def get_padding_value(padding, kernel_size, **kwargs):
|
117 |
+
dynamic = False
|
118 |
+
if isinstance(padding, str):
|
119 |
+
# for any string padding, the padding will be calculated for you, one of three ways
|
120 |
+
padding = padding.lower()
|
121 |
+
if padding == 'same':
|
122 |
+
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
123 |
+
if _is_static_pad(kernel_size, **kwargs):
|
124 |
+
# static case, no extra overhead
|
125 |
+
padding = _get_padding(kernel_size, **kwargs)
|
126 |
+
else:
|
127 |
+
# dynamic padding
|
128 |
+
padding = 0
|
129 |
+
dynamic = True
|
130 |
+
elif padding == 'valid':
|
131 |
+
# 'VALID' padding, same as padding=0
|
132 |
+
padding = 0
|
133 |
+
else:
|
134 |
+
# Default to PyTorch style 'same'-ish symmetric padding
|
135 |
+
padding = _get_padding(kernel_size, **kwargs)
|
136 |
+
return padding, dynamic
|
137 |
+
|
138 |
+
|
139 |
+
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
140 |
+
padding = kwargs.pop('padding', '')
|
141 |
+
kwargs.setdefault('bias', False)
|
142 |
+
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
143 |
+
if is_dynamic:
|
144 |
+
if is_exportable():
|
145 |
+
assert not is_scriptable()
|
146 |
+
return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
|
147 |
+
else:
|
148 |
+
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
149 |
+
else:
|
150 |
+
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
151 |
+
|
152 |
+
|
153 |
+
class MixedConv2d(nn.ModuleDict):
|
154 |
+
""" Mixed Grouped Convolution
|
155 |
+
Based on MDConv and GroupedConv in MixNet impl:
|
156 |
+
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
160 |
+
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
161 |
+
super(MixedConv2d, self).__init__()
|
162 |
+
|
163 |
+
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
164 |
+
num_groups = len(kernel_size)
|
165 |
+
in_splits = _split_channels(in_channels, num_groups)
|
166 |
+
out_splits = _split_channels(out_channels, num_groups)
|
167 |
+
self.in_channels = sum(in_splits)
|
168 |
+
self.out_channels = sum(out_splits)
|
169 |
+
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
170 |
+
conv_groups = out_ch if depthwise else 1
|
171 |
+
self.add_module(
|
172 |
+
str(idx),
|
173 |
+
create_conv2d_pad(
|
174 |
+
in_ch, out_ch, k, stride=stride,
|
175 |
+
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
176 |
+
)
|
177 |
+
self.splits = in_splits
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
x_split = torch.split(x, self.splits, 1)
|
181 |
+
x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
|
182 |
+
x = torch.cat(x_out, 1)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
187 |
+
def condconv_initializer(weight):
|
188 |
+
"""CondConv initializer function."""
|
189 |
+
num_params = np.prod(expert_shape)
|
190 |
+
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
191 |
+
weight.shape[1] != num_params):
|
192 |
+
raise (ValueError(
|
193 |
+
'CondConv variables must have shape [num_experts, num_params]'))
|
194 |
+
for i in range(num_experts):
|
195 |
+
initializer(weight[i].view(expert_shape))
|
196 |
+
return condconv_initializer
|
197 |
+
|
198 |
+
|
199 |
+
class CondConv2d(nn.Module):
|
200 |
+
""" Conditional Convolution
|
201 |
+
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
202 |
+
|
203 |
+
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
204 |
+
https://github.com/pytorch/pytorch/issues/17983
|
205 |
+
"""
|
206 |
+
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
207 |
+
|
208 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
209 |
+
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
210 |
+
super(CondConv2d, self).__init__()
|
211 |
+
|
212 |
+
self.in_channels = in_channels
|
213 |
+
self.out_channels = out_channels
|
214 |
+
self.kernel_size = _pair(kernel_size)
|
215 |
+
self.stride = _pair(stride)
|
216 |
+
padding_val, is_padding_dynamic = get_padding_value(
|
217 |
+
padding, kernel_size, stride=stride, dilation=dilation)
|
218 |
+
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
219 |
+
self.padding = _pair(padding_val)
|
220 |
+
self.dilation = _pair(dilation)
|
221 |
+
self.groups = groups
|
222 |
+
self.num_experts = num_experts
|
223 |
+
|
224 |
+
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
225 |
+
weight_num_param = 1
|
226 |
+
for wd in self.weight_shape:
|
227 |
+
weight_num_param *= wd
|
228 |
+
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
229 |
+
|
230 |
+
if bias:
|
231 |
+
self.bias_shape = (self.out_channels,)
|
232 |
+
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
233 |
+
else:
|
234 |
+
self.register_parameter('bias', None)
|
235 |
+
|
236 |
+
self.reset_parameters()
|
237 |
+
|
238 |
+
def reset_parameters(self):
|
239 |
+
init_weight = get_condconv_initializer(
|
240 |
+
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
241 |
+
init_weight(self.weight)
|
242 |
+
if self.bias is not None:
|
243 |
+
fan_in = np.prod(self.weight_shape[1:])
|
244 |
+
bound = 1 / math.sqrt(fan_in)
|
245 |
+
init_bias = get_condconv_initializer(
|
246 |
+
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
247 |
+
init_bias(self.bias)
|
248 |
+
|
249 |
+
def forward(self, x, routing_weights):
|
250 |
+
B, C, H, W = x.shape
|
251 |
+
weight = torch.matmul(routing_weights, self.weight)
|
252 |
+
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
253 |
+
weight = weight.view(new_weight_shape)
|
254 |
+
bias = None
|
255 |
+
if self.bias is not None:
|
256 |
+
bias = torch.matmul(routing_weights, self.bias)
|
257 |
+
bias = bias.view(B * self.out_channels)
|
258 |
+
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
259 |
+
x = x.view(1, B * C, H, W)
|
260 |
+
if self.dynamic_padding:
|
261 |
+
out = conv2d_same(
|
262 |
+
x, weight, bias, stride=self.stride, padding=self.padding,
|
263 |
+
dilation=self.dilation, groups=self.groups * B)
|
264 |
+
else:
|
265 |
+
out = F.conv2d(
|
266 |
+
x, weight, bias, stride=self.stride, padding=self.padding,
|
267 |
+
dilation=self.dilation, groups=self.groups * B)
|
268 |
+
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
269 |
+
|
270 |
+
# Literal port (from TF definition)
|
271 |
+
# x = torch.split(x, 1, 0)
|
272 |
+
# weight = torch.split(weight, 1, 0)
|
273 |
+
# if self.bias is not None:
|
274 |
+
# bias = torch.matmul(routing_weights, self.bias)
|
275 |
+
# bias = torch.split(bias, 1, 0)
|
276 |
+
# else:
|
277 |
+
# bias = [None] * B
|
278 |
+
# out = []
|
279 |
+
# for xi, wi, bi in zip(x, weight, bias):
|
280 |
+
# wi = wi.view(*self.weight_shape)
|
281 |
+
# if bi is not None:
|
282 |
+
# bi = bi.view(*self.bias_shape)
|
283 |
+
# out.append(self.conv_fn(
|
284 |
+
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
285 |
+
# dilation=self.dilation, groups=self.groups))
|
286 |
+
# out = torch.cat(out, 0)
|
287 |
+
return out
|
288 |
+
|
289 |
+
|
290 |
+
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
291 |
+
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
292 |
+
if isinstance(kernel_size, list):
|
293 |
+
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
294 |
+
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
295 |
+
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
296 |
+
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
297 |
+
else:
|
298 |
+
depthwise = kwargs.pop('depthwise', False)
|
299 |
+
groups = out_chs if depthwise else 1
|
300 |
+
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
301 |
+
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
302 |
+
else:
|
303 |
+
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
304 |
+
return m
|
geffnet/efficientnet_builder.py
ADDED
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" EfficientNet / MobileNetV3 Blocks and Builder
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
|
8 |
+
from .conv2d_layers import *
|
9 |
+
from geffnet.activations import *
|
10 |
+
|
11 |
+
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
|
12 |
+
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
|
13 |
+
'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
|
14 |
+
'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
|
15 |
+
]
|
16 |
+
|
17 |
+
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
18 |
+
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
|
19 |
+
# NOTE: momentum varies btw .99 and .9997 depending on source
|
20 |
+
# .99 in official TF TPU impl
|
21 |
+
# .9997 (/w .999 in search space) for paper
|
22 |
+
#
|
23 |
+
# PyTorch defaults are momentum = .1, eps = 1e-5
|
24 |
+
#
|
25 |
+
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
26 |
+
BN_EPS_TF_DEFAULT = 1e-3
|
27 |
+
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
28 |
+
|
29 |
+
|
30 |
+
def get_bn_args_tf():
|
31 |
+
return _BN_ARGS_TF.copy()
|
32 |
+
|
33 |
+
|
34 |
+
def resolve_bn_args(kwargs):
|
35 |
+
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
|
36 |
+
bn_momentum = kwargs.pop('bn_momentum', None)
|
37 |
+
if bn_momentum is not None:
|
38 |
+
bn_args['momentum'] = bn_momentum
|
39 |
+
bn_eps = kwargs.pop('bn_eps', None)
|
40 |
+
if bn_eps is not None:
|
41 |
+
bn_args['eps'] = bn_eps
|
42 |
+
return bn_args
|
43 |
+
|
44 |
+
|
45 |
+
_SE_ARGS_DEFAULT = dict(
|
46 |
+
gate_fn=sigmoid,
|
47 |
+
act_layer=None, # None == use containing block's activation layer
|
48 |
+
reduce_mid=False,
|
49 |
+
divisor=1)
|
50 |
+
|
51 |
+
|
52 |
+
def resolve_se_args(kwargs, in_chs, act_layer=None):
|
53 |
+
se_kwargs = kwargs.copy() if kwargs is not None else {}
|
54 |
+
# fill in args that aren't specified with the defaults
|
55 |
+
for k, v in _SE_ARGS_DEFAULT.items():
|
56 |
+
se_kwargs.setdefault(k, v)
|
57 |
+
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
58 |
+
if not se_kwargs.pop('reduce_mid'):
|
59 |
+
se_kwargs['reduced_base_chs'] = in_chs
|
60 |
+
# act_layer override, if it remains None, the containing block's act_layer will be used
|
61 |
+
if se_kwargs['act_layer'] is None:
|
62 |
+
assert act_layer is not None
|
63 |
+
se_kwargs['act_layer'] = act_layer
|
64 |
+
return se_kwargs
|
65 |
+
|
66 |
+
|
67 |
+
def resolve_act_layer(kwargs, default='relu'):
|
68 |
+
act_layer = kwargs.pop('act_layer', default)
|
69 |
+
if isinstance(act_layer, str):
|
70 |
+
act_layer = get_act_layer(act_layer)
|
71 |
+
return act_layer
|
72 |
+
|
73 |
+
|
74 |
+
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
|
75 |
+
min_value = min_value or divisor
|
76 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
77 |
+
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
|
78 |
+
new_v += divisor
|
79 |
+
return new_v
|
80 |
+
|
81 |
+
|
82 |
+
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
83 |
+
"""Round number of filters based on depth multiplier."""
|
84 |
+
if not multiplier:
|
85 |
+
return channels
|
86 |
+
channels *= multiplier
|
87 |
+
return make_divisible(channels, divisor, channel_min)
|
88 |
+
|
89 |
+
|
90 |
+
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
|
91 |
+
"""Apply drop connect."""
|
92 |
+
if not training:
|
93 |
+
return inputs
|
94 |
+
|
95 |
+
keep_prob = 1 - drop_connect_rate
|
96 |
+
random_tensor = keep_prob + torch.rand(
|
97 |
+
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
|
98 |
+
random_tensor.floor_() # binarize
|
99 |
+
output = inputs.div(keep_prob) * random_tensor
|
100 |
+
return output
|
101 |
+
|
102 |
+
|
103 |
+
class SqueezeExcite(nn.Module):
|
104 |
+
|
105 |
+
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
|
106 |
+
super(SqueezeExcite, self).__init__()
|
107 |
+
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
|
108 |
+
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
109 |
+
self.act1 = act_layer(inplace=True)
|
110 |
+
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
111 |
+
self.gate_fn = gate_fn
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
x_se = x.mean((2, 3), keepdim=True)
|
115 |
+
x_se = self.conv_reduce(x_se)
|
116 |
+
x_se = self.act1(x_se)
|
117 |
+
x_se = self.conv_expand(x_se)
|
118 |
+
x = x * self.gate_fn(x_se)
|
119 |
+
return x
|
120 |
+
|
121 |
+
|
122 |
+
class ConvBnAct(nn.Module):
|
123 |
+
def __init__(self, in_chs, out_chs, kernel_size,
|
124 |
+
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
125 |
+
super(ConvBnAct, self).__init__()
|
126 |
+
assert stride in [1, 2]
|
127 |
+
norm_kwargs = norm_kwargs or {}
|
128 |
+
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
|
129 |
+
self.bn1 = norm_layer(out_chs, **norm_kwargs)
|
130 |
+
self.act1 = act_layer(inplace=True)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
x = self.conv(x)
|
134 |
+
x = self.bn1(x)
|
135 |
+
x = self.act1(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
|
139 |
+
class DepthwiseSeparableConv(nn.Module):
|
140 |
+
""" DepthwiseSeparable block
|
141 |
+
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
|
142 |
+
factor of 1.0. This is an alternative to having a IR with optional first pw conv.
|
143 |
+
"""
|
144 |
+
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
145 |
+
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
146 |
+
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
|
147 |
+
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
148 |
+
super(DepthwiseSeparableConv, self).__init__()
|
149 |
+
assert stride in [1, 2]
|
150 |
+
norm_kwargs = norm_kwargs or {}
|
151 |
+
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
152 |
+
self.drop_connect_rate = drop_connect_rate
|
153 |
+
|
154 |
+
self.conv_dw = select_conv2d(
|
155 |
+
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
|
156 |
+
self.bn1 = norm_layer(in_chs, **norm_kwargs)
|
157 |
+
self.act1 = act_layer(inplace=True)
|
158 |
+
|
159 |
+
# Squeeze-and-excitation
|
160 |
+
if se_ratio is not None and se_ratio > 0.:
|
161 |
+
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
162 |
+
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
|
163 |
+
else:
|
164 |
+
self.se = nn.Identity()
|
165 |
+
|
166 |
+
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
|
167 |
+
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
168 |
+
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
residual = x
|
172 |
+
|
173 |
+
x = self.conv_dw(x)
|
174 |
+
x = self.bn1(x)
|
175 |
+
x = self.act1(x)
|
176 |
+
|
177 |
+
x = self.se(x)
|
178 |
+
|
179 |
+
x = self.conv_pw(x)
|
180 |
+
x = self.bn2(x)
|
181 |
+
x = self.act2(x)
|
182 |
+
|
183 |
+
if self.has_residual:
|
184 |
+
if self.drop_connect_rate > 0.:
|
185 |
+
x = drop_connect(x, self.training, self.drop_connect_rate)
|
186 |
+
x += residual
|
187 |
+
return x
|
188 |
+
|
189 |
+
|
190 |
+
class InvertedResidual(nn.Module):
|
191 |
+
""" Inverted residual block w/ optional SE"""
|
192 |
+
|
193 |
+
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
194 |
+
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
195 |
+
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
196 |
+
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
197 |
+
conv_kwargs=None, drop_connect_rate=0.):
|
198 |
+
super(InvertedResidual, self).__init__()
|
199 |
+
norm_kwargs = norm_kwargs or {}
|
200 |
+
conv_kwargs = conv_kwargs or {}
|
201 |
+
mid_chs: int = make_divisible(in_chs * exp_ratio)
|
202 |
+
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
203 |
+
self.drop_connect_rate = drop_connect_rate
|
204 |
+
|
205 |
+
# Point-wise expansion
|
206 |
+
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
|
207 |
+
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
208 |
+
self.act1 = act_layer(inplace=True)
|
209 |
+
|
210 |
+
# Depth-wise convolution
|
211 |
+
self.conv_dw = select_conv2d(
|
212 |
+
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
|
213 |
+
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
|
214 |
+
self.act2 = act_layer(inplace=True)
|
215 |
+
|
216 |
+
# Squeeze-and-excitation
|
217 |
+
if se_ratio is not None and se_ratio > 0.:
|
218 |
+
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
219 |
+
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
220 |
+
else:
|
221 |
+
self.se = nn.Identity() # for jit.script compat
|
222 |
+
|
223 |
+
# Point-wise linear projection
|
224 |
+
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
225 |
+
self.bn3 = norm_layer(out_chs, **norm_kwargs)
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
residual = x
|
229 |
+
|
230 |
+
# Point-wise expansion
|
231 |
+
x = self.conv_pw(x)
|
232 |
+
x = self.bn1(x)
|
233 |
+
x = self.act1(x)
|
234 |
+
|
235 |
+
# Depth-wise convolution
|
236 |
+
x = self.conv_dw(x)
|
237 |
+
x = self.bn2(x)
|
238 |
+
x = self.act2(x)
|
239 |
+
|
240 |
+
# Squeeze-and-excitation
|
241 |
+
x = self.se(x)
|
242 |
+
|
243 |
+
# Point-wise linear projection
|
244 |
+
x = self.conv_pwl(x)
|
245 |
+
x = self.bn3(x)
|
246 |
+
|
247 |
+
if self.has_residual:
|
248 |
+
if self.drop_connect_rate > 0.:
|
249 |
+
x = drop_connect(x, self.training, self.drop_connect_rate)
|
250 |
+
x += residual
|
251 |
+
return x
|
252 |
+
|
253 |
+
|
254 |
+
class CondConvResidual(InvertedResidual):
|
255 |
+
""" Inverted residual block w/ CondConv routing"""
|
256 |
+
|
257 |
+
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
258 |
+
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
259 |
+
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
260 |
+
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
261 |
+
num_experts=0, drop_connect_rate=0.):
|
262 |
+
|
263 |
+
self.num_experts = num_experts
|
264 |
+
conv_kwargs = dict(num_experts=self.num_experts)
|
265 |
+
|
266 |
+
super(CondConvResidual, self).__init__(
|
267 |
+
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
|
268 |
+
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
269 |
+
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
|
270 |
+
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
|
271 |
+
drop_connect_rate=drop_connect_rate)
|
272 |
+
|
273 |
+
self.routing_fn = nn.Linear(in_chs, self.num_experts)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
residual = x
|
277 |
+
|
278 |
+
# CondConv routing
|
279 |
+
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
|
280 |
+
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
|
281 |
+
|
282 |
+
# Point-wise expansion
|
283 |
+
x = self.conv_pw(x, routing_weights)
|
284 |
+
x = self.bn1(x)
|
285 |
+
x = self.act1(x)
|
286 |
+
|
287 |
+
# Depth-wise convolution
|
288 |
+
x = self.conv_dw(x, routing_weights)
|
289 |
+
x = self.bn2(x)
|
290 |
+
x = self.act2(x)
|
291 |
+
|
292 |
+
# Squeeze-and-excitation
|
293 |
+
x = self.se(x)
|
294 |
+
|
295 |
+
# Point-wise linear projection
|
296 |
+
x = self.conv_pwl(x, routing_weights)
|
297 |
+
x = self.bn3(x)
|
298 |
+
|
299 |
+
if self.has_residual:
|
300 |
+
if self.drop_connect_rate > 0.:
|
301 |
+
x = drop_connect(x, self.training, self.drop_connect_rate)
|
302 |
+
x += residual
|
303 |
+
return x
|
304 |
+
|
305 |
+
|
306 |
+
class EdgeResidual(nn.Module):
|
307 |
+
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
|
308 |
+
|
309 |
+
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
310 |
+
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
|
311 |
+
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
312 |
+
super(EdgeResidual, self).__init__()
|
313 |
+
norm_kwargs = norm_kwargs or {}
|
314 |
+
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
|
315 |
+
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
316 |
+
self.drop_connect_rate = drop_connect_rate
|
317 |
+
|
318 |
+
# Expansion convolution
|
319 |
+
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
|
320 |
+
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
321 |
+
self.act1 = act_layer(inplace=True)
|
322 |
+
|
323 |
+
# Squeeze-and-excitation
|
324 |
+
if se_ratio is not None and se_ratio > 0.:
|
325 |
+
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
326 |
+
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
327 |
+
else:
|
328 |
+
self.se = nn.Identity()
|
329 |
+
|
330 |
+
# Point-wise linear projection
|
331 |
+
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
|
332 |
+
self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
|
333 |
+
|
334 |
+
def forward(self, x):
|
335 |
+
residual = x
|
336 |
+
|
337 |
+
# Expansion convolution
|
338 |
+
x = self.conv_exp(x)
|
339 |
+
x = self.bn1(x)
|
340 |
+
x = self.act1(x)
|
341 |
+
|
342 |
+
# Squeeze-and-excitation
|
343 |
+
x = self.se(x)
|
344 |
+
|
345 |
+
# Point-wise linear projection
|
346 |
+
x = self.conv_pwl(x)
|
347 |
+
x = self.bn2(x)
|
348 |
+
|
349 |
+
if self.has_residual:
|
350 |
+
if self.drop_connect_rate > 0.:
|
351 |
+
x = drop_connect(x, self.training, self.drop_connect_rate)
|
352 |
+
x += residual
|
353 |
+
|
354 |
+
return x
|
355 |
+
|
356 |
+
|
357 |
+
class EfficientNetBuilder:
|
358 |
+
""" Build Trunk Blocks for Efficient/Mobile Networks
|
359 |
+
|
360 |
+
This ended up being somewhat of a cross between
|
361 |
+
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
362 |
+
and
|
363 |
+
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
|
364 |
+
|
365 |
+
"""
|
366 |
+
|
367 |
+
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
368 |
+
pad_type='', act_layer=None, se_kwargs=None,
|
369 |
+
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
370 |
+
self.channel_multiplier = channel_multiplier
|
371 |
+
self.channel_divisor = channel_divisor
|
372 |
+
self.channel_min = channel_min
|
373 |
+
self.pad_type = pad_type
|
374 |
+
self.act_layer = act_layer
|
375 |
+
self.se_kwargs = se_kwargs
|
376 |
+
self.norm_layer = norm_layer
|
377 |
+
self.norm_kwargs = norm_kwargs
|
378 |
+
self.drop_connect_rate = drop_connect_rate
|
379 |
+
|
380 |
+
# updated during build
|
381 |
+
self.in_chs = None
|
382 |
+
self.block_idx = 0
|
383 |
+
self.block_count = 0
|
384 |
+
|
385 |
+
def _round_channels(self, chs):
|
386 |
+
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
387 |
+
|
388 |
+
def _make_block(self, ba):
|
389 |
+
bt = ba.pop('block_type')
|
390 |
+
ba['in_chs'] = self.in_chs
|
391 |
+
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
392 |
+
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
393 |
+
# FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
|
394 |
+
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
395 |
+
ba['norm_layer'] = self.norm_layer
|
396 |
+
ba['norm_kwargs'] = self.norm_kwargs
|
397 |
+
ba['pad_type'] = self.pad_type
|
398 |
+
# block act fn overrides the model default
|
399 |
+
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
400 |
+
assert ba['act_layer'] is not None
|
401 |
+
if bt == 'ir':
|
402 |
+
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
403 |
+
ba['se_kwargs'] = self.se_kwargs
|
404 |
+
if ba.get('num_experts', 0) > 0:
|
405 |
+
block = CondConvResidual(**ba)
|
406 |
+
else:
|
407 |
+
block = InvertedResidual(**ba)
|
408 |
+
elif bt == 'ds' or bt == 'dsa':
|
409 |
+
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
410 |
+
ba['se_kwargs'] = self.se_kwargs
|
411 |
+
block = DepthwiseSeparableConv(**ba)
|
412 |
+
elif bt == 'er':
|
413 |
+
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
|
414 |
+
ba['se_kwargs'] = self.se_kwargs
|
415 |
+
block = EdgeResidual(**ba)
|
416 |
+
elif bt == 'cn':
|
417 |
+
block = ConvBnAct(**ba)
|
418 |
+
else:
|
419 |
+
assert False, 'Uknkown block type (%s) while building model.' % bt
|
420 |
+
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
421 |
+
return block
|
422 |
+
|
423 |
+
def _make_stack(self, stack_args):
|
424 |
+
blocks = []
|
425 |
+
# each stack (stage) contains a list of block arguments
|
426 |
+
for i, ba in enumerate(stack_args):
|
427 |
+
if i >= 1:
|
428 |
+
# only the first block in any stack can have a stride > 1
|
429 |
+
ba['stride'] = 1
|
430 |
+
block = self._make_block(ba)
|
431 |
+
blocks.append(block)
|
432 |
+
self.block_idx += 1 # incr global idx (across all stacks)
|
433 |
+
return nn.Sequential(*blocks)
|
434 |
+
|
435 |
+
def __call__(self, in_chs, block_args):
|
436 |
+
""" Build the blocks
|
437 |
+
Args:
|
438 |
+
in_chs: Number of input-channels passed to first block
|
439 |
+
block_args: A list of lists, outer list defines stages, inner
|
440 |
+
list contains strings defining block configuration(s)
|
441 |
+
Return:
|
442 |
+
List of block stacks (each stack wrapped in nn.Sequential)
|
443 |
+
"""
|
444 |
+
self.in_chs = in_chs
|
445 |
+
self.block_count = sum([len(x) for x in block_args])
|
446 |
+
self.block_idx = 0
|
447 |
+
blocks = []
|
448 |
+
# outer list of block_args defines the stacks ('stages' by some conventions)
|
449 |
+
for stack_idx, stack in enumerate(block_args):
|
450 |
+
assert isinstance(stack, list)
|
451 |
+
stack = self._make_stack(stack)
|
452 |
+
blocks.append(stack)
|
453 |
+
return blocks
|
454 |
+
|
455 |
+
|
456 |
+
def _parse_ksize(ss):
|
457 |
+
if ss.isdigit():
|
458 |
+
return int(ss)
|
459 |
+
else:
|
460 |
+
return [int(k) for k in ss.split('.')]
|
461 |
+
|
462 |
+
|
463 |
+
def _decode_block_str(block_str):
|
464 |
+
""" Decode block definition string
|
465 |
+
|
466 |
+
Gets a list of block arg (dicts) through a string notation of arguments.
|
467 |
+
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
468 |
+
|
469 |
+
All args can exist in any order with the exception of the leading string which
|
470 |
+
is assumed to indicate the block type.
|
471 |
+
|
472 |
+
leading string - block type (
|
473 |
+
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
474 |
+
r - number of repeat blocks,
|
475 |
+
k - kernel size,
|
476 |
+
s - strides (1-9),
|
477 |
+
e - expansion ratio,
|
478 |
+
c - output channels,
|
479 |
+
se - squeeze/excitation ratio
|
480 |
+
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
481 |
+
Args:
|
482 |
+
block_str: a string representation of block arguments.
|
483 |
+
Returns:
|
484 |
+
A list of block args (dicts)
|
485 |
+
Raises:
|
486 |
+
ValueError: if the string def not properly specified (TODO)
|
487 |
+
"""
|
488 |
+
assert isinstance(block_str, str)
|
489 |
+
ops = block_str.split('_')
|
490 |
+
block_type = ops[0] # take the block type off the front
|
491 |
+
ops = ops[1:]
|
492 |
+
options = {}
|
493 |
+
noskip = False
|
494 |
+
for op in ops:
|
495 |
+
# string options being checked on individual basis, combine if they grow
|
496 |
+
if op == 'noskip':
|
497 |
+
noskip = True
|
498 |
+
elif op.startswith('n'):
|
499 |
+
# activation fn
|
500 |
+
key = op[0]
|
501 |
+
v = op[1:]
|
502 |
+
if v == 're':
|
503 |
+
value = get_act_layer('relu')
|
504 |
+
elif v == 'r6':
|
505 |
+
value = get_act_layer('relu6')
|
506 |
+
elif v == 'hs':
|
507 |
+
value = get_act_layer('hard_swish')
|
508 |
+
elif v == 'sw':
|
509 |
+
value = get_act_layer('swish')
|
510 |
+
else:
|
511 |
+
continue
|
512 |
+
options[key] = value
|
513 |
+
else:
|
514 |
+
# all numeric options
|
515 |
+
splits = re.split(r'(\d.*)', op)
|
516 |
+
if len(splits) >= 2:
|
517 |
+
key, value = splits[:2]
|
518 |
+
options[key] = value
|
519 |
+
|
520 |
+
# if act_layer is None, the model default (passed to model init) will be used
|
521 |
+
act_layer = options['n'] if 'n' in options else None
|
522 |
+
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
523 |
+
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
524 |
+
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
525 |
+
|
526 |
+
num_repeat = int(options['r'])
|
527 |
+
# each type of block has different valid arguments, fill accordingly
|
528 |
+
if block_type == 'ir':
|
529 |
+
block_args = dict(
|
530 |
+
block_type=block_type,
|
531 |
+
dw_kernel_size=_parse_ksize(options['k']),
|
532 |
+
exp_kernel_size=exp_kernel_size,
|
533 |
+
pw_kernel_size=pw_kernel_size,
|
534 |
+
out_chs=int(options['c']),
|
535 |
+
exp_ratio=float(options['e']),
|
536 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
537 |
+
stride=int(options['s']),
|
538 |
+
act_layer=act_layer,
|
539 |
+
noskip=noskip,
|
540 |
+
)
|
541 |
+
if 'cc' in options:
|
542 |
+
block_args['num_experts'] = int(options['cc'])
|
543 |
+
elif block_type == 'ds' or block_type == 'dsa':
|
544 |
+
block_args = dict(
|
545 |
+
block_type=block_type,
|
546 |
+
dw_kernel_size=_parse_ksize(options['k']),
|
547 |
+
pw_kernel_size=pw_kernel_size,
|
548 |
+
out_chs=int(options['c']),
|
549 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
550 |
+
stride=int(options['s']),
|
551 |
+
act_layer=act_layer,
|
552 |
+
pw_act=block_type == 'dsa',
|
553 |
+
noskip=block_type == 'dsa' or noskip,
|
554 |
+
)
|
555 |
+
elif block_type == 'er':
|
556 |
+
block_args = dict(
|
557 |
+
block_type=block_type,
|
558 |
+
exp_kernel_size=_parse_ksize(options['k']),
|
559 |
+
pw_kernel_size=pw_kernel_size,
|
560 |
+
out_chs=int(options['c']),
|
561 |
+
exp_ratio=float(options['e']),
|
562 |
+
fake_in_chs=fake_in_chs,
|
563 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
564 |
+
stride=int(options['s']),
|
565 |
+
act_layer=act_layer,
|
566 |
+
noskip=noskip,
|
567 |
+
)
|
568 |
+
elif block_type == 'cn':
|
569 |
+
block_args = dict(
|
570 |
+
block_type=block_type,
|
571 |
+
kernel_size=int(options['k']),
|
572 |
+
out_chs=int(options['c']),
|
573 |
+
stride=int(options['s']),
|
574 |
+
act_layer=act_layer,
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
assert False, 'Unknown block type (%s)' % block_type
|
578 |
+
|
579 |
+
return block_args, num_repeat
|
580 |
+
|
581 |
+
|
582 |
+
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
583 |
+
""" Per-stage depth scaling
|
584 |
+
Scales the block repeats in each stage. This depth scaling impl maintains
|
585 |
+
compatibility with the EfficientNet scaling method, while allowing sensible
|
586 |
+
scaling for other models that may have multiple block arg definitions in each stage.
|
587 |
+
"""
|
588 |
+
|
589 |
+
# We scale the total repeat count for each stage, there may be multiple
|
590 |
+
# block arg defs per stage so we need to sum.
|
591 |
+
num_repeat = sum(repeats)
|
592 |
+
if depth_trunc == 'round':
|
593 |
+
# Truncating to int by rounding allows stages with few repeats to remain
|
594 |
+
# proportionally smaller for longer. This is a good choice when stage definitions
|
595 |
+
# include single repeat stages that we'd prefer to keep that way as long as possible
|
596 |
+
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
597 |
+
else:
|
598 |
+
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
599 |
+
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
600 |
+
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
601 |
+
|
602 |
+
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
603 |
+
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
604 |
+
# The first block makes less sense to repeat in most of the arch definitions.
|
605 |
+
repeats_scaled = []
|
606 |
+
for r in repeats[::-1]:
|
607 |
+
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
608 |
+
repeats_scaled.append(rs)
|
609 |
+
num_repeat -= r
|
610 |
+
num_repeat_scaled -= rs
|
611 |
+
repeats_scaled = repeats_scaled[::-1]
|
612 |
+
|
613 |
+
# Apply the calculated scaling to each block arg in the stage
|
614 |
+
sa_scaled = []
|
615 |
+
for ba, rep in zip(stack_args, repeats_scaled):
|
616 |
+
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
617 |
+
return sa_scaled
|
618 |
+
|
619 |
+
|
620 |
+
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
|
621 |
+
arch_args = []
|
622 |
+
for stack_idx, block_strings in enumerate(arch_def):
|
623 |
+
assert isinstance(block_strings, list)
|
624 |
+
stack_args = []
|
625 |
+
repeats = []
|
626 |
+
for block_str in block_strings:
|
627 |
+
assert isinstance(block_str, str)
|
628 |
+
ba, rep = _decode_block_str(block_str)
|
629 |
+
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
|
630 |
+
ba['num_experts'] *= experts_multiplier
|
631 |
+
stack_args.append(ba)
|
632 |
+
repeats.append(rep)
|
633 |
+
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
|
634 |
+
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
|
635 |
+
else:
|
636 |
+
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
637 |
+
return arch_args
|
638 |
+
|
639 |
+
|
640 |
+
def initialize_weight_goog(m, n='', fix_group_fanout=True):
|
641 |
+
# weight init as per Tensorflow Official impl
|
642 |
+
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
643 |
+
if isinstance(m, CondConv2d):
|
644 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
645 |
+
if fix_group_fanout:
|
646 |
+
fan_out //= m.groups
|
647 |
+
init_weight_fn = get_condconv_initializer(
|
648 |
+
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
649 |
+
init_weight_fn(m.weight)
|
650 |
+
if m.bias is not None:
|
651 |
+
m.bias.data.zero_()
|
652 |
+
elif isinstance(m, nn.Conv2d):
|
653 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
654 |
+
if fix_group_fanout:
|
655 |
+
fan_out //= m.groups
|
656 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
657 |
+
if m.bias is not None:
|
658 |
+
m.bias.data.zero_()
|
659 |
+
elif isinstance(m, nn.BatchNorm2d):
|
660 |
+
m.weight.data.fill_(1.0)
|
661 |
+
m.bias.data.zero_()
|
662 |
+
elif isinstance(m, nn.Linear):
|
663 |
+
fan_out = m.weight.size(0) # fan-out
|
664 |
+
fan_in = 0
|
665 |
+
if 'routing_fn' in n:
|
666 |
+
fan_in = m.weight.size(1)
|
667 |
+
init_range = 1.0 / math.sqrt(fan_in + fan_out)
|
668 |
+
m.weight.data.uniform_(-init_range, init_range)
|
669 |
+
m.bias.data.zero_()
|
670 |
+
|
671 |
+
|
672 |
+
def initialize_weight_default(m, n=''):
|
673 |
+
if isinstance(m, CondConv2d):
|
674 |
+
init_fn = get_condconv_initializer(partial(
|
675 |
+
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
676 |
+
init_fn(m.weight)
|
677 |
+
elif isinstance(m, nn.Conv2d):
|
678 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
679 |
+
elif isinstance(m, nn.BatchNorm2d):
|
680 |
+
m.weight.data.fill_(1.0)
|
681 |
+
m.bias.data.zero_()
|
682 |
+
elif isinstance(m, nn.Linear):
|
683 |
+
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
geffnet/gen_efficientnet.py
ADDED
@@ -0,0 +1,1450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Generic Efficient Networks
|
2 |
+
|
3 |
+
A generic MobileNet class with building blocks to support a variety of models:
|
4 |
+
|
5 |
+
* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports)
|
6 |
+
- EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
|
7 |
+
- CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
|
8 |
+
- Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
|
9 |
+
- Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
|
10 |
+
|
11 |
+
* EfficientNet-Lite
|
12 |
+
|
13 |
+
* MixNet (Small, Medium, and Large)
|
14 |
+
- MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
|
15 |
+
|
16 |
+
* MNasNet B1, A1 (SE), Small
|
17 |
+
- MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
|
18 |
+
|
19 |
+
* FBNet-C
|
20 |
+
- FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
|
21 |
+
|
22 |
+
* Single-Path NAS Pixel1
|
23 |
+
- Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
|
24 |
+
|
25 |
+
* And likely more...
|
26 |
+
|
27 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
28 |
+
"""
|
29 |
+
import torch.nn as nn
|
30 |
+
import torch.nn.functional as F
|
31 |
+
|
32 |
+
from .config import layer_config_kwargs, is_scriptable
|
33 |
+
from .conv2d_layers import select_conv2d
|
34 |
+
from .helpers import load_pretrained
|
35 |
+
from .efficientnet_builder import *
|
36 |
+
|
37 |
+
__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
|
38 |
+
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
|
39 |
+
'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d',
|
40 |
+
'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
|
41 |
+
'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
|
42 |
+
'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
|
43 |
+
'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
|
44 |
+
'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
|
45 |
+
'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
|
46 |
+
'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
|
47 |
+
'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
|
48 |
+
'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap',
|
49 |
+
'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns',
|
50 |
+
'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns',
|
51 |
+
'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
|
52 |
+
'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
|
53 |
+
'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
|
54 |
+
'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
|
55 |
+
'tf_efficientnet_lite4',
|
56 |
+
'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l']
|
57 |
+
|
58 |
+
|
59 |
+
model_urls = {
|
60 |
+
'mnasnet_050': None,
|
61 |
+
'mnasnet_075': None,
|
62 |
+
'mnasnet_100':
|
63 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
|
64 |
+
'mnasnet_140': None,
|
65 |
+
'mnasnet_small': None,
|
66 |
+
|
67 |
+
'semnasnet_050': None,
|
68 |
+
'semnasnet_075': None,
|
69 |
+
'semnasnet_100':
|
70 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
|
71 |
+
'semnasnet_140': None,
|
72 |
+
|
73 |
+
'mobilenetv2_100':
|
74 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth',
|
75 |
+
'mobilenetv2_110d':
|
76 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth',
|
77 |
+
'mobilenetv2_120d':
|
78 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth',
|
79 |
+
'mobilenetv2_140':
|
80 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth',
|
81 |
+
|
82 |
+
'fbnetc_100':
|
83 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
|
84 |
+
'spnasnet_100':
|
85 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
|
86 |
+
|
87 |
+
'efficientnet_b0':
|
88 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth',
|
89 |
+
'efficientnet_b1':
|
90 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
|
91 |
+
'efficientnet_b2':
|
92 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
|
93 |
+
'efficientnet_b3':
|
94 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth',
|
95 |
+
'efficientnet_b4': None,
|
96 |
+
'efficientnet_b5': None,
|
97 |
+
'efficientnet_b6': None,
|
98 |
+
'efficientnet_b7': None,
|
99 |
+
'efficientnet_b8': None,
|
100 |
+
'efficientnet_l2': None,
|
101 |
+
|
102 |
+
'efficientnet_es':
|
103 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
|
104 |
+
'efficientnet_em': None,
|
105 |
+
'efficientnet_el': None,
|
106 |
+
|
107 |
+
'efficientnet_cc_b0_4e': None,
|
108 |
+
'efficientnet_cc_b0_8e': None,
|
109 |
+
'efficientnet_cc_b1_8e': None,
|
110 |
+
|
111 |
+
'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth',
|
112 |
+
'efficientnet_lite1': None,
|
113 |
+
'efficientnet_lite2': None,
|
114 |
+
'efficientnet_lite3': None,
|
115 |
+
'efficientnet_lite4': None,
|
116 |
+
|
117 |
+
'tf_efficientnet_b0':
|
118 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
119 |
+
'tf_efficientnet_b1':
|
120 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
121 |
+
'tf_efficientnet_b2':
|
122 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
123 |
+
'tf_efficientnet_b3':
|
124 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
125 |
+
'tf_efficientnet_b4':
|
126 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
127 |
+
'tf_efficientnet_b5':
|
128 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
|
129 |
+
'tf_efficientnet_b6':
|
130 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
131 |
+
'tf_efficientnet_b7':
|
132 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
|
133 |
+
'tf_efficientnet_b8':
|
134 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
135 |
+
|
136 |
+
'tf_efficientnet_b0_ap':
|
137 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
|
138 |
+
'tf_efficientnet_b1_ap':
|
139 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
|
140 |
+
'tf_efficientnet_b2_ap':
|
141 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
|
142 |
+
'tf_efficientnet_b3_ap':
|
143 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
|
144 |
+
'tf_efficientnet_b4_ap':
|
145 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
|
146 |
+
'tf_efficientnet_b5_ap':
|
147 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
|
148 |
+
'tf_efficientnet_b6_ap':
|
149 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
|
150 |
+
'tf_efficientnet_b7_ap':
|
151 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
|
152 |
+
'tf_efficientnet_b8_ap':
|
153 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
|
154 |
+
|
155 |
+
'tf_efficientnet_b0_ns':
|
156 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
157 |
+
'tf_efficientnet_b1_ns':
|
158 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
|
159 |
+
'tf_efficientnet_b2_ns':
|
160 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
|
161 |
+
'tf_efficientnet_b3_ns':
|
162 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
|
163 |
+
'tf_efficientnet_b4_ns':
|
164 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
|
165 |
+
'tf_efficientnet_b5_ns':
|
166 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
|
167 |
+
'tf_efficientnet_b6_ns':
|
168 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
|
169 |
+
'tf_efficientnet_b7_ns':
|
170 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
|
171 |
+
'tf_efficientnet_l2_ns_475':
|
172 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
|
173 |
+
'tf_efficientnet_l2_ns':
|
174 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
175 |
+
|
176 |
+
'tf_efficientnet_es':
|
177 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
178 |
+
'tf_efficientnet_em':
|
179 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
|
180 |
+
'tf_efficientnet_el':
|
181 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
|
182 |
+
|
183 |
+
'tf_efficientnet_cc_b0_4e':
|
184 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
|
185 |
+
'tf_efficientnet_cc_b0_8e':
|
186 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
|
187 |
+
'tf_efficientnet_cc_b1_8e':
|
188 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
|
189 |
+
|
190 |
+
'tf_efficientnet_lite0':
|
191 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
|
192 |
+
'tf_efficientnet_lite1':
|
193 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
|
194 |
+
'tf_efficientnet_lite2':
|
195 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
|
196 |
+
'tf_efficientnet_lite3':
|
197 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
|
198 |
+
'tf_efficientnet_lite4':
|
199 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
|
200 |
+
|
201 |
+
'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
|
202 |
+
'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
|
203 |
+
'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
|
204 |
+
'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth',
|
205 |
+
|
206 |
+
'tf_mixnet_s':
|
207 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth',
|
208 |
+
'tf_mixnet_m':
|
209 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth',
|
210 |
+
'tf_mixnet_l':
|
211 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth',
|
212 |
+
}
|
213 |
+
|
214 |
+
|
215 |
+
class GenEfficientNet(nn.Module):
|
216 |
+
""" Generic EfficientNets
|
217 |
+
|
218 |
+
An implementation of mobile optimized networks that covers:
|
219 |
+
* EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
|
220 |
+
* MixNet (Small, Medium, and Large, XL)
|
221 |
+
* MNASNet A1, B1, and small
|
222 |
+
* FBNet C
|
223 |
+
* Single-Path NAS Pixel1
|
224 |
+
"""
|
225 |
+
|
226 |
+
def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
|
227 |
+
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
228 |
+
pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
|
229 |
+
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
230 |
+
weight_init='goog'):
|
231 |
+
super(GenEfficientNet, self).__init__()
|
232 |
+
self.drop_rate = drop_rate
|
233 |
+
|
234 |
+
if not fix_stem:
|
235 |
+
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
236 |
+
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
237 |
+
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
238 |
+
self.act1 = act_layer(inplace=True)
|
239 |
+
in_chs = stem_size
|
240 |
+
|
241 |
+
builder = EfficientNetBuilder(
|
242 |
+
channel_multiplier, channel_divisor, channel_min,
|
243 |
+
pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
|
244 |
+
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
245 |
+
in_chs = builder.in_chs
|
246 |
+
|
247 |
+
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
|
248 |
+
self.bn2 = norm_layer(num_features, **norm_kwargs)
|
249 |
+
self.act2 = act_layer(inplace=True)
|
250 |
+
self.global_pool = nn.AdaptiveAvgPool2d(1)
|
251 |
+
self.classifier = nn.Linear(num_features, num_classes)
|
252 |
+
|
253 |
+
for n, m in self.named_modules():
|
254 |
+
if weight_init == 'goog':
|
255 |
+
initialize_weight_goog(m, n)
|
256 |
+
else:
|
257 |
+
initialize_weight_default(m, n)
|
258 |
+
|
259 |
+
def features(self, x):
|
260 |
+
x = self.conv_stem(x)
|
261 |
+
x = self.bn1(x)
|
262 |
+
x = self.act1(x)
|
263 |
+
x = self.blocks(x)
|
264 |
+
x = self.conv_head(x)
|
265 |
+
x = self.bn2(x)
|
266 |
+
x = self.act2(x)
|
267 |
+
return x
|
268 |
+
|
269 |
+
def as_sequential(self):
|
270 |
+
layers = [self.conv_stem, self.bn1, self.act1]
|
271 |
+
layers.extend(self.blocks)
|
272 |
+
layers.extend([
|
273 |
+
self.conv_head, self.bn2, self.act2,
|
274 |
+
self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
275 |
+
return nn.Sequential(*layers)
|
276 |
+
|
277 |
+
def forward(self, x):
|
278 |
+
x = self.features(x)
|
279 |
+
x = self.global_pool(x)
|
280 |
+
x = x.flatten(1)
|
281 |
+
if self.drop_rate > 0.:
|
282 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
283 |
+
return self.classifier(x)
|
284 |
+
|
285 |
+
|
286 |
+
def _create_model(model_kwargs, variant, pretrained=False):
|
287 |
+
as_sequential = model_kwargs.pop('as_sequential', False)
|
288 |
+
model = GenEfficientNet(**model_kwargs)
|
289 |
+
if pretrained:
|
290 |
+
load_pretrained(model, model_urls[variant])
|
291 |
+
if as_sequential:
|
292 |
+
model = model.as_sequential()
|
293 |
+
return model
|
294 |
+
|
295 |
+
|
296 |
+
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
297 |
+
"""Creates a mnasnet-a1 model.
|
298 |
+
|
299 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
300 |
+
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
channel_multiplier: multiplier to number of channels per layer.
|
304 |
+
"""
|
305 |
+
arch_def = [
|
306 |
+
# stage 0, 112x112 in
|
307 |
+
['ds_r1_k3_s1_e1_c16_noskip'],
|
308 |
+
# stage 1, 112x112 in
|
309 |
+
['ir_r2_k3_s2_e6_c24'],
|
310 |
+
# stage 2, 56x56 in
|
311 |
+
['ir_r3_k5_s2_e3_c40_se0.25'],
|
312 |
+
# stage 3, 28x28 in
|
313 |
+
['ir_r4_k3_s2_e6_c80'],
|
314 |
+
# stage 4, 14x14in
|
315 |
+
['ir_r2_k3_s1_e6_c112_se0.25'],
|
316 |
+
# stage 5, 14x14in
|
317 |
+
['ir_r3_k5_s2_e6_c160_se0.25'],
|
318 |
+
# stage 6, 7x7 in
|
319 |
+
['ir_r1_k3_s1_e6_c320'],
|
320 |
+
]
|
321 |
+
with layer_config_kwargs(kwargs):
|
322 |
+
model_kwargs = dict(
|
323 |
+
block_args=decode_arch_def(arch_def),
|
324 |
+
stem_size=32,
|
325 |
+
channel_multiplier=channel_multiplier,
|
326 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
327 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
328 |
+
**kwargs
|
329 |
+
)
|
330 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
331 |
+
return model
|
332 |
+
|
333 |
+
|
334 |
+
def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
335 |
+
"""Creates a mnasnet-b1 model.
|
336 |
+
|
337 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
338 |
+
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
channel_multiplier: multiplier to number of channels per layer.
|
342 |
+
"""
|
343 |
+
arch_def = [
|
344 |
+
# stage 0, 112x112 in
|
345 |
+
['ds_r1_k3_s1_c16_noskip'],
|
346 |
+
# stage 1, 112x112 in
|
347 |
+
['ir_r3_k3_s2_e3_c24'],
|
348 |
+
# stage 2, 56x56 in
|
349 |
+
['ir_r3_k5_s2_e3_c40'],
|
350 |
+
# stage 3, 28x28 in
|
351 |
+
['ir_r3_k5_s2_e6_c80'],
|
352 |
+
# stage 4, 14x14in
|
353 |
+
['ir_r2_k3_s1_e6_c96'],
|
354 |
+
# stage 5, 14x14in
|
355 |
+
['ir_r4_k5_s2_e6_c192'],
|
356 |
+
# stage 6, 7x7 in
|
357 |
+
['ir_r1_k3_s1_e6_c320_noskip']
|
358 |
+
]
|
359 |
+
with layer_config_kwargs(kwargs):
|
360 |
+
model_kwargs = dict(
|
361 |
+
block_args=decode_arch_def(arch_def),
|
362 |
+
stem_size=32,
|
363 |
+
channel_multiplier=channel_multiplier,
|
364 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
365 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
366 |
+
**kwargs
|
367 |
+
)
|
368 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
369 |
+
return model
|
370 |
+
|
371 |
+
|
372 |
+
def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
373 |
+
"""Creates a mnasnet-b1 model.
|
374 |
+
|
375 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
376 |
+
Paper: https://arxiv.org/pdf/1807.11626.pdf.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
channel_multiplier: multiplier to number of channels per layer.
|
380 |
+
"""
|
381 |
+
arch_def = [
|
382 |
+
['ds_r1_k3_s1_c8'],
|
383 |
+
['ir_r1_k3_s2_e3_c16'],
|
384 |
+
['ir_r2_k3_s2_e6_c16'],
|
385 |
+
['ir_r4_k5_s2_e6_c32_se0.25'],
|
386 |
+
['ir_r3_k3_s1_e6_c32_se0.25'],
|
387 |
+
['ir_r3_k5_s2_e6_c88_se0.25'],
|
388 |
+
['ir_r1_k3_s1_e6_c144']
|
389 |
+
]
|
390 |
+
with layer_config_kwargs(kwargs):
|
391 |
+
model_kwargs = dict(
|
392 |
+
block_args=decode_arch_def(arch_def),
|
393 |
+
stem_size=8,
|
394 |
+
channel_multiplier=channel_multiplier,
|
395 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
396 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
397 |
+
**kwargs
|
398 |
+
)
|
399 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
400 |
+
return model
|
401 |
+
|
402 |
+
|
403 |
+
def _gen_mobilenet_v2(
|
404 |
+
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
|
405 |
+
""" Generate MobileNet-V2 network
|
406 |
+
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
|
407 |
+
Paper: https://arxiv.org/abs/1801.04381
|
408 |
+
"""
|
409 |
+
arch_def = [
|
410 |
+
['ds_r1_k3_s1_c16'],
|
411 |
+
['ir_r2_k3_s2_e6_c24'],
|
412 |
+
['ir_r3_k3_s2_e6_c32'],
|
413 |
+
['ir_r4_k3_s2_e6_c64'],
|
414 |
+
['ir_r3_k3_s1_e6_c96'],
|
415 |
+
['ir_r3_k3_s2_e6_c160'],
|
416 |
+
['ir_r1_k3_s1_e6_c320'],
|
417 |
+
]
|
418 |
+
with layer_config_kwargs(kwargs):
|
419 |
+
model_kwargs = dict(
|
420 |
+
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
|
421 |
+
num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
|
422 |
+
stem_size=32,
|
423 |
+
fix_stem=fix_stem_head,
|
424 |
+
channel_multiplier=channel_multiplier,
|
425 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
426 |
+
act_layer=nn.ReLU6,
|
427 |
+
**kwargs
|
428 |
+
)
|
429 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
430 |
+
return model
|
431 |
+
|
432 |
+
|
433 |
+
def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
434 |
+
""" FBNet-C
|
435 |
+
|
436 |
+
Paper: https://arxiv.org/abs/1812.03443
|
437 |
+
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
|
438 |
+
|
439 |
+
NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper,
|
440 |
+
it was used to confirm some building block details
|
441 |
+
"""
|
442 |
+
arch_def = [
|
443 |
+
['ir_r1_k3_s1_e1_c16'],
|
444 |
+
['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
|
445 |
+
['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
|
446 |
+
['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
|
447 |
+
['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
|
448 |
+
['ir_r4_k5_s2_e6_c184'],
|
449 |
+
['ir_r1_k3_s1_e6_c352'],
|
450 |
+
]
|
451 |
+
with layer_config_kwargs(kwargs):
|
452 |
+
model_kwargs = dict(
|
453 |
+
block_args=decode_arch_def(arch_def),
|
454 |
+
stem_size=16,
|
455 |
+
num_features=1984, # paper suggests this, but is not 100% clear
|
456 |
+
channel_multiplier=channel_multiplier,
|
457 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
458 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
459 |
+
**kwargs
|
460 |
+
)
|
461 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
462 |
+
return model
|
463 |
+
|
464 |
+
|
465 |
+
def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
466 |
+
"""Creates the Single-Path NAS model from search targeted for Pixel1 phone.
|
467 |
+
|
468 |
+
Paper: https://arxiv.org/abs/1904.02877
|
469 |
+
|
470 |
+
Args:
|
471 |
+
channel_multiplier: multiplier to number of channels per layer.
|
472 |
+
"""
|
473 |
+
arch_def = [
|
474 |
+
# stage 0, 112x112 in
|
475 |
+
['ds_r1_k3_s1_c16_noskip'],
|
476 |
+
# stage 1, 112x112 in
|
477 |
+
['ir_r3_k3_s2_e3_c24'],
|
478 |
+
# stage 2, 56x56 in
|
479 |
+
['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
|
480 |
+
# stage 3, 28x28 in
|
481 |
+
['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
|
482 |
+
# stage 4, 14x14in
|
483 |
+
['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
|
484 |
+
# stage 5, 14x14in
|
485 |
+
['ir_r4_k5_s2_e6_c192'],
|
486 |
+
# stage 6, 7x7 in
|
487 |
+
['ir_r1_k3_s1_e6_c320_noskip']
|
488 |
+
]
|
489 |
+
with layer_config_kwargs(kwargs):
|
490 |
+
model_kwargs = dict(
|
491 |
+
block_args=decode_arch_def(arch_def),
|
492 |
+
stem_size=32,
|
493 |
+
channel_multiplier=channel_multiplier,
|
494 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
495 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
496 |
+
**kwargs
|
497 |
+
)
|
498 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
499 |
+
return model
|
500 |
+
|
501 |
+
|
502 |
+
def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
503 |
+
"""Creates an EfficientNet model.
|
504 |
+
|
505 |
+
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
506 |
+
Paper: https://arxiv.org/abs/1905.11946
|
507 |
+
|
508 |
+
EfficientNet params
|
509 |
+
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
510 |
+
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
511 |
+
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
512 |
+
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
513 |
+
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
514 |
+
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
515 |
+
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
516 |
+
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
517 |
+
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
518 |
+
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
|
519 |
+
|
520 |
+
Args:
|
521 |
+
channel_multiplier: multiplier to number of channels per layer
|
522 |
+
depth_multiplier: multiplier to number of repeats per stage
|
523 |
+
|
524 |
+
"""
|
525 |
+
arch_def = [
|
526 |
+
['ds_r1_k3_s1_e1_c16_se0.25'],
|
527 |
+
['ir_r2_k3_s2_e6_c24_se0.25'],
|
528 |
+
['ir_r2_k5_s2_e6_c40_se0.25'],
|
529 |
+
['ir_r3_k3_s2_e6_c80_se0.25'],
|
530 |
+
['ir_r3_k5_s1_e6_c112_se0.25'],
|
531 |
+
['ir_r4_k5_s2_e6_c192_se0.25'],
|
532 |
+
['ir_r1_k3_s1_e6_c320_se0.25'],
|
533 |
+
]
|
534 |
+
with layer_config_kwargs(kwargs):
|
535 |
+
model_kwargs = dict(
|
536 |
+
block_args=decode_arch_def(arch_def, depth_multiplier),
|
537 |
+
num_features=round_channels(1280, channel_multiplier, 8, None),
|
538 |
+
stem_size=32,
|
539 |
+
channel_multiplier=channel_multiplier,
|
540 |
+
act_layer=resolve_act_layer(kwargs, 'swish'),
|
541 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
542 |
+
**kwargs,
|
543 |
+
)
|
544 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
545 |
+
return model
|
546 |
+
|
547 |
+
|
548 |
+
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
549 |
+
arch_def = [
|
550 |
+
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
|
551 |
+
# present in other models
|
552 |
+
['er_r1_k3_s1_e4_c24_fc24_noskip'],
|
553 |
+
['er_r2_k3_s2_e8_c32'],
|
554 |
+
['er_r4_k3_s2_e8_c48'],
|
555 |
+
['ir_r5_k5_s2_e8_c96'],
|
556 |
+
['ir_r4_k5_s1_e8_c144'],
|
557 |
+
['ir_r2_k5_s2_e8_c192'],
|
558 |
+
]
|
559 |
+
with layer_config_kwargs(kwargs):
|
560 |
+
model_kwargs = dict(
|
561 |
+
block_args=decode_arch_def(arch_def, depth_multiplier),
|
562 |
+
num_features=round_channels(1280, channel_multiplier, 8, None),
|
563 |
+
stem_size=32,
|
564 |
+
channel_multiplier=channel_multiplier,
|
565 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
566 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
567 |
+
**kwargs,
|
568 |
+
)
|
569 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
570 |
+
return model
|
571 |
+
|
572 |
+
|
573 |
+
def _gen_efficientnet_condconv(
|
574 |
+
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
|
575 |
+
"""Creates an efficientnet-condconv model."""
|
576 |
+
arch_def = [
|
577 |
+
['ds_r1_k3_s1_e1_c16_se0.25'],
|
578 |
+
['ir_r2_k3_s2_e6_c24_se0.25'],
|
579 |
+
['ir_r2_k5_s2_e6_c40_se0.25'],
|
580 |
+
['ir_r3_k3_s2_e6_c80_se0.25'],
|
581 |
+
['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
|
582 |
+
['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
|
583 |
+
['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
|
584 |
+
]
|
585 |
+
with layer_config_kwargs(kwargs):
|
586 |
+
model_kwargs = dict(
|
587 |
+
block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
|
588 |
+
num_features=round_channels(1280, channel_multiplier, 8, None),
|
589 |
+
stem_size=32,
|
590 |
+
channel_multiplier=channel_multiplier,
|
591 |
+
act_layer=resolve_act_layer(kwargs, 'swish'),
|
592 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
593 |
+
**kwargs,
|
594 |
+
)
|
595 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
596 |
+
return model
|
597 |
+
|
598 |
+
|
599 |
+
def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
600 |
+
"""Creates an EfficientNet-Lite model.
|
601 |
+
|
602 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
|
603 |
+
Paper: https://arxiv.org/abs/1905.11946
|
604 |
+
|
605 |
+
EfficientNet params
|
606 |
+
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
|
607 |
+
'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
|
608 |
+
'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
|
609 |
+
'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
|
610 |
+
'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
|
611 |
+
'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
|
612 |
+
|
613 |
+
Args:
|
614 |
+
channel_multiplier: multiplier to number of channels per layer
|
615 |
+
depth_multiplier: multiplier to number of repeats per stage
|
616 |
+
"""
|
617 |
+
arch_def = [
|
618 |
+
['ds_r1_k3_s1_e1_c16'],
|
619 |
+
['ir_r2_k3_s2_e6_c24'],
|
620 |
+
['ir_r2_k5_s2_e6_c40'],
|
621 |
+
['ir_r3_k3_s2_e6_c80'],
|
622 |
+
['ir_r3_k5_s1_e6_c112'],
|
623 |
+
['ir_r4_k5_s2_e6_c192'],
|
624 |
+
['ir_r1_k3_s1_e6_c320'],
|
625 |
+
]
|
626 |
+
with layer_config_kwargs(kwargs):
|
627 |
+
model_kwargs = dict(
|
628 |
+
block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
|
629 |
+
num_features=1280,
|
630 |
+
stem_size=32,
|
631 |
+
fix_stem=True,
|
632 |
+
channel_multiplier=channel_multiplier,
|
633 |
+
act_layer=nn.ReLU6,
|
634 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
635 |
+
**kwargs,
|
636 |
+
)
|
637 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
638 |
+
return model
|
639 |
+
|
640 |
+
|
641 |
+
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
642 |
+
"""Creates a MixNet Small model.
|
643 |
+
|
644 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
|
645 |
+
Paper: https://arxiv.org/abs/1907.09595
|
646 |
+
"""
|
647 |
+
arch_def = [
|
648 |
+
# stage 0, 112x112 in
|
649 |
+
['ds_r1_k3_s1_e1_c16'], # relu
|
650 |
+
# stage 1, 112x112 in
|
651 |
+
['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
|
652 |
+
# stage 2, 56x56 in
|
653 |
+
['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
|
654 |
+
# stage 3, 28x28 in
|
655 |
+
['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
|
656 |
+
# stage 4, 14x14in
|
657 |
+
['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
|
658 |
+
# stage 5, 14x14in
|
659 |
+
['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
|
660 |
+
# 7x7
|
661 |
+
]
|
662 |
+
with layer_config_kwargs(kwargs):
|
663 |
+
model_kwargs = dict(
|
664 |
+
block_args=decode_arch_def(arch_def),
|
665 |
+
num_features=1536,
|
666 |
+
stem_size=16,
|
667 |
+
channel_multiplier=channel_multiplier,
|
668 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
669 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
670 |
+
**kwargs
|
671 |
+
)
|
672 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
673 |
+
return model
|
674 |
+
|
675 |
+
|
676 |
+
def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
677 |
+
"""Creates a MixNet Medium-Large model.
|
678 |
+
|
679 |
+
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
|
680 |
+
Paper: https://arxiv.org/abs/1907.09595
|
681 |
+
"""
|
682 |
+
arch_def = [
|
683 |
+
# stage 0, 112x112 in
|
684 |
+
['ds_r1_k3_s1_e1_c24'], # relu
|
685 |
+
# stage 1, 112x112 in
|
686 |
+
['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
|
687 |
+
# stage 2, 56x56 in
|
688 |
+
['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
|
689 |
+
# stage 3, 28x28 in
|
690 |
+
['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
|
691 |
+
# stage 4, 14x14in
|
692 |
+
['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
|
693 |
+
# stage 5, 14x14in
|
694 |
+
['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
|
695 |
+
# 7x7
|
696 |
+
]
|
697 |
+
with layer_config_kwargs(kwargs):
|
698 |
+
model_kwargs = dict(
|
699 |
+
block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
|
700 |
+
num_features=1536,
|
701 |
+
stem_size=24,
|
702 |
+
channel_multiplier=channel_multiplier,
|
703 |
+
act_layer=resolve_act_layer(kwargs, 'relu'),
|
704 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
705 |
+
**kwargs
|
706 |
+
)
|
707 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
708 |
+
return model
|
709 |
+
|
710 |
+
|
711 |
+
def mnasnet_050(pretrained=False, **kwargs):
|
712 |
+
""" MNASNet B1, depth multiplier of 0.5. """
|
713 |
+
model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
|
714 |
+
return model
|
715 |
+
|
716 |
+
|
717 |
+
def mnasnet_075(pretrained=False, **kwargs):
|
718 |
+
""" MNASNet B1, depth multiplier of 0.75. """
|
719 |
+
model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
|
720 |
+
return model
|
721 |
+
|
722 |
+
|
723 |
+
def mnasnet_100(pretrained=False, **kwargs):
|
724 |
+
""" MNASNet B1, depth multiplier of 1.0. """
|
725 |
+
model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
|
726 |
+
return model
|
727 |
+
|
728 |
+
|
729 |
+
def mnasnet_b1(pretrained=False, **kwargs):
|
730 |
+
""" MNASNet B1, depth multiplier of 1.0. """
|
731 |
+
return mnasnet_100(pretrained, **kwargs)
|
732 |
+
|
733 |
+
|
734 |
+
def mnasnet_140(pretrained=False, **kwargs):
|
735 |
+
""" MNASNet B1, depth multiplier of 1.4 """
|
736 |
+
model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
|
737 |
+
return model
|
738 |
+
|
739 |
+
|
740 |
+
def semnasnet_050(pretrained=False, **kwargs):
|
741 |
+
""" MNASNet A1 (w/ SE), depth multiplier of 0.5 """
|
742 |
+
model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
|
743 |
+
return model
|
744 |
+
|
745 |
+
|
746 |
+
def semnasnet_075(pretrained=False, **kwargs):
|
747 |
+
""" MNASNet A1 (w/ SE), depth multiplier of 0.75. """
|
748 |
+
model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
|
749 |
+
return model
|
750 |
+
|
751 |
+
|
752 |
+
def semnasnet_100(pretrained=False, **kwargs):
|
753 |
+
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
754 |
+
model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
|
755 |
+
return model
|
756 |
+
|
757 |
+
|
758 |
+
def mnasnet_a1(pretrained=False, **kwargs):
|
759 |
+
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
|
760 |
+
return semnasnet_100(pretrained, **kwargs)
|
761 |
+
|
762 |
+
|
763 |
+
def semnasnet_140(pretrained=False, **kwargs):
|
764 |
+
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
765 |
+
model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
|
766 |
+
return model
|
767 |
+
|
768 |
+
|
769 |
+
def mnasnet_small(pretrained=False, **kwargs):
|
770 |
+
""" MNASNet Small, depth multiplier of 1.0. """
|
771 |
+
model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
|
772 |
+
return model
|
773 |
+
|
774 |
+
|
775 |
+
def mobilenetv2_100(pretrained=False, **kwargs):
|
776 |
+
""" MobileNet V2 w/ 1.0 channel multiplier """
|
777 |
+
model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
|
778 |
+
return model
|
779 |
+
|
780 |
+
|
781 |
+
def mobilenetv2_140(pretrained=False, **kwargs):
|
782 |
+
""" MobileNet V2 w/ 1.4 channel multiplier """
|
783 |
+
model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
|
784 |
+
return model
|
785 |
+
|
786 |
+
|
787 |
+
def mobilenetv2_110d(pretrained=False, **kwargs):
|
788 |
+
""" MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers"""
|
789 |
+
model = _gen_mobilenet_v2(
|
790 |
+
'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs)
|
791 |
+
return model
|
792 |
+
|
793 |
+
|
794 |
+
def mobilenetv2_120d(pretrained=False, **kwargs):
|
795 |
+
""" MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """
|
796 |
+
model = _gen_mobilenet_v2(
|
797 |
+
'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs)
|
798 |
+
return model
|
799 |
+
|
800 |
+
|
801 |
+
def fbnetc_100(pretrained=False, **kwargs):
|
802 |
+
""" FBNet-C """
|
803 |
+
if pretrained:
|
804 |
+
# pretrained model trained with non-default BN epsilon
|
805 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
806 |
+
model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
|
807 |
+
return model
|
808 |
+
|
809 |
+
|
810 |
+
def spnasnet_100(pretrained=False, **kwargs):
|
811 |
+
""" Single-Path NAS Pixel1"""
|
812 |
+
model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
|
813 |
+
return model
|
814 |
+
|
815 |
+
|
816 |
+
def efficientnet_b0(pretrained=False, **kwargs):
|
817 |
+
""" EfficientNet-B0 """
|
818 |
+
# NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
|
819 |
+
model = _gen_efficientnet(
|
820 |
+
'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
821 |
+
return model
|
822 |
+
|
823 |
+
|
824 |
+
def efficientnet_b1(pretrained=False, **kwargs):
|
825 |
+
""" EfficientNet-B1 """
|
826 |
+
# NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
|
827 |
+
model = _gen_efficientnet(
|
828 |
+
'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
829 |
+
return model
|
830 |
+
|
831 |
+
|
832 |
+
def efficientnet_b2(pretrained=False, **kwargs):
|
833 |
+
""" EfficientNet-B2 """
|
834 |
+
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
|
835 |
+
model = _gen_efficientnet(
|
836 |
+
'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
837 |
+
return model
|
838 |
+
|
839 |
+
|
840 |
+
def efficientnet_b3(pretrained=False, **kwargs):
|
841 |
+
""" EfficientNet-B3 """
|
842 |
+
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
|
843 |
+
model = _gen_efficientnet(
|
844 |
+
'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
845 |
+
return model
|
846 |
+
|
847 |
+
|
848 |
+
def efficientnet_b4(pretrained=False, **kwargs):
|
849 |
+
""" EfficientNet-B4 """
|
850 |
+
# NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
|
851 |
+
model = _gen_efficientnet(
|
852 |
+
'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
853 |
+
return model
|
854 |
+
|
855 |
+
|
856 |
+
def efficientnet_b5(pretrained=False, **kwargs):
|
857 |
+
""" EfficientNet-B5 """
|
858 |
+
# NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
|
859 |
+
model = _gen_efficientnet(
|
860 |
+
'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
|
861 |
+
return model
|
862 |
+
|
863 |
+
|
864 |
+
def efficientnet_b6(pretrained=False, **kwargs):
|
865 |
+
""" EfficientNet-B6 """
|
866 |
+
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
|
867 |
+
model = _gen_efficientnet(
|
868 |
+
'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
|
869 |
+
return model
|
870 |
+
|
871 |
+
|
872 |
+
def efficientnet_b7(pretrained=False, **kwargs):
|
873 |
+
""" EfficientNet-B7 """
|
874 |
+
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
|
875 |
+
model = _gen_efficientnet(
|
876 |
+
'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
|
877 |
+
return model
|
878 |
+
|
879 |
+
|
880 |
+
def efficientnet_b8(pretrained=False, **kwargs):
|
881 |
+
""" EfficientNet-B8 """
|
882 |
+
# NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
|
883 |
+
model = _gen_efficientnet(
|
884 |
+
'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
|
885 |
+
return model
|
886 |
+
|
887 |
+
|
888 |
+
def efficientnet_l2(pretrained=False, **kwargs):
|
889 |
+
""" EfficientNet-L2. """
|
890 |
+
# NOTE for train, drop_rate should be 0.5
|
891 |
+
model = _gen_efficientnet(
|
892 |
+
'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
|
893 |
+
return model
|
894 |
+
|
895 |
+
|
896 |
+
def efficientnet_es(pretrained=False, **kwargs):
|
897 |
+
""" EfficientNet-Edge Small. """
|
898 |
+
model = _gen_efficientnet_edge(
|
899 |
+
'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
900 |
+
return model
|
901 |
+
|
902 |
+
|
903 |
+
def efficientnet_em(pretrained=False, **kwargs):
|
904 |
+
""" EfficientNet-Edge-Medium. """
|
905 |
+
model = _gen_efficientnet_edge(
|
906 |
+
'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
907 |
+
return model
|
908 |
+
|
909 |
+
|
910 |
+
def efficientnet_el(pretrained=False, **kwargs):
|
911 |
+
""" EfficientNet-Edge-Large. """
|
912 |
+
model = _gen_efficientnet_edge(
|
913 |
+
'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
914 |
+
return model
|
915 |
+
|
916 |
+
|
917 |
+
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
|
918 |
+
""" EfficientNet-CondConv-B0 w/ 8 Experts """
|
919 |
+
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
|
920 |
+
model = _gen_efficientnet_condconv(
|
921 |
+
'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
922 |
+
return model
|
923 |
+
|
924 |
+
|
925 |
+
def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
926 |
+
""" EfficientNet-CondConv-B0 w/ 8 Experts """
|
927 |
+
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
|
928 |
+
model = _gen_efficientnet_condconv(
|
929 |
+
'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
|
930 |
+
pretrained=pretrained, **kwargs)
|
931 |
+
return model
|
932 |
+
|
933 |
+
|
934 |
+
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
935 |
+
""" EfficientNet-CondConv-B1 w/ 8 Experts """
|
936 |
+
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
|
937 |
+
model = _gen_efficientnet_condconv(
|
938 |
+
'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
|
939 |
+
pretrained=pretrained, **kwargs)
|
940 |
+
return model
|
941 |
+
|
942 |
+
|
943 |
+
def efficientnet_lite0(pretrained=False, **kwargs):
|
944 |
+
""" EfficientNet-Lite0 """
|
945 |
+
model = _gen_efficientnet_lite(
|
946 |
+
'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
947 |
+
return model
|
948 |
+
|
949 |
+
|
950 |
+
def efficientnet_lite1(pretrained=False, **kwargs):
|
951 |
+
""" EfficientNet-Lite1 """
|
952 |
+
model = _gen_efficientnet_lite(
|
953 |
+
'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
954 |
+
return model
|
955 |
+
|
956 |
+
|
957 |
+
def efficientnet_lite2(pretrained=False, **kwargs):
|
958 |
+
""" EfficientNet-Lite2 """
|
959 |
+
model = _gen_efficientnet_lite(
|
960 |
+
'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
961 |
+
return model
|
962 |
+
|
963 |
+
|
964 |
+
def efficientnet_lite3(pretrained=False, **kwargs):
|
965 |
+
""" EfficientNet-Lite3 """
|
966 |
+
model = _gen_efficientnet_lite(
|
967 |
+
'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
968 |
+
return model
|
969 |
+
|
970 |
+
|
971 |
+
def efficientnet_lite4(pretrained=False, **kwargs):
|
972 |
+
""" EfficientNet-Lite4 """
|
973 |
+
model = _gen_efficientnet_lite(
|
974 |
+
'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
975 |
+
return model
|
976 |
+
|
977 |
+
|
978 |
+
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
979 |
+
""" EfficientNet-B0 AutoAug. Tensorflow compatible variant """
|
980 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
981 |
+
kwargs['pad_type'] = 'same'
|
982 |
+
model = _gen_efficientnet(
|
983 |
+
'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
984 |
+
return model
|
985 |
+
|
986 |
+
|
987 |
+
def tf_efficientnet_b1(pretrained=False, **kwargs):
|
988 |
+
""" EfficientNet-B1 AutoAug. Tensorflow compatible variant """
|
989 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
990 |
+
kwargs['pad_type'] = 'same'
|
991 |
+
model = _gen_efficientnet(
|
992 |
+
'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
993 |
+
return model
|
994 |
+
|
995 |
+
|
996 |
+
def tf_efficientnet_b2(pretrained=False, **kwargs):
|
997 |
+
""" EfficientNet-B2 AutoAug. Tensorflow compatible variant """
|
998 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
999 |
+
kwargs['pad_type'] = 'same'
|
1000 |
+
model = _gen_efficientnet(
|
1001 |
+
'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
1002 |
+
return model
|
1003 |
+
|
1004 |
+
|
1005 |
+
def tf_efficientnet_b3(pretrained=False, **kwargs):
|
1006 |
+
""" EfficientNet-B3 AutoAug. Tensorflow compatible variant """
|
1007 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1008 |
+
kwargs['pad_type'] = 'same'
|
1009 |
+
model = _gen_efficientnet(
|
1010 |
+
'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
1011 |
+
return model
|
1012 |
+
|
1013 |
+
|
1014 |
+
def tf_efficientnet_b4(pretrained=False, **kwargs):
|
1015 |
+
""" EfficientNet-B4 AutoAug. Tensorflow compatible variant """
|
1016 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1017 |
+
kwargs['pad_type'] = 'same'
|
1018 |
+
model = _gen_efficientnet(
|
1019 |
+
'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
1020 |
+
return model
|
1021 |
+
|
1022 |
+
|
1023 |
+
def tf_efficientnet_b5(pretrained=False, **kwargs):
|
1024 |
+
""" EfficientNet-B5 RandAug. Tensorflow compatible variant """
|
1025 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1026 |
+
kwargs['pad_type'] = 'same'
|
1027 |
+
model = _gen_efficientnet(
|
1028 |
+
'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
|
1029 |
+
return model
|
1030 |
+
|
1031 |
+
|
1032 |
+
def tf_efficientnet_b6(pretrained=False, **kwargs):
|
1033 |
+
""" EfficientNet-B6 AutoAug. Tensorflow compatible variant """
|
1034 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1035 |
+
kwargs['pad_type'] = 'same'
|
1036 |
+
model = _gen_efficientnet(
|
1037 |
+
'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
|
1038 |
+
return model
|
1039 |
+
|
1040 |
+
|
1041 |
+
def tf_efficientnet_b7(pretrained=False, **kwargs):
|
1042 |
+
""" EfficientNet-B7 RandAug. Tensorflow compatible variant """
|
1043 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1044 |
+
kwargs['pad_type'] = 'same'
|
1045 |
+
model = _gen_efficientnet(
|
1046 |
+
'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
|
1047 |
+
return model
|
1048 |
+
|
1049 |
+
|
1050 |
+
def tf_efficientnet_b8(pretrained=False, **kwargs):
|
1051 |
+
""" EfficientNet-B8 RandAug. Tensorflow compatible variant """
|
1052 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1053 |
+
kwargs['pad_type'] = 'same'
|
1054 |
+
model = _gen_efficientnet(
|
1055 |
+
'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
|
1056 |
+
return model
|
1057 |
+
|
1058 |
+
|
1059 |
+
def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
|
1060 |
+
""" EfficientNet-B0 AdvProp. Tensorflow compatible variant
|
1061 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1062 |
+
"""
|
1063 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1064 |
+
kwargs['pad_type'] = 'same'
|
1065 |
+
model = _gen_efficientnet(
|
1066 |
+
'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1067 |
+
return model
|
1068 |
+
|
1069 |
+
|
1070 |
+
def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
|
1071 |
+
""" EfficientNet-B1 AdvProp. Tensorflow compatible variant
|
1072 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1073 |
+
"""
|
1074 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1075 |
+
kwargs['pad_type'] = 'same'
|
1076 |
+
model = _gen_efficientnet(
|
1077 |
+
'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
1078 |
+
return model
|
1079 |
+
|
1080 |
+
|
1081 |
+
def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
|
1082 |
+
""" EfficientNet-B2 AdvProp. Tensorflow compatible variant
|
1083 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1084 |
+
"""
|
1085 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1086 |
+
kwargs['pad_type'] = 'same'
|
1087 |
+
model = _gen_efficientnet(
|
1088 |
+
'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
1089 |
+
return model
|
1090 |
+
|
1091 |
+
|
1092 |
+
def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
|
1093 |
+
""" EfficientNet-B3 AdvProp. Tensorflow compatible variant
|
1094 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1095 |
+
"""
|
1096 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1097 |
+
kwargs['pad_type'] = 'same'
|
1098 |
+
model = _gen_efficientnet(
|
1099 |
+
'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
1100 |
+
return model
|
1101 |
+
|
1102 |
+
|
1103 |
+
def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
|
1104 |
+
""" EfficientNet-B4 AdvProp. Tensorflow compatible variant
|
1105 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1106 |
+
"""
|
1107 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1108 |
+
kwargs['pad_type'] = 'same'
|
1109 |
+
model = _gen_efficientnet(
|
1110 |
+
'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
1111 |
+
return model
|
1112 |
+
|
1113 |
+
|
1114 |
+
def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
|
1115 |
+
""" EfficientNet-B5 AdvProp. Tensorflow compatible variant
|
1116 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1117 |
+
"""
|
1118 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1119 |
+
kwargs['pad_type'] = 'same'
|
1120 |
+
model = _gen_efficientnet(
|
1121 |
+
'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
|
1122 |
+
return model
|
1123 |
+
|
1124 |
+
|
1125 |
+
def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
|
1126 |
+
""" EfficientNet-B6 AdvProp. Tensorflow compatible variant
|
1127 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1128 |
+
"""
|
1129 |
+
# NOTE for train, drop_rate should be 0.5
|
1130 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1131 |
+
kwargs['pad_type'] = 'same'
|
1132 |
+
model = _gen_efficientnet(
|
1133 |
+
'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
|
1134 |
+
return model
|
1135 |
+
|
1136 |
+
|
1137 |
+
def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
|
1138 |
+
""" EfficientNet-B7 AdvProp. Tensorflow compatible variant
|
1139 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1140 |
+
"""
|
1141 |
+
# NOTE for train, drop_rate should be 0.5
|
1142 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1143 |
+
kwargs['pad_type'] = 'same'
|
1144 |
+
model = _gen_efficientnet(
|
1145 |
+
'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
|
1146 |
+
return model
|
1147 |
+
|
1148 |
+
|
1149 |
+
def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
|
1150 |
+
""" EfficientNet-B8 AdvProp. Tensorflow compatible variant
|
1151 |
+
Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
|
1152 |
+
"""
|
1153 |
+
# NOTE for train, drop_rate should be 0.5
|
1154 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1155 |
+
kwargs['pad_type'] = 'same'
|
1156 |
+
model = _gen_efficientnet(
|
1157 |
+
'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
|
1158 |
+
return model
|
1159 |
+
|
1160 |
+
|
1161 |
+
def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
|
1162 |
+
""" EfficientNet-B0 NoisyStudent. Tensorflow compatible variant
|
1163 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1164 |
+
"""
|
1165 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1166 |
+
kwargs['pad_type'] = 'same'
|
1167 |
+
model = _gen_efficientnet(
|
1168 |
+
'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1169 |
+
return model
|
1170 |
+
|
1171 |
+
|
1172 |
+
def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
|
1173 |
+
""" EfficientNet-B1 NoisyStudent. Tensorflow compatible variant
|
1174 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1175 |
+
"""
|
1176 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1177 |
+
kwargs['pad_type'] = 'same'
|
1178 |
+
model = _gen_efficientnet(
|
1179 |
+
'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
1180 |
+
return model
|
1181 |
+
|
1182 |
+
|
1183 |
+
def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
|
1184 |
+
""" EfficientNet-B2 NoisyStudent. Tensorflow compatible variant
|
1185 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1186 |
+
"""
|
1187 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1188 |
+
kwargs['pad_type'] = 'same'
|
1189 |
+
model = _gen_efficientnet(
|
1190 |
+
'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
1191 |
+
return model
|
1192 |
+
|
1193 |
+
|
1194 |
+
def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
|
1195 |
+
""" EfficientNet-B3 NoisyStudent. Tensorflow compatible variant
|
1196 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1197 |
+
"""
|
1198 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1199 |
+
kwargs['pad_type'] = 'same'
|
1200 |
+
model = _gen_efficientnet(
|
1201 |
+
'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
1202 |
+
return model
|
1203 |
+
|
1204 |
+
|
1205 |
+
def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
|
1206 |
+
""" EfficientNet-B4 NoisyStudent. Tensorflow compatible variant
|
1207 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1208 |
+
"""
|
1209 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1210 |
+
kwargs['pad_type'] = 'same'
|
1211 |
+
model = _gen_efficientnet(
|
1212 |
+
'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
1213 |
+
return model
|
1214 |
+
|
1215 |
+
|
1216 |
+
def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
|
1217 |
+
""" EfficientNet-B5 NoisyStudent. Tensorflow compatible variant
|
1218 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1219 |
+
"""
|
1220 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1221 |
+
kwargs['pad_type'] = 'same'
|
1222 |
+
model = _gen_efficientnet(
|
1223 |
+
'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
|
1224 |
+
return model
|
1225 |
+
|
1226 |
+
|
1227 |
+
def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
|
1228 |
+
""" EfficientNet-B6 NoisyStudent. Tensorflow compatible variant
|
1229 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1230 |
+
"""
|
1231 |
+
# NOTE for train, drop_rate should be 0.5
|
1232 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1233 |
+
kwargs['pad_type'] = 'same'
|
1234 |
+
model = _gen_efficientnet(
|
1235 |
+
'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
|
1236 |
+
return model
|
1237 |
+
|
1238 |
+
|
1239 |
+
def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
|
1240 |
+
""" EfficientNet-B7 NoisyStudent. Tensorflow compatible variant
|
1241 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1242 |
+
"""
|
1243 |
+
# NOTE for train, drop_rate should be 0.5
|
1244 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1245 |
+
kwargs['pad_type'] = 'same'
|
1246 |
+
model = _gen_efficientnet(
|
1247 |
+
'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
|
1248 |
+
return model
|
1249 |
+
|
1250 |
+
|
1251 |
+
def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
|
1252 |
+
""" EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant
|
1253 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1254 |
+
"""
|
1255 |
+
# NOTE for train, drop_rate should be 0.5
|
1256 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1257 |
+
kwargs['pad_type'] = 'same'
|
1258 |
+
model = _gen_efficientnet(
|
1259 |
+
'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
|
1260 |
+
return model
|
1261 |
+
|
1262 |
+
|
1263 |
+
def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
|
1264 |
+
""" EfficientNet-L2 NoisyStudent. Tensorflow compatible variant
|
1265 |
+
Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
|
1266 |
+
"""
|
1267 |
+
# NOTE for train, drop_rate should be 0.5
|
1268 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1269 |
+
kwargs['pad_type'] = 'same'
|
1270 |
+
model = _gen_efficientnet(
|
1271 |
+
'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
|
1272 |
+
return model
|
1273 |
+
|
1274 |
+
|
1275 |
+
def tf_efficientnet_es(pretrained=False, **kwargs):
|
1276 |
+
""" EfficientNet-Edge Small. Tensorflow compatible variant """
|
1277 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1278 |
+
kwargs['pad_type'] = 'same'
|
1279 |
+
model = _gen_efficientnet_edge(
|
1280 |
+
'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1281 |
+
return model
|
1282 |
+
|
1283 |
+
|
1284 |
+
def tf_efficientnet_em(pretrained=False, **kwargs):
|
1285 |
+
""" EfficientNet-Edge-Medium. Tensorflow compatible variant """
|
1286 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1287 |
+
kwargs['pad_type'] = 'same'
|
1288 |
+
model = _gen_efficientnet_edge(
|
1289 |
+
'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
1290 |
+
return model
|
1291 |
+
|
1292 |
+
|
1293 |
+
def tf_efficientnet_el(pretrained=False, **kwargs):
|
1294 |
+
""" EfficientNet-Edge-Large. Tensorflow compatible variant """
|
1295 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1296 |
+
kwargs['pad_type'] = 'same'
|
1297 |
+
model = _gen_efficientnet_edge(
|
1298 |
+
'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
1299 |
+
return model
|
1300 |
+
|
1301 |
+
|
1302 |
+
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
|
1303 |
+
""" EfficientNet-CondConv-B0 w/ 4 Experts """
|
1304 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1305 |
+
kwargs['pad_type'] = 'same'
|
1306 |
+
model = _gen_efficientnet_condconv(
|
1307 |
+
'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1308 |
+
return model
|
1309 |
+
|
1310 |
+
|
1311 |
+
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
1312 |
+
""" EfficientNet-CondConv-B0 w/ 8 Experts """
|
1313 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1314 |
+
kwargs['pad_type'] = 'same'
|
1315 |
+
model = _gen_efficientnet_condconv(
|
1316 |
+
'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
|
1317 |
+
pretrained=pretrained, **kwargs)
|
1318 |
+
return model
|
1319 |
+
|
1320 |
+
|
1321 |
+
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
1322 |
+
""" EfficientNet-CondConv-B1 w/ 8 Experts """
|
1323 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1324 |
+
kwargs['pad_type'] = 'same'
|
1325 |
+
model = _gen_efficientnet_condconv(
|
1326 |
+
'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
|
1327 |
+
pretrained=pretrained, **kwargs)
|
1328 |
+
return model
|
1329 |
+
|
1330 |
+
|
1331 |
+
def tf_efficientnet_lite0(pretrained=False, **kwargs):
|
1332 |
+
""" EfficientNet-Lite0. Tensorflow compatible variant """
|
1333 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1334 |
+
kwargs['pad_type'] = 'same'
|
1335 |
+
model = _gen_efficientnet_lite(
|
1336 |
+
'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1337 |
+
return model
|
1338 |
+
|
1339 |
+
|
1340 |
+
def tf_efficientnet_lite1(pretrained=False, **kwargs):
|
1341 |
+
""" EfficientNet-Lite1. Tensorflow compatible variant """
|
1342 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1343 |
+
kwargs['pad_type'] = 'same'
|
1344 |
+
model = _gen_efficientnet_lite(
|
1345 |
+
'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
|
1346 |
+
return model
|
1347 |
+
|
1348 |
+
|
1349 |
+
def tf_efficientnet_lite2(pretrained=False, **kwargs):
|
1350 |
+
""" EfficientNet-Lite2. Tensorflow compatible variant """
|
1351 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1352 |
+
kwargs['pad_type'] = 'same'
|
1353 |
+
model = _gen_efficientnet_lite(
|
1354 |
+
'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
1355 |
+
return model
|
1356 |
+
|
1357 |
+
|
1358 |
+
def tf_efficientnet_lite3(pretrained=False, **kwargs):
|
1359 |
+
""" EfficientNet-Lite3. Tensorflow compatible variant """
|
1360 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1361 |
+
kwargs['pad_type'] = 'same'
|
1362 |
+
model = _gen_efficientnet_lite(
|
1363 |
+
'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
|
1364 |
+
return model
|
1365 |
+
|
1366 |
+
|
1367 |
+
def tf_efficientnet_lite4(pretrained=False, **kwargs):
|
1368 |
+
""" EfficientNet-Lite4. Tensorflow compatible variant """
|
1369 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1370 |
+
kwargs['pad_type'] = 'same'
|
1371 |
+
model = _gen_efficientnet_lite(
|
1372 |
+
'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
|
1373 |
+
return model
|
1374 |
+
|
1375 |
+
|
1376 |
+
def mixnet_s(pretrained=False, **kwargs):
|
1377 |
+
"""Creates a MixNet Small model.
|
1378 |
+
"""
|
1379 |
+
# NOTE for train set drop_rate=0.2
|
1380 |
+
model = _gen_mixnet_s(
|
1381 |
+
'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1382 |
+
return model
|
1383 |
+
|
1384 |
+
|
1385 |
+
def mixnet_m(pretrained=False, **kwargs):
|
1386 |
+
"""Creates a MixNet Medium model.
|
1387 |
+
"""
|
1388 |
+
# NOTE for train set drop_rate=0.25
|
1389 |
+
model = _gen_mixnet_m(
|
1390 |
+
'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1391 |
+
return model
|
1392 |
+
|
1393 |
+
|
1394 |
+
def mixnet_l(pretrained=False, **kwargs):
|
1395 |
+
"""Creates a MixNet Large model.
|
1396 |
+
"""
|
1397 |
+
# NOTE for train set drop_rate=0.25
|
1398 |
+
model = _gen_mixnet_m(
|
1399 |
+
'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
|
1400 |
+
return model
|
1401 |
+
|
1402 |
+
|
1403 |
+
def mixnet_xl(pretrained=False, **kwargs):
|
1404 |
+
"""Creates a MixNet Extra-Large model.
|
1405 |
+
Not a paper spec, experimental def by RW w/ depth scaling.
|
1406 |
+
"""
|
1407 |
+
# NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
|
1408 |
+
model = _gen_mixnet_m(
|
1409 |
+
'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
1410 |
+
return model
|
1411 |
+
|
1412 |
+
|
1413 |
+
def mixnet_xxl(pretrained=False, **kwargs):
|
1414 |
+
"""Creates a MixNet Double Extra Large model.
|
1415 |
+
Not a paper spec, experimental def by RW w/ depth scaling.
|
1416 |
+
"""
|
1417 |
+
# NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
|
1418 |
+
model = _gen_mixnet_m(
|
1419 |
+
'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
|
1420 |
+
return model
|
1421 |
+
|
1422 |
+
|
1423 |
+
def tf_mixnet_s(pretrained=False, **kwargs):
|
1424 |
+
"""Creates a MixNet Small model. Tensorflow compatible variant
|
1425 |
+
"""
|
1426 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1427 |
+
kwargs['pad_type'] = 'same'
|
1428 |
+
model = _gen_mixnet_s(
|
1429 |
+
'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1430 |
+
return model
|
1431 |
+
|
1432 |
+
|
1433 |
+
def tf_mixnet_m(pretrained=False, **kwargs):
|
1434 |
+
"""Creates a MixNet Medium model. Tensorflow compatible variant
|
1435 |
+
"""
|
1436 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1437 |
+
kwargs['pad_type'] = 'same'
|
1438 |
+
model = _gen_mixnet_m(
|
1439 |
+
'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
|
1440 |
+
return model
|
1441 |
+
|
1442 |
+
|
1443 |
+
def tf_mixnet_l(pretrained=False, **kwargs):
|
1444 |
+
"""Creates a MixNet Large model. Tensorflow compatible variant
|
1445 |
+
"""
|
1446 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
1447 |
+
kwargs['pad_type'] = 'same'
|
1448 |
+
model = _gen_mixnet_m(
|
1449 |
+
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
|
1450 |
+
return model
|
geffnet/helpers.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Checkpoint loading / state_dict helpers
|
2 |
+
Copyright 2020 Ross Wightman
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
from collections import OrderedDict
|
7 |
+
try:
|
8 |
+
from torch.hub import load_state_dict_from_url
|
9 |
+
except ImportError:
|
10 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
11 |
+
|
12 |
+
|
13 |
+
def load_checkpoint(model, checkpoint_path):
|
14 |
+
if checkpoint_path and os.path.isfile(checkpoint_path):
|
15 |
+
print("=> Loading checkpoint '{}'".format(checkpoint_path))
|
16 |
+
checkpoint = torch.load(checkpoint_path)
|
17 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
18 |
+
new_state_dict = OrderedDict()
|
19 |
+
for k, v in checkpoint['state_dict'].items():
|
20 |
+
if k.startswith('module'):
|
21 |
+
name = k[7:] # remove `module.`
|
22 |
+
else:
|
23 |
+
name = k
|
24 |
+
new_state_dict[name] = v
|
25 |
+
model.load_state_dict(new_state_dict)
|
26 |
+
else:
|
27 |
+
model.load_state_dict(checkpoint)
|
28 |
+
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
|
29 |
+
else:
|
30 |
+
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
|
31 |
+
raise FileNotFoundError()
|
32 |
+
|
33 |
+
|
34 |
+
def load_pretrained(model, url, filter_fn=None, strict=True):
|
35 |
+
if not url:
|
36 |
+
print("=> Warning: Pretrained model URL is empty, using random initialization.")
|
37 |
+
return
|
38 |
+
|
39 |
+
state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
|
40 |
+
|
41 |
+
input_conv = 'conv_stem'
|
42 |
+
classifier = 'classifier'
|
43 |
+
in_chans = getattr(model, input_conv).weight.shape[1]
|
44 |
+
num_classes = getattr(model, classifier).weight.shape[0]
|
45 |
+
|
46 |
+
input_conv_weight = input_conv + '.weight'
|
47 |
+
pretrained_in_chans = state_dict[input_conv_weight].shape[1]
|
48 |
+
if in_chans != pretrained_in_chans:
|
49 |
+
if in_chans == 1:
|
50 |
+
print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
|
51 |
+
input_conv_weight, pretrained_in_chans))
|
52 |
+
conv1_weight = state_dict[input_conv_weight]
|
53 |
+
state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
|
54 |
+
else:
|
55 |
+
print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
|
56 |
+
input_conv_weight, pretrained_in_chans))
|
57 |
+
del state_dict[input_conv_weight]
|
58 |
+
strict = False
|
59 |
+
|
60 |
+
classifier_weight = classifier + '.weight'
|
61 |
+
pretrained_num_classes = state_dict[classifier_weight].shape[0]
|
62 |
+
if num_classes != pretrained_num_classes:
|
63 |
+
print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
|
64 |
+
del state_dict[classifier_weight]
|
65 |
+
del state_dict[classifier + '.bias']
|
66 |
+
strict = False
|
67 |
+
|
68 |
+
if filter_fn is not None:
|
69 |
+
state_dict = filter_fn(state_dict)
|
70 |
+
|
71 |
+
model.load_state_dict(state_dict, strict=strict)
|
geffnet/mobilenetv3.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" MobileNet-V3
|
2 |
+
|
3 |
+
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
|
4 |
+
|
5 |
+
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
6 |
+
|
7 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
8 |
+
"""
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from .activations import get_act_fn, get_act_layer, HardSwish
|
13 |
+
from .config import layer_config_kwargs
|
14 |
+
from .conv2d_layers import select_conv2d
|
15 |
+
from .helpers import load_pretrained
|
16 |
+
from .efficientnet_builder import *
|
17 |
+
|
18 |
+
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
|
19 |
+
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
|
20 |
+
'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
|
21 |
+
'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
|
22 |
+
|
23 |
+
model_urls = {
|
24 |
+
'mobilenetv3_rw':
|
25 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
26 |
+
'mobilenetv3_large_075': None,
|
27 |
+
'mobilenetv3_large_100':
|
28 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
|
29 |
+
'mobilenetv3_large_minimal_100': None,
|
30 |
+
'mobilenetv3_small_075': None,
|
31 |
+
'mobilenetv3_small_100': None,
|
32 |
+
'mobilenetv3_small_minimal_100': None,
|
33 |
+
'tf_mobilenetv3_large_075':
|
34 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
|
35 |
+
'tf_mobilenetv3_large_100':
|
36 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
|
37 |
+
'tf_mobilenetv3_large_minimal_100':
|
38 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
|
39 |
+
'tf_mobilenetv3_small_075':
|
40 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
|
41 |
+
'tf_mobilenetv3_small_100':
|
42 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
|
43 |
+
'tf_mobilenetv3_small_minimal_100':
|
44 |
+
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
class MobileNetV3(nn.Module):
|
49 |
+
""" MobileNet-V3
|
50 |
+
|
51 |
+
A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
|
52 |
+
head convolution without a final batch-norm layer before the classifier.
|
53 |
+
|
54 |
+
Paper: https://arxiv.org/abs/1905.02244
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
58 |
+
channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
|
59 |
+
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
|
60 |
+
super(MobileNetV3, self).__init__()
|
61 |
+
self.drop_rate = drop_rate
|
62 |
+
|
63 |
+
stem_size = round_channels(stem_size, channel_multiplier)
|
64 |
+
self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
65 |
+
self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
|
66 |
+
self.act1 = act_layer(inplace=True)
|
67 |
+
in_chs = stem_size
|
68 |
+
|
69 |
+
builder = EfficientNetBuilder(
|
70 |
+
channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
|
71 |
+
norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
|
72 |
+
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
73 |
+
in_chs = builder.in_chs
|
74 |
+
|
75 |
+
self.global_pool = nn.AdaptiveAvgPool2d(1)
|
76 |
+
self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
|
77 |
+
self.act2 = act_layer(inplace=True)
|
78 |
+
self.classifier = nn.Linear(num_features, num_classes)
|
79 |
+
|
80 |
+
for m in self.modules():
|
81 |
+
if weight_init == 'goog':
|
82 |
+
initialize_weight_goog(m)
|
83 |
+
else:
|
84 |
+
initialize_weight_default(m)
|
85 |
+
|
86 |
+
def as_sequential(self):
|
87 |
+
layers = [self.conv_stem, self.bn1, self.act1]
|
88 |
+
layers.extend(self.blocks)
|
89 |
+
layers.extend([
|
90 |
+
self.global_pool, self.conv_head, self.act2,
|
91 |
+
nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
92 |
+
return nn.Sequential(*layers)
|
93 |
+
|
94 |
+
def features(self, x):
|
95 |
+
x = self.conv_stem(x)
|
96 |
+
x = self.bn1(x)
|
97 |
+
x = self.act1(x)
|
98 |
+
x = self.blocks(x)
|
99 |
+
x = self.global_pool(x)
|
100 |
+
x = self.conv_head(x)
|
101 |
+
x = self.act2(x)
|
102 |
+
return x
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
x = self.features(x)
|
106 |
+
x = x.flatten(1)
|
107 |
+
if self.drop_rate > 0.:
|
108 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
109 |
+
return self.classifier(x)
|
110 |
+
|
111 |
+
|
112 |
+
def _create_model(model_kwargs, variant, pretrained=False):
|
113 |
+
as_sequential = model_kwargs.pop('as_sequential', False)
|
114 |
+
model = MobileNetV3(**model_kwargs)
|
115 |
+
if pretrained and model_urls[variant]:
|
116 |
+
load_pretrained(model, model_urls[variant])
|
117 |
+
if as_sequential:
|
118 |
+
model = model.as_sequential()
|
119 |
+
return model
|
120 |
+
|
121 |
+
|
122 |
+
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
123 |
+
"""Creates a MobileNet-V3 model (RW variant).
|
124 |
+
|
125 |
+
Paper: https://arxiv.org/abs/1905.02244
|
126 |
+
|
127 |
+
This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
|
128 |
+
eventual Tensorflow reference impl but has a few differences:
|
129 |
+
1. This model has no bias on the head convolution
|
130 |
+
2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
|
131 |
+
3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
|
132 |
+
from their parent block
|
133 |
+
4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
|
134 |
+
|
135 |
+
Overall the changes are fairly minor and result in a very small parameter count difference and no
|
136 |
+
top-1/5
|
137 |
+
|
138 |
+
Args:
|
139 |
+
channel_multiplier: multiplier to number of channels per layer.
|
140 |
+
"""
|
141 |
+
arch_def = [
|
142 |
+
# stage 0, 112x112 in
|
143 |
+
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
|
144 |
+
# stage 1, 112x112 in
|
145 |
+
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
146 |
+
# stage 2, 56x56 in
|
147 |
+
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
148 |
+
# stage 3, 28x28 in
|
149 |
+
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
150 |
+
# stage 4, 14x14in
|
151 |
+
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
152 |
+
# stage 5, 14x14in
|
153 |
+
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
154 |
+
# stage 6, 7x7 in
|
155 |
+
['cn_r1_k1_s1_c960'], # hard-swish
|
156 |
+
]
|
157 |
+
with layer_config_kwargs(kwargs):
|
158 |
+
model_kwargs = dict(
|
159 |
+
block_args=decode_arch_def(arch_def),
|
160 |
+
head_bias=False, # one of my mistakes
|
161 |
+
channel_multiplier=channel_multiplier,
|
162 |
+
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
|
163 |
+
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
|
164 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
165 |
+
**kwargs,
|
166 |
+
)
|
167 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
168 |
+
return model
|
169 |
+
|
170 |
+
|
171 |
+
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
172 |
+
"""Creates a MobileNet-V3 large/small/minimal models.
|
173 |
+
|
174 |
+
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
|
175 |
+
Paper: https://arxiv.org/abs/1905.02244
|
176 |
+
|
177 |
+
Args:
|
178 |
+
channel_multiplier: multiplier to number of channels per layer.
|
179 |
+
"""
|
180 |
+
if 'small' in variant:
|
181 |
+
num_features = 1024
|
182 |
+
if 'minimal' in variant:
|
183 |
+
act_layer = 'relu'
|
184 |
+
arch_def = [
|
185 |
+
# stage 0, 112x112 in
|
186 |
+
['ds_r1_k3_s2_e1_c16'],
|
187 |
+
# stage 1, 56x56 in
|
188 |
+
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
|
189 |
+
# stage 2, 28x28 in
|
190 |
+
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
|
191 |
+
# stage 3, 14x14 in
|
192 |
+
['ir_r2_k3_s1_e3_c48'],
|
193 |
+
# stage 4, 14x14in
|
194 |
+
['ir_r3_k3_s2_e6_c96'],
|
195 |
+
# stage 6, 7x7 in
|
196 |
+
['cn_r1_k1_s1_c576'],
|
197 |
+
]
|
198 |
+
else:
|
199 |
+
act_layer = 'hard_swish'
|
200 |
+
arch_def = [
|
201 |
+
# stage 0, 112x112 in
|
202 |
+
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
|
203 |
+
# stage 1, 56x56 in
|
204 |
+
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
|
205 |
+
# stage 2, 28x28 in
|
206 |
+
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
|
207 |
+
# stage 3, 14x14 in
|
208 |
+
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
|
209 |
+
# stage 4, 14x14in
|
210 |
+
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
|
211 |
+
# stage 6, 7x7 in
|
212 |
+
['cn_r1_k1_s1_c576'], # hard-swish
|
213 |
+
]
|
214 |
+
else:
|
215 |
+
num_features = 1280
|
216 |
+
if 'minimal' in variant:
|
217 |
+
act_layer = 'relu'
|
218 |
+
arch_def = [
|
219 |
+
# stage 0, 112x112 in
|
220 |
+
['ds_r1_k3_s1_e1_c16'],
|
221 |
+
# stage 1, 112x112 in
|
222 |
+
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
|
223 |
+
# stage 2, 56x56 in
|
224 |
+
['ir_r3_k3_s2_e3_c40'],
|
225 |
+
# stage 3, 28x28 in
|
226 |
+
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
|
227 |
+
# stage 4, 14x14in
|
228 |
+
['ir_r2_k3_s1_e6_c112'],
|
229 |
+
# stage 5, 14x14in
|
230 |
+
['ir_r3_k3_s2_e6_c160'],
|
231 |
+
# stage 6, 7x7 in
|
232 |
+
['cn_r1_k1_s1_c960'],
|
233 |
+
]
|
234 |
+
else:
|
235 |
+
act_layer = 'hard_swish'
|
236 |
+
arch_def = [
|
237 |
+
# stage 0, 112x112 in
|
238 |
+
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
239 |
+
# stage 1, 112x112 in
|
240 |
+
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
241 |
+
# stage 2, 56x56 in
|
242 |
+
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
243 |
+
# stage 3, 28x28 in
|
244 |
+
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
245 |
+
# stage 4, 14x14in
|
246 |
+
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
247 |
+
# stage 5, 14x14in
|
248 |
+
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
249 |
+
# stage 6, 7x7 in
|
250 |
+
['cn_r1_k1_s1_c960'], # hard-swish
|
251 |
+
]
|
252 |
+
with layer_config_kwargs(kwargs):
|
253 |
+
model_kwargs = dict(
|
254 |
+
block_args=decode_arch_def(arch_def),
|
255 |
+
num_features=num_features,
|
256 |
+
stem_size=16,
|
257 |
+
channel_multiplier=channel_multiplier,
|
258 |
+
act_layer=resolve_act_layer(kwargs, act_layer),
|
259 |
+
se_kwargs=dict(
|
260 |
+
act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
|
261 |
+
norm_kwargs=resolve_bn_args(kwargs),
|
262 |
+
**kwargs,
|
263 |
+
)
|
264 |
+
model = _create_model(model_kwargs, variant, pretrained)
|
265 |
+
return model
|
266 |
+
|
267 |
+
|
268 |
+
def mobilenetv3_rw(pretrained=False, **kwargs):
|
269 |
+
""" MobileNet-V3 RW
|
270 |
+
Attn: See note in gen function for this variant.
|
271 |
+
"""
|
272 |
+
# NOTE for train set drop_rate=0.2
|
273 |
+
if pretrained:
|
274 |
+
# pretrained model trained with non-default BN epsilon
|
275 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
276 |
+
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
|
277 |
+
return model
|
278 |
+
|
279 |
+
|
280 |
+
def mobilenetv3_large_075(pretrained=False, **kwargs):
|
281 |
+
""" MobileNet V3 Large 0.75"""
|
282 |
+
# NOTE for train set drop_rate=0.2
|
283 |
+
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
284 |
+
return model
|
285 |
+
|
286 |
+
|
287 |
+
def mobilenetv3_large_100(pretrained=False, **kwargs):
|
288 |
+
""" MobileNet V3 Large 1.0 """
|
289 |
+
# NOTE for train set drop_rate=0.2
|
290 |
+
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
291 |
+
return model
|
292 |
+
|
293 |
+
|
294 |
+
def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
295 |
+
""" MobileNet V3 Large (Minimalistic) 1.0 """
|
296 |
+
# NOTE for train set drop_rate=0.2
|
297 |
+
model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
298 |
+
return model
|
299 |
+
|
300 |
+
|
301 |
+
def mobilenetv3_small_075(pretrained=False, **kwargs):
|
302 |
+
""" MobileNet V3 Small 0.75 """
|
303 |
+
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
304 |
+
return model
|
305 |
+
|
306 |
+
|
307 |
+
def mobilenetv3_small_100(pretrained=False, **kwargs):
|
308 |
+
""" MobileNet V3 Small 1.0 """
|
309 |
+
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
310 |
+
return model
|
311 |
+
|
312 |
+
|
313 |
+
def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
314 |
+
""" MobileNet V3 Small (Minimalistic) 1.0 """
|
315 |
+
model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
316 |
+
return model
|
317 |
+
|
318 |
+
|
319 |
+
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
|
320 |
+
""" MobileNet V3 Large 0.75. Tensorflow compat variant. """
|
321 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
322 |
+
kwargs['pad_type'] = 'same'
|
323 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
|
328 |
+
""" MobileNet V3 Large 1.0. Tensorflow compat variant. """
|
329 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
330 |
+
kwargs['pad_type'] = 'same'
|
331 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
332 |
+
return model
|
333 |
+
|
334 |
+
|
335 |
+
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
336 |
+
""" MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
|
337 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
338 |
+
kwargs['pad_type'] = 'same'
|
339 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
340 |
+
return model
|
341 |
+
|
342 |
+
|
343 |
+
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
|
344 |
+
""" MobileNet V3 Small 0.75. Tensorflow compat variant. """
|
345 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
346 |
+
kwargs['pad_type'] = 'same'
|
347 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
348 |
+
return model
|
349 |
+
|
350 |
+
|
351 |
+
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
|
352 |
+
""" MobileNet V3 Small 1.0. Tensorflow compat variant."""
|
353 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
354 |
+
kwargs['pad_type'] = 'same'
|
355 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
356 |
+
return model
|
357 |
+
|
358 |
+
|
359 |
+
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
360 |
+
""" MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
|
361 |
+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
362 |
+
kwargs['pad_type'] = 'same'
|
363 |
+
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
364 |
+
return model
|
geffnet/model_factory.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .config import set_layer_config
|
2 |
+
from .helpers import load_checkpoint
|
3 |
+
|
4 |
+
from .gen_efficientnet import *
|
5 |
+
from .mobilenetv3 import *
|
6 |
+
|
7 |
+
|
8 |
+
def create_model(
|
9 |
+
model_name='mnasnet_100',
|
10 |
+
pretrained=None,
|
11 |
+
num_classes=1000,
|
12 |
+
in_chans=3,
|
13 |
+
checkpoint_path='',
|
14 |
+
**kwargs):
|
15 |
+
|
16 |
+
model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)
|
17 |
+
|
18 |
+
if model_name in globals():
|
19 |
+
create_fn = globals()[model_name]
|
20 |
+
model = create_fn(**model_kwargs)
|
21 |
+
else:
|
22 |
+
raise RuntimeError('Unknown model (%s)' % model_name)
|
23 |
+
|
24 |
+
if checkpoint_path and not pretrained:
|
25 |
+
load_checkpoint(model, checkpoint_path)
|
26 |
+
|
27 |
+
return model
|
geffnet/version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = '1.0.2'
|
requirements.txt
CHANGED
@@ -109,12 +109,13 @@ sympy==1.12.1
|
|
109 |
tokenizers==0.15.2
|
110 |
tomlkit==0.12.0
|
111 |
toolz==0.12.1
|
112 |
-
torch==2.0
|
113 |
-
torchvision==
|
|
|
|
|
114 |
tqdm==4.66.4
|
115 |
transformers==4.36.1
|
116 |
trimesh==4.0.5
|
117 |
-
triton==2.0.0
|
118 |
typer==0.12.3
|
119 |
typing-inspect==0.9.0
|
120 |
typing_extensions==4.11.0
|
@@ -126,10 +127,8 @@ uvloop==0.19.0
|
|
126 |
watchfiles==0.22.0
|
127 |
websockets==11.0.3
|
128 |
wrapt==1.16.0
|
129 |
-
xformers==0.0.20
|
130 |
xxhash==3.4.1
|
131 |
yarl==1.9.4
|
132 |
zipp==3.19.1
|
133 |
einops==0.7.0
|
134 |
-
opencv-python-headless==4.8.1.78
|
135 |
-
geffnet==1.0.2
|
|
|
109 |
tokenizers==0.15.2
|
110 |
tomlkit==0.12.0
|
111 |
toolz==0.12.1
|
112 |
+
torch==2.2.0
|
113 |
+
torchvision==0.18.0
|
114 |
+
xformers==0.0.24
|
115 |
+
triton==2.2.0
|
116 |
tqdm==4.66.4
|
117 |
transformers==4.36.1
|
118 |
trimesh==4.0.5
|
|
|
119 |
typer==0.12.3
|
120 |
typing-inspect==0.9.0
|
121 |
typing_extensions==4.11.0
|
|
|
127 |
watchfiles==0.22.0
|
128 |
websockets==11.0.3
|
129 |
wrapt==1.16.0
|
|
|
130 |
xxhash==3.4.1
|
131 |
yarl==1.9.4
|
132 |
zipp==3.19.1
|
133 |
einops==0.7.0
|
134 |
+
opencv-python-headless==4.8.1.78
|
|