|
import math |
|
import os |
|
import json |
|
import re |
|
import cv2 |
|
from dataclasses import dataclass, field |
|
|
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
from step1x3d_geometry import register |
|
from step1x3d_geometry.utils.typing import * |
|
from step1x3d_geometry.utils.config import parse_structured |
|
|
|
from streaming import StreamingDataLoader |
|
from .base import BaseDataModuleConfig, BaseDataset |
|
|
|
|
|
@dataclass |
|
class ObjaverseDataModuleConfig(BaseDataModuleConfig): |
|
pass |
|
|
|
|
|
class ObjaverseDataset(BaseDataset): |
|
pass |
|
|
|
|
|
@register("Objaverse-datamodule") |
|
class ObjaverseDataModule(pl.LightningDataModule): |
|
cfg: ObjaverseDataModuleConfig |
|
|
|
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: |
|
super().__init__() |
|
self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg) |
|
|
|
def setup(self, stage=None) -> None: |
|
if stage in [None, "fit"]: |
|
self.train_dataset = ObjaverseDataset(self.cfg, "train") |
|
if stage in [None, "fit", "validate"]: |
|
self.val_dataset = ObjaverseDataset(self.cfg, "val") |
|
if stage in [None, "test", "predict"]: |
|
self.test_dataset = ObjaverseDataset(self.cfg, "test") |
|
|
|
def prepare_data(self): |
|
pass |
|
|
|
def general_loader( |
|
self, dataset, batch_size, collate_fn=None, num_workers=0 |
|
) -> DataLoader: |
|
return DataLoader( |
|
dataset, |
|
batch_size=batch_size, |
|
collate_fn=collate_fn, |
|
num_workers=num_workers, |
|
) |
|
|
|
def train_dataloader(self) -> DataLoader: |
|
return self.general_loader( |
|
self.train_dataset, |
|
batch_size=self.cfg.batch_size, |
|
collate_fn=self.train_dataset.collate, |
|
num_workers=self.cfg.num_workers, |
|
) |
|
|
|
def val_dataloader(self) -> DataLoader: |
|
return self.general_loader(self.val_dataset, batch_size=1) |
|
|
|
def test_dataloader(self) -> DataLoader: |
|
return self.general_loader(self.test_dataset, batch_size=1) |
|
|
|
def predict_dataloader(self) -> DataLoader: |
|
return self.general_loader(self.test_dataset, batch_size=1) |
|
|