Update app.py
Browse files
app.py
CHANGED
@@ -24,11 +24,11 @@ def resampler(input_file_path, output_file_path):
|
|
24 |
subprocess.call(command, shell=True)
|
25 |
|
26 |
|
27 |
-
def parse_transcription_with_lm(logits):
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
def parse_transcription(logits):
|
34 |
predicted_ids = torch.argmax(logits, dim=-1)
|
@@ -41,7 +41,8 @@ def parse(wav_file, applyLM):
|
|
41 |
logits = model(**input_values).logits
|
42 |
|
43 |
if applyLM:
|
44 |
-
return parse_transcription_with_lm(logits)
|
|
|
45 |
else:
|
46 |
return parse_transcription(logits)
|
47 |
|
@@ -64,7 +65,7 @@ model_id = "anuragshas/wav2vec2-large-xlsr-53-odia"
|
|
64 |
|
65 |
|
66 |
processor = Wav2Vec2Processor.from_pretrained(model_id)
|
67 |
-
processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
|
68 |
model = Wav2Vec2ForCTC.from_pretrained(model_id)
|
69 |
|
70 |
|
|
|
24 |
subprocess.call(command, shell=True)
|
25 |
|
26 |
|
27 |
+
# def parse_transcription_with_lm(logits):
|
28 |
+
# result = processor_with_LM.batch_decode(logits.cpu().numpy())
|
29 |
+
# text = result.text
|
30 |
+
# transcription = text[0].replace('<s>','')
|
31 |
+
# return transcription
|
32 |
|
33 |
def parse_transcription(logits):
|
34 |
predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
41 |
logits = model(**input_values).logits
|
42 |
|
43 |
if applyLM:
|
44 |
+
# return parse_transcription_with_lm(logits)
|
45 |
+
return "done"
|
46 |
else:
|
47 |
return parse_transcription(logits)
|
48 |
|
|
|
65 |
|
66 |
|
67 |
processor = Wav2Vec2Processor.from_pretrained(model_id)
|
68 |
+
# processor_with_LM = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
|
69 |
model = Wav2Vec2ForCTC.from_pretrained(model_id)
|
70 |
|
71 |
|