Update custom_interface_app.py
Browse files- custom_interface_app.py +2 -2
custom_interface_app.py
CHANGED
@@ -251,7 +251,7 @@ class ASR(Pretrained):
|
|
251 |
|
252 |
# Pass the segment through the ASR model
|
253 |
inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
|
254 |
-
inputs['input_values'] = inputs['input_values']
|
255 |
outputs = model(**inputs).logits
|
256 |
ids = torch.argmax(outputs, dim=-1)[0]
|
257 |
segment_output = processor.decode(ids)
|
@@ -259,7 +259,7 @@ class ASR(Pretrained):
|
|
259 |
else:
|
260 |
waveform = torch.tensor(waveform).to(device)
|
261 |
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
|
262 |
-
inputs['input_values'] = inputs['input_values']
|
263 |
outputs = model(**inputs).logits
|
264 |
ids = torch.argmax(outputs, dim=-1)[0]
|
265 |
transcription = processor.decode(ids)
|
|
|
251 |
|
252 |
# Pass the segment through the ASR model
|
253 |
inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
|
254 |
+
inputs['input_values'] = inputs['input_values']
|
255 |
outputs = model(**inputs).logits
|
256 |
ids = torch.argmax(outputs, dim=-1)[0]
|
257 |
segment_output = processor.decode(ids)
|
|
|
259 |
else:
|
260 |
waveform = torch.tensor(waveform).to(device)
|
261 |
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
|
262 |
+
inputs['input_values'] = inputs['input_values']
|
263 |
outputs = model(**inputs).logits
|
264 |
ids = torch.argmax(outputs, dim=-1)[0]
|
265 |
transcription = processor.decode(ids)
|