ashutosh-pathak's picture
Restructure + Add Gradio Interface
8444121
raw
history blame
1.19 kB
import numpy as np
from keras.callbacks import ModelCheckpoint
from src.model.model import get_unet
from src.data.data_processing import load_and_preprocess_train_data
def train_model(*, batch_size: int = 10, epochs: int = 20,
                validation_split: float = 0.2,
                weights_path='../../models/weights.h5',
                mean_path='mean.npy', std_path='std.npy'):
    """Load the training data, normalise it and fit the U-Net.

    The training images are cast to float32 and standardised with the
    global mean and standard deviation of the whole training array. Both
    statistics are written to disk with ``np.save`` so that the same
    normalisation can be re-applied later. The masks are cast to float32
    and passed to ``model.fit`` otherwise unchanged.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so ``train_model()`` behaves as before.

    Args:
        batch_size: Batch size forwarded to ``model.fit``.
        epochs: Number of epochs forwarded to ``model.fit``.
        validation_split: Fraction of the data ``model.fit`` holds out for
            validation; the checkpoint monitors the resulting ``val_loss``.
        weights_path: File the ``ModelCheckpoint`` callback writes to
            (only the best ``val_loss`` is kept). Relative paths are
            resolved against the current working directory.
        mean_path: Destination of the saved training mean.
        std_path: Destination of the saved training standard deviation.

    Returns:
        A tuple ``(model, history, mean, std)`` where ``history`` is the
        value returned by ``model.fit``.

    Raises:
        ValueError: If the standard deviation of the training images is
            zero or not finite (e.g. empty or constant input), because
            dividing by it would fill the training data with NaN/inf.
    """
    # Local import so the function does not depend on a file-level change.
    from pathlib import Path

    def _banner(message):
        # Same three-line banner the script has always printed.
        print('-'*30)
        print(message)
        print('-'*30)

    _banner('Loading and preprocessing train data...')
    imgs_train, imgs_mask_train = load_and_preprocess_train_data()

    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)
    std = np.std(imgs_train)

    # Fail early instead of silently training on NaN/inf values.
    if not np.isfinite(std) or std == 0:
        raise ValueError(
            'Cannot normalise training images: standard deviation is '
            f'{std!r} (empty or constant input?)'
        )

    # save mean and std
    np.save(mean_path, mean)
    np.save(std_path, std)

    # astype() above returned a fresh array, so in-place ops are safe here.
    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')

    _banner('Creating and compiling model...')
    model = get_unet()

    # ModelCheckpoint is not guaranteed to create the target directory in
    # every Keras version; make sure it exists before training starts so a
    # missing folder does not surface only after the first epoch.
    Path(weights_path).parent.mkdir(parents=True, exist_ok=True)
    model_checkpoint = ModelCheckpoint(
        weights_path, monitor='val_loss', save_best_only=True
    )

    _banner('Fitting model...')
    history = model.fit(
        imgs_train,
        imgs_mask_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
        shuffle=True,
        validation_split=validation_split,
        callbacks=[model_checkpoint],
    )

    return model, history, mean, std
# Script entry point: run a full training session with the default settings
# when this file is executed directly (not when it is imported).
if __name__ == '__main__':
    train_model()