Spaces:
Starting
on
T4
Starting
on
T4
import functools
from math import pi

import torch
import torchvision.transforms.functional as TF
from torchvision import transforms
# Define helper functions
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values such as 0 still count)."""
    return not (val is None)
def uniq(arr):
    """Return the distinct elements of ``arr`` in first-seen order.

    The result is the ``keys()`` view of a dict whose values are all ``True``;
    dicts preserve insertion order, so duplicates keep their first position.
    Elements must be hashable.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if val is None:
        return d
    return val
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s float dtype."""
    info = torch.finfo(t.dtype)
    return -info.max
def cast_tuple(val, depth=1):
    """Coerce ``val`` to a tuple.

    Lists are converted element-wise and tuples are returned unchanged; in
    both cases ``depth`` is ignored. Any other value is repeated ``depth``
    times to form the tuple.
    """
    if isinstance(val, list):
        return tuple(val)
    if isinstance(val, tuple):
        return val
    return (val,) * depth
def is_empty(t):
    """Return True when tensor ``t`` holds no elements (any dimension of size 0)."""
    element_count = t.numel()
    return element_count == 0
def masked_mean(t, mask, dim=1):
    """
    Compute the mean of a tensor over ``dim``, counting only positions where ``mask`` is True.

    Args:
        t (torch.Tensor): input tensor, e.g. of shape (batch_size, seq_len, hidden_dim)
        mask (torch.Tensor): boolean mask whose shape matches the leading dims of ``t``,
            e.g. (batch_size, seq_len); it is broadcast over ``t``'s remaining trailing dims
        dim (int): dimension along which to compute the mean (default=1)
    Returns:
        torch.Tensor: the masked mean with ``dim`` reduced, e.g. (batch_size, hidden_dim).
        Slices with no True mask entries along ``dim`` yield 0 instead of NaN.
    """
    # Append singleton dims so the mask broadcasts over t's trailing dims
    # (for a 3-D t and 2-D mask this matches mask[:, :, None]).
    extra_dims = t.dim() - mask.dim()
    mask = mask.reshape(tuple(mask.shape) + (1,) * extra_dims)
    t = t.masked_fill(~mask, 0.0)
    # clamp avoids 0 / 0 = NaN when every position along `dim` is masked out.
    counts = mask.sum(dim=dim).clamp(min=1)
    return t.sum(dim=dim) / counts
def set_requires_grad(model, value):
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model (torch.nn.Module): module whose parameters are updated in place
        value (bool): new ``requires_grad`` flag applied to each parameter
    """
    params = model.parameters()
    for weight in params:
        weight.requires_grad = value
def eval_decorator(fn):
    """
    Wrap ``fn`` so it runs with its model argument temporarily in eval mode.

    The wrapped function must take the model as its first positional argument.
    The model's previous training/eval mode is restored afterwards, even if
    ``fn`` raises.

    Args:
        fn (callable): function whose first argument is a ``torch.nn.Module``
    Returns:
        callable: the wrapped function (name and docstring preserved via functools.wraps)
    """
    @functools.wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Restore the caller's mode even when fn raised.
            model.train(was_training)

    return inner
def log(t, eps=1e-20):
    """
    Natural logarithm of ``t`` after shifting it by ``eps``.

    Args:
        t (torch.Tensor): input tensor
        eps (float): offset added before the log so zeros stay finite (default=1e-20)
    Returns:
        torch.Tensor: ``torch.log(t + eps)``
    """
    shifted = t + eps
    return torch.log(shifted)
def gumbel_noise(t):
    """
    Draw standard Gumbel noise shaped like ``t``.

    Uses the inverse-CDF transform ``-log(-log(u))`` with ``u ~ U[0, 1)``, where
    ``log`` is this module's eps-shifted logarithm.

    Args:
        t (torch.Tensor): tensor whose shape, dtype and device the noise copies
    Returns:
        torch.Tensor: Gumbel noise with the same shape as ``t``
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)
def gumbel_sample(t, temperature=0.9, dim=-1):
    """
    Temperature-scale the logits ``t`` and add Gumbel noise.

    Note: the result is the perturbed logits (same shape as ``t``), not class
    indices. No argmax is taken here; presumably the caller reduces over the
    class dimension to draw a sample. TODO confirm against callers.

    Args:
        t (torch.Tensor): logits, e.g. (batch_size, num_classes)
        temperature (float): divisor for the logits, floored at 1e-10 (default=0.9)
        dim (int): accepted for interface compatibility; unused by this implementation
    Returns:
        torch.Tensor: ``t / temperature + gumbel_noise(t)``, same shape as ``t``
    """
    # Floor the temperature so it can never divide by zero.
    safe_temperature = max(temperature, 1e-10)
    scaled = t / safe_temperature
    return scaled + gumbel_noise(t)
def top_k(logits, thres=0.5):
    """
    Keep only the largest logits along the last dimension and mask the rest with -inf.

    The number kept is ``max(int((1 - thres) * num_classes), 1)``, so ``thres``
    is the fraction of logits to discard and at least one is always kept.

    Args:
        logits (torch.Tensor): input tensor, e.g. (batch_size, num_classes)
        thres (float): fraction of logits to discard (default=0.5)
    Returns:
        torch.Tensor: same shape as ``logits``; kept entries hold their original
        values and all others are ``-inf``
    """
    keep = max(int((1 - thres) * logits.shape[-1]), 1)
    values, indices = logits.topk(keep)
    masked = torch.full_like(logits, float("-inf"))
    return masked.scatter(-1, indices, values)
def gamma_func(mode="cosine", scale=0.15):
    """Return the schedule function ``r -> value`` selected by ``mode``.

    Supported modes: "linear", "cosine", "square", "cubic", "scaled-cosine"
    (the cosine curve multiplied by ``scale``). The cosine variants call
    ``torch.cos``, so they expect a tensor ``r``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported names.
    """
    schedules = (
        ("linear", lambda r: 1 - r),
        ("cosine", lambda r: torch.cos(r * pi / 2)),
        ("square", lambda r: 1 - r**2),
        ("cubic", lambda r: 1 - r**3),
        ("scaled-cosine", lambda r: scale * (torch.cos(r * pi / 2))),
    )
    # Compare with == in the original order, mirroring an if/elif chain.
    for name, schedule in schedules:
        if mode == name:
            return schedule
    raise NotImplementedError
class always:
    """Callable that ignores its arguments and returns one fixed value.

    Args:
        val: the object returned by every call (stored as-is, not copied).
    """

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # ``x`` and any further arguments are accepted and deliberately ignored,
        # presumably so this can stand in for any one-or-more-argument callable.
        return self.val
class DivideMax(torch.nn.Module):
    """Divide the input by its maximum along a fixed dimension.

    The maximum is detached, so gradients do not flow through the divisor.

    Args:
        dim (int): dimension reduced when taking the maximum.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # keepdim=True keeps the reduced axis so `peak` broadcasts against `x`.
        peak = x.amax(dim=self.dim, keepdim=True)
        return x / peak.detach()
def replace_outliers(image, percentile=0.0001):
    """Clip values outside the [percentile, 1 - percentile] quantile range.

    Values outside the quantile bounds are clipped to the min/max of the values
    inside the bounds. If no value lies inside the bounds (possible for very
    small inputs), they are clipped to the quantile bounds themselves.

    Args:
        image (torch.Tensor): floating-point tensor (torch.quantile requires a
            float dtype); quantiles are taken over all elements.
        percentile (float): lower quantile; the upper one is ``1 - percentile``.
    Returns:
        torch.Tensor: ``image`` itself, modified in place via boolean-index assignment.
    """
    lower_bound = torch.quantile(image, percentile)
    upper_bound = torch.quantile(image, 1 - percentile)
    inlier_mask = (image <= upper_bound) & (image >= lower_bound)
    inliers = image[inlier_mask]
    if inliers.numel() == 0:
        # Nothing to take a min/max of; fall back to the quantile bounds.
        low, high = lower_bound.item(), upper_bound.item()
    else:
        # Tensor reductions instead of Python's builtin min()/max(), which
        # iterated element by element.
        low, high = inliers.min().item(), inliers.max().item()
    outlier_mask = ~inlier_mask
    image[outlier_mask] = torch.clamp(image[outlier_mask], low, high)
    return image
def process_image(image, dataset, image_type=None, apply_normalization=False):
    """Convert an image to a batched tensor and take a random 256x256 crop.

    Args:
        image: input accepted by ``TF.to_tensor`` (presumably a PIL image or an
            HxW[xC] ndarray; TODO confirm against callers).
        dataset (str): dataset name; "HPA" and "OpenCell" have normalization stats.
        image_type (str | None): "nucleus" or "protein" for the stats lookup.
        apply_normalization (bool): if True, also apply ``transforms.Normalize``
            with the stats for ``(dataset, image_type)``. Defaults to False, which
            matches the previous behavior where normalization was commented out.
    Returns:
        torch.Tensor: tensor with a leading batch dim of 1, randomly cropped to
        256x256 (``RandomCrop`` fails if either side is smaller than 256).
    Raises:
        ValueError: if normalization is requested for an unknown
            ``(dataset, image_type)`` pair.
    """
    # (mean, std) arguments for transforms.Normalize, per dataset and channel type.
    norm_stats = {
        ("HPA", "nucleus"): (0.0655, 0.0650),
        ("HPA", "protein"): (0.1732, 0.1208),
        ("OpenCell", "nucleus"): (0.0272, 0.0244),
        ("OpenCell", "protein"): (0.0486, 0.0671),
    }
    image = TF.to_tensor(image).unsqueeze(0)
    t_forms = [transforms.RandomCrop(256)]
    if apply_normalization:
        try:
            mean, std = norm_stats[(dataset, image_type)]
        except KeyError:
            raise ValueError(
                f"No normalization statistics for dataset={dataset!r}, "
                f"image_type={image_type!r}"
            ) from None
        t_forms.append(transforms.Normalize(mean, std))
    return transforms.Compose(t_forms)(image)