import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import comet.src.data.config as cfg
import comet.src.train.utils as train_utils
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.utils.utils as utils


##############################################################################
# BATCH
##############################################################################


def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
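    """Run one training/eval step on a sampled ATOMIC batch.

    Samples a batch from the data loader, computes per-token NLL with
    mle_steps, masks out padding with the batch's loss mask, updates the
    running micro/macro loss counters, and returns the summed per-example
    mean loss along with the updated counts and the loader's reset flag.
    """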
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Targets are the input tokens shifted left by one (next-token prediction)
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss names for the running counters
    micro_name = "total_micro"
    macro_name = "total_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Mask out padding tokens and sum per example; the length-normalized
    # version is only computed after the counters have been updated
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs


def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
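    """Run one training/eval step on a sampled ConceptNet batch.

    Same procedure as batch_atomic_generate, but samples from a specific
    relation category and, during evaluation, books negative examples under
    separate loss counters. When tracking_mode is set, the per-example
    losses are also returned under the "tracking" key.
    """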
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Targets are the input tokens shifted left by one (next-token prediction)
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss names: negative examples get their own counters during evaluation
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Mask out padding tokens and sum per example; the length-normalized
    # version is only computed after the counters have been updated
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        outputs["tracking"] = final_loss.squeeze().tolist()

    return outputs


def mle_steps(key, model, input_, targets, attention_mask,
              loss_reduction="mean", i=None):
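    """Compute the maximum-likelihood (NLL) loss for one forward pass.

    Decodes the input through the model, converts the activations to
    log-probabilities, and evaluates F.nll_loss against the shifted targets.
    With loss_reduction="none" the per-token losses are reshaped to
    (batch, sequence length); otherwise the reduced loss is returned.
    The distribution over the vocabulary is returned in both cases.
    """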
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)

    # Convert activations to log-probabilities so they can be fed to nll_loss
    word_dist = train_utils.modify_output_for_loss_fn(
        "nll", word_acts, dim=-1)

    # Compute losses
    loss = F.nll_loss(
        word_dist.view(-1, word_dist.size(-1)),
        targets, reduction=loss_reduction)

    if loss_reduction != "mean":
        # Per-token losses, reshaped to (batch, sequence length)
        return loss.view(word_dist.size(0), -1), word_dist
    else:
        return loss, word_dist


def decode(model, input_, attention_mask, i=None):
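    """Forward the position-augmented input through the model.

    The `i` argument is accepted for interface compatibility but is not
    used here.
    """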
    return model(input_, sequence_mask=attention_mask)


def update_generation_losses(losses, nums, micro, macro, bs,
                             length, loss, split):
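    """Dispatch loss bookkeeping to the train or eval utilities by split."""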
    if split == "train":
        train_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        eval_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
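

# Minimal usage sketch (illustrative only; the names below are placeholders,
# not part of this module). A training script is assumed to have built the
# COMET config `opt`, the `nums`/`losses` counter dicts, an ATOMIC data
# loader, and a model before calling into these helpers:
#
#     batch_variables = {"data": data_loader, "model": model, "split": "train"}
#     outputs = batch_atomic_generate(opt, nums, losses, batch_variables)
#     outputs["loss"].backward()
#     if outputs["reset"]:
#         pass  # the loader signals a reset, e.g. the split was exhausted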