import math
import os
from functools import reduce
from pathlib import Path

import numpy as np
import torch

def factors(n):
    """Return an unsorted list of all divisors of n (the square root is
    listed twice when n is a perfect square)."""
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
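
# Usage sketch for factors (illustrative, not part of the original module):
def _demo_factors():
    assert sorted(set(factors(12))) == [1, 2, 3, 4, 6, 12]
    assert sorted(factors(9)) == [1, 3, 3, 9]  # 3 appears twice for a perfect square
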
def file_line_count(filename: str) -> int:
    """Count the number of lines in a file."""
    with open(filename, 'rb') as f:
        return sum(1 for _ in f)

def compute_attention(qkv, scale=None):
    """
    Compute the attention matrix and attention output (same as the PyTorch scaled dot product attention).
    Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors, either stacked along the first
        dimension or given as a sequence of three tensors
    :param scale: Scale factor for the attention computation; defaults to
        1 / sqrt(head_dim)
    :return: Tuple of (attention weights, attention output)
    """
    if isinstance(qkv, torch.Tensor):
        query, key, value = qkv.unbind(0)
    else:
        query, key, value = qkv
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out
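
# Sanity-check sketch comparing against the PyTorch reference implementation
# (the shapes below are arbitrary examples, not values from the project):
def _demo_compute_attention():
    qkv = torch.randn(3, 2, 4, 8, 16)  # (q/k/v, batch, heads, tokens, head_dim)
    attn_weight, attn_out = compute_attention(qkv)
    # The output should match torch's implementation up to floating-point error
    reference = torch.nn.functional.scaled_dot_product_attention(*qkv.unbind(0))
    assert torch.allclose(attn_out, reference, atol=1e-5)
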
def compute_dot_product_similarity(a, b):
    """Dot-product similarity scores between two batches of vectors."""
    scores = a @ b.transpose(-1, -2)
    return scores

def compute_cross_entropy(p, q):
    """Cross-entropy between a target distribution p and logits q."""
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return -loss.mean()
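
# Usage sketch: with one-hot targets, compute_cross_entropy reduces to the
# standard classification loss (arbitrary example shapes):
def _demo_compute_cross_entropy():
    logits = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4,))
    p = torch.nn.functional.one_hot(labels, num_classes=10).float()
    loss = compute_cross_entropy(p, logits)
    reference = torch.nn.functional.cross_entropy(logits, labels)
    assert torch.allclose(loss, reference, atol=1e-6)
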
def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout.
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16

    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation

    Returns
    -------
    result : torch.Tensor
        Rolled-out attention of shape (batch, num_tokens, num_tokens); the CLS
        row (excluding the CLS column) reshapes into a square patch mask
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(dim=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(dim=1).values
            else:
                raise ValueError(f"Attention head fusion type {head_fusion} not supported")
            # Drop the lowest attentions, but don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0
            # Add the identity to account for residual connections, then
            # re-normalize each row (keepdim so broadcasting divides row-wise)
            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
    # Normalize the result by the max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result
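
# Usage sketch for rollout. The ViT-style shapes (1 CLS token + 196 patches,
# 14x14 grid) are assumptions for illustration; real attention maps would come
# from forward hooks on a transformer.
def _demo_rollout():
    num_layers, num_heads, num_tokens = 2, 4, 197
    attentions = [
        torch.softmax(torch.randn(1, num_heads, num_tokens, num_tokens), dim=-1)
        for _ in range(num_layers)
    ]
    result = rollout(attentions, discard_ratio=0.9, head_fusion="max",
                     device=torch.device("cpu"))
    # The CLS row (minus the CLS column) reshapes into a square patch mask
    mask = result[0, 0, 1:].reshape(14, 14)
    return mask
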
def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm layers to SyncBatchNorm (used for DDP)
    :param model: PyTorch model
    :return: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model
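
# Usage sketch (hypothetical DDP setup; process-group initialization and the
# project's actual model are omitted):
def _demo_sync_bn_conversion():
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    model = sync_bn_conversion(model)
    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # In a real DDP run the converted model would then be wrapped, e.g.:
    # model = torch.nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank])
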
def check_snapshot(args):
    """
    Create the directory for saving training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new sub-directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create the directory to save training checkpoints, otherwise load the existing checkpoint
        if not os.path.exists(args.snapshot_dir):
            # Only create a directory when the path is not a checkpoint file
            if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
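
# Usage sketch for check_snapshot. SimpleNamespace stands in for the project's
# argparse namespace; the attribute names mirror what the function reads.
def _demo_check_snapshot():
    from types import SimpleNamespace
    args = SimpleNamespace(snapshot_dir="snapshots/run1", seed=0,
                           array_training_job=True, resume_training=False)
    check_snapshot(args)
    # Array jobs get a per-seed sub-directory, e.g. snapshots/run1/0
    assert args.snapshot_dir.endswith(str(args.seed))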