xiaoqinfeng commited on
Commit
ae1b88c
·
verified ·
1 Parent(s): 679c1f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -115,12 +115,12 @@ def run_tts(
115
  return save_path
116
 
117
 
118
- def build_ui(model_dir):
119
 
120
  global MODEL
121
 
122
  # Initialize model with proper device handling
123
- device = "cuda" if torch.cuda.is_available() else "cpu"
124
  if MODEL is None:
125
  MODEL = initialize_model(model_dir, device=device)
126
  if device == "cuda":
@@ -164,7 +164,7 @@ def build_ui(model_dir):
164
 
165
  with gr.Blocks() as demo:
166
  # Use HTML for centered title
167
- gr.HTML('<h1 style="text-align: center;">(Official) Spark-TTS by SparkAudio</h1>')
168
  with gr.Tabs():
169
  # Voice Clone Tab
170
  with gr.TabItem("Voice Clone"):
@@ -260,7 +260,12 @@ def parse_arguments():
260
  default=None,
261
  help="Path to the model directory."
262
  )
263
-
 
 
 
 
 
264
  parser.add_argument(
265
  "--server_name",
266
  type=str,
@@ -279,8 +284,11 @@ if __name__ == "__main__":
279
  # Parse command-line arguments
280
  args = parse_arguments()
281
 
282
- # Build the Gradio demo by specifying the model directory
283
- demo = build_ui(model_dir=args.model_dir)
 
 
 
284
 
285
  # Launch Gradio with the specified server name and port
286
  demo.launch(
 
115
  return save_path
116
 
117
 
118
+ def build_ui(model_dir, device=0):
119
 
120
  global MODEL
121
 
122
  # Initialize model with proper device handling
123
+ device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
124
  if MODEL is None:
125
  MODEL = initialize_model(model_dir, device=device)
126
  if device == "cuda":
 
164
 
165
  with gr.Blocks() as demo:
166
  # Use HTML for centered title
167
+ gr.HTML('<h1 style="text-align: center;">(Unofficial) Spark-TTS by SparkAudio</h1>')
168
  with gr.Tabs():
169
  # Voice Clone Tab
170
  with gr.TabItem("Voice Clone"):
 
260
  default=None,
261
  help="Path to the model directory."
262
  )
263
+ parser.add_argument(
264
+ "--device",
265
+ type=str,
266
+ default="cpu",
267
+ help="Device to use (e.g., 'cpu' or 'cuda:0')."
268
+ )
269
  parser.add_argument(
270
  "--server_name",
271
  type=str,
 
284
  # Parse command-line arguments
285
  args = parse_arguments()
286
 
287
+ # Build the Gradio demo by specifying the model directory and GPU device
288
+ demo = build_ui(
289
+ model_dir=args.model_dir,
290
+ device=args.device
291
+ )
292
 
293
  # Launch Gradio with the specified server name and port
294
  demo.launch(