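# Training loops for the COMET models used here: an epoch-based Trainer and an
# iteration-based IteratorTrainer, with TensorBoard logging and checkpointing.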
import torch
import torch.nn as nn
import torch.nn.functional as F
import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.train.utils as train_utils
import comet.src.train.batch as batch
import comet.src.evaluate.evaluate as evaluate
import comet.src.evaluate.generate as gen
import comet.src.evaluate.sampler as sampling
import comet.utils.utils as utils
from tensorboardX import SummaryWriter
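
# Base trainer: owns the model, optimizer, data loader, and per-split loss
# bookkeeping. The task-specific batch function (self.batch) and evaluator
# (self.evaluator) used below are expected to be attached elsewhere (e.g. by
# subclass-specific setup code).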
class Trainer(object):
def __init__(self, opt, meta, data_loader, model, optimizer):
self.optimizer = optimizer
self.model = model
if opt.trainer == "epoch":
self.epochs = meta.epochs
self.data_loader = data_loader
self.opt = opt
self.losses = {"dev": {}, "test": {}, "train": {}}
self.top_score = None
self.lrs = {}
self.batch_variables = {
"data": self.data_loader,
"model": self.model,
"split": "train"
}
self.do_gen = cfg.do_gen
self.samplers = {}
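
    # Decide whether this epoch's checkpoint should be written. Under the
    # "best" save strategy, only save when the top score was set this epoch.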
def decide_to_save(self):
to_save = cfg.save and not cfg.toy
to_save = to_save or cfg.test_save
print(cfg.save_strategy)
if cfg.save_strategy == "best":
if self.top_score[0] != self.opt.train.dynamic.epoch:
print("DOING IT RIGHT")
to_save = False
return to_save
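
    # Record the current learning rates, then checkpoint the model, vocab
    # encoder, optimizer, and options if decide_to_save() allows it.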
def save_model(self, tracked_score):
lrs = {}
for i, param_group in enumerate(self.optimizer.param_groups):
lrs[i] = param_group['lr']
self.lrs[self.opt.train.dynamic.epoch] = lrs
to_save = self.decide_to_save()
if to_save:
data.save_step(
self.model, self.data_loader.vocab_encoder,
self.optimizer, self.opt,
self.opt.train.dynamic.epoch, self.lrs)
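
    # Persist the train/dev/test loss histories to evaluation files (skipped
    # for toy runs unless test_save is set).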
def log_losses(self, opt, losses):
if (not cfg.toy and cfg.save) or cfg.test_save:
data.save_eval_file(opt, losses["train"], "losses", split="train")
data.save_eval_file(opt, losses['dev'], "losses", split="dev")
data.save_eval_file(opt, losses['test'], "losses", split="test")
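
    # TensorBoard writers: toy runs log under garbage/logs/ so throwaway
    # experiments stay out of the main logs/ directory.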
def set_logger(self):
if cfg.toy:
self.logger = SummaryWriter(utils.make_name(
self.opt, prefix="garbage/logs/", eval_=True, do_epoch=False))
else:
self.logger = SummaryWriter(utils.make_name(
self.opt, prefix="logs/", eval_=True, do_epoch=False))
print("Logging Tensorboard Files at: {}".format(self.logger.logdir))
def stop_logger(self):
self.logger.close()
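
    # Main entry point for epoch-based training: one epoch() pass per epoch,
    # with evaluation, logging, and checkpointing handled inside epoch().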
def run(self):
self.set_logger()
self.count = 0
for epoch in range(self.epochs):
self.model.train()
self.opt.train.dynamic.epoch += 1
self.epoch()
self.stop_logger()
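
    # One pass over the training split: forward/backward/update per batch until
    # the data loader signals a reset, then evaluate, log, and checkpoint.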
def epoch(self):
nums = self.reset_losses()
# Initialize progress bar
bar = utils.initialize_progress_bar(
self.data_loader.sequences["train"])
reset = False
while not reset:
loss, nums, reset = self.do_forward_pass(nums)
self.do_backward_pass(loss)
self.update_parameters()
bar.update(self.opt.train.dynamic.bs)
self.count += 1
for loss_name in self.losses["train"]:
self.logger.add_scalar(
"train/{}".format(loss_name),
loss.item() / self.opt.train.dynamic.bs,
self.count)
if cfg.toy and self.counter(nums) > 300:
break
with torch.no_grad():
self.run_evaluation_cycle()
self.log_losses(self.opt, self.losses)
self.update_top_score(self.opt)
self.save_model(self.get_tracked_score())
self.data_loader.reset_offsets("train")
def run_evaluation_cycle(self):
for split in ["dev", "test"]:
self.evaluator.validate(
self.opt.train.dynamic.epoch, split,
self.losses[split])
if self.do_gen:
gen.do_gen_run(
self.opt, self.generator, self.opt.train.dynamic.epoch,
split, self.losses[split])
iter_num = self.opt.train.dynamic.epoch
for loss_name in self.losses[split]:
self.logger.add_scalar(
"{}/{}".format(split, loss_name),
self.losses[split][loss_name][iter_num],
iter_num)
def clip_gradients(self):
if self.opt.train.static.clip:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.opt.train.static.clip)
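
    # self.batch is the task-specific batching/loss function (attached outside
    # this class); it returns the token loss, updated counters, and a reset
    # flag once the training split has been exhausted.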
def do_forward_pass(self, nums):
token_loss, nums, reset = self.batch(
self.opt, nums, self.losses["train"],
self.batch_variables)
return token_loss, nums, reset
def do_backward_pass(self, loss):
loss.backward()
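
    # Optimizer step; gradients are clipped only for the LSTM baseline and are
    # zeroed after each update rather than before the forward pass.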
def update_parameters(self):
if self.opt.model == "lstm":
self.clip_gradients()
self.optimizer.step()
self.optimizer.zero_grad()
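
    # Re-initialize loss accumulators; the rstrip calls strip trailing
    # "_micro"/"_macro"-style suffixes (rstrip removes characters, not
    # substrings) so each base loss name is initialized once.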
def reset_losses(self):
loss_names = set([i.rstrip("maicro").rstrip("_") for
i in self.losses["train"].keys()])
return self.initialize_losses(list(loss_names))
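
# Iteration-based trainer: runs meta.iterations training steps in cycles of
# meta.cycle steps, evaluating and checkpointing after each cycle.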
class IteratorTrainer(Trainer):
def __init__(self, opt, meta, data_loader, model, optimizer):
super(IteratorTrainer, self).__init__(
opt, meta, data_loader, model, optimizer)
self.iters = meta.cycle
self.total_iters = meta.iterations
def run(self):
self.set_logger()
# Initialize progress bar
bar = utils.set_progress_bar(self.total_iters)
for cycle_num in range(int(self.total_iters / self.iters)):
self.model.train()
self.cycle(bar, cycle_num)
with torch.no_grad():
self.run_evaluation_cycle()
self.log_losses(self.opt, self.losses)
self.update_top_score(self.opt)
self.save_model(self.get_tracked_score())
self.stop_logger()
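
    # One training cycle: self.iters forward/backward/update steps, logging the
    # per-batch loss after every step and resetting the data loader offsets
    # when it wraps around.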
def cycle(self, bar, cycle_num):
nums = self.reset_losses()
print(self.losses["train"])
for i in range(1, self.iters + 1):
# self.model.zero_grad()
loss, nums, reset = self.do_forward_pass(nums)
self.do_backward_pass(loss)
self.update_parameters()
self.opt.train.dynamic.epoch += 1
for loss_name in self.losses["train"]:
self.logger.add_scalar(
"train/{}".format(loss_name),
loss.item() / self.opt.train.dynamic.bs,
self.opt.train.dynamic.epoch)
bar.update(1)
if cfg.toy and i > 10:
break
if reset:
self.data_loader.reset_offsets("train")