# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
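    """KL regularizer on the per-language layer-selection distributions.

    Penalizes the divergence between a language's layer-selection samples and
    a prior: either a uniform prior (0.5 per layer) or the aggregated
    posterior, i.e. the normalized sum of the (detached) samples of all
    languages. The penalty weight is linearly annealed up to
    ``args.sparsity_weight`` over ``args.anneal_updates`` updates, starting
    after ``args.soft_update`` updates.
    """
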
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (
                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
            ).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
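        # linearly anneal the KL weight from 0 up to args.sparsity_weight,
        # starting after args.soft_update updates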
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update)
            * self.args.sparsity_weight
            / self.args.anneal_updates,
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
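    """Sparsity regularizer on the layer-selection distributions.

    Once ``update_num`` exceeds ``args.soft_update + args.anneal_updates``,
    this adds (i) an entropy term over the per-layer utilization averaged
    across languages, weighted by ``args.share_weight``, to encourage
    languages to share layers, and (ii) an L2 penalty pushing the expected
    number of selected layers towards ``args.target_layers``.
    """
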
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
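        """Return True once target_layers is set and the annealing phase is over."""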
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if (
            self.args.target_layers > 0 or self.args.share_weight > 0
        ) and update_num > (self.args.soft_update + self.args.anneal_updates):
            # anneal sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight
                    / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # per-layer utilization averaged across languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encouraging sharing across languages
                share_loss = sum(
                    -1.0 * v * math.log(v) for v in layer_utilization if v > 0
                )
                batch_loss += (
                    weight_anneal * self.args.share_weight * sample_size * share_loss
                )
            if self.args.target_layers > 0:
                # compute the expected number of layers selected
                expected_layers = sum(layer_utilization)
                # compute the l2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += (
                    weight_anneal
                    * self.args.share_weight
                    * sample_size
                    * global_sparsity_loss
                )
        return batch_loss
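

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above). It assumes `args` is an
# argparse-style namespace exposing the attributes read by the two losses, and
# that `layer_samples` holds one tensor of per-layer selection probabilities
# per language with shape (num_layers,); both are illustrative assumptions,
# not the canonical fairseq call sites.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        prior="uniform",
        sparsity_weight=0.1,
        share_weight=0.1,
        target_layers=12,
        soft_update=500,
        anneal_updates=1000,
    )
    num_langs, num_layers = 4, 24
    # per-language layer-selection probabilities in (0, 1)
    layer_samples = [torch.sigmoid(torch.randn(num_layers)) for _ in range(num_langs)]

    kl_loss_fn = LatentLayersKLLoss(args)
    sparsity_loss_fn = LatentLayersSparsityLoss(args)

    update_num, sample_size = 2000, 32
    kl = kl_loss_fn(layer_samples, 0, update_num, sample_size)
    sparsity = sparsity_loss_fn(layer_samples, update_num, sample_size)
    print(f"kl={kl.item():.4f} sparsity={float(sparsity):.4f}")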