import torch
import torch.nn as nn
import torch.nn.functional as F


def is_dist_initialized():
    # Guard with is_available() so this is safe on builds without the
    # distributed package and falls through cleanly in single-process runs.
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def get_world_size():
    if is_dist_initialized():
        return torch.distributed.get_world_size()
    return 1
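

# A minimal launch-time setup sketch, not part of the original file: the
# gather helpers below only take effect once a process group exists. Under
# torchrun, init_process_group() reads MASTER_ADDR/MASTER_PORT and the rank
# from the environment; LOCAL_RANK selects this process's GPU.
import os


def init_distributed():
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))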


def all_gather_grad(x):
    # Gather feature shards from every rank. torch.distributed.all_gather
    # does not propagate gradients, so re-insert the local shard (which is
    # still attached to the autograd graph); gradients then flow through
    # this rank's slice of the concatenated batch.
    if get_world_size() > 1:
        all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
        torch.distributed.all_gather(all_x, x)
        all_x[torch.distributed.get_rank()] = x
        x = torch.cat(all_x, dim=0)
    return x
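

# An alternative sketch, not in the original file: a gather that keeps
# gradients for *every* shard by summing gradient contributions across ranks
# in backward, in the style of SimCLR-type "GatherLayer" utilities.
# all_gather_grad above instead keeps gradients only for the local shard.
# Assumes an initialized process group and equal batch sizes per rank
# (all_gather requires the latter anyway). Usage: x_all = AllGatherFullGrad.apply(x)
class AllGatherFullGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.batch_size = x.shape[0]
        gathered = [torch.zeros_like(x) for _ in range(get_world_size())]
        torch.distributed.all_gather(gathered, x)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank holds a gradient for the full concatenated batch; sum
        # them (all_reduce defaults to SUM) so each shard receives
        # contributions from all ranks, then return this rank's slice.
        grad_output = grad_output.contiguous()
        torch.distributed.all_reduce(grad_output)
        start = torch.distributed.get_rank() * ctx.batch_size
        return grad_output[start:start + ctx.batch_size]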


@torch.no_grad()
def all_gather_nograd(tensor):
    """
    Performs an all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if get_world_size() > 1:
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
        tensor = torch.cat(tensors_gather, dim=0)
    return tensor
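

# Evaluation-time sketch (the function name and usage are assumptions, not
# part of the original file): gather features from all ranks without autograd
# bookkeeping and score image-to-text similarity over the full eval set.
@torch.no_grad()
def gather_eval_similarity(image_feat, text_feat):
    image_all = all_gather_nograd(image_feat)
    text_all = all_gather_nograd(text_feat)
    return torch.matmul(image_all, text_all.t())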


def image_text_contrastive_loss(image_feat, text_feat, temperature,
                                image_id=None, text_id=None):
    # Gather features from every rank so each rank computes the loss over
    # the global batch; gradients flow through the local shard only.
    image_feat = all_gather_grad(image_feat)
    text_feat = all_gather_grad(text_feat)

    # Image-to-text similarity logits, scaled by the temperature.
    logits = torch.matmul(image_feat, text_feat.t())
    logits /= temperature

    if image_id is None and text_id is None:
        # Standard InfoNCE: the i-th image matches the i-th text.
        gt = torch.arange(logits.shape[0], device=logits.device)
        loss1 = F.cross_entropy(logits, gt)
        loss2 = F.cross_entropy(logits.t(), gt)
    else:
        # Ids mark duplicates (e.g. several captions for one image): any
        # pair sharing an image id or a text id counts as a positive. The
        # ids carry no gradient, so the no-grad gather suffices.
        image_id = all_gather_nograd(image_id)
        text_id = all_gather_nograd(text_id)

        gt_image = image_id.reshape((-1, 1)) == image_id.reshape((1, -1))
        gt_text = text_id.reshape((-1, 1)) == text_id.reshape((1, -1))
        gt = torch.logical_or(gt_image, gt_text).float()

        # Soft-target cross entropy in both directions, normalized by the
        # number of positive pairs (gt is symmetric, so one count serves both).
        loss1 = -torch.sum(gt * F.log_softmax(logits, dim=1)) / gt.sum()
        loss2 = -torch.sum(gt.t() * F.log_softmax(logits.t(), dim=1)) / gt.sum()

    # DDP averages gradients across ranks; multiplying by the world size
    # restores the scale of a loss computed on the full global batch.
    return (loss1 + loss2) / 2 * get_world_size()
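

# Single-process usage sketch (the shapes and the 0.07 temperature are
# assumptions for illustration): features are L2-normalized so the logits
# are cosine similarities, as is typical for this loss.
if __name__ == "__main__":
    image_feat = F.normalize(torch.randn(8, 256), dim=-1)
    text_feat = F.normalize(torch.randn(8, 256), dim=-1)
    loss = image_text_contrastive_loss(image_feat, text_feat, temperature=0.07)
    print(f"contrastive loss: {loss.item():.4f}")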