|
import torch |
|
import random |
|
import numpy as np |
|
|
|
def shuffle_tensor_along_axis(tensor, axis=0, seed=None):
    """
    Shuffle a tensor along a specified axis without affecting the global random state.

    The permutation is drawn from a private ``torch.Generator``. The global torch
    (CPU and CUDA), NumPy and Python ``random`` states are never reseeded or
    consumed.

    Args:
        tensor (torch.Tensor): The input tensor to shuffle.
        axis (int, optional): The axis along which to shuffle; negative values
            count from the last dimension. Defaults to 0.
        seed (int, optional): Random seed for reproducibility. When None, the
            private generator is seeded non-deterministically, so repeated
            calls yield independent permutations. Defaults to None.

    Returns:
        torch.Tensor: A new tensor with the same shape, dtype and device as
        ``tensor`` whose slices along ``axis`` are permuted.

    Raises:
        RuntimeError: If shuffling fails (e.g. ``axis`` is out of range for
        ``tensor``). The underlying exception is attached as ``__cause__``.
    """
    try:
        # A local generator isolates this call from every global RNG. Saving and
        # restoring the global state instead made repeated unseeded calls return
        # the same permutation, and torch.manual_seed also reseeds CUDA
        # generators that were never restored.
        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)
        else:
            generator.seed()

        dim_size = tensor.shape[axis]
        indices = torch.randperm(dim_size, generator=generator)

        # The generator lives on the CPU; index_select needs indices on the
        # same device as the tensor. It also returns a new tensor, so no
        # explicit clone is required.
        return tensor.index_select(axis, indices.to(tensor.device))
    except Exception as e:
        raise RuntimeError(f"Error during shuffling: {e}") from e