lwm-v1.1 / utils.py
wi-lab's picture
Upload the pre-trained model and pre-training, inference, downstream, and utility scripts
8920c6e verified
raw
history blame contribute delete
10 kB
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
import torch
import numpy as np
#%%
def create_dataloader(grouped_data, batch_size, shuffle):
dataloaders = {}
for seq_length, group in grouped_data.items():
print(f"dataloader in progress ...\nkey: {seq_length}")
## Uncomment the following line if you run out of memory during pre-training
# batch_size = batch_size // 8 if seq_length >= 5 else batch_size
# Unpack samples for the current group
input_ids, masked_tokens, masked_pos = zip(*group)
# Convert to tensors
input_ids_tensor = torch.tensor(input_ids, dtype=torch.float32)
masked_tokens_tensor = torch.tensor(masked_tokens, dtype=torch.float32)
masked_pos_tensor = torch.tensor(masked_pos, dtype=torch.long)
# Create TensorDataset and DataLoader
dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
dataloaders[seq_length] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)
return dataloaders
#%%
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
#%%
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
def visualize_embeddings(embeddings, labels, method="pca", label=None):
    """
    Visualize embeddings using PCA, UMAP, or t-SNE with color-coded labels.
    Args:
        embeddings (torch.Tensor or np.ndarray): Embeddings to visualize, shape (n_samples, n_features).
        labels (torch.Tensor or np.ndarray): Class labels corresponding to embeddings, shape (n_samples,).
            Assumed to be integer codes in range(num_classes) — confirm with callers.
        method (str): Dimensionality reduction method ('pca', 'umap', or 'tsne').
        label (str, optional): Title prefix for the plot.
    Raises:
        ValueError: If *method* is not one of the supported reducers.
    """
    # Convert to numpy if input is a torch.Tensor
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    # Apply the selected dimensionality reduction method
    if method.lower() == "pca":
        reducer = PCA(n_components=2)
    elif method.lower() == "umap":
        reducer = umap.UMAP(n_components=2, n_neighbors=16, random_state=42)
    elif method.lower() == "tsne":
        reducer = TSNE(n_components=2, random_state=42, init="random")
    else:
        raise ValueError("Invalid method. Choose from 'pca', 'umap', or 'tsne'.")
    reduced_embeddings = reducer.fit_transform(embeddings)
    # Create a scatter plot with color-coding based on labels
    plt.figure(figsize=(10, 8))
    num_classes = len(np.unique(labels))
    # BUGFIX: plt.cm.get_cmap(name, lut) was removed in matplotlib 3.9 and the
    # returned colormap was never used, so the scatter silently fell back to the
    # default prop cycle. Fetch the colormap via pyplot and apply it per class.
    cmap = plt.get_cmap("tab10")
    for class_idx in range(num_classes):
        class_points = reduced_embeddings[labels == class_idx]
        plt.scatter(
            class_points[:, 0], class_points[:, 1],
            color=cmap(class_idx % 10),  # tab10 has 10 entries; wrap beyond that
            label=f"Class {class_idx}",
            alpha=0.6
        )
    # Customize the plot
    plt.title(f"{label} ({method.upper()})")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.legend()
    plt.show()
#%%
def generate_gaussian_noise(data, snr_db):
    """
    Generate Gaussian noise at a target SNR for every channel but the first.

    Args:
        data (torch.Tensor): Input data tensor of shape (n_samples, seq_len, feature_dim).
        snr_db (float): Signal-to-Noise Ratio in decibels (dB).

    Returns:
        torch.Tensor: Noise tensor shaped like *data*. Its first channel is all
        zeros so that channel stays noiseless when the noise is added.
    """
    # Work only on channels 1.., leaving channel 0 untouched.
    payload = data[:, 1:, :]  # (n_samples, seq_len-1, feature_dim)
    flattened = payload.view(payload.size(0), -1)
    # Per-sample mean power sets a per-sample noise power via the linear SNR.
    signal_power = (flattened ** 2).mean(dim=1, keepdim=True)  # (n_samples, 1)
    noise_power = signal_power / (10 ** (snr_db / 10))
    scaled_noise = torch.randn_like(flattened) * noise_power.sqrt()
    scaled_noise = scaled_noise.view_as(payload)
    # Prepend a zero-noise slice for the first channel.
    zero_channel = torch.zeros_like(data[:, :1, :])
    return torch.cat((zero_channel, scaled_noise), dim=1)
#%%
def plot_coverage(rxs, cov_map, dpi=200, figsize=(6,4), cbar_title=None, title=False,
                  scat_sz=.5, tx_pos=None, tx_ori=None, legend=False, lims=None,
                  proj_3D=False, equal_aspect=False, tight=True, cmap='tab20'):
    """Scatter-plot a coverage map over receiver positions, in 2D or 3D.

    Args:
        rxs: Array of receiver positions; columns 0/1 (and 2 in 3D) are x/y(/z).
        cov_map: Per-receiver value used to color the scatter points.
        dpi, figsize: Figure resolution and size.
        cbar_title: Label for the colorbar (empty string when None).
        title: Axes title, shown only when truthy.
        scat_sz: Marker size of the scatter points.
        tx_pos: Optional transmitter position; plotted as a red 'P' marker.
        tx_ori: Optional transmitter orientation; a short line is drawn from
            tx_pos in this direction (requires tx_pos as well).
        legend: Whether to draw a legend.
        lims: Optional (vmin, vmax) for the color scale.
        proj_3D: If True, use a 3D projection and the z column of rxs.
        equal_aspect: Equal-scaled axes (2D only).
        tight: Clip axis limits to the receiver extent plus a 1 m margin.
        cmap: Colormap name for the coverage values.

    Returns:
        tuple: (fig, ax, cbar) matplotlib handles.
    """
    plt_params = {'cmap': cmap}
    if lims:
        plt_params['vmin'], plt_params['vmax'] = lims[0], lims[1]
    n = 3 if proj_3D else 2 # n coordinates to consider 2 = xy | 3 = xyz
    # Keyword dict fed to plt.scatter; 'zs' only exists for 3D axes.
    xyz = {'x': rxs[:,0], 'y': rxs[:,1]}
    if proj_3D:
        xyz['zs'] = rxs[:,2]
    fig, ax = plt.subplots(dpi=dpi, figsize=figsize,
                           subplot_kw={'projection': '3d'} if proj_3D else {})
    im = plt.scatter(**xyz, c=cov_map, s=scat_sz, marker='s', **plt_params)
    cbar = plt.colorbar(im, label='' if not cbar_title else cbar_title)
    plt.xlabel('x (m)')
    plt.ylabel('y (m)')
    # TX position
    if tx_pos is not None:
        ax.scatter(*tx_pos[:n], marker='P', c='r', label='TX')
    # TX orientation
    if tx_ori is not None and tx_pos is not None: # ori = [azi, el]
        # positive azimuths point left (like positive angles in a unit circle)
        # positive elevations point up
        # NOTE(review): azimuth is read from tx_ori[2] and elevation from
        # tx_ori[1], which disagrees with the "[azi, el]" comment above —
        # presumably tx_ori is at least length 3; confirm against callers.
        r = 30 # ref size of pointing direction
        tx_lookat = np.copy(tx_pos)
        tx_lookat[:2] += r * np.array([np.cos(tx_ori[2]), np.sin(tx_ori[2])]) # azimuth
        tx_lookat[2] += r * np.sin(tx_ori[1]) # elevation
        # Endpoints of the pointing line, one [start, end] pair per coordinate.
        line_components = [[tx_pos[i], tx_lookat[i]] for i in range(n)]
        line = {key:val for key,val in zip(['xs', 'ys', 'zs'], line_components)}
        if n == 2:
            ax.plot(line_components[0], line_components[1], c='k', alpha=.5, zorder=3)
        else:
            ax.plot(**line, c='k', alpha=.5, zorder=3)
    if title:
        ax.set_title(title)
    if legend:
        plt.legend(loc='upper center', ncols=10, framealpha=.5)
    if tight:
        s = 1 # 1 m margin around the receiver extent
        mins, maxs = np.min(rxs, axis=0)-s, np.max(rxs, axis=0)+s
        if not proj_3D:
            plt.xlim([mins[0], maxs[0]])
            plt.ylim([mins[1], maxs[1]])
        else:
            ax.axes.set_xlim3d([mins[0], maxs[0]])
            ax.axes.set_ylim3d([mins[1], maxs[1]])
            # Widen the z range to include the TX height when one is given.
            if tx_pos is None:
                ax.axes.set_zlim3d([mins[2], maxs[2]])
            else:
                ax.axes.set_zlim3d([np.min([mins[2], tx_pos[2]]),
                                    np.max([mins[2], tx_pos[2]])])
    if equal_aspect and not proj_3D: # disrupts the plot in 3D, so 2D only
        plt.axis('scaled')
    return fig, ax, cbar
#%%
def prepare_loaders(
    preprocessed_data,
    labels=None,
    selected_patches_idxs=None,
    input_type="raw",
    task_type="classification",
    feature_selection=False,
    train_ratio=0.8,
    batch_size=64,
    seed=42  # Default seed for reproducibility
):
    """
    Prepares datasets and data loaders for training and validation.
    Args:
        preprocessed_data (torch.Tensor): The input data, either raw or preprocessed.
        labels (torch.Tensor, optional): The labels for classification tasks.
        selected_patches_idxs (torch.Tensor, optional): Indices of selected patches for feature selection.
        input_type (str): "raw" or "processed" to specify input data type.
        task_type (str): "classification" or "regression".
        feature_selection (bool): Whether to perform feature selection based on selected_patches_idxs.
        train_ratio (float): Proportion of data to use for training (remaining for validation).
        batch_size (int): Batch size for data loaders.
        seed (int): Random seed for reproducibility.
    Returns:
        tuple: (train_loader, val_loader, samples, target)
    """
    # Global seed for any stochastic ops triggered below.
    torch.manual_seed(seed)

    # --- Build the sample tensor ---------------------------------------
    if input_type == "raw":
        if feature_selection and selected_patches_idxs is not None:
            # Row indices broadcast against the per-sample patch indices.
            row_idx = torch.arange(preprocessed_data.size(0)).unsqueeze(1)
            samples = torch.tensor(
                preprocessed_data[row_idx, selected_patches_idxs],
                dtype=torch.float32,
            )
        else:
            # Drop the first patch (raw channels start at index 1).
            samples = torch.tensor(preprocessed_data[:, 1:], dtype=torch.float32)  # raw_chs
    else:
        samples = torch.tensor(preprocessed_data, dtype=torch.float32)

    # --- Build the dataset and the auxiliary `target` return value -----
    if task_type == "classification":
        if labels is None:
            raise ValueError("Labels are required for classification tasks.")
        class_labels = torch.tensor(labels, dtype=torch.long)
        dataset = TensorDataset(samples, class_labels)
        target = 0  # REVISE if needed
    elif task_type == "regression":
        # Regression targets are the flattened samples minus the first patch.
        target = samples[:, 1:, :].view(samples.size(0), -1)
        dataset = TensorDataset(samples, target)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # --- Deterministic train/validation split and loaders ---------------
    generator = torch.Generator().manual_seed(seed)
    total = len(dataset)
    n_train = int(train_ratio * total)
    train_dataset, val_dataset = random_split(
        dataset, [n_train, total - n_train], generator=generator
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
    return train_loader, val_loader, samples, target