Higobeatz committed
Commit e142cef · verified · 1 Parent(s): f6fb58d

Delete dreamvoice/.ipynb_checkpoints/api-checkpoint.py

dreamvoice/.ipynb_checkpoints/api-checkpoint.py DELETED
@@ -1,295 +0,0 @@
import os
import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from .src.vc_wrapper import ReDiffVC, DreamVC
from .src.plugin_wrapper import DreamVG
from .src.modules.speaker_encoder.encoder import inference as spk_encoder
from .src.modules.BigVGAN.inference import load_model as load_vocoder
from .src.feats.contentvec_hf import get_content_model, get_content


class DreamVoice:
    def __init__(self, config='dreamvc.yaml', mode='plugin', device='cuda', chunk_size=16):
        # Initial setup
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config

        # Load configuration file
        with open(config_path, 'r') as fp:
            self.config = yaml.safe_load(fp)

        self.script_dir = script_dir

        # Ensure all checkpoints are downloaded
        self._ensure_checkpoints_exist()

        # Initialize attributes
        self.device = device
        self.sr = self.config['sample_rate']

        # Load vocoder
        vocoder_path = script_dir / self.config['vocoder_path']
        self.hifigan, _ = load_vocoder(vocoder_path, device)
        self.hifigan.eval()

        # Load content model
        self.content_model = get_content_model().to(device)

        # Load tokenizer and text encoder
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        # Set mode
        self.mode = mode
        if mode == 'plugin':
            self._init_plugin_mode()
        elif mode == 'end2end':
            self._init_end2end_mode()
        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        # Split long inputs into chunks of `chunk_size` seconds
        # (content features run at 50 frames per second)
        self.chunk_size = chunk_size * 50

    def _ensure_checkpoints_exist(self):
        checkpoints = [
            ('vocoder_path', self.config.get('vocoder_url')),
            ('vocoder_config_path', self.config.get('vocoder_config_url')),
            ('speaker_path', self.config.get('speaker_url')),
            ('dreamvc.ckpt_path', self.config.get('dreamvc', {}).get('ckpt_url')),
            ('rediffvc.ckpt_path', self.config.get('rediffvc', {}).get('ckpt_url')),
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url'))
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        # Resolve a dotted config key (e.g. 'dreamvc.ckpt_path') to a local path
        keys = path_key.split('.')
        local_path = self.config
        for key in keys:
            local_path = local_path.get(key, {})
        return self.script_dir / local_path

    def _download_file(self, url, local_path):
        try:
            # Attempt to send a GET request to the URL
            response = requests.get(url, stream=True)
            response.raise_for_status()  # Raise an exception for HTTP errors
        except requests.exceptions.RequestException as e:
            # Log the error for debugging purposes
            print(f"Error encountered: {e}")

            # Development mode: prompt the user for a Hugging Face API key
            user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key: ")
            self.hf_key = user_input if user_input else None

            # Set headers if an API key is provided
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

            try:
                # Retry the GET request with authorization headers
                response = requests.get(url, stream=True, headers=headers)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                # Without a valid response the download cannot proceed
                raise RuntimeError(f"Failed to download {url}") from e

        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Stream the checkpoint to disk with a progress bar
        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192
        t = tqdm(total=total_size, unit='iB', unit_scale=True)

        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=block_size):
                t.update(len(chunk))
                f.write(chunk)
        t.close()

    def _init_plugin_mode(self):
        # Initialize ReDiffVC
        self.dreamvc = ReDiffVC(
            config_path=self.script_dir / self.config['rediffvc']['config_path'],
            ckpt_path=self.script_dir / self.config['rediffvc']['ckpt_path'],
            device=self.device
        )

        # Initialize DreamVG
        self.dreamvg = DreamVG(
            config_path=self.script_dir / self.config['dreamvg']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvg']['ckpt_path'],
            device=self.device
        )

        # Load speaker encoder
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        self.spk_embed_cache = None

    def _init_end2end_mode(self):
        # Initialize DreamVC
        self.dreamvc = DreamVC(
            config_path=self.script_dir / self.config['dreamvc']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvc']['ckpt_path'],
            device=self.device
        )

    def _load_content(self, audio_path):
        content_audio, _ = librosa.load(audio_path, sr=16000)
        # Calculate the required length to make it a multiple of 16*160
        target_length = ((len(content_audio) + 16 * 160 - 1) // (16 * 160)) * (16 * 160)
        # Pad with zeros if necessary
        if len(content_audio) < target_length:
            content_audio = np.pad(content_audio, (0, target_length - len(content_audio)), mode='constant')
        content_audio = torch.tensor(content_audio).unsqueeze(0).to(self.device)
        content_clip = get_content(self.content_model, content_audio)
        return content_clip

    def load_spk_embed(self, emb_path):
        self.spk_embed_cache = torch.load(emb_path, map_location=self.device)

    def save_spk_embed(self, emb_path):
        assert self.spk_embed_cache is not None
        torch.save(self.spk_embed_cache.cpu(), emb_path)

    def save_audio(self, output_path, audio, sr):
        sf.write(output_path, audio, samplerate=sr)

    @torch.no_grad()
    def genvc(self, content_audio, prompt,
              prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
              prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,
              vc_guidance_scale=3, vc_guidance_rescale=0.7,
              vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
              ):

        content_clip = self._load_content(content_audio)

        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text, text_mask = text_batch.input_ids.to(self.device), \
            text_batch.attention_mask.to(self.device)
        text = self.text_encoder(input_ids=text, attention_mask=text_mask)[0]

        if self.mode == 'plugin':
            # Sample a speaker embedding from the text prompt, then convert in chunks
            spk_embed = self.dreamvg.inference([text, text_mask],
                                               guidance_scale=prompt_guidance_scale,
                                               guidance_rescale=prompt_guidance_rescale,
                                               ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                               random_seed=prompt_random_seed)

            B, L, D = content_clip.shape
            gen_audio_chunks = []
            num_chunks = (L + self.chunk_size - 1) // self.chunk_size
            for i in range(num_chunks):
                start_idx = i * self.chunk_size
                end_idx = min((i + 1) * self.chunk_size, L)
                content_clip_chunk = content_clip[:, start_idx:end_idx, :]

                gen_audio_chunk = self.dreamvc.inference(
                    spk_embed, content_clip_chunk, None,
                    guidance_scale=vc_guidance_scale,
                    guidance_rescale=vc_guidance_rescale,
                    ddim_steps=vc_ddim_steps,
                    eta=vc_eta,
                    random_seed=vc_random_seed)

                gen_audio_chunks.append(gen_audio_chunk)

            gen_audio = torch.cat(gen_audio_chunks, dim=-1)

            self.spk_embed_cache = spk_embed

        elif self.mode == 'end2end':
            B, L, D = content_clip.shape
            gen_audio_chunks = []
            num_chunks = (L + self.chunk_size - 1) // self.chunk_size

            for i in range(num_chunks):
                start_idx = i * self.chunk_size
                end_idx = min((i + 1) * self.chunk_size, L)
                content_clip_chunk = content_clip[:, start_idx:end_idx, :]

                gen_audio_chunk = self.dreamvc.inference([text, text_mask], content_clip_chunk,
                                                         guidance_scale=prompt_guidance_scale,
                                                         guidance_rescale=prompt_guidance_rescale,
                                                         ddim_steps=prompt_ddim_steps,
                                                         eta=prompt_eta, random_seed=prompt_random_seed)
                gen_audio_chunks.append(gen_audio_chunk)

            gen_audio = torch.cat(gen_audio_chunks, dim=-1)

        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        # Vocode the generated features back to a waveform
        gen_audio = self.hifigan(gen_audio.squeeze(1))
        gen_audio = gen_audio.cpu().numpy().squeeze(0).squeeze(0)

        return gen_audio, self.sr

    @torch.no_grad()
    def simplevc(self, content_audio, speaker_audio=None, use_spk_cache=False,
                 vc_guidance_scale=3, vc_guidance_rescale=0.7,
                 vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
                 ):

        assert self.mode == 'plugin'
        if speaker_audio is not None:
            # Extract the speaker embedding from a reference recording
            speaker_audio, _ = librosa.load(speaker_audio, sr=16000)
            speaker_audio = torch.tensor(speaker_audio).unsqueeze(0).to(self.device)
            spk_embed = spk_encoder.embed_utterance_batch(speaker_audio)
            self.spk_embed_cache = spk_embed
        elif use_spk_cache:
            assert self.spk_embed_cache is not None
            spk_embed = self.spk_embed_cache
        else:
            raise NotImplementedError

        content_clip = self._load_content(content_audio)

        B, L, D = content_clip.shape
        gen_audio_chunks = []
        num_chunks = (L + self.chunk_size - 1) // self.chunk_size
        for i in range(num_chunks):
            start_idx = i * self.chunk_size
            end_idx = min((i + 1) * self.chunk_size, L)
            content_clip_chunk = content_clip[:, start_idx:end_idx, :]

            gen_audio_chunk = self.dreamvc.inference(
                spk_embed, content_clip_chunk, None,
                guidance_scale=vc_guidance_scale,
                guidance_rescale=vc_guidance_rescale,
                ddim_steps=vc_ddim_steps,
                eta=vc_eta,
                random_seed=vc_random_seed)

            gen_audio_chunks.append(gen_audio_chunk)

        gen_audio = torch.cat(gen_audio_chunks, dim=-1)

        gen_audio = self.hifigan(gen_audio.squeeze(1))
        gen_audio = gen_audio.cpu().numpy().squeeze(0).squeeze(0)

        return gen_audio, self.sr


if __name__ == '__main__':
    dreamvoice = DreamVoice(config='dreamvc.yaml', mode='plugin', device='cuda')
    content_audio = 'test.wav'
    speaker_audio = 'speaker.wav'
    prompt = 'young female voice, sounds young and cute'
    # Prompt-guided conversion
    gen_audio, sr = dreamvoice.genvc(content_audio, prompt)
    dreamvoice.save_audio('debug.wav', gen_audio, sr)
    # Reference-audio conversion with the same content clip
    gen_audio, sr = dreamvoice.simplevc(content_audio, speaker_audio=speaker_audio)
    dreamvoice.save_audio('debug_simple.wav', gen_audio, sr)
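
In plugin mode the module is built so that a voice sampled once from a text prompt can be reused: genvc stores the sampled speaker embedding in spk_embed_cache, save_spk_embed and load_spk_embed persist it to disk, and simplevc with use_spk_cache=True converts further clips without re-running the DreamVG prompt-to-embedding diffusion. A minimal sketch of that workflow, assuming the package exports DreamVoice at the top level; the file paths ('clip1.wav', 'voice.pt', and so on) are placeholders, not files in the repository:

from dreamvoice import DreamVoice

dv = DreamVoice(config='dreamvc.yaml', mode='plugin', device='cuda')

# Sample a voice from a text prompt and convert the first clip;
# genvc leaves the sampled speaker embedding in dv.spk_embed_cache.
audio, sr = dv.genvc('clip1.wav', 'young female voice, sounds young and cute')
dv.save_audio('clip1_vc.wav', audio, sr)
dv.save_spk_embed('voice.pt')  # persist the voice for later sessions

# Later (or in another process): reload the same voice and convert
# more clips while skipping the prompt-to-embedding sampling step.
dv.load_spk_embed('voice.pt')
audio, sr = dv.simplevc('clip2.wav', use_spk_cache=True)
dv.save_audio('clip2_vc.wav', audio, sr)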