Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,050 Bytes
600759a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
#!/usr/bin/env python3
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import pytorch_lightning as pl
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module assembled from dataset config dicts.

    Each of ``train`` / ``validation`` / ``test`` is an iterable of configs
    consumable by ``instantiate_from_config``; the instantiated datasets of a
    split are concatenated into a single ``ConcatDataset`` per dataloader.

    Args:
        batch_size: per-process batch size for the train/test loaders.
        num_workers: worker processes per DataLoader.
        train/validation/test: optional iterables of dataset configs; splits
            left as ``None`` are simply absent from ``self.dataset_configs``.
        **kwargs: ignored (accepted for config-file forward compatibility).
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Record only the splits that were actually configured, so setup()
        # and the dataloaders can iterate this dict without None checks.
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs["train"] = train
        if validation is not None:
            self.dataset_configs["validation"] = validation
        if test is not None:
            self.dataset_configs["test"] = test

    def setup(self, stage):
        """Instantiate all configured datasets.

        Fix: the previous implementation raised ``NotImplementedError`` for
        any stage other than ``"fit"``, which crashed ``trainer.validate()``
        and ``trainer.test()``. Since ``self.dataset_configs`` already holds
        only the splits supplied to ``__init__``, we build every configured
        split regardless of ``stage`` — behavior for stage ``"fit"`` is
        unchanged.
        """
        # Local import kept as in the original: avoids a hard dependency on
        # the project package at module-import time.
        from src.utils.train_util import instantiate_from_config

        self.datasets = {
            split: [instantiate_from_config(cfg) for cfg in cfgs]
            for split, cfgs in self.dataset_configs.items()
        }

    def train_dataloader(self):
        """Return the distributed training DataLoader over all train datasets."""
        dataset = ConcatDataset(self.datasets["train"])
        # DistributedSampler shuffles by default (shuffle=True), so the
        # DataLoader itself must not shuffle.
        sampler = DistributedSampler(dataset)
        # NOTE(review): prefetch_factor is only valid when num_workers > 0;
        # passing num_workers=0 here would raise — confirm callers never do.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=sampler,
            prefetch_factor=2,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the distributed validation DataLoader."""
        dataset = ConcatDataset(self.datasets["validation"])
        sampler = DistributedSampler(dataset)
        # NOTE(review): validation batch size is hard-coded to 4 (not
        # self.batch_size) — presumably a memory-headroom choice; confirm.
        return DataLoader(dataset, batch_size=4, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def test_dataloader(self):
        """Return the (non-distributed) test DataLoader."""
        dataset = ConcatDataset(self.datasets["test"])
        # NOTE(review): unlike train/val, no DistributedSampler here — under
        # DDP every rank will see the full test set; verify this is intended.
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|