seungheondoh committed
Commit e48ca55 · Parent(s): 7ccf3fd

add model

Browse files
- app.py +81 -4
- model/bart.py +151 -0
- model/modules.py +95 -0
- utils/audio_utils.py +247 -0
    	
app.py
CHANGED

@@ -1,7 +1,84 @@
(the removed lines of the previous 7-line app.py are not recoverable from this view; the new file follows)

import os
import argparse
import gradio as gr
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from model.bart import BartCaptionModel
from utils.audio_utils import load_audio, STR_CH_FIRST

if os.path.isfile("transfer.pth") == False:
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth', 'transfer.pth')
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/electronic.mp3', 'electronic.mp3')
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/orchestra.wav', 'orchestra.wav')

device = "cuda:0" if torch.cuda.is_available() else "cpu"

example_list = ['electronic.mp3', 'orchestra.wav']
model = BartCaptionModel(max_length=128)
pretrained_object = torch.load('./transfer.pth', map_location='cpu')
state_dict = pretrained_object['state_dict']
model.load_state_dict(state_dict)
torch.cuda.set_device(device)
model = model.cuda(device)
model.eval()

def get_audio(audio_path, duration=10, target_sr=16000):
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(0, False)  # to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad sequence
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    ceil = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
    return audio

def captioning(audio_path):
    audio_tensor = get_audio(audio_path=audio_path)
    if device is not None:
        audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=5,
        )
    inference = ""
    number_of_chunks = range(audio_tensor.shape[0])
    for chunk, text in zip(number_of_chunks, output):
        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
        inference += f"{time}\n{text} \n \n"
    return inference

title = "Interactive demo: Music Captioning 🤖🎵"
description = """
<p style='text-align: center'> LP-MusicCaps: LLM-Based Pseudo Music Captioning</p>
<p style='text-align: center'> SeungHeon Doh, Keunwoo Choi, Jongpil Lee, Juhan Nam, ISMIR 2023</p>
<p style='text-align: center'> <a href='#' target='_blank'>ArXiv</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>Github</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps-Dataset</a> </p>
<p style='text-align: center'> To use it, simply upload your audio and click 'submit', or click one of the examples to load them. Read more at the links below. </p>
"""
article = "<p style='text-align: center'><a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps Github</a> | <a href='#' target='_blank'>LP-MusicCaps Paper</a></p>"


demo = gr.Interface(fn=captioning,
                    inputs=gr.Audio(type="filepath"),
                    outputs=[
                        gr.Textbox(label="Caption generated by LP-MusicCaps Transfer Model"),
                        ],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article,
                    cache_examples=False
                    )
demo.launch()
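The subtle step in app.py is get_audio's pad-and-split: the waveform is zero-padded up to one 10-second window, split into floor(len / n_samples) windows (any trailing partial window beyond the first is dropped), and stacked into a (num_chunks, 160000) tensor that model.generate captions chunk by chunk. A self-contained sketch of just that step (synthetic input, not part of the commit):

import numpy as np
import torch

def chunk_waveform(audio: np.ndarray, duration: int = 10, target_sr: int = 16000) -> torch.Tensor:
    """Replicates get_audio()'s pad-and-split step on a mono waveform array."""
    n_samples = int(duration * target_sr)
    if audio.shape[-1] < n_samples:            # pad short clips up to one full window
        pad = np.zeros(n_samples)
        pad[: audio.shape[-1]] = audio
        audio = pad
    n_chunks = int(audio.shape[-1] // n_samples)   # drop the trailing partial window
    return torch.from_numpy(
        np.stack(np.split(audio[: n_chunks * n_samples], n_chunks)).astype("float32")
    )

wav = np.random.randn(25 * 16000)              # 25 s of fake audio at 16 kHz
print(chunk_waveform(wav).shape)               # torch.Size([2, 160000]); last 5 s dropped

Note also that the "[0:00-10:00]"-style labels built in captioning are chunk offsets in seconds with a literal ":00" suffix, so "[10:00-20:00]" means 10-20 seconds, not minutes.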
    	
model/bart.py
ADDED

@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

class BartCaptionModel(nn.Module):
    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        super(BartCaptionModel, self).__init__()
        # non-finetuning case
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard-coded hop size
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,  # hard-coded n_mels
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv
        )

        self.max_length = max_length
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        return list(self.parameters())[0].device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=self.max_length,
                              return_tensors="pt")
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )
        lm_logits = decoder_outputs["logits"]
        loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
        return loss

    def forward(self, audio, text):
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):

        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
        input_ids[:, 0] = self.bart.config.decoder_start_token_id
        decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
        if use_nucleus_sampling:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1)
        else:
            outputs = self.bart.generate(input_ids=None,
                                         attention_mask=None,
                                         decoder_input_ids=input_ids,
                                         decoder_attention_mask=decoder_attention_mask,
                                         encoder_outputs=encoder_outputs,
                                         head_mask=None,
                                         decoder_head_mask=None,
                                         inputs_embeds=None,
                                         decoder_inputs_embeds=None,
                                         use_cache=None,
                                         output_attentions=None,
                                         output_hidden_states=None,
                                         max_length=max_length,
                                         min_length=min_length,
                                         num_beams=num_beams,
                                         repetition_penalty=repetition_penalty)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
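For reference, a minimal driving sketch for BartCaptionModel (not part of the commit; it assumes network access to fetch the facebook/bart-base config and tokenizer, and the weights stay random unless a checkpoint such as transfer.pth is loaded as in app.py):

import torch
from model.bart import BartCaptionModel

model = BartCaptionModel(max_length=128)   # randomly initialized BART with an audio encoder
model.eval()

wav = torch.randn(2, 160000)               # two 10 s chunks at 16 kHz

# training-style forward: returns the label-smoothed cross-entropy loss
loss = model(wav, ["a warm synth pad", "an orchestral swell"])

# inference: beam search (num_beams=5) over BART's decoder
with torch.no_grad():
    captions = model.generate(samples=wav, num_beams=5)
print(loss.item(), captions)               # scalar loss, list of 2 caption strings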
    	
model/modules.py
ADDED

@@ -0,0 +1,95 @@
### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py

import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(0.01 * SAMPLE_RATE)
DURATION = 10
N_SAMPLES = int(DURATION * SAMPLE_RATE)
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class MelEncoder(nn.Module):
    """
    time-frequency representation
    """
    def __init__(self,
                 sample_rate=16000,
                 f_min=0,
                 f_max=8000,
                 n_fft=1024,
                 win_length=1024,
                 hop_length=int(0.01 * 16000),
                 n_mels=128,
                 power=None,
                 pad=0,
                 normalized=False,
                 center=True,
                 pad_mode="reflect"
                 ):
        super(MelEncoder, self).__init__()
        self.window = torch.hann_window(win_length)
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=power
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels,
            sample_rate,
            f_min,
            f_max,
            n_fft // 2 + 1)

        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        spec = self.spec_fn(wav)
        power_spec = spec.real.abs().pow(2)
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)  # log10(max(reference value, amin))
        return mel_spec

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        self.conv_stack = nn.ModuleList([])
        for _ in range(num_of_stride_conv):
            self.conv_stack.append(
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            )
        # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, waveform)
            single-channel waveform
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, n_ctx)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))
        x = x.permute(0, 2, 1)
        x = (x + self.positional_embedding).to(x.dtype)
        return x
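A quick shape check for AudioEncoder (a sketch, not part of the commit): 10 s at 16 kHz with hop 160 yields 1001 mel frames, and the five stride-2 convolutions reduce that to 32 frames, matching n_ctx = (160000 // 160) // 2**5 + 1 as computed in BartCaptionModel. Note the positional embedding is sized to text_dim, so the module as committed implicitly assumes audio_dim == text_dim (both 768 here).

import torch
from model.modules import AudioEncoder, N_MELS

enc = AudioEncoder(n_mels=N_MELS, n_ctx=32, audio_dim=768, text_dim=768, num_of_stride_conv=5)

wav = torch.randn(2, 160000)          # (batch, samples): two 10 s clips at 16 kHz
out = enc(wav)
print(out.shape)                      # torch.Size([2, 32, 768]): (batch, n_ctx, text_dim)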
    	
utils/audio_utils.py
ADDED

@@ -0,0 +1,247 @@
STR_CLIP_ID = 'clip_id'
STR_AUDIO_SIGNAL = 'audio_signal'
STR_TARGET_VECTOR = 'target_vector'


STR_CH_FIRST = 'channels_first'
STR_CH_LAST = 'channels_last'

import io
import os
import tqdm
import logging
import subprocess
from typing import Tuple
from pathlib import Path

# import librosa  # needed only for the 'librosa' resampling backend below
import numpy as np
import soundfile as sf

import itertools
from numpy.fft import irfft

def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
    """
    Decoding, downmixing, and downsampling by ffmpeg.
    Returns a channel-first audio signal.

    Args:
        path:
        sample_rate:
        downmix_to_mono:

    Returns:
        (audio signal, sample rate)
    """

    def _decode_resample_by_ffmpeg(filename, sr):
        """decode, downmix, and resample an audio file"""
        channel_cmd = '-ac 1 ' if downmix_to_mono else ''  # downmixing option
        resampling_cmd = f'-ar {str(sr)}' if sr else ''  # downsampling option
        cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
        p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()
        return out

    src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
    return src.T, sr


def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
    """
    Decoding, downmixing, and downsampling by librosa.
    Returns a channel-first audio signal.
    """
    src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)  # requires the librosa import above
    return src, sr


def load_audio(
    path: str or Path,
    ch_format: str,
    sample_rate: int = None,
    downmix_to_mono: bool = False,
    resample_by: str = 'ffmpeg',
    **kwargs,
) -> Tuple[np.ndarray, int]:
    """A wrapper of librosa.load that:
        - forces the returned audio to be 2-dim,
        - defaults to sr=None, and
        - defaults to downmix_to_mono=False.

    The audio decoding is done by `audioread` or `soundfile` package and ultimately, often by ffmpeg.
    The resampling is done by `librosa`'s child package `resampy`.

    Args:
        path: audio file path
        ch_format: one of 'channels_first' or 'channels_last'
        sample_rate: target sampling rate. if None, use the rate of the audio file
        downmix_to_mono:
        resample_by (str): 'librosa' or 'ffmpeg'. it decides the backend for audio decoding and resampling.
        **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type.

    Returns:
        (audio, sr) tuple
    """
    if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
        raise ValueError(f'ch_format is wrong here -> {ch_format}')

    if os.stat(path).st_size > 8000:
        if resample_by == 'librosa':
            src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
        elif resample_by == 'ffmpeg':
            src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
        else:
            raise NotImplementedError(f'resample_by: "{resample_by}" is not supported yet')
    else:
        raise ValueError('Given audio is too short!')
    return src, sr

    # if src.ndim == 1:
    #     src = np.expand_dims(src, axis=0)
    # # now always 2d and channels_first

    # if ch_format == STR_CH_FIRST:
    #     return src, sr
    # else:
    #     return src.T, sr

def ms(x):
    """Mean value of signal `x` squared.
    :param x: Dynamic quantity.
    :returns: Mean squared of `x`.
    """
    return (np.abs(x)**2.0).mean()

def normalize(y, x=None):
    """Normalize power in y to a (standard normal) white noise signal.
    Optionally normalize to power in signal `x`.
    # The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1.
    """
    if x is not None:
        x = ms(x)
    else:
        x = 1.0
    return y * np.sqrt(x / ms(y))

def noise(N, color='white', state=None):
    """Noise generator.
    :param N: Amount of samples.
    :param color: Color of noise.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    """
    try:
        return _noise_generators[color](N, state)
    except KeyError:
        raise ValueError("Incorrect color.")

def white(N, state=None):
    """
    White noise.
    :param N: Amount of samples.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    White noise has a constant power density. Its narrowband spectrum is therefore flat.
    The power in white noise will increase by a factor of two for each octave band,
    and therefore increases with 3 dB per octave.
    """
    state = np.random.RandomState() if state is None else state
    return state.randn(N)

def pink(N, state=None):
    """
    Pink noise.
    :param N: Amount of samples.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    Pink noise has equal power in bands that are proportionally wide.
    Power density decreases with 3 dB per octave.
    """
    state = np.random.RandomState() if state is None else state
    uneven = N % 2
    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
    S = np.sqrt(np.arange(len(X)) + 1.)  # +1 to avoid divide by zero
    y = (irfft(X / S)).real
    if uneven:
        y = y[:-1]
    return normalize(y)

def blue(N, state=None):
    """
    Blue noise.
    :param N: Amount of samples.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    Power increases with 6 dB per octave.
    Power density increases with 3 dB per octave.
    """
    state = np.random.RandomState() if state is None else state
    uneven = N % 2
    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
    S = np.sqrt(np.arange(len(X)))  # Filter
    y = (irfft(X * S)).real
    if uneven:
        y = y[:-1]
    return normalize(y)

def brown(N, state=None):
    """
    Brown noise.
    :param N: Amount of samples.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    Power decreases with -3 dB per octave.
    Power density decreases with 6 dB per octave.
    """
    state = np.random.RandomState() if state is None else state
    uneven = N % 2
    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
    S = (np.arange(len(X)) + 1)  # Filter
    y = (irfft(X / S)).real
    if uneven:
        y = y[:-1]
    return normalize(y)

def violet(N, state=None):
    """
    Violet noise. Power increases with 6 dB per octave.
    :param N: Amount of samples.
    :param state: State of PRNG.
    :type state: :class:`np.random.RandomState`
    Power increases with +9 dB per octave.
    Power density increases with +6 dB per octave.
    """
    state = np.random.RandomState() if state is None else state
    uneven = N % 2
    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
    S = (np.arange(len(X)))  # Filter
    y = (irfft(X * S)).real
    if uneven:
        y = y[:-1]
    return normalize(y)

_noise_generators = {
    'white': white,
    'pink': pink,
    'blue': blue,
    'brown': brown,
    'violet': violet,
}

def noise_generator(N=44100, color='white', state=None):
    """Noise generator.
    :param N: Amount of unique samples to generate.
    :param color: Color of noise.
    Generate `N` amount of unique samples and cycle over these samples.
    """
    # yield from itertools.cycle(noise(N, color))  # Python 3.3
    for sample in itertools.cycle(noise(N, color, state)):
        yield sample

def heaviside(N):
    """Heaviside.
    Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`.
    """
    return 0.5 * (np.sign(N) + 1)
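A round-trip sketch for load_audio (not part of the commit; it assumes an ffmpeg binary on PATH, since resample_by defaults to 'ffmpeg'). Note that, as committed, the function returns before the ch_format branch (that handling is commented out), so the flag is validated but the signal comes back exactly as ffmpeg/soundfile produced it:

import numpy as np
import soundfile as sf
from utils.audio_utils import load_audio, STR_CH_FIRST

# Write a 3 s stereo test tone, then load it back downmixed and resampled.
t = np.linspace(0, 3, 3 * 44100, endpoint=False)
sf.write("tone.wav", np.stack([np.sin(2 * np.pi * 440 * t)] * 2, axis=1), 44100)

audio, sr = load_audio(path="tone.wav", ch_format=STR_CH_FIRST,
                       sample_rate=16000, downmix_to_mono=True)
print(audio.shape, sr)   # roughly (48000,) at sr == 16000; mono comes back 1-D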

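The noise helpers (apparently adapted from the python-acoustics generator module, and unused by app.py) all pass through normalize, which scales a clip to unit mean-square power; a small sketch:

import numpy as np
from utils.audio_utils import noise, noise_generator

rng = np.random.RandomState(0)
pink_clip = noise(16000, color='pink', state=rng)    # 1 s of pink noise
print(round(float(np.mean(pink_clip ** 2)), 3))      # 1.0: unit power via normalize()

gen = noise_generator(N=8000, color='white', state=rng)
first = [next(gen) for _ in range(5)]                # samples repeat after 8000 draws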