---
license: cc-by-4.0
language:
  - en
tags:
  - music
  - art
---

# MusiLingo-short-v1

This repo contains the checkpoint and inference code for the paper *MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response* (NAACL 2024).

For more details, see the GitHub repo.

You can use the MusicInstruct (MI) dataset with the demo below.

This checkpoint was trained on the MI-short subset of MusicInstruct.
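
If you only want to inspect the MI question–answer pairs themselves, a minimal sketch with the `datasets` library might look like the one below. The Hub repo id `m-a-p/Music-Instruct` is an assumption of this sketch, not something stated on this card; check the GitHub repo for the authoritative location.

```python
# Hedged sketch: browse the MusicInstruct (MI) question-answer pairs.
# The repo id "m-a-p/Music-Instruct" is an assumption, not confirmed by this card.
from datasets import load_dataset

mi = load_dataset("m-a-p/Music-Instruct")
print(mi)  # show the available splits and their fields
```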

## Inference Code

```python
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList

# Note: CMIDataset (used below) is defined in the MusiLingo GitHub repo.

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once one of the given token-id sequences is produced."""
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # `**kwargs` absorbs the extra arguments newer transformers versions pass in.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
           repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
    # `self` is the MusiLingo model, so this function mirrors the model's bound method.
    audio = samples["audio"].cuda()
    audio_embeds, atts_audio = self.encode_audio(audio)
    if 'instruction_input' in samples:  # instruction dataset
        instruction_prompt = []
        for instruction in samples['instruction_input']:
            prompt = '<Audio><AudioHere></Audio> ' + instruction
            instruction_prompt.append(self.prompt_template.format(prompt))
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    self.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                    dtype=torch.long,
                    device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model may emit an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # likewise remove a leading start token <s>
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

# MERT-v1-330M is the music encoder; its feature extractor prepares the audio input.
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='short')
dl = DataLoader(
    ds,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
    drop_last=True,
    collate_fn=ds.collater,
)

# Stop generation at the '###' separator (its token ids in the LLaMA vocabulary).
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])

from transformers import AutoModel
# The model class is custom code shipped with the repo, hence trust_remote_code.
musilingo_short = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
musilingo_short.to("cuda")
musilingo_short.eval()

for idx, sample in tqdm(enumerate(dl)):
    ans = answer(musilingo_short.model, sample, stopping, length_penalty=100, temperature=0.1)
    txt = sample['text_input'][0]  # ground-truth answer from the dataset
    print(txt)
    print(ans)
```
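
If you do not have the MI dataset locally, the checkpoint can also be queried on a single clip. The sketch below is an assumption-laden variant of the loop above: the WAV path and the question are placeholders, and the 24 kHz mono input matches what the MERT feature extractor expects.

```python
# Hedged sketch: single-clip inference, reusing `processor`, `answer`,
# `stopping`, and `musilingo_short` from the code above.
import torchaudio

wav, sr = torchaudio.load("path/to/clip.wav")  # placeholder path
wav = torchaudio.functional.resample(wav.mean(dim=0), sr, 24000)  # mono, 24 kHz
inputs = processor(wav, sampling_rate=24000, return_tensors="pt")

sample = {
    "audio": inputs["input_values"],
    "instruction_input": ["Please describe this piece of music."],  # placeholder question
}
print(answer(musilingo_short.model, sample, stopping, temperature=0.1))
```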

## Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:

```bibtex
@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
```