import torch
from torch import nn
from risk_biased.models.mlp import MLP


def pool(x, dim):
    """Max-pool a tensor along `dim`, discarding the argmax indices."""
    x, _ = x.max(dim)
    return x


class ContextGating(nn.Module):
    """Context gating block inspired by Multi-Path++ (https://arxiv.org/pdf/2111.14973v3.pdf),
    though not an exact reimplementation.

    Args:
        d_model: input and output feature dimension of the model
        d: hidden dimension of the gating block
        num_layers: number of layers in each MLP block
        is_mlp_residual: whether to use residual connections in the MLP blocks
    """

    def __init__(self, d_model, d, num_layers, is_mlp_residual):
        super().__init__()

        # MLPs projecting agent states, object states, and the global context
        # into the shared gating dimension d.
        self.w_s = MLP(d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual)
        self.w_c_cross = MLP(
            d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual
        )
        self.w_c_global = MLP(d, d, d, num_layers, is_mlp_residual)

        # Projects the gated agent features back to the model dimension.
        self.output_layer = nn.Linear(d, d_model)

    def forward(self, s, c_cross, c_global):
        """Context gating forward pass.

        Args:
            s: (batch, agents, features) tensor of encoded agent states
            c_cross: (batch, objects, features) tensor of encoded object states
            c_global: (batch, d) tensor of global context

        Returns:
            s: (batch, agents, features) updated tensor of encoded agent states
            c_global: (batch, d) updated tensor of global context
        """
        s = self.w_s(s)
        c_cross = self.w_c_cross(c_cross)
        # Max-pool object features over the object dimension and use them to
        # gate the global context.
        c_global = pool(c_cross, -2) * self.w_c_global(c_global)
        # Gate each agent's features with the updated global context.
        # b: batch, a: agents, k: features
        s = torch.einsum("bak,bk->bak", [s, c_global])
        s = self.output_layer(s)
        return s, c_global
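

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). The shapes
    # and hyperparameters below are illustrative, and it assumes the MLP blocks
    # accept tensors with arbitrary leading batch dimensions, as nn.Linear does.
    batch, agents, objects, d_model, d = 4, 3, 5, 64, 32
    gating = ContextGating(d_model, d, num_layers=2, is_mlp_residual=True)
    s = torch.randn(batch, agents, d_model)
    c_cross = torch.randn(batch, objects, d_model)
    c_global = torch.randn(batch, d)
    s_out, c_global_out = gating(s, c_cross, c_global)
    # Agent states keep their shape; the global context stays in dimension d.
    assert s_out.shape == (batch, agents, d_model)
    assert c_global_out.shape == (batch, d)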