Porjaz commited on
Commit
b3b62d6
·
verified ·
1 Parent(s): c31791e

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +7 -18
custom_interface.py CHANGED
@@ -11,25 +11,14 @@ class ASR(Pretrained):
11
  wavs = wavs.to(device)
12
  wav_lens = wav_lens.to(device)
13
 
14
- # Forward pass
15
- encoded_outputs = self.mods.encoder_w2v2(wavs.detach())
16
- # append
17
- tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(device)
18
- embedded_tokens = self.mods.embedding(tokens_bos)
19
- decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, wav_lens)
20
 
21
- # Output layer for seq2seq log-probabilities
22
- predictions = self.hparams.test_search(encoded_outputs, wav_lens)[0]
23
- # predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
24
- predicted_words = []
25
- for prediction in predictions:
26
- prediction = [token for token in prediction if token != 0]
27
- predicted_words.append(self.hparams.tokenizer.decode_ids(prediction).split(" "))
28
- prediction = []
29
- for sent in predicted_words:
30
- sent = self.filter_repetitions(sent, 3)
31
- prediction.append(sent)
32
- predicted_words = prediction
33
  return predicted_words
34
 
35
  def filter_repetitions(self, seq, max_repetition_length):
 
11
  wavs = wavs.to(device)
12
  wav_lens = wav_lens.to(device)
13
 
14
+ # Forward encoder + decoder
15
+ tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
16
+ tokens = tokens.to(device)
17
+ enc_out, logits, _ = self.mods.whisper(wavs, tokens)
18
+ log_probs = self.hparams.log_softmax(logits)
 
19
 
20
+ hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
21
+ predicted_words = [self.mods.whisper.tokenizer.decode(token, skip_special_tokens=True).strip() for token in hyps]
 
 
 
 
 
 
 
 
 
 
22
  return predicted_words
23
 
24
  def filter_repetitions(self, seq, max_repetition_length):