Steveeeeeeen (HF staff) committed
Commit 182306d · verified · 1 parent: 7be3eaa

Create app.py

Files changed (1): app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@

import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import os

from huggingface_hub import login

# Authenticate with the Hub using the Space's HF_TOKEN secret
api_key = os.getenv("HF_TOKEN")
login(token=api_key)

llasa_3b = 'Steveeeeeeen/Llasagna-v0.1'

tokenizer = AutoTokenizer.from_pretrained(llasa_3b)

model = AutoModelForCausalLM.from_pretrained(
    llasa_3b,
    trust_remote_code=True,
    device_map='cuda',
)

model_path = "srinivasbilla/xcodec2"

Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()

whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)
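
# Three models cooperate in this app: Whisper transcribes the reference clip,
# XCodec2 maps audio to discrete codec ids and back, and the Llasagna LLM
# continues the codec-id sequence autoregressively for the new text.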

# Preset speaker prompts (not currently wired into the UI below)
SPEAKERS = {
    "Female 1": {
        "path": "speakers/female_1.mp3",
        "transcript": "e lo stesso alessi che andò ad aprire non riconobbe antoni il quale tornava con la sporta sotto il braccio tanto era mutato coperto di polvere e con la barba lunga come fu entrato e si fu messo a sedere in un cantuccio non osavano quasi fargli festa.",
        "description": "A female voice.",
    },
}


def preview_speaker(display_name):
    """Return the audio and transcript for a speaker preview."""
    if display_name in SPEAKERS:
        waveform, sample_rate = torchaudio.load(SPEAKERS[display_name]["path"])
        return (sample_rate, waveform[0].numpy()), SPEAKERS[display_name]["transcript"]
    return None, ""
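
# e.g. preview_speaker("Female 1") -> ((sample_rate, mono_waveform), transcript)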


def ids_to_speech_tokens(speech_ids):
    """Convert integer codec ids to speech token strings: 12345 -> '<|s_12345|>'."""
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]


def extract_speech_ids(speech_tokens_str):
    """Convert speech token strings back to integer codec ids: '<|s_12345|>' -> 12345."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
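
# Round-trip sanity check (example values, not model output):
#   ids_to_speech_tokens([123, 456])               -> ['<|s_123|>', '<|s_456|>']
#   extract_speech_ids(['<|s_123|>', '<|s_456|>']) -> [123, 456]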


@spaces.GPU(duration=60)  # request a ZeroGPU slot for up to 60 s per call
def infer(sample_audio_path, target_text, progress=gr.Progress()):
    progress(0, 'Loading and trimming audio...')
    waveform, sample_rate = torchaudio.load(sample_audio_path)
    if len(waveform[0]) / sample_rate > 15:
        gr.Warning("Trimming audio to the first 15 seconds.")
        waveform = waveform[:, :sample_rate * 15]

    # Convert stereo to mono by averaging the channels
    if waveform.size(0) > 1:
        waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
    else:
        waveform_mono = waveform

    # Resample to 16 kHz, the rate expected here by Whisper and XCodec2
    prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
    prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
    progress(0.5, 'Transcribed! Generating speech...')

    if len(target_text) == 0:
        return None
    elif len(target_text) > 300:
        gr.Warning("Text is too long. Please keep it under 300 characters.")
        target_text = target_text[:300]

    # The model continues the prompt transcript with the target text
    input_text = prompt_text + ' ' + target_text

    # TTS starts here
    with torch.no_grad():
        # Encode the prompt wav into discrete codec ids
        vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
        vq_code_prompt = vq_code_prompt[0, 0, :]

        # Convert int 12345 to token <|s_12345|>
        speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

        # Build the chat with the speech prefix already in the assistant turn
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
        ]
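
        # Schematically, the assembled prompt now looks like:
        #   user:      Convert the text to speech:<|TEXT_UNDERSTANDING_START|>{transcript + target}<|TEXT_UNDERSTANDING_END|>
        #   assistant: <|SPEECH_GENERATION_START|><|s_..|><|s_..|>...  <- generate() continues from here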

        # continue_final_message=True keeps the assistant turn open, so the
        # model appends speech tokens rather than starting a new turn
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True,
        )
        input_ids = input_ids.to('cuda')
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

        # Generate the speech tokens autoregressively, stopping at <|SPEECH_GENERATION_END|>
        outputs = model.generate(
            input_ids,
            max_length=2048,  # the model was trained with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=1,
            temperature=0.8,
        )

        # Extract the speech tokens: keep the prompt's speech prefix (so the
        # decoder sees its context) and drop the final EOS token
        generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]

        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Convert token <|s_23456|> back to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)

        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

        # Decode the speech tokens to a speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)

        # Keep only the newly generated part (drop the re-synthesized prompt)
        gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]

    progress(1, 'Synthesized!')

    return (16000, gen_wav[0, 0, :].cpu().numpy())
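
# Minimal sketch of calling infer() directly (hypothetical inputs; sf is the
# soundfile import above):
#   sr, audio = infer("speakers/female_1.mp3", "Buongiorno!")
#   sf.write("cloned.wav", audio, sr)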

with gr.Blocks() as app_tts:
    gr.Markdown("# Zero-Shot Voice Clone TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)

    generate_btn = gr.Button("Synthesize", variant="primary")

    audio_output = gr.Audio(label="Synthesized Audio")

    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
        ],
        outputs=[audio_output],
    )

with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
""")

with gr.Blocks() as app:
    gr.HTML(
        "<img src='https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/llasagna.png' alt='Llasagna' style='width: 100%; height: auto;'>",
        elem_id="banner",
    )
    gr.Markdown(
        """
# Llasagna 1b TTS

This is a web UI for Llasagna 1b, a zero-shot voice cloning and TTS model.

The checkpoints support English and Chinese.

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
"""
    )
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])


app.launch(ssr_mode=False, share=True)