Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023-2024, Zexin He | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from abc import ABC, abstractmethod | |
import traceback | |
import json | |
import numpy as np | |
import torch | |
from PIL import Image | |
from typing import Optional, Union | |
from megfile import smart_open, smart_path_join, smart_exists | |
class BaseDataset(torch.utils.data.Dataset, ABC):
    """Abstract map-style dataset indexed by a list of sample UIDs.

    Subclasses implement :meth:`inner_get_item`. :meth:`__getitem__` wraps it so
    that a sample which fails to load is reported (traceback + UID) and the next
    index, wrapping around at the end, is tried instead.
    """

    def __init__(self, root_dirs: Union[str, list], meta_path: Optional[Union[list, str]]):
        """
        Args:
            root_dirs: root directory (or directories) of the data; stored as-is on
                ``self.root_dirs`` (presumably handed to ``_locate_datadir`` by
                subclasses -- TODO confirm).
            meta_path: path to a JSON file holding the UID list, or a list of
                ``(json_path, weight)`` pairs to merge (see ``_load_uids``).
        """
        super().__init__()
        self.root_dirs = root_dirs
        self.uids = self._load_uids(meta_path)

    def __len__(self):
        return len(self.uids)

    @abstractmethod
    def inner_get_item(self, idx):
        """Load and return the sample at position ``idx``; implemented by subclasses."""

    def __getitem__(self, idx):
        """Return ``self.inner_get_item(idx)``, skipping samples that fail to load.

        On failure the traceback and the offending UID are printed and the next
        index (modulo ``len(self)``) is tried. Each sample is tried at most once, so
        a dataset in which nothing loads raises ``RuntimeError`` (chained to the
        last load error) instead of recursing until ``RecursionError``.
        """
        num_items = len(self)
        cur = idx
        attempts = 0
        while True:
            try:
                return self.inner_get_item(cur)
            except Exception as err:
                traceback.print_exc()
                print(f"[DEBUG-DATASET] Error when loading {self.uids[cur]}")
                attempts += 1
                if attempts >= num_items:
                    raise RuntimeError(
                        f"[DEBUG-DATASET] Failed to load any of the {num_items} samples "
                        f"(starting from index {idx})"
                    ) from err
                cur = (cur + 1) % num_items

    @staticmethod
    def _load_uids(meta_path: Optional[Union[list, str]]):
        """Load the UID list described by ``meta_path``.

        Args:
            meta_path: either a path to a JSON file containing the UID list, or an
                iterable of ``(json_path, weight)`` pairs. For the weighted form,
                ``max_total`` is the largest ``len(uids) / weight`` over all sources;
                a source shorter than ``int(weight * max_total)`` is repeated a whole
                number of times (``target // len(uids)``) to approach that target.
                Sources are never down-sampled, and the results are concatenated in
                input order.

        Returns:
            The loaded (and possibly merged) UID list.

        Raises:
            ValueError: if a weighted source has a non-positive weight.
        """
        if isinstance(meta_path, str):
            # meta_path is a single json file
            with open(meta_path, 'r', encoding='utf-8') as f:
                uids = json.load(f)
        else:
            sources = []
            max_total = 0
            for pth, weight in meta_path:
                if weight <= 0:
                    raise ValueError(f"Weight for {pth} must be positive, got {weight}")
                with open(pth, 'r', encoding='utf-8') as f:
                    src_uids = json.load(f)
                max_total = max(len(src_uids) / weight, max_total)
                sources.append((src_uids, weight, pth))

            merged_uids = []
            for src_uids, weight, pth in sources:
                target = int(weight * max_total)
                repeat = 1
                # An empty source cannot be up-sampled (and would divide by zero).
                if src_uids and len(src_uids) < target:
                    repeat = target // len(src_uids)
                cur_uids = src_uids * repeat
                merged_uids += cur_uids
                print("Data Path:", pth, "Repeat:", repeat, "Final Length:", len(cur_uids))
            uids = merged_uids
        print("Total UIDs:", len(uids))
        return uids

    @staticmethod
    def _load_rgba_image(file_path, bg_color: float = 1.0):
        ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled.

        Non-RGBA images (e.g. RGB, L, P) are converted to RGBA first; images without
        transparency get an opaque alpha, so their colours pass through unchanged.

        Returns:
            A float tensor of shape (1, 3, H, W) with values in [0, 1] (given a
            bg_color in [0, 1]).
        '''
        # Close both the underlying file handle and the PIL image once decoded.
        with smart_open(file_path, 'rb') as f, Image.open(f) as img:
            rgba_img = img if img.mode == 'RGBA' else img.convert('RGBA')
            rgba = np.array(rgba_img)
        rgba = torch.from_numpy(rgba).float() / 255.0
        rgba = rgba.permute(2, 0, 1).unsqueeze(0)  # (H, W, 4) -> (1, 4, H, W)
        alpha = rgba[:, 3:4, :, :]
        rgb = rgba[:, :3, :, :] * alpha + bg_color * (1 - alpha)
        return rgb

    @staticmethod
    def _locate_datadir(root_dirs, uid, locator: str):
        """Return the first root in ``root_dirs`` for which ``<root>/<uid>/<locator>`` exists.

        Note that the matching *root* directory is returned, not the joined path.

        Raises:
            FileNotFoundError: if no root contains the requested entry.
        """
        if isinstance(root_dirs, str):
            # A bare string would otherwise be iterated character by character.
            root_dirs = [root_dirs]
        for root_dir in root_dirs:
            if smart_exists(smart_path_join(root_dir, uid, locator)):
                return root_dir
        raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")