|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
import traceback |
|
import json |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from typing import Optional, Union |
|
from megfile import smart_open, smart_path_join, smart_exists |
|
|
|
|
|
class BaseDataset(torch.utils.data.Dataset, ABC):
    """Abstract dataset over a list of sample UIDs read from JSON metadata.

    Subclasses implement :meth:`inner_get_item`. :meth:`__getitem__` wraps it
    so that a sample which fails to load is reported and the next sample is
    returned instead.
    """

    def __init__(self, root_dirs: Union[list, str], meta_path: Optional[Union[list, str]]):
        """Store the data roots and load the UID list.

        Args:
            root_dirs: Stored as-is on ``self.root_dirs``. ``_locate_datadir``
                iterates over it, so it is presumably a list of root
                directories (verify against the callers).
            meta_path: Either the path of one JSON file, or an iterable of
                ``(path, weight)`` pairs; see :meth:`_load_uids`.
        """
        super().__init__()
        self.root_dirs = root_dirs
        self.uids = self._load_uids(meta_path)

    def __len__(self):
        """Return the number of UIDs, counting repeats added by weighting."""
        return len(self.uids)

    @abstractmethod
    def inner_get_item(self, idx):
        """Load and return the sample at ``idx``; implemented by subclasses."""
        pass

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to the following samples.

        When ``inner_get_item`` raises, the traceback and the offending UID are
        printed and the next index (wrapping around at the end) is tried.

        Raises:
            IndexError: If ``idx`` is outside ``self.uids``; this is what ends
                plain ``for sample in dataset`` iteration.
            RuntimeError: If a full pass over the dataset produced no sample.
        """
        num_items = len(self)
        current = idx
        # An explicit, bounded loop instead of recursion: a long run of broken
        # samples can neither exhaust the interpreter stack nor spin forever.
        for _ in range(max(num_items, 1)):
            try:
                return self.inner_get_item(current)
            except Exception:
                traceback.print_exc()
                # Indexing `uids` here deliberately lets an out-of-range
                # `current` surface as IndexError instead of being retried.
                print(f"[DEBUG-DATASET] Error when loading {self.uids[current]}")
            current = (current + 1) % num_items
        raise RuntimeError(
            f"Failed to load any sample after trying all {num_items} items starting at index {idx}"
        )

    @staticmethod
    def _load_uids(meta_path: Optional[Union[list, str]]):
        """Load UIDs from one JSON file or merge several weighted JSON files.

        Args:
            meta_path: A ``str`` path to a JSON file, or an iterable of
                ``(path, weight)`` pairs. In the weighted form, each list that
                is shorter than ``weight * max(len(list) / weight)`` is repeated
                a whole number of times so the sources end up roughly in
                proportion to their weights.

        Returns:
            The decoded JSON content for a single path, otherwise the
            concatenation of the (possibly repeated) lists.

        NOTE(review): the annotation allows ``None`` but ``None`` is not
        handled; it fails with ``TypeError`` when iterated. Confirm whether
        that is intended.
        """
        if isinstance(meta_path, str):
            with open(meta_path, 'r') as f:
                uids = json.load(f)
        else:
            sources = []
            max_total = 0
            for pth, weight in meta_path:
                with open(pth, 'r') as f:
                    uids = json.load(f)
                # Size the merged set would need for this source to be used
                # exactly once at its weight; the largest one sets the scale.
                max_total = max(len(uids) / weight, max_total)
                sources.append((uids, weight, pth))
            merged_uids = []
            for uids, weight, pth in sources:
                target = int(weight * max_total)
                repeat = 1
                # An empty source cannot be up-sampled; skipping it here avoids
                # a division by zero.
                if uids and len(uids) < target:
                    repeat = target // len(uids)
                cur_uids = uids * repeat
                merged_uids += cur_uids
                print("Data Path:", pth, "Repeat:", repeat, "Final Length:", len(cur_uids))
            uids = merged_uids
        print("Total UIDs:", len(uids))
        return uids

    @staticmethod
    def _load_rgba_image(file_path, bg_color: float = 1.0):
        """Load an RGBA image and alpha-blend it onto a uniform background.

        Args:
            file_path: Path handed to ``smart_open``.
            bg_color: Background value used where alpha is zero.

        Returns:
            Float tensor of shape ``(1, 3, H, W)``. Values are in ``[0, 1]``
            assuming the file decodes to an 8-bit ``H x W x 4`` array
            (TODO confirm that inputs are always RGBA).
        """
        # Decode inside the `with` block: PIL reads pixel data lazily, and the
        # handle must be closed afterwards instead of being leaked.
        with smart_open(file_path, 'rb') as f:
            rgba = np.array(Image.open(f))
        rgba = torch.from_numpy(rgba).float() / 255.0
        rgba = rgba.permute(2, 0, 1).unsqueeze(0)
        color = rgba[:, :3, :, :]
        alpha = rgba[:, 3:4, :, :]
        rgb = color * alpha + bg_color * (1 - alpha)
        return rgb

    @staticmethod
    def _locate_datadir(root_dirs, uid, locator: str):
        """Return the first root directory that contains ``uid/locator``.

        Args:
            root_dirs: Candidate root directories, tried in order. A single
                string is treated as one root rather than iterated per
                character.
            uid: Sample identifier, used as the sub-directory name.
            locator: Relative path whose existence marks a valid data dir.

        Raises:
            FileNotFoundError: If no root contains ``uid/locator``.
        """
        if isinstance(root_dirs, str):
            root_dirs = [root_dirs]
        for root_dir in root_dirs:
            datadir = smart_path_join(root_dir, uid, locator)
            if smart_exists(datadir):
                return root_dir
        raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
|
|