JulienHalgand commited on
Commit
162bd64
·
1 Parent(s): e5007d2

GPU support

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. request.py +1 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import subprocess
7
 
8
  import gradio as gr
9
  import torchaudio
 
10
 
11
  from model_helper import load_model_checkpoint, transcribe
12
  from prepare_media import prepare_media
@@ -51,7 +52,7 @@ MODELS = {
51
  log_file = 'amt/log.txt'
52
 
53
  model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
54
- #model.to("cuda")
55
 
56
  def prepare_media(source_path_or_url: os.PathLike,
57
  source_type: Literal['audio_filepath', 'youtube_url'],
@@ -104,7 +105,7 @@ def prepare_media(source_path_or_url: os.PathLike,
104
  "duration": int(info.num_frames / info.sample_rate),
105
  "encoding": str.lower(info.encoding),
106
  }
107
-
108
  def handle_audio(file_path):
109
  # Guess extension from MIME
110
  mime_type, _ = mimetypes.guess_type(file_path)
 
7
 
8
  import gradio as gr
9
  import torchaudio
10
+ import spaces
11
 
12
  from model_helper import load_model_checkpoint, transcribe
13
  from prepare_media import prepare_media
 
52
  log_file = 'amt/log.txt'
53
 
54
  model = load_model_checkpoint(args=MODELS[MODEL_NAME]["args"], device="cpu")
55
+ model.to("cuda")
56
 
57
  def prepare_media(source_path_or_url: os.PathLike,
58
  source_type: Literal['audio_filepath', 'youtube_url'],
 
105
  "duration": int(info.num_frames / info.sample_rate),
106
  "encoding": str.lower(info.encoding),
107
  }
108
+ @spaces.GPU
109
  def handle_audio(file_path):
110
  # Guess extension from MIME
111
  mime_type, _ = mimetypes.guess_type(file_path)
request.py CHANGED
@@ -1,6 +1,7 @@
1
  from gradio_client import Client, handle_file
2
 
3
  client = Client("http://127.0.0.1:7860/")
 
4
  result = client.predict(
5
  file_path=handle_file('/home/julien/Downloads/La Marseillaise_1st_couplet.mp3'),
6
  api_name="/predict"
 
1
  from gradio_client import Client, handle_file
2
 
3
  client = Client("http://127.0.0.1:7860/")
4
+ #client = Client("JulienHalgand/hf-gradio-example")
5
  result = client.predict(
6
  file_path=handle_file('/home/julien/Downloads/La Marseillaise_1st_couplet.mp3'),
7
  api_name="/predict"