Spaces:
Running
Running
from typing import Tuple

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from risk_biased.models.context_gating import ContextGating
from risk_biased.models.mlp import MLP
from risk_biased.models.multi_head_attention import MultiHeadAttention
class SequenceEncoderMaskedLSTM(nn.Module):
    """Linear embedding followed by a single-layer LSTM cell that skips missing points.

    At every time step where the mask is False, the cell is fed its own current
    hidden state instead of the embedded (missing) point.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of the embedding and of the LSTM hidden state
    """

    def __init__(self, input_dim: int, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o ... -> (b o) ...")
        self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
        # LSTMCell advances a single time step: inputs are (batch, features).
        self._lstm = nn.LSTMCell(input_size=h_dim, hidden_size=h_dim)
        # Learned initial states, broadcast to every (batch, object) sequence.
        self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Encode each object's sequence into its final LSTM hidden state.

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
        Returns:
            torch.Tensor: (batch_size, num_objects, h_dim) tensor
        """
        batch_size, num_objects, seq_len, _ = input.shape
        flat_size = batch_size * num_objects
        embedded = self._embed(self._group_objects(input))
        valid = self._group_objects(mask_input).float()

        # Start from the first embedded point when it is valid, otherwise from h0.
        first_valid = valid[:, 0, None]
        hidden = first_valid * embedded[:, 0, :] + (1 - first_valid) * repeat(
            self.h0, "b f -> (size b) f", size=flat_size
        )
        cell = repeat(self.c0, "b f -> (size b) f", size=flat_size)

        for step in range(seq_len):
            step_valid = valid[:, step, None]
            # Missing points are replaced by the current hidden state.
            step_input = step_valid * embedded[:, step, :] + (1 - step_valid) * hidden
            hidden, cell = self._lstm(step_input, (hidden, cell))

        return rearrange(hidden, "(b o) f -> b o f", b=batch_size, o=num_objects)
class SequenceEncoderLSTM(nn.Module):
    """Linear embedding followed by a single-layer LSTM.

    The mask only selects the initial hidden state. Every time step, including
    missing ones, is fed to the LSTM unchanged. SequenceEncoderMaskedLSTM is the
    variant that substitutes missing points.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of the embedding and of the LSTM hidden state
    """

    def __init__(self, input_dim: int, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o ... -> (b o) ...")
        self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
        self._lstm = nn.LSTM(
            input_size=h_dim,
            hidden_size=h_dim,
            batch_first=True,
        )  # expects (batch, seq, features)
        # Learned initial states, shared by every (batch, object) sequence.
        self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Encode each object's sequence into the final LSTM hidden state.

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
        Returns:
            torch.Tensor: (batch_size, num_objects, h_dim) tensor
        """
        batch_size, num_objects, _, _ = input.shape
        num_sequences = batch_size * num_objects
        split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)
        embedded_input = self._embed(self._group_objects(input))
        mask_input = self._group_objects(mask_input).float()

        # Initial hidden state: the first embedded point when it is valid,
        # otherwise the learned h0. Broadcasting against the repeated h0 yields the
        # (num_layers=1, batch, h_dim) layout nn.LSTM expects for its states.
        first_valid = mask_input[:, 0, None]
        h = first_valid * embedded_input[:, 0, :] + (1 - first_valid) * repeat(
            self.h0, "one f -> one size f", size=num_sequences
        ).contiguous()
        c = repeat(self.c0, "one f -> one size f", size=num_sequences).contiguous()

        # NOTE(review): missing points are not masked inside the recurrence;
        # only the initial state depends on the mask.
        _, (h, _) = self._lstm(embedded_input, (h, c))
        return split_objects(h.squeeze(0))
class SequenceEncoderMLP(nn.Module):
    """MLP applied to each object's flattened (masked) sequence.

    Args:
        input_dim : dimension of the input variable
        h_dim : dimension of a hidden layer of MLP
        num_layers: number of layers to use in the MLP
        sequence_length: dimension of the input sequence
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        input_dim: int,
        h_dim: int,
        num_layers: int,
        sequence_length: int,
        is_mlp_residual: bool,
    ) -> None:
        super().__init__()
        self._mlp = MLP(
            input_dim * sequence_length, h_dim, h_dim, num_layers, is_mlp_residual
        )

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Encode each object's whole sequence with a single MLP call.

        Args:
            input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
            mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
        Returns:
            torch.Tensor: (batch_size, num_objects, output_dim) tensor, or a
            (batch_size, num_objects, 0) tensor when seq_len * input_dim == 0
        """
        batch_size, num_objects, _, _ = input.shape
        # Zero out missing points so they contribute nothing to the MLP input.
        input = input * mask_input.unsqueeze(-1)
        h = rearrange(input, "b o s f -> (b o) (s f)")
        if h.shape[-1] == 0:
            # No features to encode: skip the MLP and return an empty feature
            # tensor. Reshaping to 3-D and then applying the 2-D "(b o) f"
            # pattern below would always raise.
            # NOTE(review): confirm callers expect an empty feature dim here.
            return h.reshape(batch_size, num_objects, 0)
        h = self._mlp(h)
        return rearrange(h, "(b o) f -> b o f", b=batch_size, o=num_objects)
class SequenceDecoderLSTM(nn.Module):
    """One-to-many single-layer LSTM decoder.

    The input feature vector serves both as the initial hidden state and as the
    LSTM input at every output step.

    Args:
        h_dim : dimension of the input features, the LSTM hidden state and the output
    """

    def __init__(self, h_dim: int) -> None:
        super().__init__()
        self._group_objects = Rearrange("b o f -> (b o) f")
        self._lstm = nn.LSTM(input_size=h_dim, hidden_size=h_dim)
        self._out_layer = nn.Linear(in_features=h_dim, out_features=h_dim)
        self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))

    def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
        """Unroll the LSTM for ``sequence_length`` steps from each object's features.

        Args:
            input (torch.Tensor): (batch_size, num_objects, h_dim) tensor
            sequence_length: output sequence length to create
        Returns:
            torch.Tensor: (batch_size, num_objects, sequence_length, h_dim) tensor
        """
        batch_size, num_objects, _ = input.shape
        num_sequences = batch_size * num_objects

        # nn.LSTM (batch_first=False) takes states shaped (num_layers, batch, hidden).
        hidden_init = self._group_objects(input).unsqueeze(0).contiguous()
        cell_init = self.c0.unsqueeze(1).expand(-1, num_sequences, -1).contiguous()
        # The same feature vector is fed at every step: (seq, batch, features).
        step_inputs = hidden_init.expand(sequence_length, -1, -1).contiguous()

        outputs, _ = self._lstm(step_inputs, (hidden_init, cell_init))
        outputs = rearrange(
            outputs, "t (b o) f -> b o t f", b=batch_size, o=num_objects
        )
        return self._out_layer(outputs)
class SequenceDecoderMLP(nn.Module):
    """One-to-many MLP decoder: one MLP call produces the whole output sequence.

    Args:
        h_dim : dimension of a hidden layer
        num_layers: number of layers to use in the MLP
        sequence_length: output sequence length to return
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self, h_dim: int, num_layers: int, sequence_length: int, is_mlp_residual: bool
    ) -> None:
        super().__init__()
        self._mlp = MLP(
            h_dim, h_dim * sequence_length, h_dim, num_layers, is_mlp_residual
        )

    def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
        """Decode each object's features into a sequence of feature vectors.

        Args:
            input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
            sequence_length: output sequence length to create; must divide the
                MLP output size
        Returns:
            torch.Tensor: (batch_size, num_objects, sequence_length, f) tensor,
            where f is the MLP output size divided by sequence_length
        """
        batch_size, num_objects, _ = input.shape
        decoded = self._mlp(rearrange(input, "b o f -> (b o) f"))
        # Split each flat MLP output into sequence_length per-step feature vectors.
        return rearrange(
            decoded,
            "(b o) (s f) -> b o s f",
            b=batch_size,
            o=num_objects,
            s=sequence_length,
        )
class AttentionBlock(nn.Module):
    """Agent-map cross attention followed by agent-agent self attention.

    Pipeline: agent-map cross attention -> sigmoid(linear) -> +residual -> layer_norm
    -> agent-agent attention -> ReLU(linear) -> +residual -> layer_norm.

    NOTE(review): the cross-attention stage applies a sigmoid, while CG_block and
    HybridBlock apply ReLU (``self._activation``) at the same position. The
    sigmoid is kept so trained weights behave identically; confirm which is intended.

    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
    """

    def __init__(self, hidden_dim: int, num_attention_heads: int):
        super().__init__()
        self._num_attention_heads = num_attention_heads
        self._agent_map_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm2 = nn.LayerNorm(hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Run both attention stages and return the agent features (no attention matrix).

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
        Returns:
            (batch_size, num_agents, feature_size) tensor of updated agent features
        """
        # Padding agents get their attention outputs zeroed in both stages.
        invalid_agents = torch.logical_not(mask_agents.unsqueeze(-1))

        if not mask_map.any():
            # No map object anywhere in the batch: skip cross attention.
            h = self._layer_norm1(encoded_agents)
        else:
            cross_mask = torch.einsum("ba,bo->bao", mask_agents, mask_map)
            attended, _ = self._agent_map_attention(
                encoded_agents + encoded_absolute_agents,
                encoded_map,
                encoded_map,
                mask=cross_mask,
            )
            attended = torch.masked_fill(attended, invalid_agents, 0)
            h = self._layer_norm1(encoded_agents + torch.sigmoid(self._lin1(attended)))

        residual = h.clone()
        self_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
        queries = h + encoded_absolute_agents
        attended, _ = self._agent_agent_attention(
            queries, queries, queries, mask=self_mask
        )
        attended = torch.masked_fill(attended, invalid_agents, 0)
        return self._layer_norm2(residual + self._activation(self._lin2(attended)))
class CG_block(nn.Module):
    """Block performing agent-map then agent-agent context gating.

    Pipeline: agent-map context gating -> ReLU(linear) -> +residual -> layer_norm
    -> agent-agent context gating -> linear.

    Args:
        hidden_dim: feature dimension
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        dim_expansion: int,
        num_layers: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        self._agent_map = ContextGating(
            hidden_dim,
            hidden_dim * dim_expansion,
            num_layers=num_layers,
            is_mlp_residual=is_mlp_residual,
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent = ContextGating(
            hidden_dim, hidden_dim * dim_expansion, num_layers, is_mlp_residual
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        global_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function of the block, returning the output and global context.

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
            global_context: global context tensor, documented as (batch_size, dim_context)
                (MCG passes a (1, dim_context) parameter on the first block — TODO confirm broadcasting)
        Returns:
            Tuple of the (batch_size, num_agents, feature_size) agent features and the
            global context returned by the agent-agent context gating.
        """
        # Check if map_info is available. If not, don't compute cross-interaction with it
        if mask_map.any():
            s, global_context = self._agent_map(
                encoded_agents + encoded_absolute_agents, encoded_map, global_context
            )
            s = s * mask_agents.unsqueeze(-1)
            s = self._activation(self._lin1(s))
            s = self._layer_norm1(encoded_agents + s)
        else:
            s = self._layer_norm1(encoded_agents)
        s = s + encoded_absolute_agents
        s, global_context = self._agent_agent(s, s, global_context)
        s = s * mask_agents.unsqueeze(-1)
        # NOTE(review): unlike the agent-map stage, no activation, residual or
        # layer norm follows this projection.
        s = self._lin2(s)
        return s, global_context
class HybridBlock(nn.Module):
    """Block performing agent-map cross context_gating->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm

    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        num_attention_heads: int,
        dim_expansion: int,
        num_layers: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        self._num_attention_heads = num_attention_heads
        self._agent_map_cg = ContextGating(
            hidden_dim,
            hidden_dim * dim_expansion,
            num_layers=num_layers,
            is_mlp_residual=is_mlp_residual,
        )
        self._lin1 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm1 = nn.LayerNorm(hidden_dim)
        self._agent_agent_attention = MultiHeadAttention(
            hidden_dim, num_attention_heads, hidden_dim, hidden_dim
        )
        self._lin2 = nn.Linear(hidden_dim, hidden_dim)
        self._layer_norm2 = nn.LayerNorm(hidden_dim)
        self._activation = nn.ReLU()

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        global_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function of the block, returning the output and the context (no attention matrix).

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
            global_context: global context tensor, documented as (batch_size, dim_context)
                (MHB passes a (1, dim_context) parameter on the first block — TODO confirm broadcasting)
        Returns:
            Tuple of the (batch_size, num_agents, feature_size) agent features and the
            global context (updated by the agent-map context gating when map data exists,
            otherwise returned unchanged).
        """
        # Check if map_info is available. If not, don't compute cross-context gating with it
        if mask_map.any():
            h, global_context = self._agent_map_cg(
                encoded_agents + encoded_absolute_agents, encoded_map, global_context
            )
            h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
            h = self._activation(self._lin1(h))
            h = self._layer_norm1(encoded_agents + h)
        else:
            h = self._layer_norm1(encoded_agents)
        h_res = h.clone()
        agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
        h = h + encoded_absolute_agents
        h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
        # Zero the outputs of padding agents before the residual update.
        h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
        h = self._activation(self._lin2(h))
        h = self._layer_norm2(h_res + h)
        return h, global_context
class MCG(nn.Module):
    """Stack of CG_block modules whose outputs are averaged with their input.

    Each block receives the running mean of the original input and all previous
    block outputs, for both the agent features and the global context.

    Args:
        hidden_dim: feature dimension
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        num_blocks: number of successive context encoding blocks to use in the module
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        dim_expansion: int,
        num_layers: int,
        num_blocks: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        # Learned context shared by the whole batch: shape (1, hidden_dim * dim_expansion).
        self.initial_global_context = nn.parameter.Parameter(
            torch.ones(1, hidden_dim * dim_expansion)
        )
        self.mcg = nn.ModuleList(
            [
                CG_block(hidden_dim, dim_expansion, num_layers, is_mlp_residual)
                for _ in range(num_blocks)
            ]
        )

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Run every block and return the averaged agent features (no context).

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
        Returns:
            (batch_size, num_agents, feature_size) tensor of agent features
        """
        agents = encoded_agents
        context = self.initial_global_context
        agents_total, context_total = agents, context
        # num_terms counts the input plus the outputs accumulated so far.
        for num_terms, block in enumerate(self.mcg, start=2):
            new_agents, new_context = block(
                agents, mask_agents, encoded_absolute_agents, encoded_map, mask_map, context
            )
            agents_total = agents_total + new_agents
            context_total = context_total + new_context
            agents = (agents_total / num_terms).clone()
            context = (context_total / num_terms).clone()
        return agents
class MAB(nn.Module):
    """Stack of AttentionBlock modules whose outputs are averaged with their input.

    Each block receives the running mean of the original input and all previous
    block outputs.

    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
        num_blocks: number of successive blocks to use in the module.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_attention_heads: int,
        num_blocks: int,
    ):
        super().__init__()
        self.attention_blocks = nn.ModuleList(
            [AttentionBlock(hidden_dim, num_attention_heads) for _ in range(num_blocks)]
        )

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Run every block and return the averaged agent features (no attention matrix).

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
        Returns:
            (batch_size, num_agents, feature_size) tensor of agent features
        """
        current = encoded_agents
        running_total = encoded_agents
        # num_terms counts the input plus the outputs accumulated so far.
        for num_terms, block in enumerate(self.attention_blocks, start=2):
            running_total = running_total + block(
                current, mask_agents, encoded_absolute_agents, encoded_map, mask_map
            )
            current = (running_total / num_terms).clone()
        return current
class MHB(nn.Module):
    """Stack of HybridBlock modules whose outputs are averaged with their input.

    Each block receives the running mean of the original input and all previous
    block outputs, for both the agent features and the global context.

    Args:
        hidden_dim: feature dimension
        num_attention_heads: number of attention heads to use
        dim_expansion: multiplicative factor on the hidden dimension for the global context representation
        num_layers: number of layers to use in the MLP for context encoding
        num_blocks: number of successive blocks to use in the module.
        is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
    """

    def __init__(
        self,
        hidden_dim: int,
        num_attention_heads: int,
        dim_expansion: int,
        num_layers: int,
        num_blocks: int,
        is_mlp_residual: bool,
    ):
        super().__init__()
        # Learned context shared by the whole batch: shape (1, hidden_dim * dim_expansion).
        self.initial_global_context = nn.parameter.Parameter(
            torch.ones(1, hidden_dim * dim_expansion)
        )
        self.hybrid_blocks = nn.ModuleList(
            [
                HybridBlock(
                    hidden_dim,
                    num_attention_heads,
                    dim_expansion,
                    num_layers,
                    is_mlp_residual,
                )
                for _ in range(num_blocks)
            ]
        )

    def forward(
        self,
        encoded_agents: torch.Tensor,
        mask_agents: torch.Tensor,
        encoded_absolute_agents: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Run every block and return the averaged agent features (no attention matrix nor context).

        Args:
            encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
            mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
            encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
            mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
        Returns:
            (batch_size, num_agents, feature_size) tensor of agent features
        """
        agents = encoded_agents
        context = self.initial_global_context
        agents_total, context_total = agents, context
        # num_terms counts the input plus the outputs accumulated so far.
        for num_terms, block in enumerate(self.hybrid_blocks, start=2):
            new_agents, new_context = block(
                agents, mask_agents, encoded_absolute_agents, encoded_map, mask_map, context
            )
            agents_total = agents_total + new_agents
            context_total = context_total + new_context
            agents = (agents_total / num_terms).clone()
            context = (context_total / num_terms).clone()
        return agents