File size: 3,994 Bytes
fcd5579 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import pickle
from functools import partial
from pathlib import Path
import cv2
import numpy as np
from core.interact import interact as io
from core.leras import nn
class XSegNet(object):
    """Wrapper around an ``nn.XSeg`` model: graph setup, weight I/O, inference.

    In inference mode (``training=False``) the constructor builds a prediction
    op and exposes it through :meth:`extract`. In training mode it registers
    the supplied optimizer next to the model so that both are loaded and saved
    together; no prediction op is built in that mode.

    Attributes set by the constructor:
        resolution: side length used for the placeholders.
        weights_file_root: ``Path`` of the directory holding the weight files.
        model_filename_list: ``[object, filename]`` pairs that are loaded in
            the constructor and written by :meth:`save_weights`.
        input_t / target_t: placeholders with 3 and 1 channels respectively.
        model / model_weights: the ``nn.XSeg`` instance and its weights.
        opt: the optimizer (training mode only).
        net_run: callable running the prediction op (inference mode only).
        initialized: ``False`` when, in inference mode, a weights file could
            not be loaded; :meth:`extract` then returns a placeholder.
    """

    # Version tag of this wrapper; it is not read anywhere inside the class.
    VERSION = 1

    def __init__(self, name,
                 resolution=256,
                 load_weights=True,
                 weights_file_root=None,
                 training=False,
                 place_model_on_cpu=False,
                 run_on_cpu=False,
                 optimizer=None,
                 data_format="NHWC",
                 raise_on_no_model_files=False):
        """Build the model graph and load or initialize its weights.

        Args:
            name: model name; used as the ``nn.XSeg`` name and, combined with
                ``resolution``, as the stem of the weight file names.
            resolution: height and width of the input/target placeholders.
            load_weights: when true, try to load every registered object from
                ``weights_file_root``; when false, initialize all of them.
            weights_file_root: directory with the ``.npy`` weight files.
                Defaults to the directory containing this source file.
            training: register ``optimizer`` for loading/saving and skip
                building the inference op.
            place_model_on_cpu: create the model (and optimizer variables)
                under ``/CPU:0`` instead of the default device.
            run_on_cpu: build the inference op under ``/CPU:0``.
            optimizer: optimizer object; required when ``training`` is true.
            data_format: passed to ``nn.initialize``.
            raise_on_no_model_files: raise instead of falling back when a
                weights file cannot be loaded.

        Raises:
            ValueError: ``training`` is true but ``optimizer`` is ``None``.
            Exception: a weights file failed to load and
                ``raise_on_no_model_files`` is true.
        """
        self.resolution = resolution
        if weights_file_root is not None:
            self.weights_file_root = Path(weights_file_root)
        else:
            self.weights_file_root = Path(__file__).parent

        nn.initialize(data_format=data_format)
        tf = nn.tf

        model_name = f'{name}_{resolution}'
        self.model_filename_list = []

        with tf.device('/CPU:0'):
            # Placeholders live on the CPU.
            self.input_t = tf.placeholder(
                nn.floatx, nn.get4Dshape(resolution, resolution, 3))
            self.target_t = tf.placeholder(
                nn.floatx, nn.get4Dshape(resolution, resolution, 1))

        # Initializing model classes
        model_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name
        with tf.device(model_device):
            # NOTE(review): the positional arguments are presumably
            # (input channels, base channels, output channels) -- confirm
            # against nn.XSeg.
            self.model = nn.XSeg(3, 32, 1, name=name)
            self.model_weights = self.model.get_weights()
            if training:
                if optimizer is None:
                    raise ValueError("Optimizer should be provided for training mode.")
                self.opt = optimizer
                self.opt.initialize_variables(self.model_weights,
                                              vars_on_cpu=place_model_on_cpu)
                # The optimizer is registered before the model, so it is also
                # loaded/saved first.
                self.model_filename_list.append([self.opt, f'{model_name}_opt.npy'])

        self.model_filename_list.append([self.model, f'{model_name}.npy'])

        if not training:
            run_device = '/CPU:0' if run_on_cpu else nn.tf_default_device_name
            with tf.device(run_device):
                # The model call returns a pair; only the second element is
                # used as the prediction here.
                _, pred = self.model(self.input_t)

            def net_run(input_np):
                return nn.tf_sess.run([pred], feed_dict={self.input_t: input_np})[0]
            self.net_run = net_run

        self.initialized = True
        # Loading/initializing all models/optimizers weights
        for model, filename in self.model_filename_list:
            if load_weights:
                model_file_path = self.weights_file_root / filename
                if model.load_weights(model_file_path):
                    continue  # loaded successfully, nothing to initialize
                if raise_on_no_model_files:
                    raise Exception(f'{model_file_path} does not exists.')
                if not training:
                    # Inference without weights is pointless: flag the net as
                    # unusable and stop without initializing anything.
                    self.initialized = False
                    break
            # Either loading was not requested, or it failed in training mode.
            model.init_weights()

    def get_resolution(self):
        """Return the resolution this network was constructed with."""
        return self.resolution

    def flow(self, x, pretrain=False):
        """Call the underlying model on ``x`` and return its result unchanged."""
        return self.model(x, pretrain=pretrain)

    def get_weights(self):
        """Return the model weights captured at construction time."""
        return self.model_weights

    def save_weights(self):
        """Save every registered object to its file under ``weights_file_root``."""
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False):
            model.save_weights(self.weights_file_root / filename)

    def extract(self, input_image):
        """Run inference and return the cleaned-up prediction.

        Requires an instance constructed with ``training=False``; only then
        is ``net_run`` defined.

        Args:
            input_image: a single 3-D image or a 4-D batch whose axis 0 is the
                batch axis.

        Returns:
            The prediction clipped to ``[0, 1]`` with every value below
            ``0.1`` set to zero. A 3-D input yields the result without the
            batch axis; a batch yields one result per item.

            When the weights were not loaded (``initialized`` is false), a
            constant array filled with ``0.5`` is returned instead: shape
            ``(resolution, resolution, 1)``, preceded by the batch size for a
            4-D input so that the rank matches the regular path.
        """
        if not self.initialized:
            # NOTE(review): the placeholder is always laid out channels-last,
            # independent of the data_format given to the constructor.
            fallback_shape = (self.resolution, self.resolution, 1)
            input_shape = getattr(input_image, 'shape', ())
            if len(input_shape) == 4:
                fallback_shape = (input_shape[0],) + fallback_shape
            return 0.5 * np.ones(fallback_shape, nn.floatx.as_numpy_dtype)

        is_single_image = len(input_image.shape) == 3
        if is_single_image:
            # The network is fed batches; add a batch axis of length one.
            input_image = input_image[None, ...]

        result = np.clip(self.net_run(input_image), 0, 1.0)
        result[result < 0.1] = 0  # get rid of noise

        if is_single_image:
            result = result[0]

        return result