Pendrokar commited on
Commit
ba6c0cb
·
1 Parent(s): 76a480c

parler multi & dia test scripts

Browse files
Files changed (2) hide show
  1. test_tts_dia.py +46 -0
  2. test_tts_parler_multi.py +46 -0
test_tts_dia.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from test_overrides import _get_param_examples, _override_params
3
+ from gradio_client import Client, file
4
+
5
+ model = "nari-labs/Dia-1.6B"
6
+ client = Client(model, hf_token=os.getenv('HF_TOKEN'))
7
+ endpoints = client.view_api(all_endpoints=True, print_info=False, return_format='dict')
8
+
9
+ api_name = '/generate_audio'
10
+ fn_index = None
11
+ end_parameters = None
12
+ text = '[2] This is what my voice sounds like. \n[2] \n[2]'
13
+
14
+ end_parameters = _get_param_examples(
15
+ endpoints['named_endpoints'][api_name]['parameters']
16
+ )
17
+ print(end_parameters)
18
+
19
+
20
+ # override some or all default parameters
21
+ space_inputs = _override_params(end_parameters, model)
22
+
23
+ if(type(space_inputs) == dict):
24
+ space_inputs['text_input'] = text
25
+ result = client.predict(
26
+ **space_inputs,
27
+ api_name=api_name
28
+ )
29
+ else:
30
+ space_inputs[0] = text
31
+ result = client.predict(
32
+ *space_inputs,
33
+ api_name=api_name
34
+ )
35
+ # space_inputs = {str(i): value for i, value in enumerate(space_inputs)}
36
+
37
+ print(space_inputs)
38
+ # print(*space_inputs)
39
+ # print(**space_inputs)
40
+
41
+ # result = client.predict(
42
+ # **space_inputs,
43
+ # api_name=api_name,
44
+ # fn_index=fn_index
45
+ # )
46
+ print(result)
test_tts_parler_multi.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from test_overrides import _get_param_examples, _override_params
3
+ from gradio_client import Client, file
4
+
5
+ model = "PHBJT/multi_parler_tts"
6
+ client = Client(model, hf_token=os.getenv('HF_TOKEN'))
7
+ endpoints = client.view_api(all_endpoints=True, print_info=False, return_format='dict')
8
+
9
+ api_name = '/gen_tts'
10
+ fn_index = None
11
+ end_parameters = None
12
+ text = 'This is what my voice sounds like.'
13
+
14
+ end_parameters = _get_param_examples(
15
+ endpoints['named_endpoints'][api_name]['parameters']
16
+ )
17
+ print(end_parameters)
18
+
19
+
20
+ # override some or all default parameters
21
+ space_inputs = _override_params(end_parameters, model)
22
+
23
+ if(type(space_inputs) == dict):
24
+ space_inputs['text'] = text
25
+ result = client.predict(
26
+ **space_inputs,
27
+ api_name=api_name
28
+ )
29
+ else:
30
+ space_inputs[0] = text
31
+ result = client.predict(
32
+ *space_inputs,
33
+ api_name=api_name
34
+ )
35
+ # space_inputs = {str(i): value for i, value in enumerate(space_inputs)}
36
+
37
+ print(space_inputs)
38
+ # print(*space_inputs)
39
+ # print(**space_inputs)
40
+
41
+ # result = client.predict(
42
+ # **space_inputs,
43
+ # api_name=api_name,
44
+ # fn_index=fn_index
45
+ # )
46
+ print(result)