Spaces:
Build error
Commit · 394811b · 0 Parent(s)
Duplicate from nguyenvulebinh/spoken-norm-taggen
Co-authored-by: Binh Nguyen <[email protected]>
- .gitattributes +27 -0
- README.md +38 -0
- app.py +25 -0
- attentions.py +466 -0
- data_handling.py +336 -0
- infer.py +374 -0
- model_config_handling.py +90 -0
- model_handling.py +763 -0
- requirements.txt +8 -0
- utils.py +271 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,38 @@
+---
+title: Spoken Norm
+emoji: 📊
+colorFrom: gray
+colorTo: yellow
+sdk: gradio
+app_file: app.py
+pinned: false
+duplicated_from: nguyenvulebinh/spoken-norm-taggen
+---
+
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version`: _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
app.py
ADDED
@@ -0,0 +1,25 @@
+import gradio as gr
+from infer import infer
+
+
+
+def format_text(text_input, list_bias_input):
+    print('{}\n{}\n\n'.format(text_input, list_bias_input))
+    bias_list = list_bias_input.strip().split('\n')
+    norm_result = infer([text_input], bias_list)
+    return norm_result[0]
+
+
+title = "Transformation spoken text to written text"
+
+iface = gr.Interface(format_text,
+                     [
+                         gr.inputs.Textbox(
+                             lines=1,
+                             default="ngày hai tám tháng tư cô vít bùng phát ở xì cút len chiếm tám mươi phần trăm là biến chủng đen ta và bê ta và ô mi cờ ron"),
+                         gr.inputs.Textbox(
+                             lines=5, default='covid\ndelta\nbeta\nomicron | ô mi cờ ron\nscotland | sờ cốt lờn | xì cút len'),
+                     ],
+                     outputs="text",
+                     title=title)
+iface.launch()
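The app is a thin Gradio wrapper around `infer`: `format_text` splits the bias textbox into one phrase per line and passes a single-element batch to `infer`. Below is a minimal sketch of driving the same entry point from plain Python, assuming the repository root is on the import path and that `infer(texts, bias_list)` returns one normalized string per input, exactly as `format_text` relies on:

```python
# Minimal sketch: call the Space's entry point without the Gradio UI.
# Assumes infer.py is importable from the repository root and that
# infer(list_of_texts, bias_list) -> list of normalized strings,
# as app.py above uses it.
from infer import infer

# Bias entries follow the "written | spoken variant | ..." convention
# used in the app's default textbox.
bias_list = [
    "covid",
    "omicron | ô mi cờ ron",
    "scotland | sờ cốt lờn | xì cút len",
]
text = "ngày hai tám tháng tư cô vít bùng phát ở xì cút len"

result = infer([text], bias_list)
print(result[0])  # written-form normalization of the spoken input
```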
attentions.py
ADDED
@@ -0,0 +1,466 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+import numpy as np
+from typing import Optional, Tuple
+
+
+class ScaledDotProductAttention(nn.Module):
+    """
+    Scaled Dot-Product Attention proposed in "Attention Is All You Need"
+    Compute the dot products of the query with all keys, divide each by sqrt(dim),
+    and apply a softmax function to obtain the weights on the values
+
+    Args: dim, mask
+        dim (int): dimention of attention
+        mask (torch.Tensor): tensor containing indices to be masked
+
+    Inputs: query, key, value, mask
+        - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
+        - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
+        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
+        - **mask** (-): tensor containing indices to be masked
+
+    Returns: context, attn
+        - **context**: tensor containing the context vector from attention mechanism.
+        - **attn**: tensor containing the attention (alignment) from the encoder outputs.
+    """
+
+    def __init__(self, dim: int):
+        super(ScaledDotProductAttention, self).__init__()
+        self.sqrt_dim = np.sqrt(dim)
+
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[
+            Tensor, Tensor]:
+        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
+
+        if mask is not None:
+            score.masked_fill_(mask.view(score.size()), -float('Inf'))
+
+        attn = F.softmax(score, -1)
+        context = torch.bmm(attn, value)
+        return context, attn
+
+
+class DotProductAttention(nn.Module):
+    """
+    Compute the dot products of the query with all values and apply a softmax function to obtain the weights on the values
+    """
+
+    def __init__(self, hidden_dim):
+        super(DotProductAttention, self).__init__()
+        self.normalize = nn.LayerNorm(hidden_dim)
+
+    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+        batch_size, hidden_dim, input_size = query.size(0), query.size(2), value.size(1)
+
+        score = torch.bmm(query, value.transpose(1, 2))
+        attn = F.softmax(score.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
+        context = torch.bmm(attn, value)
+
+        return context, attn
+
+
+class AdditiveAttention(nn.Module):
+    """
+    Applies a additive attention (bahdanau) mechanism on the output features from the decoder.
+    Additive attention proposed in "Neural Machine Translation by Jointly Learning to Align and Translate" paper.
+
+    Args:
+        hidden_dim (int): dimesion of hidden state vector
+
+    Inputs: query, value
+        - **query** (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
+        - **value** (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+
+    Returns: context, attn
+        - **context**: tensor containing the context vector from attention mechanism.
+        - **attn**: tensor containing the alignment from the encoder outputs.
+
+    Reference:
+        - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
+    """
+
+    def __init__(self, hidden_dim: int) -> None:
+        super(AdditiveAttention, self).__init__()
+        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+        self.score_proj = nn.Linear(hidden_dim, 1)
+
+    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+        score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
+        attn = F.softmax(score, dim=-1)
+        context = torch.bmm(attn.unsqueeze(1), value)
+        return context, attn
+
+
+class LocationAwareAttention(nn.Module):
+    """
+    Applies a location-aware attention mechanism on the output features from the decoder.
+    Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+    The location-aware attention mechanism is performing well in speech recognition tasks.
+    We refer to implementation of ClovaCall Attention style.
+
+    Args:
+        hidden_dim (int): dimesion of hidden state vector
+        smoothing (bool): flag indication whether to use smoothing or not.
+
+    Inputs: query, value, last_attn, smoothing
+        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+        - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
+
+    Returns: output, attn
+        - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+
+    Reference:
+        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+        - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
+    """
+
+    def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
+        super(LocationAwareAttention, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
+        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+        self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
+        self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+        self.smoothing = smoothing
+
+    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+        batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)
+
+        # Initialize previous attention (alignment) to zeros
+        if last_attn is None:
+            last_attn = value.new_zeros(batch_size, seq_len)
+
+        conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
+        score = self.score_proj(torch.tanh(
+            self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+            + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+            + conv_attn
+            + self.bias
+        )).squeeze(dim=-1)
+
+        if self.smoothing:
+            score = torch.sigmoid(score)
+            attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
+        else:
+            attn = F.softmax(score, dim=-1)
+
+        context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1)  # Bx1xT X BxTxD => Bx1xD => BxD
+
+        return context, attn
+
+
+class MultiHeadLocationAwareAttention(nn.Module):
+    """
+    Applies a multi-headed location-aware attention mechanism on the output features from the decoder.
+    Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+    The location-aware attention mechanism is performing well in speech recognition tasks.
+    In the above paper applied a signle head, but we applied multi head concept.
+
+    Args:
+        hidden_dim (int): The number of expected features in the output
+        num_heads (int): The number of heads. (default: )
+        conv_out_channel (int): The number of out channel in convolution
+
+    Inputs: query, value, prev_attn
+        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+        - **prev_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s attention (alignment)
+
+    Returns: output, attn
+        - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+
+    Reference:
+        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+    """
+
+    def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
+        super(MultiHeadLocationAwareAttention, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.dim = int(hidden_dim / num_heads)
+        self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
+        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+        self.score_proj = nn.Linear(self.dim, 1, bias=True)
+        self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))
+
+    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+        batch_size, seq_len = value.size(0), value.size(1)
+
+        if last_attn is None:
+            last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)
+
+        loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
+        loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)
+
+        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+        query = query.contiguous().view(-1, 1, self.dim)
+        value = value.contiguous().view(-1, seq_len, self.dim)
+
+        score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
+        attn = F.softmax(score, dim=1)
+
+        value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
+        value = value.contiguous().view(-1, seq_len, self.dim)
+
+        context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
+        attn = attn.view(batch_size, self.num_heads, -1)
+
+        return context, attn
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-Head Attention proposed in "Attention Is All You Need"
+    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
+    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
+    These are concatenated and once again projected, resulting in the final values.
+    Multi-head attention allows the model to jointly attend to information from different representation
+    subspaces at different positions.
+
+    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
+        where head_i = Attention(Q · W_q, K · W_k, V · W_v)
+
+    Args:
+        d_model (int): The dimension of keys / values / quries (default: 512)
+        num_heads (int): The number of attention heads. (default: 8)
+
+    Inputs: query, key, value, mask
+        - **query** (batch, q_len, d_model): In transformer, three different ways:
+            Case 1: come from previoys decoder layer
+            Case 2: come from the input embedding
+            Case 3: come from the output embedding (masked)
+
+        - **key** (batch, k_len, d_model): In transformer, three different ways:
+            Case 1: come from the output of the encoder
+            Case 2: come from the input embeddings
+            Case 3: come from the output embedding (masked)
+
+        - **value** (batch, v_len, d_model): In transformer, three different ways:
+            Case 1: come from the output of the encoder
+            Case 2: come from the input embeddings
+            Case 3: come from the output embedding (masked)
+
+        - **mask** (-): tensor containing indices to be masked
+
+    Returns: output, attn
+        - **output** (batch, output_len, dimensions): tensor containing the attended output features.
+        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+    """
+
+    def __init__(self, d_model: int = 512, num_heads: int = 8):
+        super(MultiHeadAttention, self).__init__()
+
+        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+
+        self.d_head = int(d_model / num_heads)
+        self.num_heads = num_heads
+        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
+        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
+        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
+        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
+
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        batch_size = value.size(0)
+
+        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
+        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)  # BxK_LENxNxD
+        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD
+
+        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
+        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxK_LENxD
+        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD
+
+        if mask is not None:
+            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN
+
+        context, attn = self.scaled_dot_attn(query, key, value, mask)
+
+        context = context.view(self.num_heads, batch_size, -1, self.d_head)
+        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND
+
+        return context, attn
+
+
+class RelativeMultiHeadAttention(nn.Module):
+    """
+    Multi-head attention with relative positional encoding.
+    This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        d_model (int): The dimension of model
+        num_heads (int): The number of attention heads.
+        dropout_p (float): probability of dropout
+
+    Inputs: query, key, value, pos_embedding, mask
+        - **query** (batch, time, dim): Tensor containing query vector
+        - **key** (batch, time, dim): Tensor containing key vector
+        - **value** (batch, time, dim): Tensor containing value vector
+        - **pos_embedding** (batch, time, dim): Positional embedding tensor
+        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+
+    Returns:
+        - **outputs**: Tensor produces by relative multi head attention module.
+    """
+
+    def __init__(
+            self,
+            d_model: int = 512,
+            num_heads: int = 16,
+            dropout_p: float = 0.1,
+    ):
+        super(RelativeMultiHeadAttention, self).__init__()
+        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+        self.d_model = d_model
+        self.d_head = int(d_model / num_heads)
+        self.num_heads = num_heads
+        self.sqrt_dim = math.sqrt(d_model)
+
+        self.query_proj = nn.Linear(d_model, d_model)
+        self.key_proj = nn.Linear(d_model, d_model)
+        self.value_proj = nn.Linear(d_model, d_model)
+        self.pos_proj = nn.Linear(d_model, d_model, bias=False)
+
+        self.dropout = nn.Dropout(p=dropout_p)
+        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+        torch.nn.init.xavier_uniform_(self.u_bias)
+        torch.nn.init.xavier_uniform_(self.v_bias)
+
+        self.out_proj = nn.Linear(d_model, d_model)
+
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            pos_embedding: Tensor,
+            mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        batch_size = value.size(0)
+
+        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
+
+        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
+        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
+        pos_score = self._compute_relative_positional_encoding(pos_score)
+
+        score = (content_score + pos_score) / self.sqrt_dim
+
+        if mask is not None:
+            mask = mask.unsqueeze(1)
+            score.masked_fill_(mask, -1e9)
+
+        attn = F.softmax(score, -1)
+        attn = self.dropout(attn)
+
+        context = torch.matmul(attn, value).transpose(1, 2)
+        context = context.contiguous().view(batch_size, -1, self.d_model)
+
+        return self.out_proj(context)
+
+    def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
+        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
+        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
+
+        return pos_score
+
+
+class CustomizingAttention(nn.Module):
+    r"""
+    Customizing Attention
+
+    Applies a multi-head + location-aware attention mechanism on the output features from the decoder.
+    Multi-head attention proposed in "Attention Is All You Need" paper.
+    Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+    I combined these two attention mechanisms as custom.
+
+    Args:
+        hidden_dim (int): The number of expected features in the output
+        num_heads (int): The number of heads. (default: )
+        conv_out_channel (int): The dimension of convolution
+
+    Inputs: query, value, last_attn
+        - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+        - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+        - **last_attn** (batch_size * num_heads, v_len): tensor containing previous timestep`s alignment
+
+    Returns: output, attn
+        - **output** (batch, output_len, dimensions): tensor containing the attended output features from the decoder.
+        - **attn** (batch * num_heads, v_len): tensor containing the alignment from the encoder outputs.
+
+    Reference:
+        - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+        - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+    """
+
+    def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
+        super(CustomizingAttention, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.dim = int(hidden_dim / num_heads)
+        self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
+        self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
+        self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
+        self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+        self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+        self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))
+
+    def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+        batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)
+
+        if last_attn is None:
+            last_attn = value.new_zeros(batch_size * self.num_heads, v_len)
+
+        loc_energy = self.get_loc_energy(last_attn, batch_size, v_len)  # get location energy
+
+        query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
+        value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias
+
+        query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+        value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+        query = query.contiguous().view(-1, q_len, self.dim)
+        value = value.contiguous().view(-1, v_len, self.dim)
+
+        context, attn = self.scaled_dot_attn(query, value)
+        attn = attn.squeeze()
+
+        context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
+        context = context.contiguous().view(batch_size, q_len, -1)
+
+        return context, attn
+
+    def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
+        conv_feat = self.conv1d(last_attn.unsqueeze(1))
+        conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)
+
+        loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
+        loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)
+
+        return loc_energy
data_handling.py
ADDED
@@ -0,0 +1,336 @@
+import datasets
+import model_handling
+from transformers import PreTrainedTokenizerBase
+from typing import Optional, Union, Any
+from transformers.file_utils import PaddingStrategy
+import re
+import os
+from tqdm import tqdm
+# import time
+import json
+import random
+import regtag
+from dataclasses import dataclass
+import validators
+
+import utils
+
+regexp = re.compile(r"\d{4}[\-/]\d{2}[\-/]\d{2}t\d{2}:\d{2}:\d{2}")
+target_bias_words = set(regtag.get_general_en_word())
+tokenizer = None
+
+
+def get_bias_words():
+    regtag.augment.get_random_oov()
+    return list(regtag.augment.oov_dict.keys())
+
+
+def check_common_phrase(word):
+    if validators.email(word.replace(' @', '@')):
+        return True
+    if validators.domain(word):
+        return True
+    if validators.url(word):
+        return True
+    if word in regtag.get_general_en_word():
+        return True
+    return False
+
+
+@dataclass
+class DataCollatorForNormSeq2Seq:
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    return_tensors: str = "pt"
+
+    def bias_phrases_extractor(self, features, max_bias_per_sample=30):
+        # src_ids, src_length, tgt_ids, tgt_length
+        phrase_candidate = []
+        sample_output_words = []
+        bias_labels = []
+
+        for sample in features:
+            words = []
+            for idx, (src_word_len, tgt_word_len) in enumerate(zip(sample['inputs_length'], sample['outputs_length'])):
+                src_start_idx = sum(sample['inputs_length'][:idx])
+                tgt_start_idx = sum(sample['outputs_length'][:idx])
+                word_input = self.tokenizer.decode(sample['input_ids'][src_start_idx: src_start_idx + src_word_len])
+                word_output = self.tokenizer.decode(sample['outputs'][tgt_start_idx: tgt_start_idx + tgt_word_len])
+                words.append(word_output)
+                if word_input != word_output and not any(map(str.isdigit, word_output)):
+                    phrase_candidate.append(word_output)
+            sample_output_words.append(words)
+
+        phrase_candidate = list(set(phrase_candidate))
+        phrase_candidate_revised = []
+        phrase_candidate_common = []
+        raw_phrase_candidate = []
+        for item in phrase_candidate:
+            raw_item = self.tokenizer.sp_model.DecodePieces(item.split())
+            if check_common_phrase(raw_item):
+                phrase_candidate_common.append(raw_item)
+            else:
+                phrase_candidate_revised.append(item)
+                raw_phrase_candidate.append(raw_item)
+
+        remain_phrase = max(0, max_bias_per_sample * len(features) - len(phrase_candidate_revised))
+
+        if remain_phrase > 0:
+            words_candidate = list(
+                set(get_bias_words()) - set(raw_phrase_candidate))
+            random.shuffle(words_candidate)
+            phrase_candidate_revised += [' '.join(self.tokenizer.sp_model.EncodeAsPieces(item)[:5]) for item in
+                                         words_candidate[:remain_phrase]]
+
+        for i in range(len(features)):
+            sample_bias_lables = []
+            for w_idx, w in enumerate(sample_output_words[i]):
+                try:
+                    sample_bias_lables.extend(
+                        [phrase_candidate_revised.index(w) + 1] * features[i]['outputs_length'][w_idx])
+                except:
+                    # random ignore 0 label
+                    if random.random() < 0.5:
+                        sample_bias_lables.extend([0] * features[i]['outputs_length'][w_idx])
+                    else:
+                        sample_bias_lables.extend([self.label_pad_token_id] * features[i]['outputs_length'][w_idx])
+            bias_labels.append(sample_bias_lables)
+            assert len(sample_bias_lables) == len(features[i]['outputs']), "{} vs {}".format(sample_bias_lables,
+                                                                                             features[i]['outputs'])
+
+        # phrase_candidate_ids = [self.tokenizer.encode(item) for item in phrase_candidate]
+        phrase_candidate_ids = [self.tokenizer.encode(self.tokenizer.sp_model.DecodePieces(item.split())) for item in
+                                phrase_candidate_revised]
+        phrase_candidate_mask = [[self.tokenizer.pad_token_id] * len(item) for item in phrase_candidate_ids]
+
+        return phrase_candidate_ids, phrase_candidate_mask, bias_labels
+        # pass
+
+    def encode_list_string(self, list_text):
+        text_tokenized = self.tokenizer(list_text)
+        return self.tokenizer.pad(
+            text_tokenized,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors='pt',
+        )
+
+    def __call__(self, features, return_tensors=None):
+        # start_time = time.time()
+        batch_src, batch_tgt = [], []
+        for item in features:
+            src_spans, tgt_spans = utils.make_spoken(item['text'])
+            batch_src.append(src_spans)
+            batch_tgt.append(tgt_spans)
+        # print("Make src-tgt {}s".format(time.time() - start_time))
+        # start_time = time.time()
+
+        features = preprocess_function({"src": batch_src, "tgt": batch_tgt})
+
+
+        # print("Make feature {}s".format(time.time() - start_time))
+        # start_time = time.time()
+
+        phrase_candidate_ids, phrase_candidate_mask, samples_bias_labels = self.bias_phrases_extractor(features)
+        # print("Make bias {}s".format(time.time() - start_time))
+        # start_time = time.time()
+
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+        labels = [feature["outputs"] for feature in features] if "outputs" in features[0].keys() else None
+        spoken_labels = [feature["spoken_label"] for feature in features] if "spoken_label" in features[0].keys() else None
+        spoken_idx = [feature["src_spoken_idx"] for feature in features] if "src_spoken_idx" in features[0].keys() else None
+
+        word_src_lengths = [feature["inputs_length"] for feature in features] if "inputs_length" in features[0].keys() else None
+        word_tgt_lengths = [feature["outputs_length"] for feature in features] if "outputs_length" in features[0].keys() else None
+        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+        # same length to return tensors.
+        if labels is not None:
+            max_label_length = max(len(l) for l in labels)
+            max_src_length = max(len(l) for l in spoken_labels)
+            max_spoken_idx_length = max(len(l) for l in spoken_idx)
+            max_word_src_length = max(len(l) for l in word_src_lengths)
+            max_word_tgt_length = max(len(l) for l in word_tgt_lengths)
+
+            padding_side = self.tokenizer.padding_side
+            for feature, bias_labels in zip(features, samples_bias_labels):
+                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["outputs"]))
+                remainder_word_tgt_length = [0] * (max_word_tgt_length - len(feature["outputs_length"]))
+                remainder_spoken = [self.label_pad_token_id] * (max_src_length - len(feature["spoken_label"]))
+                remainder_spoken_idx = [self.label_pad_token_id] * (max_spoken_idx_length - len(feature["src_spoken_idx"]))
+                remainder_word_src_length = [0] * (max_word_src_length - len(feature["inputs_length"]))
+
+                feature["labels"] = (
+                    feature["outputs"] + [
+                        self.tokenizer.eos_token_id] + remainder if padding_side == "right" else remainder + feature[
+                        "outputs"] + [self.tokenizer.eos_token_id]
+                )
+                feature["labels_bias"] = (
+                    bias_labels + [0] + remainder if padding_side == "right" else remainder + bias_labels + [0]
+                )
+
+                feature["spoken_label"] = [self.label_pad_token_id] + feature["spoken_label"] + [self.label_pad_token_id]
+                feature["spoken_label"] = feature["spoken_label"] + remainder_spoken if padding_side == "right" else remainder_spoken + feature["spoken_label"]
+                feature["src_spoken_idx"] = feature["src_spoken_idx"] + remainder_spoken_idx
+
+                feature['inputs_length'] = [1] + feature['inputs_length'] + [1]
+                feature['outputs_length'] = feature['outputs_length'] + [1]
+
+                feature["inputs_length"] = feature["inputs_length"] + remainder_word_src_length
+                feature["outputs_length"] = feature["outputs_length"] + remainder_word_tgt_length
+
+
+        features_inputs = [{
+            "input_ids": [self.tokenizer.bos_token_id] + item["input_ids"] + [self.tokenizer.eos_token_id],
+            "attention_mask": [self.tokenizer.pad_token_id] + item["attention_mask"] + [self.tokenizer.pad_token_id]
+        } for item in features]
+        features_inputs = self.tokenizer.pad(
+            features_inputs,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        bias_phrases_inputs = [{
+            "input_ids": ids,
+            "attention_mask": mask
+        } for ids, mask in zip(phrase_candidate_ids, phrase_candidate_mask)]
+        bias_phrases_inputs = self.tokenizer.pad(
+            bias_phrases_inputs,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        outputs = self.tokenizer.pad({"input_ids": [feature["labels"] for feature in features]},
+                                     return_tensors=return_tensors)['input_ids']
+        outputs_bias = self.tokenizer.pad({"input_ids": [feature["labels_bias"] for feature in features]},
+                                          return_tensors=return_tensors)['input_ids']
+        spoken_label = self.tokenizer.pad({"input_ids": [feature["spoken_label"] for feature in features]},
+                                          return_tensors=return_tensors)['input_ids']
+        spoken_idx = self.tokenizer.pad({"input_ids": [feature["src_spoken_idx"] for feature in features]},
+                                        return_tensors=return_tensors)['input_ids'] + 1  # 1 for bos token
+        word_src_lengths = self.tokenizer.pad({"input_ids": [feature["inputs_length"] for feature in features]},
+                                              return_tensors=return_tensors)['input_ids']
+        word_tgt_lengths = self.tokenizer.pad({"input_ids": [feature["outputs_length"] for feature in features]},
+                                              return_tensors=return_tensors)['input_ids']
+
+        features = {
+            "input_ids": features_inputs["input_ids"],
+            "spoken_label": spoken_label,
+            "spoken_idx": spoken_idx,
+            "word_src_lengths": word_src_lengths,
+            "word_tgt_lengths": word_tgt_lengths,
+            "attention_mask": features_inputs["attention_mask"],
+            "bias_input_ids": bias_phrases_inputs["input_ids"],
+            "bias_attention_mask": bias_phrases_inputs["attention_mask"],
+            "labels": outputs,
+            "labels_bias": outputs_bias
+        }
+
+        # print("Make batch {}s".format(time.time() - start_time))
+        # start_time = time.time()
+
+        # prepare decoder_input_ids
+        if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
+            features["decoder_input_ids"] = decoder_input_ids
+
+        return features
+
+
+# data init
+def init_data(train_corpus_path='./data-bin/raw/train_raw.txt',
+              test_corpus_path='./data-bin/raw/valid_raw.txt'):
+    dataset_oov = datasets.load_dataset('text', data_files={"train": train_corpus_path,
+                                                            "test": test_corpus_path})
+
+    print(dataset_oov)
+    return dataset_oov
+
+
+def preprocess_function(batch):
+
+    global tokenizer
+    if tokenizer is None:
+        tokenizer = model_handling.init_tokenizer()
+
+    features = []
+    for src_words, tgt_words in zip(batch["src"], batch["tgt"]):
+        src_ids, pad_ids, src_lengths, tgt_ids, tgt_lengths = [], [], [], [], []
+        src_spoken_label = []  # 0: "O", 1: "B", 2: "I"
+
+        src_spoken_idx = []
+        tgt_spoken_ids = []
+
+        for idx, (src, tgt) in enumerate(zip(src_words, tgt_words)):
+            is_remain = False
+            if src == tgt:
+                is_remain = True
+
+            src_tokenized = tokenizer(src)
+            if len(src_tokenized['input_ids']) < 3:
+                continue
+            # hardcode fix tokenizer email
+            if validators.email(tgt):
+                tgt_tokenized = tokenizer(tgt.replace('@', ' @'))
+            else:
+                tgt_tokenized = tokenizer(tgt)
+            if len(tgt_tokenized['input_ids']) < 3:
+                continue
+            src_ids.extend(src_tokenized["input_ids"][1:-1])
+            if is_remain:
+                src_spoken_label.extend([0 if random.random() < 0.5 else -100 for _ in range(len(src_tokenized["input_ids"][1:-1]))])
+                if random.random() < 0.1:
+                    # Random pick normal word for spoken norm
+                    src_spoken_idx.append(idx)
+                    tgt_spoken_ids.append(tgt_tokenized["input_ids"][1:-1])
+            else:
+                src_spoken_label.extend([1] + [2] * (len(src_tokenized["input_ids"][1:-1]) - 1))
+                src_spoken_idx.append(idx)
+                tgt_spoken_ids.append(tgt_tokenized["input_ids"][1:-1])
+
+            pad_ids.extend(src_tokenized["attention_mask"][1:-1])
+            src_lengths.append(len(src_tokenized["input_ids"]) - 2)
+            tgt_ids.extend(tgt_tokenized["input_ids"][1:-1])
+            tgt_lengths.append(len(tgt_tokenized["input_ids"]) - 2)
+            if len(src_ids) > 70 or len(tgt_ids) > 70:
+                # print("Ignore sample")
+                break
+
+        if len(src_ids) < 1 or len(tgt_ids) < 1:
+            continue
+        # else:
+        #     print("ignore")
+
+        features.append({
+            "input_ids": src_ids,
+            "attention_mask": pad_ids,
+            "spoken_label": src_spoken_label,
+            "inputs_length": src_lengths,
+            "outputs": tgt_ids,
+            "outputs_length": tgt_lengths,
+            "src_spoken_idx": src_spoken_idx,
+            "tgt_spoken_ids": tgt_spoken_ids
+        })
+
+    return features
+
+
+if __name__ == "__main__":
+    split_datasets = init_data()
+
+    model, model_tokenizer = model_handling.init_model()
+    data_collator = DataCollatorForNormSeq2Seq(model_tokenizer, model=model)
+
+    # start = time.time()
+    batch = data_collator([split_datasets["train"][i] for i in [random.randint(0, 900) for _ in range(0, 12)]])
+    print(batch)
+    # print("{}s".format(time.time() - start))
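Besides assembling training batches, the collator exposes `encode_list_string`, which `infer.py` below uses to turn the raw bias-phrase list into padded tensors. A small sketch of that path, assuming `model_handling.init_tokenizer()` is available as in the `__main__` block above:

```python
# Sketch: batch-encode bias phrases the same way infer.py does.
import model_handling
from data_handling import DataCollatorForNormSeq2Seq

tokenizer = model_handling.init_tokenizer()
collator = DataCollatorForNormSeq2Seq(tokenizer)

bias = collator.encode_list_string(["covid", "omicron", "scotland"])
print(bias["input_ids"].shape)       # (3, max_subword_len), padded batch
print(bias["attention_mask"].shape)  # same shape; 1 = real token, 0 = padding
```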
infer.py
ADDED
@@ -0,0 +1,374 @@
+#!/usr/bin/env python
+# coding: utf-8
+import torch
+import model_handling
+from data_handling import DataCollatorForNormSeq2Seq
+from model_handling import EncoderDecoderSpokenNorm
+import os
+import random
+import data_handling
+from transformers.generation_logits_process import LogitsProcessorList
+from transformers.generation_stopping_criteria import StoppingCriteriaList
+from transformers.generation_beam_search import BeamSearchScorer
+from dataclasses import dataclass
+from transformers.file_utils import ModelOutput
+import utils
+
+# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+
+use_gpu = False
+if use_gpu:
+    if not torch.cuda.is_available():
+        use_gpu = False
+tokenizer = model_handling.init_tokenizer()
+model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm-taggen-v2').eval()
+data_collator = DataCollatorForNormSeq2Seq(tokenizer)
+if use_gpu:
+    model = model.cuda()
+
+
+def make_batch_input(text_input_list):
+    batch_src_ids, batch_src_lengths = [], []
+    for text_input in text_input_list:
+        src_ids, src_lengths = [], []
+        for src in text_input.split():
+            src_tokenized = tokenizer(src)
+            ids = src_tokenized["input_ids"][1:-1]
+            src_ids.extend(ids)
+            src_lengths.append(len(ids))
+        src_ids = torch.tensor([0] + src_ids + [2])
+        src_lengths = torch.tensor([1] + src_lengths + [1]) + 1
+        batch_src_ids.append(src_ids)
+        batch_src_lengths.append(src_lengths)
+        assert sum(src_lengths - 1) == len(src_ids), "{} vs {}".format(sum(src_lengths), len(src_ids))
+    input_tokenized = tokenizer.pad({"input_ids": batch_src_ids}, padding=True)
+    input_word_length = tokenizer.pad({"input_ids": batch_src_lengths}, padding=True)["input_ids"] - 1
+    return input_tokenized['input_ids'], input_tokenized['attention_mask'], input_word_length
+
+
+def make_batch_bias_list(bias_list):
+    if len(bias_list) > 0:
+        bias = data_collator.encode_list_string(bias_list)
+        bias_input_ids = bias['input_ids']
+        bias_attention_mask = bias['attention_mask']
+    else:
+        bias_input_ids = None
+        bias_attention_mask = None
+
+    return bias_input_ids, bias_attention_mask
+
+
+def build_spoken_pronounce_mapping(bias_list):
+    list_pronounce = []
+    mapping = dict({})
+    for item in bias_list:
+        pronounces = item.split(' | ')[1:]
+        pronounces = [tokenizer(item)['input_ids'][1:-1] for item in pronounces]
+        list_pronounce.extend(pronounces)
+    subword_ids = list(set([item for sublist in list_pronounce for item in sublist]))
+    mapping = {item: [] for item in subword_ids}
+    for item in list_pronounce:
+        for wid in subword_ids:
+            if wid in item:
+                mapping[wid].append(item)
+    return mapping
+
+def find_pivot(seq, subseq):
+    n = len(seq)
+    m = len(subseq)
+    result = []
+    for i in range(n - m + 1):
+        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
+            result.append(i)
+    return result
+
+def revise_spoken_tagging(list_tags, list_words, pronounce_mapping):
+    if len(pronounce_mapping) == 0:
+        return list_tags
+    result = []
+    for tags_tensor, sen in zip(list_tags, list_words):
+        tags = tags_tensor.detach().numpy().tolist()
+        sen = sen.detach().numpy().tolist()
+        candidate_pronounce = dict({})
+        for idx in range(len(tags)):
+            if tags[idx] != 0 and sen[idx] in pronounce_mapping:
+                for pronounce in pronounce_mapping[sen[idx]]:
+                    pronounce_word = str(pronounce)
+                    start_find_idx = max(0, idx - len(pronounce))
+                    end_find_idx = idx + len(pronounce)
+                    find_idx = find_pivot(sen[start_find_idx: end_find_idx], pronounce)
+                    if len(find_idx) > 0:
+                        find_idx = [item + start_find_idx for item in find_idx]
+                        for map_idx in find_idx:
+                            if candidate_pronounce.get(map_idx, None) is None:
+                                candidate_pronounce[map_idx] = len(pronounce)
+                            else:
+                                candidate_pronounce[map_idx] = max(candidate_pronounce[map_idx], len(pronounce))
+        for idx, len_word in candidate_pronounce.items():
+            tags_tensor[idx] = 1
+            for i in range(1, len_word):
+                tags_tensor[idx + i] = 2
+        result.append(tags_tensor)
+    return result
+
+
+def make_spoken_feature(input_features, text_input_list, pronounce_mapping=dict({})):
+    features = {
+        "input_ids": input_features[0],
+        "word_src_lengths": input_features[2],
+        "attention_mask": input_features[1],
+        # "bias_input_ids": bias_features[0],
+        # "bias_attention_mask": bias_features[1],
+        "bias_input_ids": None,
+        "bias_attention_mask": None,
+    }
+    if use_gpu:
+        for key in features.keys():
+            if features[key] is not None:
+                features[key] = features[key].cuda()
+
+    encoder_output = model.get_encoder()(**features)
+    spoken_tagging_output = torch.argmax(encoder_output[0].spoken_tagging_output, dim=-1)
+    spoken_tagging_output = revise_spoken_tagging(spoken_tagging_output, features['input_ids'], pronounce_mapping)
+
+    # print(spoken_tagging_output)
+    # print(features['input_ids'])
+    word_src_lengths = features['word_src_lengths']
+    encoder_features = encoder_output[0][0]
+    list_spoken_features = []
+    list_pre_norm = []
+    for tagging_sample, sample_word_length, text_input_features, sample_text in zip(spoken_tagging_output, word_src_lengths, encoder_features, text_input_list):
+        spoken_feature_idx = []
+        sample_words = ['<s>'] + sample_text.split() + ['</s>']
+        norm_words = []
+        spoken_phrase = []
+        spoken_features = []
+        if tagging_sample.sum() == 0:
+            list_pre_norm.append(sample_words)
+            continue
+        for idx, word_length in enumerate(sample_word_length):
+            if word_length > 0:
+                start = sample_word_length[:idx].sum()
+                end = start + word_length
+                if tagging_sample[start: end].sum() > 0 and sample_words[idx] not in ['<s>', '</s>']:
+                    # Word has start tag
+                    if (tagging_sample[start: end] == 1).sum():
+                        if len(spoken_phrase) > 0:
+                            norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
+                            spoken_phrase = []
+                            list_spoken_features.append(torch.cat(spoken_features))
+                            spoken_features = []
+                    spoken_phrase.append(sample_words[idx])
+                    spoken_features.append(text_input_features[start: end])
+                else:
+                    if len(spoken_phrase) > 0:
+                        norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
+                        spoken_phrase = []
+                        list_spoken_features.append(torch.cat(spoken_features))
+                        spoken_features = []
+                    norm_words.append(sample_words[idx])
+        if len(spoken_phrase) > 0:
+            norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
+            spoken_phrase = []
+            list_spoken_features.append(torch.cat(spoken_features))
+            spoken_features = []
+        list_pre_norm.append(norm_words)
+
+
+    list_features_mask = []
+    if len(list_spoken_features) > 0:
+        feature_pad = torch.zeros_like(list_spoken_features[0][:1, :])
+        max_length = max([len(item) for item in list_spoken_features])
+        for i in range(len(list_spoken_features)):
+            spoken_length = len(list_spoken_features[i])
+            remain_length = max_length - spoken_length
+            device = list_spoken_features[i].device
+            list_spoken_features[i] = torch.cat([list_spoken_features[i],
+                                                 feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0)
+            list_features_mask.append(torch.cat([torch.ones(spoken_length, device=device, dtype=torch.int64),
+                                                 torch.zeros(remain_length, device=device, dtype=torch.int64)]).unsqueeze(0))
+    if len(list_spoken_features) > 0:
+        list_spoken_features = torch.cat(list_spoken_features)
+        list_features_mask = torch.cat(list_features_mask)
+
+    return list_spoken_features, list_features_mask, list_pre_norm
+
+
+def make_bias_feature(bias_raw_features):
+    features = {
+        "bias_input_ids": bias_raw_features[0],
+        "bias_attention_mask": bias_raw_features[1]
+    }
+    if use_gpu:
+        for key in features.keys():
+            if features[key] is not None:
+                features[key] = features[key].cuda()
+    return model.forward_bias(**features)
+
+
+def decode_plain_output(decoder_output):
+    plain_output = [item.split()[1:] for item in tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=False)]
+    scores = torch.stack(list(decoder_output['scores'])).transpose(1, 0)
+    logit_output = torch.gather(scores, -1, decoder_output['sequences'][:, 1:].unsqueeze(-1)).squeeze(-1)
+    special_tokens = list(tokenizer.special_tokens_map.values())
+    generated_output = []
+    generated_scores = []
+    # filter special tokens
+    for out_text, out_score in zip(plain_output, logit_output):
+        temp_str, tmp_score = [], []
+        for piece, score in zip(out_text, out_score):
+            if piece not in special_tokens:
+                temp_str.append(piece)
+                tmp_score.append(score)
+        if len(temp_str) > 0:
+            generated_output.append(' '.join(temp_str).replace('▁', '|').replace(' ', '').replace('|', ' ').strip())
+            generated_scores.append((sum(tmp_score)/len(tmp_score)).cpu().detach().numpy().tolist())
+        else:
+            generated_output.append("")
+            generated_scores.append(0)
+    return generated_output, generated_scores
+
+
+def generate_spoken_norm(list_spoken_features, list_features_mask, bias_features):
+    @dataclass
+    class EncoderOutputs(ModelOutput):
+        last_hidden_state: torch.FloatTensor = None
+        hidden_states: torch.FloatTensor = None
+        attentions: torch.FloatTensor = None
+
+    batch_size = list_spoken_features.size(0)
+    max_length = 50
+    device = list_spoken_features.device
+    decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
+    stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
+                                                     stopping_criteria=StoppingCriteriaList())
+    model_kwargs = {
+        "encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
+        "encoder_bias_outputs": bias_features,
+        "attention_mask": list_features_mask
+    }
+    decoder_output = model.greedy_search(
+        decoder_input_ids,
+        logits_processor=LogitsProcessorList(),
+        stopping_criteria=stopping_criteria,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        output_scores=True,
+        return_dict_in_generate=True,
+        **model_kwargs,
+    )
+    plain_output, plain_score = decode_plain_output(decoder_output)
+    # plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
+    # # print(decoder_output)
+    # plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
+    return plain_output, plain_score
|
265 |
+
|
266 |
+
|
267 |
+
def generate_beam_spoken_norm(list_spoken_features, list_features_mask, bias_features, num_beams=3):
|
268 |
+
@dataclass
|
269 |
+
class EncoderOutputs(ModelOutput):
|
270 |
+
last_hidden_state: torch.FloatTensor = None
|
271 |
+
|
272 |
+
batch_size = list_spoken_features.size(0)
|
273 |
+
max_length = 50
|
274 |
+
num_return_sequences = 1
|
275 |
+
device = list_spoken_features.device
|
276 |
+
decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
|
277 |
+
stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
|
278 |
+
stopping_criteria=StoppingCriteriaList())
|
279 |
+
model_kwargs = {
|
280 |
+
"encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
|
281 |
+
"encoder_bias_outputs": bias_features,
|
282 |
+
"attention_mask": list_features_mask
|
283 |
+
}
|
284 |
+
beam_scorer = BeamSearchScorer(
|
285 |
+
batch_size=batch_size,
|
286 |
+
num_beams=num_beams,
|
287 |
+
device=device,
|
288 |
+
do_early_stopping=True,
|
289 |
+
num_beam_hyps_to_keep=num_return_sequences,
|
290 |
+
)
|
291 |
+
decoder_input_ids, model_kwargs = model._expand_inputs_for_generation(
|
292 |
+
decoder_input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
|
293 |
+
)
|
294 |
+
|
295 |
+
decoder_output = model.beam_search(
|
296 |
+
decoder_input_ids,
|
297 |
+
beam_scorer,
|
298 |
+
logits_processor=LogitsProcessorList(),
|
299 |
+
stopping_criteria=stopping_criteria,
|
300 |
+
pad_token_id=tokenizer.pad_token_id,
|
301 |
+
eos_token_id=tokenizer.eos_token_id,
|
302 |
+
output_scores=None,
|
303 |
+
return_dict_in_generate=True,
|
304 |
+
**model_kwargs,
|
305 |
+
)
|
306 |
+
|
307 |
+
plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
|
308 |
+
plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
|
309 |
+
return plain_output, None
|
310 |
+
|
311 |
+
|
312 |
+
def reformat_normed_term(list_pre_norm, spoken_norm_output, spoken_norm_output_score=None, threshold=None, debug=False):
|
313 |
+
output = []
|
314 |
+
for pre_norm in list_pre_norm:
|
315 |
+
normed_words = []
|
316 |
+
# words = pre_norm.split()
|
317 |
+
for w in pre_norm:
|
318 |
+
if w.startswith('<mask>'):
|
319 |
+
term = w[7:].split('](')
|
320 |
+
# print(w)
|
321 |
+
# print(term)
|
322 |
+
term_idx = int(term[0])
|
323 |
+
norm_val = spoken_norm_output[term_idx]
|
324 |
+
norm_val_score = None if (spoken_norm_output_score is None or threshold is None) else spoken_norm_output_score[term_idx]
|
325 |
+
pre_norm_val = term[1][:-1]
|
326 |
+
if debug:
|
327 |
+
if norm_val_score is not None:
|
328 |
+
normed_words.append("({})({:.2f})[{}]".format(norm_val, norm_val_score, pre_norm_val))
|
329 |
+
else:
|
330 |
+
normed_words.append("({})[{}]".format(norm_val, pre_norm_val))
|
331 |
+
else:
|
332 |
+
if threshold is not None and norm_val_score is not None:
|
333 |
+
if norm_val_score > threshold:
|
334 |
+
normed_words.append(norm_val)
|
335 |
+
else:
|
336 |
+
normed_words.append(pre_norm_val)
|
337 |
+
else:
|
338 |
+
normed_words.append(norm_val)
|
339 |
+
else:
|
340 |
+
normed_words.append(w)
|
341 |
+
output.append(" ".join(normed_words))
|
342 |
+
return output
|
343 |
+
|
344 |
+
|
345 |
+
def infer(text_input_list, bias_list):
|
346 |
+
# extract bias feature
|
347 |
+
bias_raw_features = make_batch_bias_list(bias_list)
|
348 |
+
bias_features = make_bias_feature(bias_raw_features)
|
349 |
+
pronounce_mapping = build_spoken_pronounce_mapping(bias_list)
|
350 |
+
|
351 |
+
# Chunk split input and create feature
|
352 |
+
text_input_chunk_list = [utils.split_chunk_input(item, chunk_size=60, overlap=20) for item in text_input_list]
|
353 |
+
num_chunks = [len(i) for i in text_input_chunk_list]
|
354 |
+
flatten_list = [y for x in text_input_chunk_list for y in x]
|
355 |
+
input_raw_features = make_batch_input(flatten_list)
|
356 |
+
|
357 |
+
# Extract norm term and spoken feature
|
358 |
+
list_spoken_features, list_features_mask, list_pre_norm = make_spoken_feature(input_raw_features, flatten_list, pronounce_mapping)
|
359 |
+
|
360 |
+
# Merge overlap chunks
|
361 |
+
list_pre_norm_by_input = []
|
362 |
+
for idx, input_num in enumerate(num_chunks):
|
363 |
+
start = sum(num_chunks[:idx])
|
364 |
+
end = start + num_chunks[idx]
|
365 |
+
list_pre_norm_by_input.append(list_pre_norm[start:end])
|
366 |
+
text_input_list_pre_norm = [utils.merge_chunk_pre_norm(list_chunks, overlap=20, debug=False) for list_chunks in list_pre_norm_by_input]
|
367 |
+
|
368 |
+
if len(list_spoken_features) > 0:
|
369 |
+
spoken_norm_output, spoken_norm_score = generate_spoken_norm(list_spoken_features, list_features_mask, bias_features)
|
370 |
+
else:
|
371 |
+
spoken_norm_output, spoken_norm_score = [], None
|
372 |
+
|
373 |
+
return reformat_normed_term(text_input_list_pre_norm, spoken_norm_output, spoken_norm_score, threshold=15, debug=False)
|
374 |
+
|
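For orientation, a minimal usage sketch of the `infer` entry point above. The input sentence and bias phrases are made up, and the sketch assumes infer.py's module-level `model`, `tokenizer`, and helper functions are initialized when the module is imported; it is an illustration, not part of the committed Space code.

# hypothetical caller of infer.py (illustrative values only)
import infer

spoken_inputs = ["toi sinh ngay hai muoi thang tam nam hai nghin"]   # made-up spoken-form sentence
bias_phrases = ["nguyen van a", "vin group"]                         # made-up bias vocabulary
written_outputs = infer.infer(spoken_inputs, bias_phrases)
print(written_outputs[0])

Each element of the returned list is the written-form text for the corresponding input, with tagged spoken phrases replaced by the decoder output when its score passes the threshold used in `infer`.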
model_config_handling.py
ADDED
@@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from transformers.configuration_utils import PretrainedConfig
from transformers import BertConfig
from transformers.utils import logging
# from model_handling import DecoderSpokenNorm


logger = logging.get_logger(__name__)


class DecoderSpokenNormConfig(BertConfig):
    # model_type = "decoder-spoken-norm"

    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
        """Constructs RobertaConfig."""
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.num_hidden_layers = 2
        # self.hidden_layers_from_pretrained = list(range(self.num_hidden_layers))
        # self.hidden_layers_from_pretrained = [0, 3]

        # if len(self.hidden_layers_from_pretrained) < self.num_hidden_layers:
        #     self.num_hidden_layers = len(self.hidden_layers_from_pretrained)


class EncoderDecoderSpokenNormConfig(PretrainedConfig):
    # model_type = "encoder-decoder-spoken-norm"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert (
            "encoder" in kwargs and "decoder" in kwargs
        ), "Config has to be initialized with encoder and decoder config"
        encoder_config = kwargs.pop("encoder")
        encoder_model_type = encoder_config.pop("model_type")
        decoder_config = kwargs.pop("decoder")
        decoder_model_type = decoder_config.pop("model_type")

        from transformers.models.auto.configuration_auto import AutoConfig

        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model
        configuration and decoder model configuration.

        Returns:
            :class:`EncoderDecoderConfig`: An instance of a configuration object
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["encoder"] = self.encoder.to_dict()
        output["decoder"] = self.decoder.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
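A small sketch of how the composite config above can be assembled by hand. It is only an illustration under the assumption that a plain bert-style config is acceptable on the encoder side; the Space itself builds both halves from the pretrained envibert checkpoint inside model_handling.py.

# illustrative only: assemble an encoder-decoder config from two sub-configs
from transformers import BertConfig
from model_config_handling import DecoderSpokenNormConfig, EncoderDecoderSpokenNormConfig

encoder_cfg = BertConfig(hidden_size=768, num_hidden_layers=4)   # stand-in encoder config
decoder_cfg = DecoderSpokenNormConfig(hidden_size=768)           # __init__ above forces 2 hidden layers

combined = EncoderDecoderSpokenNormConfig.from_encoder_decoder_configs(encoder_cfg, decoder_cfg)
print(combined.decoder.is_decoder, combined.decoder.add_cross_attention)   # both True after the call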
model_handling.py
ADDED
@@ -0,0 +1,763 @@
from transformers.file_utils import cached_path, hf_bucket_url
from importlib.machinery import SourceFileLoader
import os
from transformers import EncoderDecoderModel, AutoConfig, AutoModel, EncoderDecoderConfig, RobertaForCausalLM, \
    RobertaModel
from transformers.modeling_utils import PreTrainedModel, logging
import torch
from torch.nn import CrossEntropyLoss, Parameter
from transformers.modeling_outputs import Seq2SeqLMOutput, CausalLMOutputWithCrossAttentions, \
    ModelOutput
from attentions import ScaledDotProductAttention, MultiHeadAttention
from collections import namedtuple
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass
import random
from model_config_handling import EncoderDecoderSpokenNormConfig, DecoderSpokenNormConfig, PretrainedConfig

cache_dir = './cache'
model_name = 'nguyenvulebinh/envibert'

if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)
logger = logging.get_logger(__name__)


@dataclass
class SpokenNormOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    logits_spoken_tagging: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


def collect_spoken_phrases_features(encoder_hidden_states, word_src_lengths, spoken_label):
    list_features = []
    list_features_mask = []
    max_length = word_src_lengths.max()
    feature_pad = torch.zeros_like(encoder_hidden_states[0, :1, :])
    for hidden_state, word_length, list_idx in zip(encoder_hidden_states, word_src_lengths, spoken_label):
        for idx in list_idx:
            if idx > 0:
                start = sum(word_length[:idx])
                end = start + word_length[idx]
                remain_length = max_length - word_length[idx]
                list_features_mask.append(torch.cat([torch.ones_like(spoken_label[0, 0]).expand(word_length[idx]),
                                                     torch.zeros_like(
                                                         spoken_label[0, 0].expand(remain_length))]).unsqueeze(0))
                spoken_phrases_feature = hidden_state[start: end]

                list_features.append(torch.cat([spoken_phrases_feature,
                                                feature_pad.expand(remain_length, feature_pad.size(-1))]).unsqueeze(0))
    return torch.cat(list_features), torch.cat(list_features_mask)


def collect_spoken_phrases_labels(decoder_input_ids, labels, labels_bias, word_tgt_lengths, spoken_idx):
    list_decoder_input_ids = []
    list_labels = []
    list_labels_bias = []
    max_length = word_tgt_lengths.max()
    init_decoder_ids = torch.tensor([0], device=labels.device, dtype=labels.dtype)
    pad_decoder_ids = torch.tensor([1], device=labels.device, dtype=labels.dtype)
    eos_decoder_ids = torch.tensor([2], device=labels.device, dtype=labels.dtype)
    none_labels_bias = torch.tensor([0], device=labels.device, dtype=labels.dtype)
    ignore_labels_bias = torch.tensor([-100], device=labels.device, dtype=labels.dtype)

    for decoder_inputs, decoder_label, decoder_label_bias, word_length, list_idx in zip(decoder_input_ids,
                                                                                        labels, labels_bias,
                                                                                        word_tgt_lengths, spoken_idx):
        for idx in list_idx:
            if idx > 0:
                start = sum(word_length[:idx - 1])
                end = start + word_length[idx - 1]
                remain_length = max_length - word_length[idx - 1]
                remain_decoder_input_ids = max_length - len(decoder_inputs[start + 1:end + 1])
                list_decoder_input_ids.append(torch.cat([init_decoder_ids,
                                                         decoder_inputs[start + 1:end + 1],
                                                         pad_decoder_ids.expand(remain_decoder_input_ids)]).unsqueeze(0))
                list_labels.append(torch.cat([decoder_label[start:end],
                                              eos_decoder_ids,
                                              ignore_labels_bias.expand(remain_length)]).unsqueeze(0))
                list_labels_bias.append(torch.cat([decoder_label_bias[start:end],
                                                   none_labels_bias,
                                                   ignore_labels_bias.expand(remain_length)]).unsqueeze(0))

    decoder_input_ids = torch.cat(list_decoder_input_ids)
    labels = torch.cat(list_labels)
    labels_bias = torch.cat(list_labels_bias)

    return decoder_input_ids, labels, labels_bias


class EncoderDecoderSpokenNorm(EncoderDecoderModel):
    config_class = EncoderDecoderSpokenNormConfig

    def __init__(
            self,
            config: Optional[PretrainedConfig] = None,
            encoder: Optional[PreTrainedModel] = None,
            decoder: Optional[PreTrainedModel] = None,
    ):
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
                    "it has to be equal to the encoder's `hidden_size`. "
                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
                )

        # initialize with config
        super().__init__(config)

        if encoder is None:
            from transformers.models.auto.modeling_auto import AutoModel

            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            # from transformers.models.auto.modeling_auto import AutoModelForCausalLM

            decoder = DecoderSpokenNorm._from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # encoder outputs might need to be projected to different dimension for decoder
        if (
                self.encoder.config.hidden_size != self.decoder.config.hidden_size
                and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = torch.nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

        # spoken tagging
        self.dropout = torch.nn.Dropout(0.3)
        # 0: "O", 1: "B", 2: "I"
        self.spoken_tagging_classifier = torch.nn.Linear(config.encoder.hidden_size, 3)

        # tie encoder, decoder weights if config set accordingly
        self.tie_weights()

    @classmethod
    def from_encoder_decoder_pretrained(
            cls,
            encoder_pretrained_model_name_or_path: str = None,
            decoder_pretrained_model_name_or_path: str = None,
            *model_args,
            **kwargs
    ) -> PreTrainedModel:

        kwargs_encoder = {
            argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                        "from a decoder model. Cross-attention and casual mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args,
                                                **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config = DecoderSpokenNormConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                        f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                        f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                        "cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            decoder = DecoderSpokenNorm.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # instantiate config with corresponding kwargs
        config = EncoderDecoderSpokenNormConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
        return cls(encoder=encoder, decoder=decoder, config=config)

    def get_encoder(self):
        def forward(input_ids=None,
                    attention_mask=None,
                    bias_input_ids=None,
                    bias_attention_mask=None,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    word_src_lengths=None,
                    spoken_idx=None,
                    **kwargs_encoder):
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=None,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            encoder_outputs.word_src_lengths = word_src_lengths
            encoder_outputs.spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))
            if spoken_idx is not None:
                encoder_outputs.spoken_idx = spoken_idx
            else:
                pass

            encoder_bias_outputs = self.forward_bias(bias_input_ids,
                                                     bias_attention_mask,
                                                     output_attentions=output_attentions,
                                                     return_dict=return_dict,
                                                     output_hidden_states=output_hidden_states,
                                                     **kwargs_encoder)
            # d = {
            #     "encoder_bias_outputs": None,
            #     "bias_attention_mask": None,
            #     "last_hidden_state": None,
            #     "pooler_output": None
            # }
            # encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values())
            # if bias_input_ids is not None:
            #     encoder_bias_outputs = self.encoder(
            #         input_ids=bias_input_ids,
            #         attention_mask=bias_attention_mask,
            #         inputs_embeds=None,
            #         output_attentions=output_attentions,
            #         output_hidden_states=output_hidden_states,
            #         return_dict=return_dict,
            #         **kwargs_encoder,
            #     )
            #     encoder_bias_outputs.bias_attention_mask = bias_attention_mask
            return encoder_outputs, encoder_bias_outputs

        return forward

    def forward_bias(self,
                     bias_input_ids,
                     bias_attention_mask,
                     output_attentions=False,
                     return_dict=True,
                     output_hidden_states=False,
                     **kwargs_encoder):
        d = {
            "encoder_bias_outputs": None,
            "bias_attention_mask": None,
            "last_hidden_state": None,
            "pooler_output": None
        }
        encoder_bias_outputs = namedtuple('Struct', d.keys())(*d.values())
        if bias_input_ids is not None:
            encoder_bias_outputs = self.encoder(
                input_ids=bias_input_ids,
                attention_mask=bias_attention_mask,
                inputs_embeds=None,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            encoder_bias_outputs.bias_attention_mask = bias_attention_mask
        return encoder_bias_outputs

    def _prepare_encoder_decoder_kwargs_for_generation(
            self, input_ids: torch.LongTensor, model_kwargs, model_input_name
    ) -> Dict[str, Any]:
        if "encoder_outputs" not in model_kwargs:
            # retrieve encoder hidden states
            encoder = self.get_encoder()
            encoder_kwargs = {
                argument: value
                for argument, value in model_kwargs.items()
                if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
            }
            encoder_outputs, encoder_bias_outputs = encoder(input_ids, return_dict=True, **encoder_kwargs)
            model_kwargs["encoder_outputs"]: ModelOutput = encoder_outputs
            model_kwargs["encoder_bias_outputs"]: ModelOutput = encoder_bias_outputs

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
            self,
            batch_size: int,
            decoder_start_token_id: int = None,
            bos_token_id: int = None,
            model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:

        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            return model_kwargs.pop("decoder_input_ids")
        else:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            num_spoken_phrases = (model_kwargs['encoder_outputs'].spoken_idx >= 0).view(-1).sum()
            return torch.ones((num_spoken_phrases, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

    def prepare_inputs_for_generation(
            self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "encoder_bias_outputs": kwargs["encoder_bias_outputs"],
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=None,
            bias_input_ids=None,
            bias_attention_mask=None,
            labels_bias=None,
            decoder_attention_mask=None,
            encoder_outputs=None,
            encoder_bias_outputs=None,
            past_key_values=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            use_cache=None,
            spoken_label=None,
            word_src_lengths=None,
            word_tgt_lengths=None,
            spoken_idx=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            inputs_length=None,
            outputs=None,
            outputs_length=None,
            text=None,
            **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }
        spoken_tagging_output = None
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))
        # else:
        #     word_src_lengths = encoder_outputs.word_src_lengths
        #     spoken_tagging_output = encoder_outputs.spoken_tagging_output

        if encoder_bias_outputs is None:
            encoder_bias_outputs = self.encoder(
                input_ids=bias_input_ids,
                attention_mask=bias_attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            encoder_bias_outputs.bias_attention_mask = bias_attention_mask

        encoder_hidden_states = encoder_outputs[0]

        # if spoken_idx is None:
        #     # extract spoken_idx from spoken_tagging_output
        #     spoken_idx = None

        # encoder_hidden_states, attention_mask = collect_spoken_phrases_features(encoder_hidden_states,
        #                                                                         word_src_lengths,
        #                                                                         spoken_idx)
        # if labels is not None:
        #     decoder_input_ids, labels, labels_bias = collect_spoken_phrases_labels(decoder_input_ids,
        #                                                                            labels, labels_bias,
        #                                                                            word_tgt_lengths,
        #                                                                            spoken_idx)

        if spoken_idx is not None:
            encoder_hidden_states, attention_mask = collect_spoken_phrases_features(encoder_hidden_states,
                                                                                    word_src_lengths,
                                                                                    spoken_idx)

            decoder_input_ids, labels, labels_bias = collect_spoken_phrases_labels(decoder_input_ids,
                                                                                   labels, labels_bias,
                                                                                   word_tgt_lengths,
                                                                                   spoken_idx)

        # optionally project encoder_hidden_states
        if (
                self.encoder.config.hidden_size != self.decoder.config.hidden_size
                and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_bias_pooling=encoder_bias_outputs.pooler_output,
            # encoder_bias_hidden_states=encoder_bias_outputs[0],
            encoder_bias_hidden_states=encoder_bias_outputs.last_hidden_state,
            bias_attention_mask=encoder_bias_outputs.bias_attention_mask,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels_bias=labels_bias,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
            loss = loss + decoder_outputs.loss

        if spoken_label is not None:
            loss_fct = CrossEntropyLoss()
            spoken_tagging_loss = loss_fct(spoken_tagging_output.reshape(-1, 3), spoken_label.view(-1))
            loss = loss + spoken_tagging_loss

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return SpokenNormOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            logits_spoken_tagging=spoken_tagging_output,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class DecoderSpokenNorm(RobertaForCausalLM):
    config_class = DecoderSpokenNormConfig

    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config):
        super().__init__(config)
        self.dense_query_copy = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.mem_no_entry = Parameter(torch.randn(config.hidden_size).unsqueeze(0))
        self.bias_attention_layer = MultiHeadAttention(config.hidden_size)
        self.copy_attention_layer = MultiHeadAttention(config.hidden_size)

    def forward_bias_attention(self, query, values, values_mask):
        """
        :param query: batch * output_steps * hidden_state
        :param values: batch * output_steps * max_bias_steps * hidden_state
        :param values_mask: batch * output_steps * max_bias_steps
        :return: batch * output_steps * hidden_state
        """
        batch, output_steps, hidden_state = query.size()
        _, _, max_bias_steps, _ = values.size()

        query = query.view(batch * output_steps, 1, hidden_state)
        values = values.view(-1, max_bias_steps, hidden_state)
        values_mask = 1 - values_mask.view(-1, max_bias_steps)
        result_attention, attention_score = self.bias_attention_layer(query=query,
                                                                      key=values,
                                                                      value=values,
                                                                      mask=values_mask.bool())
        result_attention = result_attention.squeeze(1).view(batch, output_steps, hidden_state)
        return result_attention

    def forward_copy_attention(self, query, values, values_mask):
        """
        :param query: batch * output_steps * hidden_state
        :param values: batch * max_encoder_steps * hidden_state
        :param values_mask: batch * output_steps * max_encoder_steps
        :return: batch * output_steps * hidden_state
        """
        dot_attn_score = torch.bmm(query, values.transpose(2, 1))
        attn_mask = (1 - values_mask.clone().unsqueeze(1)).bool()
        dot_attn_score.masked_fill_(attn_mask, -float('inf'))
        dot_attn_score = torch.softmax(dot_attn_score, dim=-1)
        result_attention = torch.bmm(dot_attn_score, values)
        return result_attention

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            encoder_bias_pooling=None,
            encoder_bias_hidden_states=None,
            bias_attention_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            labels=None,
            labels_bias=None,
            past_key_values=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        # attention with input encoded
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Query for bias
        sequence_output = outputs[0]
        bias_indicate_output = None

        # output copy attention
        query_copy = torch.relu(self.dense_query_copy(sequence_output))
        sequence_atten_copy_output = self.forward_copy_attention(query_copy,
                                                                 encoder_hidden_states,
                                                                 encoder_attention_mask)

        if encoder_bias_pooling is not None:

            # Make bias features
            encoder_bias_pooling = torch.cat([self.mem_no_entry, encoder_bias_pooling], dim=0)
            mem_no_entry_feature = torch.zeros_like(encoder_bias_hidden_states[0]).unsqueeze(0)
            mem_no_entry_mask = torch.ones_like(bias_attention_mask[0]).unsqueeze(0)
            encoder_bias_hidden_states = torch.cat([mem_no_entry_feature, encoder_bias_hidden_states], dim=0)
            bias_attention_mask = torch.cat([mem_no_entry_mask, bias_attention_mask], dim=0)

            # Compute ranking score
            b, s, h = sequence_output.size()
            bias_ranking_score = sequence_output.view(b * s, h).mm(encoder_bias_pooling.T)
            bias_ranking_score = bias_ranking_score.view(b, s, encoder_bias_pooling.size(0))

            # teacher force with bias label
            if not self.training:
                bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)
            else:
                if random.random() < 0.5:
                    bias_indicate_output = labels_bias.clone()
                    bias_indicate_output[torch.where(bias_indicate_output < 0)] = 0
                else:
                    bias_indicate_output = torch.argmax(bias_ranking_score, dim=-1)

            # Bias encoder hidden state
            _, max_len, _ = encoder_bias_hidden_states.size()
            bias_encoder_hidden_states = torch.index_select(input=encoder_bias_hidden_states,
                                                            dim=0,
                                                            index=bias_indicate_output.view(b * s)).view(b, s, max_len,
                                                                                                         h)
            bias_encoder_attention_mask = torch.index_select(input=bias_attention_mask,
                                                             dim=0,
                                                             index=bias_indicate_output.view(b * s)).view(b, s, max_len)

            sequence_atten_bias_output = self.forward_bias_attention(sequence_output,
                                                                     bias_encoder_hidden_states,
                                                                     bias_encoder_attention_mask)

            # Find output words
            prediction_scores = self.lm_head(sequence_output + sequence_atten_bias_output + sequence_atten_copy_output)
        else:
            prediction_scores = self.lm_head(sequence_output + sequence_atten_copy_output)

        # run attention with bias

        bias_ranking_loss = None
        if labels_bias is not None:
            loss_fct = CrossEntropyLoss()
            bias_ranking_loss = loss_fct(bias_ranking_score.view(-1, encoder_bias_pooling.size(0)),
                                         labels_bias.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((bias_ranking_loss,) + output) if bias_ranking_loss is not None else output

        result = CausalLMOutputWithCrossAttentions(
            loss=bias_ranking_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

        result.bias_indicate_output = bias_indicate_output

        return result


def download_tokenizer_files():
    resources = ['envibert_tokenizer.py', 'dict.txt', 'sentencepiece.bpe.model']
    for item in resources:
        if not os.path.exists(os.path.join(cache_dir, item)):
            tmp_file = hf_bucket_url(model_name, filename=item)
            tmp_file = cached_path(tmp_file, cache_dir=cache_dir)
            os.rename(tmp_file, os.path.join(cache_dir, item))


def init_tokenizer():
    download_tokenizer_files()
    tokenizer = SourceFileLoader("envibert.tokenizer",
                                 os.path.join(cache_dir,
                                              'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)
    tokenizer.model_input_names = ["input_ids",
                                   "attention_mask",
                                   "bias_input_ids",
                                   "bias_attention_mask",
                                   "labels",
                                   "labels_bias"]
    return tokenizer


def init_model():
    download_tokenizer_files()
    tokenizer = SourceFileLoader("envibert.tokenizer",
                                 os.path.join(cache_dir,
                                              'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)
    tokenizer.model_input_names = ["input_ids",
                                   "attention_mask",
                                   "bias_input_ids",
                                   "bias_attention_mask",
                                   "labels",
                                   "labels_bias"]
    # set encoder decoder tying to True
    roberta_shared = EncoderDecoderSpokenNorm.from_encoder_decoder_pretrained(model_name,
                                                                              model_name,
                                                                              tie_encoder_decoder=False)

    # set special tokens
    roberta_shared.config.decoder_start_token_id = tokenizer.bos_token_id
    roberta_shared.config.eos_token_id = tokenizer.eos_token_id
    roberta_shared.config.pad_token_id = tokenizer.pad_token_id

    # sensible parameters for beam search
    # set decoding params
    roberta_shared.config.max_length = 50
    roberta_shared.config.early_stopping = True
    roberta_shared.config.no_repeat_ngram_size = 3
    roberta_shared.config.length_penalty = 2.0
    roberta_shared.config.num_beams = 1
    roberta_shared.config.vocab_size = roberta_shared.config.encoder.vocab_size

    return roberta_shared, tokenizer
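For reference, a short sketch of how the classes above are typically instantiated through `init_model`; the save path is a made-up example and the download of the envibert resources into ./cache happens inside the call.

# illustrative sketch: build and save the spoken-norm encoder-decoder defined above
from model_handling import init_model

model, tokenizer = init_model()   # fetches nguyenvulebinh/envibert resources into ./cache
print(type(model).__name__, model.config.decoder_start_token_id)

model.save_pretrained("./spoken-norm-checkpoint")   # hypothetical output directory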
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==1.10.0
sentencepiece==0.1.91
transformers==4.16.2
datasets==1.17.0
regtag
validators
jiwer
gradio
utils.py
ADDED
@@ -0,0 +1,271 @@
1 |
+
import difflib
|
2 |
+
import regtag
|
3 |
+
import random
|
4 |
+
|
5 |
+
|
6 |
+
def merge_span(words, tags):
|
7 |
+
spans, span_tags = [], []
|
8 |
+
current_tag = 'O'
|
9 |
+
span = []
|
10 |
+
for w, t in zip(words, tags):
|
11 |
+
w = w.strip(":-")
|
12 |
+
if len(w) == 0:
|
13 |
+
continue
|
14 |
+
t_info = t.split('-')
|
15 |
+
if t_info[-1] != current_tag or t_info[0] == 'B':
|
16 |
+
if len(span) > 0:
|
17 |
+
spans.append(' '.join(span))
|
18 |
+
span_tags.append(current_tag)
|
19 |
+
span = [w]
|
20 |
+
current_tag = t_info[-1]
|
21 |
+
else:
|
22 |
+
span.append(w)
|
23 |
+
if len(span) > 0:
|
24 |
+
spans.append(' '.join(span))
|
25 |
+
span_tags.append(current_tag)
|
26 |
+
return spans, span_tags
|
27 |
+
|
28 |
+
|
29 |
+
def make_spoken(text, do_split=True):
|
30 |
+
src, tgt = [], []
|
31 |
+
if do_split:
|
32 |
+
chunk_size = random.choice(list(range(0, 10)) + list(range(10, 35)) * 4)
|
33 |
+
if chunk_size > 0:
|
34 |
+
text = random.choice(split_chunk_input(text, chunk_size))
|
35 |
+
else:
|
36 |
+
text = ''
|
37 |
+
words, word_tags = merge_span(*regtag.tagging(text))
|
38 |
+
for span, t in zip(words, word_tags):
|
39 |
+
if t == 'O':
|
40 |
+
for w in span.split():
|
41 |
+
w = w.strip('/.,?!').lower()
|
42 |
+
if len(w) > 0:
|
43 |
+
src.append(w)
|
44 |
+
tgt.append(w)
|
45 |
+
if random.random() < 0.01:
|
46 |
+
random_value = regtag.augment.get_random_span()
|
47 |
+
tgt.append(random_value[0])
|
48 |
+
src.append(random_value[1].lower())
|
49 |
+
else:
|
50 |
+
random_value = regtag.augment.get_random_span(t, span.lower())
|
51 |
+
tgt.append(random_value[0])
|
52 |
+
src.append(random_value[1].lower())
|
53 |
+
|
54 |
+
if len(src) == 0:
|
55 |
+
tgt, src = regtag.get_random_span()
|
56 |
+
src = [src]
|
57 |
+
tgt = [tgt]
|
58 |
+
|
59 |
+
return src, tgt
|
60 |
+
|
61 |
+
|
62 |
+
def split_chunk_input(raw_text, chunk_size):
|
63 |
+
input_words = raw_text.strip().split()
|
64 |
+
clean_data = [input_words[i:i + chunk_size] for i in range(0, len(input_words), chunk_size)]
|
65 |
+
if len(clean_data) > 1:
|
66 |
+
clean_data = [" ".join(clean_data[i] + clean_data[i + 1]) for i in range(len(clean_data) - 1)]
|
67 |
+
else:
|
68 |
+
clean_data = [" ".join(clean_data[0])]
|
69 |
+
return clean_data
|
70 |
+
|
71 |
+
|
72 |
+
def split_chunk_input(raw_text, chunk_size=40, overlap=10):
|
73 |
+
input_words = raw_text.strip().split()
|
74 |
+
part_per_chunk = chunk_size // overlap
|
75 |
+
clean_data = [input_words[i:i + overlap] for i in range(0, len(input_words), overlap)]
|
76 |
+
if len(clean_data) > 1:
|
77 |
+
merge_data = []
|
78 |
+
for i in range(0, len(clean_data) - 1, part_per_chunk - 1):
|
79 |
+
merge_data.append(' '.join([y for x in clean_data[i:i + part_per_chunk] for y in x]))
|
80 |
+
else:
|
81 |
+
merge_data = [" ".join(clean_data[0])]
|
82 |
+
return merge_data
|
83 |
+
|
84 |
+
|
85 |
+
def merge_two_chunk(chunk_1, chunk_2, overlap, debug=False):
    def extract_phrase_word(phrase):
        if phrase.startswith('<mask>'):
            return phrase[7:].split('](')[1][:-1].split()
        else:
            return [phrase]

    def has_tag(phrase):
        if phrase.startswith('<') and phrase.endswith(')'):
            return True
        return False

    def extract_compete_region(list_phrases, is_head):
        if is_head:
            list_phrases = list_phrases[::-1]
        compete = []
        remain = []
        handle_count = 0
        for phrase in list_phrases:
            phrase_word = extract_phrase_word(phrase)
            if len(phrase_word) + handle_count <= overlap:
                compete.append(phrase)
                handle_count += len(phrase_word)
            else:
                if handle_count < overlap:
                    remain_compete_count = overlap - handle_count
                    remain.append(phrase)
                    if not is_head:
                        compete.extend(["<delete>({})".format(item) for item in phrase_word[:remain_compete_count]])
                    else:
                        compete.extend(
                            ["<delete>({})".format(item) for item in phrase_word[::-1][:remain_compete_count]])
                    handle_count = overlap
                else:
                    remain.append(phrase)
        if is_head:
            compete = compete[::-1]
            remain = remain[::-1]
        return remain, compete

    def is_equal(phrase_1, phrase_2):
        if phrase_1 == phrase_2:
            return True
        if extract_phrase_word(phrase_1) == extract_phrase_word(phrase_2):
            if phrase_1.startswith('<mask>') and phrase_2.startswith('<mask>'):
                return True
        return False

    def merge_compete(list_1, list_2):
        idx_list_1, idx_list_2, combine_phrases = [], [], []
        mark_term_complete = []
        list_raw = [extract_phrase_word(item) for item in list_1]
        list_raw = [y for x in list_raw for y in x]
        for idx, phrase in enumerate(list_1):
            idx_list_1.extend([idx] * len(extract_phrase_word(phrase)))
        for idx, phrase in enumerate(list_2):
            idx_list_2.extend([idx] * len(extract_phrase_word(phrase)))
        # print(idx_list_1, idx_list_2)
        for idx, (idx_1, idx_2) in enumerate(zip(idx_list_1, idx_list_2)):
            if list_1[idx_1].startswith('<delete>') or list_2[idx_2].startswith('<delete>'):
                continue
            elif is_equal(list_1[idx_1], list_2[idx_2]):
                # print(list_1[idx_1])
                if '1_{}'.format(idx_1) not in mark_term_complete and '2_{}'.format(idx_2) not in mark_term_complete:
                    if idx <= overlap//2:
                        combine_phrases.append(list_1[idx_1])
                        mark_term_complete.append('1_{}'.format(idx_1))
                    else:
                        combine_phrases.append(list_2[idx_2])
                        mark_term_complete.append('2_{}'.format(idx_2))
            else:
                combine_phrases.append(list_raw[idx])
                mark_term_complete.extend(['1_{}'.format(idx_1), '2_{}'.format(idx_2)])
        # print(mark_term_complete)
        return combine_phrases

    remain_1, compete_1 = extract_compete_region(chunk_1, is_head=True)
    remain_2, compete_2 = extract_compete_region(chunk_2[1:-1], is_head=False)
    compromise = merge_compete(compete_1, compete_2)

    if debug:
        print(remain_1, '\n', compete_1)
        print('-----------------------')
        print(compete_2, '\n', remain_2)
        print('-----------------------')
        print(compromise, '\n\n')

    return remain_1 + compromise + remain_2


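# Added note (not part of the original file): merge_chunk_pre_norm folds
# merge_two_chunk over a list of decoded chunks, left to right. The [1:-1]
# slices suggest each chunk carries a boundary item at each end that is
# dropped before merging.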
def merge_chunk_pre_norm(list_chunks, overlap, debug=False):
    if len(list_chunks) == 0:
        return []
    if len(list_chunks) == 1:
        return list_chunks[0][1:-1]
    current_chunk = list_chunks[0][1:-1]
    for tmp_chunk in list_chunks[1:]:
        current_chunk = merge_two_chunk(current_chunk, tmp_chunk, overlap, debug=debug)
    return current_chunk


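# Worked example (added, not part of the original file) for equalize below,
# which aligns two word sequences with difflib, pads gaps with underscores and,
# given the indentation reconstructed here, keeps roughly the first-half words
# from s1 and the second-half words from s2 in `combine`:
#   equalize("a b c", "a x c")
#   # -> ("a b _ c", "a _ x c", ['a', 'x', 'c'])   (also prints debug output)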
def equalize(s1, s2):
    l1 = s1.split()
    l2 = s2.split()
    res1 = []
    res2 = []
    combine = []
    prev = difflib.Match(0, 0, 0)
    for match in difflib.SequenceMatcher(a=l1, b=l2).get_matching_blocks():
        if prev.a + prev.size != match.a:
            for i in range(prev.a + prev.size, match.a):
                res2 += ['_' * len(l1[i])]
            res1 += l1[prev.a + prev.size:match.a]

            for i in l1[prev.a + prev.size:match.a]:
                if len(combine) < len(l1) // 2:
                    print(l1[prev.a + prev.size:match.a])
                    combine.append(i)
        if prev.b + prev.size != match.b:
            for i in range(prev.b + prev.size, match.b):
                res1 += ['_' * len(l2[i])]
            res2 += l2[prev.b + prev.size:match.b]

            for i in l2[prev.b + prev.size:match.b]:
                if len(combine) >= len(l2) // 2:
                    print(l2[prev.b + prev.size:match.b])
                    combine.append(i)
        res1 += l1[match.a:match.a + match.size]
        res2 += l2[match.b:match.b + match.size]
        combine += l2[match.b:match.b + match.size]
        prev = match
    return ' '.join(res1), ' '.join(res2), combine


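# Added example (not part of the original file): count_overlap counts matching
# words between two equally long word lists via difflib, e.g.
#   count_overlap(['a', 'b', 'c'], ['a', 'x', 'c'])  # -> 2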
def count_overlap(words_1, words_2):
    # print(words_1, words_2)
    assert len(words_1) == len(words_2)
    len_overlap = 0
    for match in difflib.SequenceMatcher(a=words_1, b=words_2).get_matching_blocks():
        len_overlap += match.size

    # for w1, w2 in zip(words_1, words_2):
    #     if w1 == w2:
    #         len_overlap += 1
    return len_overlap


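# Added note (not part of the original file): find_overlap_chunk searches for
# the window size at which the tail of txt_1 and the head of txt_2 share the
# most words, and returns those two equally long overlap regions, e.g.
#   find_overlap_chunk(['a', 'b', 'c', 'd'], ['c', 'd', 'e', 'f'])
#   # -> (['c', 'd'], ['c', 'd'])   (also prints debug output)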
def find_overlap_chunk(txt_1, txt_2):
    # print(txt_1)
    # print(txt_2)
    window_view = 1
    idx_1 = len(txt_1) - window_view
    idx_2 = window_view
    over_lap = 0
    current_best_idx_1 = len(txt_1)
    current_best_idx_2 = 0

    while window_view <= len(txt_1) and window_view <= len(txt_2):
        current_overlap = count_overlap(txt_1[idx_1:], txt_2[:idx_2])
        print(current_overlap)
        if over_lap < current_overlap:
            over_lap = current_overlap
            current_best_idx_1 = idx_1
            current_best_idx_2 = idx_2
        window_view += 1
        idx_1 = len(txt_1) - window_view
        idx_2 = window_view
        # else:
        #     break
    print('----->', txt_1[current_best_idx_1:], txt_2[:current_best_idx_2])
    return txt_1[current_best_idx_1:], txt_2[:current_best_idx_2]


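# Added example (not part of the original file): concat_chunks glues decoded
# overlapping chunks back into one string by locating and merging their shared
# regions, e.g.
#   concat_chunks(["a b c d", "c d e f"])  # -> "a b c d e f"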
def concat_chunks(list_chunks):
    concat_string = list_chunks[0].split()
    for i in range(1, len(list_chunks)):
        remain_string = list_chunks[i].split()
        s1, s2 = find_overlap_chunk(concat_string, remain_string)
        s1 = ' '.join(s1)
        s2 = ' '.join(s2)
        _, _, overlap_merged = equalize(s1, s2)
        merge_len = len(s1.split())

        concat_string = concat_string[:len(concat_string) - merge_len] + overlap_merged + remain_string[merge_len:]

    concat_string = ' '.join(concat_string)
    return concat_string
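

# Minimal round-trip sketch (added, not part of the original file; assumes
# `difflib` is imported at the top of this module). For text without repeated
# phrases, splitting and re-concatenating should reproduce the input:
#   text = " ".join("w{}".format(i) for i in range(100))
#   chunks = split_chunk_input(text, chunk_size=40, overlap=10)
#   assert concat_chunks(chunks) == text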