File size: 1,732 Bytes
9123ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.multiprocessing as mp

class DEFAULTDataset(Dataset):
    """Dataset of volumes stored as individual ``.npy`` files in one directory.

    Each sample is loaded from disk, resampled to a cube of side
    ``internal_resolution`` with trilinear interpolation, and passed through
    ``log(x + 1)``.

    Args:
        root_dir: Directory that is scanned (non-recursively) for ``.npy`` files.
        internal_resolution: Edge length of the cubic volume every sample is
            resampled to.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
    """

    def __init__(self, root_dir: str, internal_resolution: int):
        super().__init__()
        self.root_dir = root_dir
        self.internal_resolution = internal_resolution
        self.file_paths = self.get_data_files()

    def get_data_files(self) -> list:
        """Return the sorted paths of all ``.npy`` files directly in ``root_dir``.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
        """
        if not os.path.exists(self.root_dir):
            raise FileNotFoundError(f"Directory '{self.root_dir}' does not exist.")

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        return sorted(
            os.path.join(self.root_dir, entry)
            for entry in os.listdir(self.root_dir)
            if entry.endswith('.npy')
        )

    def __len__(self) -> int:
        """Return the number of ``.npy`` files found in ``root_dir``."""
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        """Load, resample and log-transform the sample at position ``idx``.

        The stored array must be 3D: two leading singleton dimensions are
        added so that it forms the 5D input trilinear interpolation requires.

        Returns:
            ``{'data': tensor}`` where ``tensor`` is a float tensor of shape
            ``(1, R, R, R)`` with ``R = internal_resolution``, or ``None`` if
            the file could not be loaded or processed (the error is printed).
            Callers that batch samples need to filter out ``None`` entries.

        Raises:
            IndexError: If ``idx`` is outside the range of discovered files.
        """
        filename = self.file_paths[idx]
        try:
            volume = torch.from_numpy(np.load(filename))
            # Add batch and channel dimensions and cast to float32, as
            # required by F.interpolate in trilinear mode.
            volume = volume.unsqueeze(0).unsqueeze(0).float()
            resolution = self.internal_resolution
            resampled = F.interpolate(
                input=volume,
                size=(resolution, resolution, resolution),
                mode='trilinear',
            )

            # Adding 1 avoids log(0); drop the batch dimension afterwards.
            # NOTE(review): values below -1 would yield NaN here - confirm the
            # stored volumes are non-negative.
            log_volume = torch.log(resampled + 1).squeeze(0)

            return {'data': log_volume}
        except Exception as e:
            # Deliberate best-effort: report the failing file and signal the
            # failure with None instead of aborting the whole iteration.
            print(f"Error loading file '{filename}': {e}")
            return None