Bishan committed on
Commit
e4c40fa
·
1 Parent(s): c0d8b29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -24,11 +24,11 @@ def resampler(input_file_path, output_file_path):
24
  subprocess.call(command, shell=True)
25
 
26
 
27
- def parse_transcription_with_lm(logits):
28
- result = processor_with_LM.batch_decode(logits.cpu().numpy())
29
- text = result.text
30
- transcription = text[0].replace('<s>','')
31
- return transcription
32
 
33
  def parse_transcription(logits):
34
  predicted_ids = torch.argmax(logits, dim=-1)
@@ -41,7 +41,8 @@ def parse(wav_file, applyLM):
41
  logits = model(**input_values).logits
42
 
43
  if applyLM:
44
- return parse_transcription_with_lm(logits)
 
45
  else:
46
  return parse_transcription(logits)
47
 
@@ -64,7 +65,7 @@ model_id = "anuragshas/wav2vec2-large-xlsr-53-odia"
64
 
65
 
66
  processor = Wav2Vec2Processor.from_pretrained(model_id)
67
- processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
68
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
69
 
70
 
 
24
  subprocess.call(command, shell=True)
25
 
26
 
27
+ # def parse_transcription_with_lm(logits):
28
+ # result = processor_with_LM.batch_decode(logits.cpu().numpy())
29
+ # text = result.text
30
+ # transcription = text[0].replace('<s>','')
31
+ # return transcription
32
 
33
  def parse_transcription(logits):
34
  predicted_ids = torch.argmax(logits, dim=-1)
 
41
  logits = model(**input_values).logits
42
 
43
  if applyLM:
44
+ # return parse_transcription_with_lm(logits)
45
+ return "done"
46
  else:
47
  return parse_transcription(logits)
48
 
 
65
 
66
 
67
  processor = Wav2Vec2Processor.from_pretrained(model_id)
68
+ # processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
69
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
70
 
71