Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from collections import OrderedDict | |
import glob | |
class Saver(object):
    """Manage the on-disk output directory of a single training run.

    Each instance claims the directory
    ``run/<train_dataset>/<checkname>/experiment_<N>``, where ``N`` is one
    more than the highest-numbered existing ``experiment_*`` entry (or 0 when
    there is none), and provides helpers that write a checkpoint and a short
    parameter summary into it.
    """

    def __init__(self, args):
        """Pick the next free experiment directory and create it.

        Args:
            args: Object exposing ``train_dataset`` and ``checkname`` (used to
                build the directory path). ``lr`` and ``epochs`` are read
                later by :meth:`save_experiment_config`.
        """
        self.args = args
        self.directory = os.path.join('run', args.train_dataset, args.checkname)
        self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))

        # Compare run ids numerically. A lexicographic sort places
        # 'experiment_10' before 'experiment_9', so taking the last sorted
        # entry would stop advancing after ten runs and reuse a directory.
        run_ids = [idx for idx in map(self._run_index, self.runs) if idx is not None]
        run_id = max(run_ids) + 1 if run_ids else 0

        self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.experiment_dir, exist_ok=True)

    @staticmethod
    def _run_index(path):
        """Return the integer suffix of an ``experiment_<N>`` path, or None.

        Entries whose suffix is not an integer (e.g. ``experiment_backup``)
        yield None so that they are ignored instead of raising ValueError.
        """
        suffix = path.rsplit('_', 1)[-1]
        try:
            return int(suffix)
        except ValueError:
            return None

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Save a checkpoint to disk inside the experiment directory.

        Args:
            state: Object handed to ``torch.save``.
            filename: File name, joined onto ``self.experiment_dir``.
        """
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def save_experiment_config(self):
        """Write ``parameters.txt`` with one ``key:value`` line per setting.

        The file records ``train_dataset``, ``lr`` and ``epoch`` (taken from
        ``args.epochs``), in that order, and is overwritten if it exists.
        """
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')

        p = OrderedDict()
        p['train_dataset'] = self.args.train_dataset
        p['lr'] = self.args.lr
        p['epoch'] = self.args.epochs

        # The context manager closes the file even if a write fails.
        with open(logfile, 'w') as log_file:
            log_file.writelines(
                key + ':' + str(val) + '\n' for key, val in p.items()
            )