import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import comet.src.data.config as cfg
import comet.src.train.utils as train_utils
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.utils.utils as utils
from IPython import embed


##############################################################################
#                                       BATCH
##############################################################################


def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
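    """Run one generation (MLE) step on a sampled ATOMIC batch.

    Samples a batch for `split`, computes the per-token NLL loss, updates
    the running micro/macro loss counters, and returns the summed
    length-normalized batch loss along with the data loader's reset flag.
    """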
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Shifted next-token targets (index 0 of the last dim holds the token ids)
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name
    micro_name = "total_micro"
    macro_name = "total_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Per-example loss: masked sum over the sequence dimension
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

    # Normalize by the number of unmasked tokens for the reported loss
    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs


def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
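    """Run one generation (MLE) step on a sampled ConceptNet batch.

    Works like `batch_atomic_generate`, but samples from a specific
    relation category. During evaluation, non-"positive" categories are
    accumulated under the negative_* loss counters. With `tracking_mode`,
    the per-example losses are also returned under the "tracking" key.
    """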
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Shifted next-token targets (index 0 of the last dim holds the token ids)
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    # Per-example loss: masked sum over the sequence dimension
    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, final_loss, split)

    # Normalize by the number of unmasked tokens for the reported loss
    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        outputs["tracking"] = final_loss.squeeze().tolist()

    return outputs


def mle_steps(key, model, input_, targets, attention_mask,
              loss_reduction="mean", i=None):
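    """Decode `input_` with the model and compute the NLL loss against
    `targets`.

    With `loss_reduction="none"`, the loss is reshaped to
    (batch, sequence) so callers can mask and aggregate it per example;
    otherwise the reduced scalar loss is returned. The word distribution
    is returned in both cases.
    """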
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)

    word_dist = train_utils.modify_output_for_loss_fn(
        "nll", word_acts, dim=-1)

    # Compute losses
    loss = F.nll_loss(
        word_dist.view(-1, word_dist.size(-1)),
        targets, reduction=loss_reduction)

    if loss_reduction != "mean":
        return loss.view(word_dist.size(0), -1), word_dist
    else:
        return loss, word_dist


def decode(model, input_, attention_mask, i=None):
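    """Forward pass through the language model with the given attention mask."""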
    return model(input_, sequence_mask=attention_mask)


def update_generation_losses(losses, nums, micro, macro, bs,
                             length, loss, split):
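    """Dispatch loss bookkeeping to the train or eval utilities based on
    the split."""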
    if split == "train":
        train_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        eval_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)