Update custom_interface.py
custom_interface.py CHANGED  (+7, -18)
@@ -11,25 +11,14 @@ class ASR(Pretrained):
         wavs = wavs.to(device)
         wav_lens = wav_lens.to(device)
 
-        # Forward
-
-
-
-
-        decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, wav_lens)
+        # Forward encoder + decoder
+        tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
+        tokens = tokens.to(device)
+        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
+        log_probs = self.hparams.log_softmax(logits)
 
-
-
-        # predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
-        predicted_words = []
-        for prediction in predictions:
-            prediction = [token for token in prediction if token != 0]
-            predicted_words.append(self.hparams.tokenizer.decode_ids(prediction).split(" "))
-        prediction = []
-        for sent in predicted_words:
-            sent = self.filter_repetitions(sent, 3)
-            prediction.append(sent)
-        predicted_words = prediction
+        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
+        predicted_words = [self.mods.whisper.tokenizer.decode(token, skip_special_tokens=True).strip() for token in hyps]
         return predicted_words
 
     def filter_repetitions(self, seq, max_repetition_length):
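For context, below is a minimal usage sketch of how a custom interface like this ASR class is typically loaded and called through SpeechBrain's foreign_class helper. The Hugging Face repo id, the audio path, and the transcribe_batch method name are illustrative assumptions and do not come from this commit; only the ASR class name and the custom_interface.py filename appear in the diff above.

# Minimal usage sketch. Assumptions: the repo id, audio path, and the
# transcribe_batch method name are placeholders; the ASR class and the
# custom_interface.py filename are taken from the diff above.
import torch
import torchaudio
from speechbrain.pretrained.interfaces import foreign_class

asr = foreign_class(
    source="<hf-repo-id>",                # placeholder model repo
    pymodule_file="custom_interface.py",  # file edited in this commit
    classname="ASR",                      # class shown in the hunk header
)

signal, sr = torchaudio.load("example.wav")  # placeholder audio file
wav_lens = torch.tensor([1.0])               # relative lengths for a batch of one

# Assumed entry point that wraps the edited forward/decode path and
# returns a list of decoded transcriptions.
predicted_words = asr.transcribe_batch(signal, wav_lens)
print(predicted_words)

The diff also changes the decoding strategy: instead of running self.mods.decoder and then filtering repeated tokens in Python (the removed filter_repetitions post-processing loop), the new code forwards the Whisper module once and delegates hypothesis search to self.hparams.test_search, presumably a searcher defined in the model's hyperparameters, and finally decodes the hypotheses with the Whisper tokenizer.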