Flux9665 committed on
Commit
db5766e
·
verified ·
1 Parent(s): 43f2732

Update InferenceInterfaces/ControllableInterface.py

Browse files
InferenceInterfaces/ControllableInterface.py CHANGED
@@ -1,23 +1,31 @@
1
  import os
2
 
3
  import torch
 
4
 
5
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
6
  from Modules.ControllabilityGAN.GAN import GanWrapper
7
- from Utility.storage_config import MODELS_DIR
8
 
9
 
10
  class ControllableInterface:
11
 
12
- def __init__(self, gpu_id="cpu", available_artificial_voices=1000):
13
  if gpu_id == "cpu":
14
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
15
- else:
 
 
16
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
17
  os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
 
 
 
 
 
 
18
  self.device = "cuda" if gpu_id != "cpu" else "cpu"
19
- self.model = ToucanTTSInterface(device=self.device, tts_model_path="Meta")
20
- self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device=self.device)
21
  self.generated_speaker_embeds = list()
22
  self.available_artificial_voices = available_artificial_voices
23
  self.current_language = ""
 
1
  import os
2
 
3
  import torch
4
+ from huggingface_hub import hf_hub_download
5
 
6
  from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
7
  from Modules.ControllabilityGAN.GAN import GanWrapper
 
8
 
9
 
10
  class ControllableInterface:
11
 
12
+ def __init__(self, gpu_id="cpu", available_artificial_voices=50, tts_model_path=None, vocoder_model_path=None, embedding_gan_path=None):
13
  if gpu_id == "cpu":
14
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
15
+ elif gpu_id == "cuda":
16
+ pass
17
+ else: # in this case we hopefully got a number.
18
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
19
  os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
20
+ if tts_model_path is None:
21
+ tts_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt")
22
+ if vocoder_model_path is None:
23
+ vocoder_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="Vocoder.pt")
24
+ if embedding_gan_path is None:
25
+ embedding_gan_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="embedding_gan.pt")
26
  self.device = "cuda" if gpu_id != "cpu" else "cpu"
27
+ self.model = ToucanTTSInterface(device=self.device, tts_model_path=tts_model_path, vocoder_model_path=vocoder_model_path)
28
+ self.wgan = GanWrapper(embedding_gan_path, num_cached_voices=available_artificial_voices, device=self.device)
29
  self.generated_speaker_embeds = list()
30
  self.available_artificial_voices = available_artificial_voices
31
  self.current_language = ""