import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import comet.src.data.config as cfg
import comet.src.train.utils as train_utils
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.utils.utils as utils
##############################################################################
# BATCH
##############################################################################
def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
    """Run one generation step on an ATOMIC batch and return the masked,
    length-normalized MLE loss along with the updated loss counters."""
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)
input_ = model_utils.prepare_position_embeddings(
opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
attention_mask = batch["attention_mask"]
loss_mask = batch["loss_mask"]
    # Next-token targets: input ids shifted left by one position
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)
loss, dist = mle_steps(
opt.net.model, model, input_[:, :-1, :], targets,
attention_mask[:, :-1], loss_reduction="none")
# Set loss name
micro_name = "total_micro"
macro_name = "total_macro"
length = loss_mask.sum(1)
bs = input_.size(0)
    final_loss = (loss * loss_mask).sum(1)
    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)
    final_loss = final_loss / length
outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}
return outputs
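
# A minimal sketch (not part of the original module) illustrating the masked,
# length-normalized loss computed above: per-token NLL of shape
# (batch, seq_len) is zeroed at padding positions, summed per example, and
# divided by the number of real tokens. Values below are illustrative only.
def _masked_loss_demo():
    loss = torch.tensor([[0.5, 1.0, 2.0], [1.5, 0.5, 3.0]])       # per-token NLL
    loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # 1 = real token
    length = loss_mask.sum(1)                # tokens per example: [2., 3.]
    per_example = (loss * loss_mask).sum(1)  # masked sums: [1.5, 5.0]
    return per_example / length              # mean NLL per token: [0.75, 1.6667]
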
def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
    """Run one generation step on a ConceptNet batch. During evaluation,
    negative-category examples are tracked in separate loss counters; with
    tracking_mode, per-example losses are also returned."""
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
category = batch_variables["category"]
batch, reset = data_loader.sample_batch(
split, bs=opt.train.dynamic.bs, cat=category)
input_ = model_utils.prepare_position_embeddings(
opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
attention_mask = batch["attention_mask"]
loss_mask = batch["loss_mask"]
    # Next-token targets: input ids shifted left by one position
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)
loss, dist = mle_steps(
opt.net.model, model, input_[:, :-1, :], targets,
attention_mask[:, :-1], loss_reduction="none")
    # Set loss name: negative examples get separate counters during eval
    if not eval_mode or category == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"
length = loss_mask.sum(1)
bs = input_.size(0)
    final_loss = (loss * loss_mask).sum(1)
    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)
    final_loss = final_loss / length
outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}
if tracking_mode:
outputs["tracking"] = final_loss.squeeze().tolist()
return outputs
def mle_steps(key, model, input_, targets, attention_mask,
              loss_reduction="mean", i=None):
    """Compute the token-level NLL loss and the predicted word distribution
    for one forward pass of the model."""
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)
word_dist = train_utils.modify_output_for_loss_fn(
"nll", word_acts, dim=-1)
# Compute losses
loss = F.nll_loss(
word_dist.view(-1, word_dist.size(-1)),
targets, reduction=loss_reduction)
    if loss_reduction != "mean":
        # Reshape flat per-token losses back to (batch, seq_len)
        return loss.view(word_dist.size(0), -1), word_dist
else:
return loss, word_dist
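
# Hedged aside (not in the original): modify_output_for_loss_fn is assumed to
# apply log_softmax for "nll", in which case the nll_loss above is equivalent
# to cross-entropy over raw logits. A small self-contained check of that
# standard identity:
def _nll_cross_entropy_demo():
    logits = torch.randn(4, 10)            # (batch, vocab) activations
    targets = torch.randint(0, 10, (4,))   # gold token ids
    nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
    ce = F.cross_entropy(logits, targets)  # same quantity in one call
    assert torch.allclose(nll, ce)
    return nll
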
def decode(model, input_, attention_mask, i=None):
    """Single forward pass through the underlying language model."""
    return model(input_, sequence_mask=attention_mask)
def update_generation_losses(losses, nums, micro, macro, bs,
                             length, loss, split):
    """Dispatch loss bookkeeping to the train or eval utilities."""
    if split == "train":
train_utils.update_generation_losses(
losses, nums, micro, macro, bs, length, loss)
else:
eval_utils.update_generation_losses(
losses, nums, micro, macro, bs, length, loss)