Higobeatz commited on
Commit
b12410b
·
verified ·
1 Parent(s): 81ccb24

Delete dreamvoice/src/.ipynb_checkpoints

Browse files
dreamvoice/src/.ipynb_checkpoints/plugin_wrapper-checkpoint.py DELETED
@@ -1,76 +0,0 @@
1
- import yaml
2
- import torch
3
- from diffusers import DDIMScheduler
4
- from .model.p2e_cross import P2E_Cross
5
- from .utils import scale_shift, scale_shift_re, rescale_noise_cfg
6
-
7
-
8
- class DreamVG(object):
9
- def __init__(self,
10
- config_path='configs/plugin_cross.yaml',
11
- ckpt_path='../ckpts/dreamvc_plugin.pt',
12
- device='cpu'):
13
-
14
- with open(config_path, 'r') as fp:
15
- config = yaml.safe_load(fp)
16
-
17
- self.device = device
18
- self.model = P2E_Cross(config['model']).to(device)
19
- self.model.load_state_dict(torch.load(ckpt_path)['model'])
20
- self.model.eval()
21
-
22
- noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
23
- beta_start=config['scheduler']['beta_start'],
24
- beta_end=config['scheduler']['beta_end'],
25
- rescale_betas_zero_snr=True,
26
- timestep_spacing="trailing",
27
- clip_sample=False,
28
- prediction_type='v_prediction')
29
- self.noise_scheduler = noise_scheduler
30
- self.scale = config['scheduler']['scale']
31
- self.shift = config['scheduler']['shift']
32
- self.spk_shape = config['model']['unet']['in_channels']
33
-
34
- @torch.no_grad()
35
- def inference(self, text,
36
- guidance_scale=5, guidance_rescale=0.7,
37
- ddim_steps=50, eta=1, random_seed=2023,
38
- ):
39
- text, text_mask = text
40
- self.model.eval()
41
-
42
- gen_shape = (1, self.spk_shape)
43
-
44
- if random_seed is not None:
45
- generator = torch.Generator(device=self.device).manual_seed(random_seed)
46
- else:
47
- generator = torch.Generator(device=self.device)
48
- generator.seed()
49
-
50
- self.noise_scheduler.set_timesteps(ddim_steps)
51
-
52
- # init noise
53
- noise = torch.randn(gen_shape, generator=generator, device=self.device)
54
- latents = noise
55
-
56
- for t in self.noise_scheduler.timesteps:
57
- latents = self.noise_scheduler.scale_model_input(latents, t)
58
-
59
- if guidance_scale:
60
- output_text = self.model(latents, t, text, text_mask, train_cfg=False)
61
- output_uncond = self.model(latents, t, text, text_mask, train_cfg=True, cfg_prob=1.0)
62
-
63
- output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
64
- if guidance_rescale > 0.0:
65
- output_pred = rescale_noise_cfg(output_pred, output_text,
66
- guidance_rescale=guidance_rescale)
67
- else:
68
- output_pred = self.model(latents, t, text, text_mask, train_cfg=False)
69
-
70
- latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
71
- eta=eta, generator=generator).prev_sample
72
-
73
- # pred = reverse_minmax_norm_diff(latents, vmin=0.0, vmax=0.5)
74
- pred = scale_shift_re(latents, 1/self.scale, self.shift)
75
- # pred = torch.clip(pred, min=0.0, max=0.5)
76
- return pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dreamvoice/src/.ipynb_checkpoints/vc_wrapper-checkpoint.py DELETED
@@ -1,144 +0,0 @@
1
- import yaml
2
- import torch
3
- from diffusers import DDIMScheduler
4
- from .model.model import DiffVC
5
- from .model.model_cross import DiffVC_Cross
6
- from .utils import scale_shift, scale_shift_re, rescale_noise_cfg
7
-
8
-
9
- class ReDiffVC(object):
10
- def __init__(self,
11
- config_path='configs/diffvc_base.yaml',
12
- ckpt_path='../ckpts/dreamvc_base.pt',
13
- device='cpu'):
14
-
15
- with open(config_path, 'r') as fp:
16
- config = yaml.safe_load(fp)
17
-
18
- self.device = device
19
- self.model = DiffVC(config['model']).to(device)
20
- self.model.load_state_dict(torch.load(ckpt_path)['model'])
21
- self.model.eval()
22
-
23
- noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
24
- beta_start=config['scheduler']['beta_start'],
25
- beta_end=config['scheduler']['beta_end'],
26
- rescale_betas_zero_snr=True,
27
- timestep_spacing="trailing",
28
- clip_sample=False,
29
- prediction_type='v_prediction')
30
- self.noise_scheduler = noise_scheduler
31
- self.scale = config['scheduler']['scale']
32
- self.shift = config['scheduler']['shift']
33
- self.melshape = config['model']['unet']['sample_size'][0]
34
-
35
- @torch.no_grad()
36
- def inference(self,
37
- spk_embed, content_clip, f0_clip=None,
38
- guidance_scale=3, guidance_rescale=0.7,
39
- ddim_steps=50, eta=1, random_seed=2023):
40
-
41
- self.model.eval()
42
- if random_seed is not None:
43
- generator = torch.Generator(device=self.device).manual_seed(random_seed)
44
- else:
45
- generator = torch.Generator(device=self.device)
46
- generator.seed()
47
-
48
- self.noise_scheduler.set_timesteps(ddim_steps)
49
-
50
- # init noise
51
- gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
52
- noise = torch.randn(gen_shape, generator=generator, device=self.device)
53
- latents = noise
54
-
55
- for t in self.noise_scheduler.timesteps:
56
- latents = self.noise_scheduler.scale_model_input(latents, t)
57
-
58
- if guidance_scale:
59
- output_text = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=False)
60
- output_uncond = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=True,
61
- speaker_cfg=1.0, pitch_cfg=0.0)
62
-
63
- output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
64
- if guidance_rescale > 0.0:
65
- output_pred = rescale_noise_cfg(output_pred, output_text,
66
- guidance_rescale=guidance_rescale)
67
- else:
68
- output_pred = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=False)
69
-
70
- latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
71
- eta=eta, generator=generator).prev_sample
72
-
73
- pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
74
- return pred
75
-
76
-
77
- class DreamVC(object):
78
- def __init__(self,
79
- config_path='configs/diffvc_cross.yaml',
80
- ckpt_path='../ckpts/dreamvc_cross.pt',
81
- device='cpu'):
82
-
83
- with open(config_path, 'r') as fp:
84
- config = yaml.safe_load(fp)
85
-
86
- self.device = device
87
- self.model = DiffVC_Cross(config['model']).to(device)
88
- self.model.load_state_dict(torch.load(ckpt_path)['model'])
89
- self.model.eval()
90
-
91
- noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
92
- beta_start=config['scheduler']['beta_start'],
93
- beta_end=config['scheduler']['beta_end'],
94
- rescale_betas_zero_snr=True,
95
- timestep_spacing="trailing",
96
- clip_sample=False,
97
- prediction_type='v_prediction')
98
- self.noise_scheduler = noise_scheduler
99
- self.scale = config['scheduler']['scale']
100
- self.shift = config['scheduler']['shift']
101
- self.melshape = config['model']['unet']['sample_size'][0]
102
-
103
- @torch.no_grad()
104
- def inference(self,
105
- text, content_clip, f0_clip=None,
106
- guidance_scale=3, guidance_rescale=0.7,
107
- ddim_steps=50, eta=1, random_seed=2023):
108
-
109
- text, text_mask = text
110
- self.model.eval()
111
- if random_seed is not None:
112
- generator = torch.Generator(device=self.device).manual_seed(random_seed)
113
- else:
114
- generator = torch.Generator(device=self.device)
115
- generator.seed()
116
-
117
- self.noise_scheduler.set_timesteps(ddim_steps)
118
-
119
- # init noise
120
- gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
121
- noise = torch.randn(gen_shape, generator=generator, device=self.device)
122
- latents = noise
123
-
124
- for t in self.noise_scheduler.timesteps:
125
- latents = self.noise_scheduler.scale_model_input(latents, t)
126
-
127
- if guidance_scale:
128
- output_text = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=False)
129
- output_uncond = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=True,
130
- speaker_cfg=1.0, pitch_cfg=0.0)
131
-
132
- output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
133
- if guidance_rescale > 0.0:
134
- output_pred = rescale_noise_cfg(output_pred, output_text,
135
- guidance_rescale=guidance_rescale)
136
- else:
137
- output_pred = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=False)
138
-
139
- latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
140
- eta=eta, generator=generator).prev_sample
141
-
142
- pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
143
- return pred
144
-