Pendrokar commited on
Commit
3c4537d
·
1 Parent(s): faa1fb8

new TTS: Index TTS; SoVITS disabled due to mid perfo

Browse files
Files changed (3) hide show
  1. app/models.py +20 -3
  2. test_tts_index.py +50 -0
  3. test_tts_llasa.py +50 -0
app/models.py CHANGED
@@ -73,7 +73,7 @@ AVAILABLE_MODELS = {
73
  # 'Svngoku/maskgct-audio-lab': 'Svngoku/maskgct-audio-lab', # DEMANDS 300 seconds of ZeroGPU!
74
 
75
  # GPT-SoVITS
76
- 'lj1995/GPT-SoVITS-v2': 'lj1995/GPT-SoVITS-v2',
77
 
78
  # OuteTTS 500M
79
  # 'OuteAI/OuteTTS-0.2-500M-Demo': 'OuteAI/OuteTTS-0.2-500M-Demo',
@@ -107,6 +107,9 @@ AVAILABLE_MODELS = {
107
  # Orpheus
108
  'MohamedRashad/Orpheus-TTS' : 'MohamedRashad/Orpheus-TTS',
109
 
 
 
 
110
  # HF TTS w issues
111
  # 'LeeSangHoon/HierSpeech_TTS': 'LeeSangHoon/HierSpeech_TTS', # irresponsive to exclamation marks # 4.29
112
  # 'PolyAI/pheme': '/predict#0', # sleepy HF Space
@@ -505,6 +508,15 @@ HF_SPACES = {
505
  'is_zero_gpu_space': True,
506
  'series': 'Orpheus',
507
  },
 
 
 
 
 
 
 
 
 
508
  }
509
 
510
  # for zero-shot TTS - voice sample used by XTTS (11 seconds)
@@ -806,7 +818,12 @@ OVERRIDE_INPUTS = {
806
  'top_p': 0.95,
807
  'repetition_penalty': 1.1,
808
  'max_new_tokens': 1200,
809
- }
 
 
 
 
 
810
  }
811
 
812
  # minor mods to model from the same space
@@ -871,7 +888,7 @@ closed_source = [
871
  ]
872
 
873
  # top five models in order to always have one of them picked and scrutinized
874
- top_five = ['thunnai/SparkTTS']
875
 
876
  # prioritize low vote models
877
  sql = 'SELECT name FROM model WHERE (upvote + downvote) < 700 ORDER BY (upvote + downvote) ASC'
 
73
  # 'Svngoku/maskgct-audio-lab': 'Svngoku/maskgct-audio-lab', # DEMANDS 300 seconds of ZeroGPU!
74
 
75
  # GPT-SoVITS
76
+ # 'lj1995/GPT-SoVITS-v2': 'lj1995/GPT-SoVITS-v2',
77
 
78
  # OuteTTS 500M
79
  # 'OuteAI/OuteTTS-0.2-500M-Demo': 'OuteAI/OuteTTS-0.2-500M-Demo',
 
107
  # Orpheus
108
  'MohamedRashad/Orpheus-TTS' : 'MohamedRashad/Orpheus-TTS',
109
 
110
+ # Index TTS
111
+ 'IndexTeam/IndexTTS': 'IndexTeam/IndexTTS',
112
+
113
  # HF TTS w issues
114
  # 'LeeSangHoon/HierSpeech_TTS': 'LeeSangHoon/HierSpeech_TTS', # irresponsive to exclamation marks # 4.29
115
  # 'PolyAI/pheme': '/predict#0', # sleepy HF Space
 
508
  'is_zero_gpu_space': True,
509
  'series': 'Orpheus',
510
  },
511
+
512
+ 'IndexTeam/IndexTTS' : {
513
+ 'name': 'Index TTS',
514
+ 'function': '/gen_single',
515
+ 'text_param_index': 'text',
516
+ 'return_audio_index': 0,
517
+ 'is_zero_gpu_space': True,
518
+ 'series': 'Index',
519
+ },
520
  }
521
 
522
  # for zero-shot TTS - voice sample used by XTTS (11 seconds)
 
818
  'top_p': 0.95,
819
  'repetition_penalty': 1.1,
820
  'max_new_tokens': 1200,
821
+ },
822
+
823
+ # Index TTS
824
+ 'IndexTeam/IndexTTS' : {
825
+ 'prompt': DEFAULT_VOICE_SAMPLE, # voice
826
+ },
827
  }
828
 
829
  # minor mods to model from the same space
 
888
  ]
889
 
890
  # top five models in order to always have one of them picked and scrutinized
891
+ top_five = []
892
 
893
  # prioritize low vote models
894
  sql = 'SELECT name FROM model WHERE (upvote + downvote) < 700 ORDER BY (upvote + downvote) ASC'
test_tts_index.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from test_overrides import _get_param_examples, _override_params
3
+ from gradio_client import Client, file
4
+
5
+ model = "IndexTeam/IndexTTS"
6
+ client = Client(model, hf_token=os.getenv('HF_TOKEN'))
7
+ endpoints = client.view_api(all_endpoints=True, print_info=False, return_format='dict')
8
+ # print(endpoints)
9
+
10
+ api_name = '/gen_single'
11
+ fn_index = None
12
+ end_parameters = None
13
+ text = 'This is what my voice sounds like.'
14
+
15
+ end_parameters = _get_param_examples(
16
+ endpoints['named_endpoints'][api_name]['parameters']
17
+ )
18
+ print(end_parameters)
19
+
20
+
21
+ space_inputs = end_parameters
22
+ # override some or all default parameters
23
+ space_inputs = _override_params(end_parameters, model)
24
+
25
+ if(type(space_inputs) == dict):
26
+ space_inputs['text'] = text
27
+ result = client.predict(
28
+ **space_inputs,
29
+ api_name=api_name,
30
+ fn_index=fn_index
31
+ )
32
+ else:
33
+ space_inputs[0] = text
34
+ result = client.predict(
35
+ *space_inputs,
36
+ api_name=api_name,
37
+ fn_index=fn_index
38
+ )
39
+ # space_inputs = {str(i): value for i, value in enumerate(space_inputs)}
40
+
41
+ print(space_inputs)
42
+ # print(*space_inputs)
43
+ # print(**space_inputs)
44
+
45
+ # result = client.predict(
46
+ # **space_inputs,
47
+ # api_name=api_name,
48
+ # fn_index=fn_index
49
+ # )
50
+ print(result)
test_tts_llasa.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from test_overrides import _get_param_examples, _override_params
3
+ from gradio_client import Client, file
4
+
5
+ model = "srinivasbilla/llasa-3b-tts"
6
+ client = Client(model, hf_token=os.getenv('HF_TOKEN'))
7
+ endpoints = client.view_api(all_endpoints=True, print_info=False, return_format='dict')
8
+ # print(endpoints)
9
+
10
+ api_name = '/infer'
11
+ fn_index = None
12
+ end_parameters = None
13
+ text = 'This is what my voice sounds like.'
14
+
15
+ end_parameters = _get_param_examples(
16
+ endpoints['named_endpoints'][api_name]['parameters']
17
+ )
18
+ print(end_parameters)
19
+
20
+
21
+ space_inputs = end_parameters
22
+ # override some or all default parameters
23
+ space_inputs = _override_params(end_parameters, model)
24
+
25
+ if(type(space_inputs) == dict):
26
+ space_inputs['target_text'] = text
27
+ result = client.predict(
28
+ **space_inputs,
29
+ api_name=api_name,
30
+ fn_index=fn_index
31
+ )
32
+ else:
33
+ space_inputs[0] = text
34
+ result = client.predict(
35
+ *space_inputs,
36
+ api_name=api_name,
37
+ fn_index=fn_index
38
+ )
39
+ # space_inputs = {str(i): value for i, value in enumerate(space_inputs)}
40
+
41
+ print(space_inputs)
42
+ # print(*space_inputs)
43
+ # print(**space_inputs)
44
+
45
+ # result = client.predict(
46
+ # **space_inputs,
47
+ # api_name=api_name,
48
+ # fn_index=fn_index
49
+ # )
50
+ print(result)