5roop committed
Commit 8428e77 · verified · 1 Parent(s): f1462b3

Update README.md

Files changed (1): README.md (+226, -1)

README.md CHANGED
@@ -14,7 +14,7 @@ base_model:
# Wav2Vec2Bert Audio frame classifier for prosodic unit detection

This model predicts prosodic units on speech.
- For each 20ms frame the model predicts a vector like `[0,1]` or `[1,0]`, indicating whether there is a prosodic unit in
+ For each 20ms frame the model predicts 1 or 0, indicating whether there is a prosodic unit in
this frame or not.

@@ -40,12 +40,237 @@ This is the model card of a 🤗 transformers model that has been pushed on the

## Uses

### Simple use (short files)

For shorter audio files that fit on your GPU, the classifier can be used directly.
```python
import numpy as np
import torch
from datasets import Audio, Dataset
from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_name = "5roop/Wav2Vec2BertProsodicUnitsFrameClassifier"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)
f = "data/Rog-Art-N-G6007-P600702_181.070_211.070.wav"


def frames_to_intervals(frames: list) -> list[tuple]:
    """Convert per-frame 0/1 predictions into (start_s, end_s) intervals."""
    from itertools import pairwise

    import pandas as pd

    results = []
    ndf = pd.DataFrame(
        data={
            # Each frame corresponds to 20 ms of audio
            "time_s": [0.020 * i for i in range(len(frames))],
            "frames": frames,
        }
    )
    ndf = ndf.dropna()
    # Indices where the predicted label changes
    indices_of_change = ndf.frames.diff()[ndf.frames.diff() != 0].index.values
    for si, ei in pairwise(indices_of_change):
        if ndf.loc[si : ei - 1, "frames"].mode()[0] == 0:
            pass
        else:
            results.append(
                (round(ndf.loc[si, "time_s"], 3), round(ndf.loc[ei - 1, "time_s"], 3))
            )
    return results


def evaluator(chunks):
    sampling_rate = chunks["audio"][0]["sampling_rate"]
    with torch.no_grad():
        inputs = feature_extractor(
            [i["array"] for i in chunks["audio"]],
            return_tensors="pt",
            sampling_rate=sampling_rate,
        ).to(device)
        logits = model(**inputs).logits
    y_pred_raw = np.array(logits.cpu())
    # Argmax over the two logits gives a 0/1 label for every 20 ms frame
    y_pred = y_pred_raw.argmax(axis=-1)
    prosodic_units = [frames_to_intervals(i) for i in y_pred]
    return {
        "y_pred": y_pred,
        "y_pred_logits": y_pred_raw,
        "prosodic_units": prosodic_units,
    }


ds = Dataset.from_dict({"audio": [f, f]}).cast_column("audio", Audio(16000, mono=True))
ds = ds.map(evaluator, batched=True, batch_size=2)
print(ds["y_pred"][0])
# Outputs: [0, 0, 1, 1, 1, 1, 1, ...]
print(ds["y_pred_logits"][0])
# Outputs:
# [[ 0.89419061, -0.77746612],
#  [ 0.44213724, -0.34862748],
#  [-0.08605709,  0.13012762],
# ....
print(ds["prosodic_units"][0])
# Outputs: [[0.04, 2.4], [3.52, 6.6], ....
```
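
If per-frame probabilities are preferred over the hard 0/1 labels, the raw logits returned above can be converted with a softmax. A minimal sketch, reusing the `ds` object from the example above (label index 1 corresponds to the prosodic-unit class):

```python
import numpy as np

# Turn the raw per-frame logits from the example above into class probabilities.
logits = np.array(ds["y_pred_logits"][0])  # shape: (num_frames, 2)
exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = exp / exp.sum(axis=-1, keepdims=True)
prosodic_probability = probs[:, 1]  # probability of a prosodic unit per 20 ms frame
print(prosodic_probability[:5])
```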

### Inference on longer files

If the file is too big for straightforward inference, it needs to be processed in chunks. The probability of false negatives increases at the starts and ends of chunks, so it is best to either process the file with some overlap between chunks or to split it on silence. We illustrate the former approach here (a sketch of the latter follows after the code block):
```python
import numpy as np
import torch
from datasets import Audio, Dataset
from transformers import AutoFeatureExtractor, Wav2Vec2BertForAudioFrameClassification

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_name = "5roop/Wav2Vec2BertProsodicUnitsFrameClassifier"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2BertForAudioFrameClassification.from_pretrained(model_name).to(device)
f = "ROG/ROG-Art/WAV/Rog-Art-N-G5025-P600022.wav"

# Process the file in 30 s chunks that overlap by 10 s
OVERLAP_S = 10
CHUNK_LENGTH_S = 30
SAMPLING_RATE = 16_000
OVERLAP_SAMPLES = OVERLAP_S * SAMPLING_RATE
CHUNK_LENGTH_SAMPLES = CHUNK_LENGTH_S * SAMPLING_RATE


def frames_to_intervals(frames: list) -> list[tuple]:
    """Convert per-frame 0/1 predictions into (start_s, end_s) intervals."""
    from itertools import pairwise

    import pandas as pd

    results = []
    ndf = pd.DataFrame(
        data={
            "time_s": [0.020 * i for i in range(len(frames))],
            "frames": frames,
        }
    )
    ndf = ndf.dropna()
    indices_of_change = ndf.frames.diff()[ndf.frames.diff() != 0].index.values
    for si, ei in pairwise(indices_of_change):
        if ndf.loc[si : ei - 1, "frames"].mode()[0] == 0:
            pass
        else:
            results.append(
                (round(ndf.loc[si, "time_s"], 3), round(ndf.loc[ei - 1, "time_s"], 3))
            )
    return results


def merge_events(events: list[list[float]], centroids):
    """Merge predictions coming from overlapping chunks.

    When two intervals from neighbouring chunks overlap, the one whose chunk
    centroid is closer to the interval midpoint is kept, i.e. the prediction
    made further away from a chunk boundary.
    """
    flattened_events = []
    flattened_centroids = []
    for batch_idx, batch in enumerate(events):
        for event in batch:
            flattened_events.append(event)
            flattened_centroids.append(centroids[batch_idx])
    # Sort the events and their chunk centroids together by event start time
    flattened = sorted(
        zip(flattened_events, flattened_centroids), key=lambda pair: pair[0][0]
    )

    # Merged list to store final intervals
    merged = []

    for event, centroid in flattened:
        if not merged:
            # If merged is empty, simply add the first event
            merged.append((event, centroid))
        else:
            last_event, last_centroid = merged[-1]
            # Check for overlap
            if (last_event[0] < event[1]) and (last_event[1] > event[0]):
                # Calculate the midpoints of the intervals
                last_event_midpoint = (last_event[0] + last_event[1]) / 2
                current_event_midpoint = (event[0] + event[1]) / 2

                # Keep the event whose chunk centroid is closer to its midpoint
                if abs(last_centroid - last_event_midpoint) <= abs(
                    centroid - current_event_midpoint
                ):
                    continue
                else:
                    merged[-1] = (event, centroid)
            else:
                merged.append((event, centroid))

    final_intervals = [event for event, _ in merged]
    return final_intervals


def evaluator(chunks):
    with torch.no_grad():
        samples = []
        for array, start, end in zip(chunks["audio"], chunks["start"], chunks["end"]):
            samples.append(array["array"][start:end])
        inputs = feature_extractor(
            samples,
            return_tensors="pt",
            sampling_rate=SAMPLING_RATE,
        ).to(device)
        logits = model(**inputs).logits
    y_pred_raw = np.array(logits.cpu())
    y_pred = y_pred_raw.argmax(axis=-1)
    # Shift the intervals by the chunk start so they are in absolute time
    prosodic_units = [
        np.array(frames_to_intervals(i)) + start / SAMPLING_RATE
        for i, start in zip(y_pred, chunks["start"])
    ]
    return {
        "y_pred": y_pred,
        "y_pred_logits": y_pred_raw,
        "prosodic_units": prosodic_units,
    }


audio_duration_samples = (
    Audio(SAMPLING_RATE, mono=True)
    .decode_example({"path": f, "bytes": None})["array"]
    .shape[0]
)
chunk_starts = np.arange(
    0, audio_duration_samples, CHUNK_LENGTH_SAMPLES - OVERLAP_SAMPLES
)
chunk_ends = chunk_starts + CHUNK_LENGTH_SAMPLES

ds = Dataset.from_dict(
    {
        "audio": [f for _ in chunk_starts],
        "start": chunk_starts,
        "end": chunk_ends,
        "chunk_centroid_s": (chunk_starts + chunk_ends) / 2 / SAMPLING_RATE,
    }
).cast_column("audio", Audio(SAMPLING_RATE, mono=True))

ds = ds.map(evaluator, batched=True, batch_size=10)


final_intervals = merge_events(ds["prosodic_units"], ds["chunk_centroid_s"])
print(final_intervals)
# Outputs: [[3.14, 4.96], [5.6, 8.4], [8.62, 9.32], [10.12, 10.7], [11.72, 13.1],....
```
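
The silence-splitting alternative mentioned above is not spelled out in the card; a rough sketch of it could look as follows, assuming `librosa` is available and using an illustrative `top_db` threshold. The `evaluator` defined above is reused, and no merging step is needed because the chunks do not overlap (very long non-silent stretches may still need further chunking):

```python
import librosa

# Detect non-silent regions and use them as chunks (illustrative parameters).
y, sr = librosa.load(f, sr=SAMPLING_RATE, mono=True)
non_silent = librosa.effects.split(y, top_db=30)  # (start, end) sample indices

ds_silence = Dataset.from_dict(
    {
        "audio": [f for _ in non_silent],
        "start": [int(s) for s, _ in non_silent],
        "end": [int(e) for _, e in non_silent],
    }
).cast_column("audio", Audio(SAMPLING_RATE, mono=True))

# Reuse the evaluator from the example above; since the chunks do not overlap,
# the per-chunk prosodic units can simply be concatenated.
ds_silence = ds_silence.map(evaluator, batched=True, batch_size=10)
final_intervals_silence = [
    interval for chunk in ds_silence["prosodic_units"] for interval in chunk
]
```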

## Bias, Risks, and Limitations

## Training Details

| hyperparameter | value |
|---|---|
| learning rate | 3e-5 |
| batch size | 1 |
| gradient accumulation steps | 16 |
| num train epochs | 20 |
| weight decay | 0.01 |

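The training script itself is not included in this card; purely as an illustration, the reported values would map onto 🤗 `TrainingArguments` roughly like this (the output directory name is made up):

```python
from transformers import TrainingArguments

# Hypothetical mapping of the reported hyperparameters; not the original training script.
training_args = TrainingArguments(
    output_dir="wav2vec2bert-prosodic-units",  # illustrative name
    learning_rate=3e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=20,
    weight_decay=0.01,
)
```
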
## Evaluation
