Porjaz commited on
Commit
9d7c1bc
·
verified ·
1 Parent(s): c095698

Update custom_interface_app.py

Browse files
Files changed (1) hide show
  1. custom_interface_app.py +2 -2
custom_interface_app.py CHANGED
@@ -251,7 +251,7 @@ class ASR(Pretrained):
251
 
252
  # Pass the segment through the ASR model
253
  inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
254
- inputs['input_values'] = inputs['input_values'].to(torch.float16)
255
  outputs = model(**inputs).logits
256
  ids = torch.argmax(outputs, dim=-1)[0]
257
  segment_output = processor.decode(ids)
@@ -259,7 +259,7 @@ class ASR(Pretrained):
259
  else:
260
  waveform = torch.tensor(waveform).to(device)
261
  inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
262
- inputs['input_values'] = inputs['input_values'].to(torch.float16)
263
  outputs = model(**inputs).logits
264
  ids = torch.argmax(outputs, dim=-1)[0]
265
  transcription = processor.decode(ids)
 
251
 
252
  # Pass the segment through the ASR model
253
  inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
254
+ inputs['input_values'] = inputs['input_values']
255
  outputs = model(**inputs).logits
256
  ids = torch.argmax(outputs, dim=-1)[0]
257
  segment_output = processor.decode(ids)
 
259
  else:
260
  waveform = torch.tensor(waveform).to(device)
261
  inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
262
+ inputs['input_values'] = inputs['input_values']
263
  outputs = model(**inputs).logits
264
  ids = torch.argmax(outputs, dim=-1)[0]
265
  transcription = processor.decode(ids)