Commit fdc4786
Parent(s): 4cda815
Upload 6 files
- loss/contrastive_loss.py +88 -0
- loss/focal_loss.py +28 -0
- loss/label_smoothing.py +21 -0
- loss/rl_loss.py +122 -0
- loss/similarity_loss.py +70 -0
- loss/triplet_loss.py +103 -0
loss/contrastive_loss.py
ADDED
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# @Time : 2022/03/23 14:50
# @Author : Jianing Wang
# @Email : [email protected]
# @File : ContrastiveLoss.py
# !/usr/bin/env python
# coding=utf-8

from enum import Enum
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class SiameseDistanceMetric(Enum):
    """
    The distance metrics for the contrastive loss
    """
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.

    @:param distance_metric: The distance metric function
    @:param margin: (float) The margin distance
    @:param size_average: (bool) Whether to return the averaged loss instead of the sum

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_candidate: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]
        label: [0, 1, ..., 1]

    Return example of forward function:
        0.015 (averaged)
        2.672 (sum)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = False):
        super(ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def forward(self, rep_anchor, rep_candidate, label: Tensor):
        # rep_anchor: [batch_size, hidden_dim] denotes the representations of anchors
        # rep_candidate: [batch_size, hidden_dim] denotes the representations of positives / negatives
        # label: [batch_size] denotes the label of each anchor - candidate pair

        distances = self.distance_metric(rep_anchor, rep_candidate)
        losses = 0.5 * (label.float() * distances.pow(2) + (1 - label).float() * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()


if __name__ == "__main__":
    # config for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples, each corresponding example is a pair
    examples1 = ["This is the sentence anchor 1.", "It is the second sentence in this article named Section D."]
    examples2 = ["It is the same as anchor 1.", "I think it is different with Section D."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # pad and convert to feature batch
    max_seq_len = 16
    features1 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features1.items()}
    features2 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features2.items()}
    label = torch.Tensor(label).long()
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_anchor = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_candidate = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_anchor = torch.mean(rep_anchor, dim=1)  # [batch_size, hidden_dim]
    rep_candidate = torch.mean(rep_candidate, dim=1)  # [batch_size, hidden_dim]
    # obtain contrastive loss
    loss_fn = ContrastiveLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_candidate=rep_candidate, label=label)
    print(loss)  # e.g. tensor(0.0869, grad_fn=<SumBackward0>)
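For a quick sanity check that does not require downloading a pretrained model, a minimal sketch with random embeddings might look like the following (the import path loss.contrastive_loss is an assumption based on the file layout above):

import torch
from loss.contrastive_loss import ContrastiveLoss, SiameseDistanceMetric  # assumed import path

# four random pairs of 768-dim "sentence embeddings"
rep_anchor = torch.randn(4, 768)
rep_candidate = torch.randn(4, 768)
label = torch.tensor([1, 0, 1, 0])  # 1 = similar pair, 0 = dissimilar pair

loss_fn = ContrastiveLoss(distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin=0.5, size_average=True)
loss = loss_fn(rep_anchor, rep_candidate, label)
print(loss)  # scalar tensor: positive pairs are pulled together, negative pairs pushed beyond the margin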
loss/focal_loss.py
ADDED
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# @Time : 2022/2/17 6:05 PM
# @Author : JianingWang
# @File : loss
import torch
from torch import nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation"""

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
        return loss
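FocalLoss down-weights well-classified examples through the (1 - pt)^gamma factor, so hard or misclassified samples dominate the gradient; with gamma=0 it reduces to ordinary cross-entropy. A brief usage sketch (assuming the file is importable as loss.focal_loss):

import torch
from loss.focal_loss import FocalLoss  # assumed import path

logits = torch.randn(8, 5, requires_grad=True)  # [N, C] raw scores for 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))             # [N] gold class indices

loss_fn = FocalLoss(gamma=2)
loss = loss_fn(logits, targets)
loss.backward()
print(loss)  # scalar focal loss; setting gamma=0 recovers standard cross-entropy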
loss/label_smoothing.py
ADDED
@@ -0,0 +1,21 @@
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction="mean", ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction == "sum":
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == "mean":
                loss = loss.mean()
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
                                                                 ignore_index=self.ignore_index)
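With reduction="mean" and no ignored targets, this should match F.cross_entropy with its built-in label_smoothing argument (PyTorch 1.10+) up to floating-point error. A small cross-check sketch (assuming the file is importable as loss.label_smoothing):

import torch
import torch.nn.functional as F
from loss.label_smoothing import LabelSmoothingCrossEntropy  # assumed import path

logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

loss_fn = LabelSmoothingCrossEntropy(eps=0.1, reduction="mean")
custom = loss_fn(logits, targets)
reference = F.cross_entropy(logits, targets, label_smoothing=0.1)
print(custom, reference)  # the two values should agree closely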
loss/rl_loss.py
ADDED
@@ -0,0 +1,122 @@
from typing import Optional

import torch
import torch.nn as nn


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward) ** 2
        surr2 = (values - reward) ** 2
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return 0.5 * loss


class PPOPtxActorLoss(nn.Module):
    """
    To Do:

    PPO-ptx Actor Loss
    """

    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss


class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
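A small sketch exercising these pieces on random tensors (not part of the commit; the import path loss.rl_loss and the [batch, num_actions] shapes are assumptions): PolicyLoss clips the probability ratio, ValueLoss clips the value update, and the two pairwise losses train a reward model to rank a chosen completion above a rejected one.

import torch
from loss.rl_loss import PolicyLoss, ValueLoss, LogSigLoss, LogExpLoss  # assumed import path

batch, num_actions = 4, 16
log_probs = torch.randn(batch, num_actions)
old_log_probs = log_probs + 0.1 * torch.randn(batch, num_actions)
advantages = torch.randn(batch, num_actions)
action_mask = torch.ones(batch, num_actions)  # 1 where a generated token should contribute

actor_loss = PolicyLoss(clip_eps=0.2)(log_probs, old_log_probs, advantages, action_mask)
critic_loss = ValueLoss(clip_eps=0.4)(torch.randn(batch, num_actions),   # values
                                      torch.randn(batch, num_actions),   # old_values
                                      torch.randn(batch, num_actions))   # reward

# reward-model pairwise losses: chosen completions should score above rejected ones
chosen_reward, reject_reward = torch.randn(batch), torch.randn(batch)
print(actor_loss, critic_loss, LogSigLoss()(chosen_reward, reject_reward), LogExpLoss()(chosen_reward, reject_reward))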
loss/similarity_loss.py
ADDED
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Time : 2022/03/23 16:55
# @Author : Jianing Wang
# @Email : [email protected]
# @File : SimilarityLoss.py
# !/usr/bin/env python
# coding=utf-8

import torch
from torch import nn, Tensor
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class CosineSimilarityLoss(nn.Module):
    """
    CosineSimilarityLoss expects that the input consists of two texts and a float label.

    It computes the vectors u = model(input_text[0]) and v = model(input_text[1]) and measures the cosine similarity between the two.
    By default, it minimizes the following loss: ||input_label - cos_score_transformation(cosine_sim(u, v))||_2.

    :param loss_fct: Which pytorch loss function should be used to compare cosine_similarity(u, v) with the input_label. By default, MSE: ||input_label - cosine_sim(u, v)||_2
    :param cos_score_transformation: The cos_score_transformation function is applied on top of cosine_similarity. By default, the identity function is used (i.e. no change).
    """

    def __init__(self, loss_fct=nn.MSELoss(), cos_score_transformation=nn.Identity()):
        super(CosineSimilarityLoss, self).__init__()
        self.loss_fct = loss_fct
        self.cos_score_transformation = cos_score_transformation

    def forward(self, rep_a, rep_b, label: Tensor):
        # rep_a: [batch_size, hidden_dim]
        # rep_b: [batch_size, hidden_dim]
        output = self.cos_score_transformation(torch.cosine_similarity(rep_a, rep_b))
        # print(output) # tensor([0.9925, 0.5846], grad_fn=<DivBackward0>), tensor(0.1709, grad_fn=<MseLossBackward0>)
        return self.loss_fct(output, label.view(-1))


if __name__ == "__main__":
    # config for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples, each corresponding example is a pair
    examples1 = ["Beijing is one of the biggest city in China.", "Disney film is well seeing for us."]
    examples2 = ["Shanghai is the largest city in east of China.", "ACL 2021 will be held in line due to COVID-19."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # pad and convert to feature batch
    max_seq_len = 24
    features1 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features1.items()}
    features2 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features2.items()}
    label = torch.Tensor(label).float()  # float regression targets for the MSE loss
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_a = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_b = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_a = torch.mean(rep_a, dim=1)  # [batch_size, hidden_dim]
    rep_b = torch.mean(rep_b, dim=1)  # [batch_size, hidden_dim]
    # obtain cosine similarity loss
    loss_fn = CosineSimilarityLoss()
    loss = loss_fn(rep_a=rep_a, rep_b=rep_b, label=label)
    print(loss)  # e.g. tensor(0.1709, grad_fn=<SumBackward0>)
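As a quick synthetic check (assuming the file is importable as loss.similarity_loss): identical vectors have cosine similarity 1 and opposite vectors -1, so labels that already match those values should give a near-zero loss.

import torch
from loss.similarity_loss import CosineSimilarityLoss  # assumed import path

a = torch.randn(2, 768)
b = torch.stack([a[0], -a[1]])     # first pair identical, second pair pointing in opposite directions
label = torch.tensor([1.0, -1.0])  # float regression targets for the cosine similarity

loss_fn = CosineSimilarityLoss()
print(loss_fn(a, b, label))        # ~0, since cosine_sim(a, b) already equals the targets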
loss/triplet_loss.py
ADDED
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# @Time : 2022/03/23 15:25
# @Author : Jianing Wang
# @Email : [email protected]
# @File : TripletLoss.py
# !/usr/bin/env python
# coding=utf-8

from enum import Enum
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class TripletDistanceMetric(Enum):
    """
    The distance metrics for the triplet loss
    """
    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)


class TripletLoss(nn.Module):
    """
    This class implements the triplet loss. Given a triplet of (anchor, positive, negative),
    the loss minimizes the distance between anchor and positive while it maximizes the distance
    between anchor and negative. It computes the following loss function:

    loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0).

    Margin is an important hyperparameter and needs to be tuned for each task.

    @:param distance_metric: The distance metric function
    @:param triplet_margin: (float) The margin distance

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_positive / rep_negative: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]

    Return example of forward function:
        0.015 (averaged)
    """

    def __init__(self, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 0.5):
        super(TripletLoss, self).__init__()
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def forward(self, rep_anchor, rep_positive, rep_negative):
        # rep_anchor: [batch_size, hidden_dim] denotes the representations of anchors
        # rep_positive: [batch_size, hidden_dim] denotes the representations of positives, which can also be obtained by applying dropout to the anchor
        # rep_negative: [batch_size, hidden_dim] denotes the representations of negatives
        distance_pos = self.distance_metric(rep_anchor, rep_positive)
        distance_neg = self.distance_metric(rep_anchor, rep_negative)

        losses = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return losses.mean()


if __name__ == "__main__":
    # config for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain anchor, positive and negative examples
    anchor_example = ["I am an anchor, which is the source example sampled from corpora."]  # anchor sentence
    positive_example = [
        "I am an anchor, which is the source example.",
        "I am the source example sampled from corpora."
    ]  # positives, obtained by randomly dropping out or adding noise to the anchor
    negative_example = [
        "It is different with the anchor.",
        "My name is Jianing Wang, please give me some stars, thank you!"
    ]  # negatives, randomly sampled from corpora
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    anchor_feature = tokenizer(anchor_example, add_special_tokens=True, padding=True)
    positive_feature = tokenizer(positive_example, add_special_tokens=True, padding=True)
    negative_feature = tokenizer(negative_example, add_special_tokens=True, padding=True)
    # pad and convert to feature batch
    max_seq_len = 24
    anchor_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in anchor_feature.items()}
    positive_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in positive_feature.items()}
    negative_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in negative_feature.items()}
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_anchor = model(**anchor_feature)[0]  # [1, max_seq_len, hidden_dim]
    rep_positive = model(**positive_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_negative = model(**negative_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    # the single anchor is broadcast against the positive / negative batch
    rep_anchor = torch.mean(rep_anchor, dim=1)  # [1, hidden_dim]
    rep_positive = torch.mean(rep_positive, dim=1)  # [batch_size, hidden_dim]
    rep_negative = torch.mean(rep_negative, dim=1)  # [batch_size, hidden_dim]
    # obtain triplet loss
    loss_fn = TripletLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_positive=rep_positive, rep_negative=rep_negative)
    print(loss)  # e.g. tensor(0.5001, grad_fn=<MeanBackward0>)
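A synthetic sketch of the margin behaviour (assuming the file is importable as loss.triplet_loss): when the positive sits much closer to the anchor than the negative, the hinge term is inactive and the loss is zero.

import torch
from loss.triplet_loss import TripletLoss, TripletDistanceMetric  # assumed import path

anchor = torch.randn(4, 768)
positive = anchor + 0.01 * torch.randn(4, 768)  # near-duplicates of the anchor
negative = torch.randn(4, 768)                  # unrelated random vectors

loss_fn = TripletLoss(distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin=0.5)
print(loss_fn(anchor, positive, negative))      # ~0: the positives are well inside the margin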