import json
import os
import random

import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from mutagen.mp3 import MP3
from muq import MuQMuLan

from diffrhythm.model import DiT, CFM


def prepare_model(device):
    # Prepare the CFM model.
    dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
    dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
    with open(dit_config_path) as f:
        model_config = json.load(f)
    dit_model_cls = DiT
    cfm = CFM(
        transformer=dit_model_cls(**model_config["model"], use_style_prompt=True),
        num_channels=model_config["model"]["mel_dim"],
        use_style_prompt=True,
    )
    cfm = cfm.to(device)
    cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)

    # Prepare the tokenizer.
    tokenizer = CNENTokenizer()

    # Prepare the MuQ-MuLan style-embedding model.
    muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
    muq = muq.to(device).eval()

    # Prepare the VAE.
    vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
    vae = torch.jit.load(vae_ckpt_path).to(device)

    return cfm, tokenizer, muq, vae
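
# Example usage (a minimal sketch; assumes a CUDA-capable machine and that the
# config/checkpoint paths above resolve from the working directory):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   cfm, tokenizer, muq, vae = prepare_model(device)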


# For song editing; to be added in the future.
def get_reference_latent(device, max_frames):
    return torch.zeros(1, max_frames, 64).to(device)


def get_negative_style_prompt(device):
    file_path = "./prompt/negative_prompt.npy"
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()
    return vocal_style


def get_style_prompt(model, wav_path):
    mulan = model
    ext = os.path.splitext(wav_path)[-1].lower()

    # Read duration and source sample rate: from metadata for MP3,
    # by decoding the file for WAV.
    if ext == ".mp3":
        meta = MP3(wav_path)
        audio_len = meta.info.length
        src_sr = meta.info.sample_rate
    elif ext == ".wav":
        audio, sr = librosa.load(wav_path, sr=None)
        audio_len = librosa.get_duration(y=audio, sr=sr)
        src_sr = sr
    else:
        raise ValueError("Unsupported file format: {}".format(ext))

    assert audio_len >= 10, "Reference audio must be at least 10 seconds long"

    # Take a 10-second window centered on the middle of the track and
    # resample it to 24 kHz for MuQ-MuLan.
    mid_time = audio_len // 2
    start_time = mid_time - 5
    wav, sr = librosa.load(wav_path, sr=None, offset=start_time, duration=10)
    resampled_wav = librosa.resample(wav, orig_sr=src_sr, target_sr=24000)
    resampled_wav = torch.tensor(resampled_wav).unsqueeze(0).to(model.device)

    with torch.no_grad():
        audio_emb = mulan(wavs=resampled_wav)  # [1, 512]

    return audio_emb.half()
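
# Example usage (a minimal sketch; the reference file path is hypothetical):
#
#   muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to("cuda").eval()
#   style_emb = get_style_prompt(muq, "./reference_song.mp3")  # [1, 512], fp16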


def parse_lyrics(lyrics: str):
    # Parse LRC-style lines of the form "[mm:ss.xx]lyric text" into
    # (seconds, lyric) tuples; lines that do not match are skipped.
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split("\n"):
        try:
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(":")
            secs = int(mins) * 60 + float(secs)
            lyrics_with_time.append((secs, lyric))
        except Exception:
            continue
    return lyrics_with_time
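
# Example (sketch):
#
#   parse_lyrics("[00:10.00]Hello world\n[00:12.50]Second line")
#   # -> [(10.0, 'Hello world'), (12.5, 'Second line')]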


class CNENTokenizer:
    def __init__(self):
        with open("./diffrhythm/g2p/g2p/vocab.json", "r") as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from diffrhythm.g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        # Shift ids by 1 so that 0 stays reserved for padding.
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def get_lrc_token(text, tokenizer, device):
    max_frames = 2048
    lyrics_shift = 0
    sampling_rate = 44100
    downsample_rate = 2048
    # Longest audio duration representable by max_frames latent frames.
    max_secs = max_frames / (sampling_rate / downsample_rate)

    pad_token_id = 0  # implicit: torch.zeros below fills with the pad id
    comma_token_id = 1
    period_token_id = 2

    lrc_with_time = parse_lyrics(text)

    # Tokenize every lyric line, keeping its start time.
    modified_lrc_with_time = []
    for time, line in lrc_with_time:
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))
    lrc_with_time = modified_lrc_with_time

    # Discard lines that start beyond the representable duration, then
    # drop the final remaining line.
    lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
    lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time

    normalized_start_time = 0.

    # Write each line's tokens into a frame-aligned buffer, starting at
    # the latent frame that corresponds to the line's timestamp.
    lrc = torch.zeros((max_frames,), dtype=torch.long)
    tokens_count = 0
    last_end_pos = 0
    for time_start, line in lrc_with_time:
        # Replace in-line periods with commas, then end the line with a period token.
        tokens = [token if token != period_token_id else comma_token_id for token in line] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]
        gt_frame_start = int(time_start * sampling_rate / downsample_rate)
        # Always 0 while lyrics_shift == 0; kept as a hook for random timing jitter.
        frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))
        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)
        lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]
        tokens_count += num_tokens
        last_end_pos = frame_start + frame_len

    lrc_emb = lrc.unsqueeze(0).to(device)
    normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
    normalized_start_time = normalized_start_time.half()

    return lrc_emb, normalized_start_time
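
# Example usage (a minimal sketch; the lyric text is made up):
#
#   tok = CNENTokenizer()
#   lrc_text = "[00:00.00]First line\n[00:04.00]Second line\n[00:08.00]Last line"
#   lrc_emb, start_t = get_lrc_token(lrc_text, tok, "cpu")
#   # lrc_emb: [1, 2048] long tensor; start_t: fp16 tensor of shape [1]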


def load_checkpoint(model, ckpt_path, device, use_ema=True):
    if device == "cuda":
        model = model.half()

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file
        checkpoint = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        # Strip the "ema_model." prefix and drop EMA bookkeeping entries.
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    return model.to(device)