Spaces:
Sleeping
Sleeping
import numpy as np | |
from keras.callbacks import ModelCheckpoint | |
from src.model.model import get_unet | |
from src.data.data_processing import load_and_preprocess_train_data | |
def train_model(batch_size=10, epochs=20, validation_split=0.2,
                weights_path='../../models/weights.h5',
                mean_path='mean.npy', std_path='std.npy'):
    """Load the training data, standardize it, and fit the model from get_unet().

    The image array is standardized with one global mean and standard
    deviation. Both statistics are written with ``np.save`` so the same
    transform can be reproduced outside this function. A ``ModelCheckpoint``
    keeps the weights with the best ``val_loss`` at ``weights_path``.

    Args:
        batch_size: Samples per gradient update passed to ``model.fit``.
        epochs: Number of training epochs passed to ``model.fit``.
        validation_split: Fraction of the training data that ``model.fit``
            holds out to compute ``val_loss``.
        weights_path: File the checkpoint callback writes the best weights to.
            Relative paths resolve against the current working directory.
        mean_path: File the global mean is saved to.
        std_path: File the global standard deviation is saved to.

    Returns:
        Tuple ``(model, history, mean, std)``: the fitted model, the object
        returned by ``model.fit``, and the normalization statistics that were
        saved.
    """
    import os  # local import keeps this block self-contained

    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_mask_train = load_and_preprocess_train_data()

    # astype() returns a new array, so the in-place ops below do not modify
    # whatever the loader returned.
    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)
    std = np.std(imgs_train)
    # A constant image set has std == 0; dividing by it would turn every
    # training pixel into NaN/inf. Fall back to 1 (only the mean is removed)
    # and save that value so the stored statistics reproduce this transform.
    if std == 0:
        std = np.float32(1.0)

    # save mean and std
    np.save(mean_path, mean)
    np.save(std_path, std)

    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = get_unet()

    # Create the checkpoint directory up front so a missing folder cannot
    # abort training when the first checkpoint is written.
    weights_dir = os.path.dirname(weights_path)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)
    model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True)

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
    history = model.fit(imgs_train, imgs_mask_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        shuffle=True,
                        validation_split=validation_split,
                        callbacks=[model_checkpoint])
    return model, history, mean, std
if __name__ == "__main__":
    # Script entry point: train with the default settings. The returned
    # (model, history, mean, std) tuple is not used here.
    train_model()