anzorq committed on
Commit
0c872e7
·
verified ·
1 Parent(s): bfb5ccb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -7,6 +7,9 @@ from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
7
  model = AutoModelForCTC.from_pretrained("anzorq/w2v-bert-2.0-kbd")
8
  processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/w2v-bert-2.0-kbd")
9
 
 
 
 
10
  @spaces.GPU
11
  def transcribe_speech(audio):
12
  # Load the audio file
@@ -25,7 +28,7 @@ def transcribe_speech(audio):
25
 
26
  # Extract input features
27
  input_features = processor(waveform.unsqueeze(0), sampling_rate=16000).input_features
28
- input_features = torch.from_numpy(input_features).to("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
  # Generate logits using the model
31
  with torch.no_grad():
 
7
  model = AutoModelForCTC.from_pretrained("anzorq/w2v-bert-2.0-kbd")
8
  processor = Wav2Vec2BertProcessor.from_pretrained("anzorq/w2v-bert-2.0-kbd")
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model.to(device)
12
+
13
  @spaces.GPU
14
  def transcribe_speech(audio):
15
  # Load the audio file
 
28
 
29
  # Extract input features
30
  input_features = processor(waveform.unsqueeze(0), sampling_rate=16000).input_features
31
+ input_features = torch.from_numpy(input_features).to(device)
32
 
33
  # Generate logits using the model
34
  with torch.no_grad():