|
import copy |
|
import logging |
|
import os |
|
import os.path as osp |
|
from os.path import join |
|
|
|
import torch |
|
from torch.utils.data import ConcatDataset, DataLoader |
|
|
|
from utils.optimizer import create_optimizer |
|
from utils.scheduler import create_scheduler |
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
def get_media_types(datasources):
    """Get the media type of every dataloader or dataset in ``datasources``.

    Args:
        datasources (List): List of dataloaders or datasets. Each element is
            inspected individually, so ``DataLoader`` objects and bare
            datasets may be mixed in the same list.

    Returns: List. The ``media_type`` of each element, in the same order as
        ``datasources``. An empty input yields an empty list.
    """
    media_types = []
    for source in datasources:
        # A DataLoader wraps its dataset; unwrap it to reach ``media_type``.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        if isinstance(dataset, ConcatDataset):
            # Only the first sub-dataset is consulted; this presumes every
            # sub-dataset of a ConcatDataset shares one media type.
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types
|
|