File size: 2,727 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Optional

from tha3.nn.conv import create_conv7_block_from_block_args, create_conv3_block_from_block_args, \
    create_downsample_block_from_block_args, create_conv3
from tha3.nn.resnet_block import ResnetBlock
from tha3.nn.resnet_block_seperable import ResnetBlockSeparable
from tha3.nn.separable_conv import create_separable_conv7_block, create_separable_conv3_block, \
    create_separable_downsample_block, create_separable_conv3
from tha3.nn.util import BlockArgs


class ConvBlockFactory:
    """Create convolution building blocks that share one ``BlockArgs`` configuration.

    Each ``create_*`` method dispatches on ``use_separable_convolution``:
    when True it uses the separable variants (``tha3.nn.separable_conv`` /
    ``ResnetBlockSeparable``), otherwise the regular ones (``tha3.nn.conv`` /
    ``ResnetBlock``). In both cases the factory's ``block_args`` is forwarded,
    so all blocks built by one factory are configured consistently.
    """

    def __init__(self,
                 block_args: BlockArgs,
                 use_separable_convolution: bool = False):
        """
        Args:
            block_args: Shared settings forwarded to every block this factory creates.
            use_separable_convolution: If True, build separable-convolution variants.
        """
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args

    def create_conv3(self,
                     in_channels: int,
                     out_channels: int,
                     bias: bool,
                     initialization_method: Optional[str] = None):
        """Create a single 3-wide convolution layer (not a full block).

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            bias: Whether the convolution has a bias term.
            initialization_method: Overrides ``block_args.initialization_method``
                when given; ``None`` means use the factory's configured method.

        Returns:
            The layer produced by ``create_separable_conv3`` or ``create_conv3``.
        """
        if initialization_method is None:
            initialization_method = self.block_args.initialization_method
        if self.use_separable_convolution:
            return create_separable_conv3(
                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)
        else:
            return create_conv3(
                in_channels, out_channels, bias, initialization_method, self.block_args.use_spectral_norm)

    def create_conv7_block(self, in_channels: int, out_channels: int):
        """Create a 7-wide convolution block configured by ``block_args``."""
        if self.use_separable_convolution:
            return create_separable_conv7_block(in_channels, out_channels, self.block_args)
        else:
            return create_conv7_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_conv3_block(self, in_channels: int, out_channels: int):
        """Create a 3-wide convolution block configured by ``block_args``."""
        if self.use_separable_convolution:
            return create_separable_conv3_block(in_channels, out_channels, self.block_args)
        else:
            return create_conv3_block_from_block_args(in_channels, out_channels, self.block_args)

    def create_downsample_block(self, in_channels: int, out_channels: int, is_output_1x1: bool):
        """Create a downsampling block configured by ``block_args``.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            is_output_1x1: Forwarded unchanged to the underlying helper.
        """
        if self.use_separable_convolution:
            return create_separable_downsample_block(in_channels, out_channels, is_output_1x1, self.block_args)
        else:
            # block_args was previously omitted here, so the helper ignored this
            # factory's configuration and fell back to its own default. Pass it
            # exactly like the separable branch above does.
            # NOTE(review): assumes the helper takes block_args as its 4th
            # positional parameter (as its *_from_block_args name suggests) — confirm.
            return create_downsample_block_from_block_args(
                in_channels, out_channels, is_output_1x1, self.block_args)

    def create_resnet_block(self, num_channels: int, is_1x1: bool):
        """Create a residual block with ``num_channels`` in and out, configured by ``block_args``.

        Args:
            num_channels: Channel count passed to the block's ``create`` constructor.
            is_1x1: Forwarded unchanged to ``ResnetBlock(Separable).create``.
        """
        if self.use_separable_convolution:
            return ResnetBlockSeparable.create(num_channels, is_1x1, block_args=self.block_args)
        else:
            return ResnetBlock.create(num_channels, is_1x1, block_args=self.block_args)