Haaribo's picture
Add application file
9b896f5
raw
history blame contribute delete
357 Bytes
import torch
class SequentialWithArgs(torch.nn.Sequential):
    """Sequential container that hands extra call arguments to its last module.

    Submodules run in registration order. Every submodule except the final
    one is called with only the running value; the final submodule is also
    given the extra positional and keyword arguments passed to ``forward``.
    """

    def forward(self, input, *args, **kwargs):
        """Run the submodules in order and return the final result.

        Args:
            input: Value passed to the first submodule. The name shadows the
                builtin but is kept because callers may pass it by keyword.
            *args: Extra positional arguments, forwarded only to the last
                submodule.
            **kwargs: Extra keyword arguments, forwarded only to the last
                submodule.

        Returns:
            The output of the last submodule. When the container holds no
            submodules, ``input`` is returned unchanged and the extra
            arguments are ignored.
        """
        # Materialise the modules so the final position is known up front.
        modules = list(self._modules.values())
        last_index = len(modules) - 1
        for index, module in enumerate(modules):
            if index == last_index:
                # Only the final stage sees the caller's extra arguments.
                input = module(input, *args, **kwargs)
            else:
                input = module(input)
        return input