Upload modeling_bert.py
Browse files- modeling_bert.py +14 -0
modeling_bert.py
CHANGED
@@ -21,6 +21,7 @@ import warnings
|
|
21 |
from dataclasses import dataclass
|
22 |
from typing import List, Optional, Tuple, Union
|
23 |
from functools import partial
|
|
|
24 |
import torch
|
25 |
import torch.utils.checkpoint
|
26 |
from packaging import version
|
@@ -56,6 +57,18 @@ from transformers.utils import (
|
|
56 |
)
|
57 |
from .configuration_bert import BertConfig
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
|
61 |
"""
|
@@ -91,6 +104,7 @@ def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
|
91 |
return torch.clip(stretched_out, 0, 1)
|
92 |
|
93 |
|
|
|
94 |
def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
95 |
sm_out = softmax_1(data, dim=dim, **kw)
|
96 |
stretched_out = sm_out * (eta - gamma) + gamma
|
|
|
21 |
from dataclasses import dataclass
|
22 |
from typing import List, Optional, Tuple, Union
|
23 |
from functools import partial
|
24 |
+
from enum import Flag, auto
|
25 |
import torch
|
26 |
import torch.utils.checkpoint
|
27 |
from packaging import version
|
|
|
57 |
)
|
58 |
from .configuration_bert import BertConfig
|
59 |
|
60 |
+
class BaseEnumOptions(Flag):
    """Shared base for option enums used in this module.

    Members render as their bare name (rather than the default
    ``ClassName.name``), and the full set of member names can be
    listed for CLI/config validation.
    """

    def __str__(self) -> str:
        # Show just the member name, e.g. "none" instead of
        # "AttentionGateType.none".
        return self.name

    @classmethod
    def list_names(cls):
        """Return the names of all members, in definition order."""
        return [member.name for member in cls]
|
67 |
+
class AttentionGateType(BaseEnumOptions):
    """Gating-mode options for attention (names suggest per-head vs
    per-token conditioning — semantics defined by the consuming code)."""
    # NOTE(review): the base is a Flag, and 1 | 2 == 3, so on Python 3.11+
    # `conditional_per_token` resolves to the alias
    # `unconditional_per_head | conditional_per_head` and is skipped during
    # iteration. These look like mutually exclusive choices, so plain Enum
    # (or power-of-two values) may be intended — confirm with callers.
    none = 0
    unconditional_per_head = 1
    conditional_per_head = 2
    conditional_per_token = 3
|
72 |
|
73 |
def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
|
74 |
"""
|
|
|
104 |
return torch.clip(stretched_out, 0, 1)
|
105 |
|
106 |
|
107 |
+
|
108 |
def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
109 |
sm_out = softmax_1(data, dim=dim, **kw)
|
110 |
stretched_out = sm_out * (eta - gamma) + gamma
|