Update app.py
Browse files
app.py
CHANGED
@@ -115,12 +115,12 @@ def run_tts(
|
|
115 |
return save_path
|
116 |
|
117 |
|
118 |
-
def build_ui(model_dir):
|
119 |
|
120 |
global MODEL
|
121 |
|
122 |
# Initialize model with proper device handling
|
123 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
124 |
if MODEL is None:
|
125 |
MODEL = initialize_model(model_dir, device=device)
|
126 |
if device == "cuda":
|
@@ -164,7 +164,7 @@ def build_ui(model_dir):
|
|
164 |
|
165 |
with gr.Blocks() as demo:
|
166 |
# Use HTML for centered title
|
167 |
-
gr.HTML('<h1 style="text-align: center;">(
|
168 |
with gr.Tabs():
|
169 |
# Voice Clone Tab
|
170 |
with gr.TabItem("Voice Clone"):
|
@@ -260,7 +260,12 @@ def parse_arguments():
|
|
260 |
default=None,
|
261 |
help="Path to the model directory."
|
262 |
)
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
264 |
parser.add_argument(
|
265 |
"--server_name",
|
266 |
type=str,
|
@@ -279,8 +284,11 @@ if __name__ == "__main__":
|
|
279 |
# Parse command-line arguments
|
280 |
args = parse_arguments()
|
281 |
|
282 |
-
# Build the Gradio demo by specifying the model directory
|
283 |
-
demo = build_ui(
|
|
|
|
|
|
|
284 |
|
285 |
# Launch Gradio with the specified server name and port
|
286 |
demo.launch(
|
|
|
115 |
return save_path
|
116 |
|
117 |
|
118 |
+
def build_ui(model_dir, device=0):
|
119 |
|
120 |
global MODEL
|
121 |
|
122 |
# Initialize model with proper device handling
|
123 |
+
device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
|
124 |
if MODEL is None:
|
125 |
MODEL = initialize_model(model_dir, device=device)
|
126 |
if device == "cuda":
|
|
|
164 |
|
165 |
with gr.Blocks() as demo:
|
166 |
# Use HTML for centered title
|
167 |
+
gr.HTML('<h1 style="text-align: center;">(Unofficial) Spark-TTS by SparkAudio</h1>')
|
168 |
with gr.Tabs():
|
169 |
# Voice Clone Tab
|
170 |
with gr.TabItem("Voice Clone"):
|
|
|
260 |
default=None,
|
261 |
help="Path to the model directory."
|
262 |
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--device",
|
265 |
+
type=str,
|
266 |
+
default="cpu",
|
267 |
+
help="Device to use (e.g., 'cpu' or 'cuda:0')."
|
268 |
+
)
|
269 |
parser.add_argument(
|
270 |
"--server_name",
|
271 |
type=str,
|
|
|
284 |
# Parse command-line arguments
|
285 |
args = parse_arguments()
|
286 |
|
287 |
+
# Build the Gradio demo by specifying the model directory and GPU device
|
288 |
+
demo = build_ui(
|
289 |
+
model_dir=args.model_dir,
|
290 |
+
device=args.device
|
291 |
+
)
|
292 |
|
293 |
# Launch Gradio with the specified server name and port
|
294 |
demo.launch(
|