Spaces:
Runtime error
Runtime error
File size: 1,040 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
from torch import nn
class MixedOperationRandom(nn.Module):
    """Container of candidate ops that are either averaged or selected by index.

    Args:
        search_ops: iterable of ``nn.Module`` candidates. When outputs are
            averaged or routed per sample, the candidates' outputs are summed /
            concatenated, so they must produce compatible shapes.
    """

    def __init__(self, search_ops):
        super().__init__()
        self.ops = nn.ModuleList(search_ops)
        # Count the registered modules rather than ``search_ops`` itself so a
        # one-shot iterable (e.g. a generator) is accepted as well.
        self.num_ops = len(self.ops)

    def forward(self, x, x_path=None):
        """Run the candidate op(s) selected by ``x_path`` on ``x``.

        Args:
            x: input tensor. Dim 0 is treated as the batch dimension when
                ``x_path`` is a per-sample tensor.
            x_path: one of
                * ``None`` -> uniform average of every op's output;
                * an ``int``/``float`` (or 0-dim tensor) in ``[0, num_ops)``
                  -> only that op runs (floats are truncated with ``int``);
                * a tensor with exactly one entry per sample -> sample ``i``
                  is passed through op ``x_path[i]``.

        Returns:
            The averaged or selected op output. In the per-sample case the
            individual outputs are concatenated along dim 0.

        Raises:
            AssertionError: if ``x_path`` has an unsupported type, lies outside
                ``[0, num_ops)``, or (tensor case) does not hold one entry per
                sample.
        """
        if x_path is None:
            return sum(op(x) for op in self.ops) / self.num_ops

        if isinstance(x_path, torch.Tensor) and x_path.dim() == 0:
            # A scalar tensor selects one op for the whole batch, like an int.
            x_path = x_path.item()

        if isinstance(x_path, torch.Tensor):
            return self._forward_per_sample(x, x_path)

        # Range is checked on the raw value *before* truncation, so e.g. -0.5
        # is rejected instead of being truncated to op 0.
        assert isinstance(x_path, (int, float)) and 0 <= x_path < self.num_ops, \
            'x_path must be None, a tensor, or a number in [0, num_ops)'
        return self.ops[int(x_path)](x)

    def _forward_per_sample(self, x, x_path):
        """Route sample ``i`` of ``x`` through ``self.ops[x_path[i]]``."""
        batch_size = x.size(0)
        # numel check: each x_path[i] must be a single value (shape (B,) or
        # e.g. (B, 1)), mirroring the original per-element ``.item()`` calls.
        assert x_path.size(0) == batch_size and x_path.numel() == batch_size, \
            'batch_size should match length of x_path'
        # One host transfer for the whole index tensor instead of one
        # ``.item()`` sync per sample.
        paths = x_path.reshape(-1).tolist()
        # Same range contract as the scalar path; without it a negative entry
        # would silently pick an op from the end of the ModuleList.
        assert all(0 <= p < self.num_ops for p in paths), \
            'every entry of x_path must lie in [0, num_ops)'
        return torch.cat(
            [self.ops[int(p)](x.narrow(0, i, 1)) for i, p in enumerate(paths)],
            dim=0,
        )