File size: 3,994 Bytes
fcd5579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import pickle
from functools import partial
from pathlib import Path

import cv2
import numpy as np

from core.interact import interact as io
from core.leras import nn


class XSegNet(object):
    """Wrapper around ``nn.XSeg`` that builds its TF graph, loads/saves its
    weights and runs mask prediction.

    With ``training=False`` the inference graph is built and :meth:`extract`
    can be used. With ``training=True`` an optimizer must be supplied; its
    state is loaded/saved alongside the model weights, and :meth:`extract` is
    not available (no inference function is built in that mode).
    """
    VERSION = 1

    def __init__(self, name,
                       resolution=256,
                       load_weights=True,
                       weights_file_root=None,
                       training=False,
                       place_model_on_cpu=False,
                       run_on_cpu=False,
                       optimizer=None,
                       data_format="NHWC",
                       raise_on_no_model_files=False):
        """Build the network and load (or freshly initialize) its weights.

        Args:
            name: model name; also used in the weight file names
                ``{name}_{resolution}.npy`` and, in training mode,
                ``{name}_{resolution}_opt.npy``.
            resolution: side length of the square input/target placeholders.
            load_weights: if False, weights are always initialized from scratch.
            weights_file_root: directory holding the weight files; defaults to
                the directory containing this module.
            training: build for training (requires ``optimizer``) instead of
                building the inference function used by ``extract``.
            place_model_on_cpu: create model (and optimizer) variables on CPU.
            run_on_cpu: build the inference graph on CPU (inference mode only).
            optimizer: optimizer instance; required when ``training`` is True.
            data_format: tensor layout passed to ``nn.initialize``.
            raise_on_no_model_files: raise instead of falling back when
                ``load_weights`` of a saved component reports failure.

        Raises:
            ValueError: ``training`` is True but ``optimizer`` is None.
            FileNotFoundError: a weight file failed to load and
                ``raise_on_no_model_files`` is True.
        """
        # Validate before creating any graph state / variables.
        if training and optimizer is None:
            raise ValueError("Optimizer should be provided for training mode.")

        self.resolution = resolution
        self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent

        nn.initialize(data_format=data_format)
        tf = nn.tf

        model_name = f'{name}_{resolution}'
        # [saveable object, file name] pairs shared by loading and save_weights().
        self.model_filename_list = []

        with tf.device('/CPU:0'):
            # Placeholders live on CPU: 3-channel input, 1-channel target.
            self.input_t  = tf.placeholder(nn.floatx, nn.get4Dshape(resolution, resolution, 3))
            self.target_t = tf.placeholder(nn.floatx, nn.get4Dshape(resolution, resolution, 1))

        # Initializing model classes
        with tf.device('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name):
            self.model = nn.XSeg(3, 32, 1, name=name)
            self.model_weights = self.model.get_weights()
            if training:
                self.opt = optimizer
                self.opt.initialize_variables(self.model_weights, vars_on_cpu=place_model_on_cpu)
                self.model_filename_list += [[self.opt, f'{model_name}_opt.npy']]

        self.model_filename_list += [[self.model, f'{model_name}.npy']]

        if not training:
            with tf.device('/CPU:0' if run_on_cpu else nn.tf_default_device_name):
                # The model call returns a pair; only the second element is used
                # as the prediction (presumably the activated output — confirm
                # against nn.XSeg).
                _, pred = self.model(self.input_t)

            def net_run(input_np):
                return nn.tf_sess.run([pred], feed_dict={self.input_t: input_np})[0]
            self.net_run = net_run

        self.initialized = True
        # Loading/initializing all models/optimizers weights
        for model, filename in self.model_filename_list:
            if load_weights:
                model_file_path = self.weights_file_root / filename
                if model.load_weights(model_file_path):
                    continue
                if raise_on_no_model_files:
                    raise FileNotFoundError(f'{model_file_path} does not exist.')
                if not training:
                    # Inference without stored weights: leave the model
                    # uninitialized; extract() then returns a constant 0.5 mask.
                    self.initialized = False
                    break
            model.init_weights()

    def get_resolution(self):
        """Return the side length the network was built for."""
        return self.resolution

    def flow(self, x, pretrain=False):
        """Apply the underlying ``nn.XSeg`` model to tensor ``x``."""
        return self.model(x, pretrain=pretrain)

    def get_weights(self):
        """Return the model weights captured at construction time."""
        return self.model_weights

    def save_weights(self):
        """Save every registered component (model and, in training mode,
        optimizer) under ``weights_file_root``, showing a progress bar."""
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):
            model.save_weights(self.weights_file_root / filename)

    def extract(self, input_image):
        """Predict a mask for a single image or a batch of images.

        Only available when the object was built with ``training=False``.

        Args:
            input_image: a single image (3 dims) or a batch with the batch on
                axis 0 (4 dims), laid out to match ``input_t`` (assumed
                HWC/NHWC for the default ``data_format`` — confirm otherwise).

        Returns:
            The network output clipped to [0, 1], with values below 0.1 set to
            0; the batch axis is dropped again for single-image input. If no
            weights were loaded, a constant 0.5 mask of shape
            ``(resolution, resolution, 1)`` is returned instead, with a leading
            batch axis when the input was 4-D.
        """
        input_shape_len = np.ndim(input_image)

        if not self.initialized:
            mask_shape = (self.resolution, self.resolution, 1)
            if input_shape_len == 4:
                # Keep the batch axis so batched callers get one mask per image,
                # consistent with the initialized path below.
                mask_shape = (np.shape(input_image)[0],) + mask_shape
            return 0.5*np.ones(mask_shape, nn.floatx.as_numpy_dtype)

        if input_shape_len == 3:
            input_image = input_image[None, ...]

        result = np.clip(self.net_run(input_image), 0, 1.0)
        result[result < 0.1] = 0  # get rid of noise

        if input_shape_len == 3:
            result = result[0]

        return result