# Compute mean/std/max/min statistics over a random sample of .pt tensor files.
import os
import torch
import random
import numpy as np
# Function to recursively find all files with a given extension in a directory
def find_pt_files(root_dir, extension=".pt"):
    """Recursively collect paths of files under *root_dir* ending in *extension*.

    Args:
        root_dir: Directory to walk. A nonexistent directory yields an empty
            list (os.walk simply produces nothing for it).
        extension: File suffix to match; defaults to ".pt" so existing
            callers are unaffected.

    Returns:
        List of full paths (root joined with filename) for every match.
    """
    pt_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for file in filenames:
            if file.endswith(extension):
                pt_files.append(os.path.join(dirpath, file))
    return pt_files
# Function to compute statistics for a given tensor list
def compute_statistics(tensor_list):
    """Concatenate a list of tensors and compute summary statistics.

    Args:
        tensor_list: Non-empty list of tensors (callers pass 1-D flattened
            tensors; any shapes concatenable along dim 0 work).

    Returns:
        Tuple of Python floats: (mean, std, max, min) over all elements.
        Note: torch.std is the unbiased estimator, so a single-element
        input yields NaN for std.

    Raises:
        ValueError: If *tensor_list* is empty (torch.cat would otherwise
            fail with an opaque RuntimeError).
    """
    if not tensor_list:
        raise ValueError("tensor_list must contain at least one tensor")
    all_data = torch.cat(tensor_list)
    mean = torch.mean(all_data).item()
    std = torch.std(all_data).item()
    max_val = torch.max(all_data).item()
    min_val = torch.min(all_data).item()
    return mean, std, max_val, min_val
# Root directory containing .pt files in subfolders
root_dir = "spk"

# Find all .pt files under the root directory.
pt_files = find_pt_files(root_dir)

# Randomly sample 1000 .pt files (or fewer if less than 1000 files are available)
sampled_files = random.sample(pt_files, min(1000, len(pt_files)))

# Load tensors from sampled files.
tensor_list = []
for file in sampled_files:
    # map_location="cpu" lets tensors that were saved on GPU load on
    # CPU-only hosts (plain torch.load would raise on a missing device).
    # NOTE(review): torch.load unpickles arbitrary objects — only run
    # this against trusted .pt files.
    tensor = torch.load(file, map_location="cpu")
    tensor_list.append(tensor.view(-1))  # Flatten the tensor to 1-D

# Compute statistics over all sampled data.
mean, std, max_val, min_val = compute_statistics(tensor_list)

# Print the results
print(f"Mean: {mean}")
print(f"Std: {std}")
print(f"Max: {max_val}")
print(f"Min: {min_val}")