Spaces:
Sleeping
Sleeping
File size: 954 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .task import Task
class MILNCETask(Task):
    """Task whose subsample reshaping also flattens caption tensors.

    Overrides ``reshape_subsample``. Besides passing every tensor in the
    batch through ``self.flat_subsample``, it merges the two leading
    dimensions of the ``"caps"`` and ``"cmasks"`` entries into one.
    """

    def reshape_subsample(self, sample):
        """Flatten subsampled tensors in ``sample`` in place.

        Nothing is changed unless ``self.config.dataset.subsampling``
        exists, is not ``None``, and is greater than 1. When those
        conditions hold:

        - Every tensor value in ``sample`` is replaced by
          ``self.flat_subsample(value)``.
        - For the ``"caps"`` and ``"cmasks"`` keys, that result then has
          its first two dimensions merged with ``view``:
          ``(d0, d1, *rest) -> (d0 * d1, *rest)``.
        - Non-tensor values are left as they are.

        Args:
            sample: mapping of batch fields (tensors or other values);
                mutated in place.

        Returns:
            The same ``sample`` object that was passed in.
        """
        subsampling = getattr(self.config.dataset, "subsampling", None)
        # `not subsampling > 1` (rather than `subsampling <= 1`) keeps the
        # exact semantics of the original `> 1` check.
        if subsampling is None or not subsampling > 1:
            return sample

        for key in sample:
            value = sample[key]
            if not torch.is_tensor(value):
                continue
            flattened = self.flat_subsample(value)
            if key in ("caps", "cmasks"):
                # Merge dims 0 and 1. The product is spelled out instead of
                # using -1 so zero-sized trailing dims still resolve.
                dims = flattened.size()
                merged_shape = (dims[0] * dims[1], *dims[2:])
                flattened = flattened.view(merged_shape)
            sample[key] = flattened
        return sample
|