from typing import Any, Dict, List
import os

import librosa
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Sample rate MusicGen expects for conditioning audio (32 kHz).
_TARGET_SR = 32000
# Max reference-audio samples fed to the processor (~15 s at 32 kHz).
_REF_SAMPLE_LIMIT = 480213


class EndpointHandler:
    """Custom inference-endpoint handler for audio+text conditioned MusicGen.

    Loads the model/processor from `path` and a fixed reference
    accompaniment clip that is used as audio conditioning for every request.
    """

    def __init__(self, path: str = ""):
        """Load the processor, the model, and the reference accompaniment.

        Args:
            path: Directory containing the model weights, the processor
                config, and an ``accompaniment.mp3`` reference track.
        """
        # Fall back to CPU so the handler also starts on hosts without a GPU
        # (the original hard-coded "cuda" and crashed on CPU-only machines).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(path).to(self.device)
        # librosa resamples the clip to the model's expected 32 kHz rate.
        self.ref_sample_y, self.ref_sr = librosa.load(
            os.path.join(path, "accompaniment.mp3"), sr=_TARGET_SR
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio conditioned on the reference clip and a text prompt.

        Args:
            data: Payload with an ``inputs`` text prompt and optional
                ``parameters`` dict of generation kwargs.

        Returns:
            A one-element list of ``{"generated_text": <numpy waveform>}``.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Truncate into a local — the original reassigned self.ref_sample_y
        # on every call, mutating shared handler state per request.
        ref_audio = self.ref_sample_y[:_REF_SAMPLE_LIMIT]

        model_inputs = self.processor(
            audio=ref_audio,
            sampling_rate=self.ref_sr,
            text=[inputs],
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Defaults, overridable by the caller — the original had identical
        # branches and silently discarded `parameters` (bug fix).
        gen_kwargs: Dict[str, Any] = {
            "do_sample": True,
            "guidance_scale": 4,
            "max_new_tokens": 256,
        }
        if parameters is not None:
            gen_kwargs.update(parameters)

        outputs = self.model.generate(**model_inputs, **gen_kwargs)

        # Move the first (only) generated waveform to host memory.
        prediction = outputs[0].cpu().numpy()
        return [{"generated_text": prediction}]