Spaces:
Running
on
T4
Running
on
T4
Update InferenceInterfaces/ControllableInterface.py
Browse files
InferenceInterfaces/ControllableInterface.py
CHANGED
@@ -1,23 +1,31 @@
|
|
1 |
import os
|
2 |
|
3 |
import torch
|
|
|
4 |
|
5 |
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
|
6 |
from Modules.ControllabilityGAN.GAN import GanWrapper
|
7 |
-
from Utility.storage_config import MODELS_DIR
|
8 |
|
9 |
|
10 |
class ControllableInterface:
|
11 |
|
12 |
-
def __init__(self, gpu_id="cpu", available_artificial_voices=
|
13 |
if gpu_id == "cpu":
|
14 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
15 |
-
|
|
|
|
|
16 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
17 |
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
self.device = "cuda" if gpu_id != "cpu" else "cpu"
|
19 |
-
self.model = ToucanTTSInterface(device=self.device, tts_model_path=
|
20 |
-
self.wgan = GanWrapper(
|
21 |
self.generated_speaker_embeds = list()
|
22 |
self.available_artificial_voices = available_artificial_voices
|
23 |
self.current_language = ""
|
|
|
1 |
import os
|
2 |
|
3 |
import torch
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
|
6 |
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
|
7 |
from Modules.ControllabilityGAN.GAN import GanWrapper
|
|
|
8 |
|
9 |
|
10 |
class ControllableInterface:
|
11 |
|
12 |
+
def __init__(self, gpu_id="cpu", available_artificial_voices=50, tts_model_path=None, vocoder_model_path=None, embedding_gan_path=None):
|
13 |
if gpu_id == "cpu":
|
14 |
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
15 |
+
elif gpu_id == "cuda":
|
16 |
+
pass
|
17 |
+
else: # in this case we hopefully got a number.
|
18 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
19 |
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
|
20 |
+
if tts_model_path is None:
|
21 |
+
tts_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt")
|
22 |
+
if vocoder_model_path is None:
|
23 |
+
vocoder_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="Vocoder.pt")
|
24 |
+
if embedding_gan_path is None:
|
25 |
+
embedding_gan_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="embedding_gan.pt")
|
26 |
self.device = "cuda" if gpu_id != "cpu" else "cpu"
|
27 |
+
self.model = ToucanTTSInterface(device=self.device, tts_model_path=tts_model_path, vocoder_model_path=vocoder_model_path)
|
28 |
+
self.wgan = GanWrapper(embedding_gan_path, num_cached_voices=available_artificial_voices, device=self.device)
|
29 |
self.generated_speaker_embeds = list()
|
30 |
self.available_artificial_voices = available_artificial_voices
|
31 |
self.current_language = ""
|