|
from audiocraft.models import MusicGen |
|
import streamlit as st |
|
import os |
|
import torch |
|
import torchaudio |
|
import numpy as np |
|
import base64 |
|
|
|
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    Decorated with ``st.cache_resource``, presumably so the checkpoint is
    loaded once and reused across Streamlit reruns rather than re-fetched
    on every interaction.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
|
|
|
def generate_music_tensors(description, duration: int):
    """Run MusicGen on a single text prompt and return the generated output.

    Args:
        description: Text prompt describing the music to generate.
        duration: Requested clip length, forwarded unchanged to
            ``set_generation_params(duration=...)``.

    Returns:
        Item 0 of what ``generate(..., return_tokens=True)`` returns —
        presumably the generated audio batch, with the token tensor (item 1)
        discarded. TODO confirm against the audiocraft ``generate`` docs.
    """
    # Echo the request to the server console for debugging.
    print("Description:", description)
    print("Duration:", duration)

    musicgen = load_model()
    musicgen.set_generation_params(use_sampling=True, top_k=250, duration=duration)

    generated = musicgen.generate(
        descriptions=[description],  # batch of exactly one prompt
        progress=True,
        return_tokens=True,
    )
    # With return_tokens=True the result is indexed; only the first item is kept.
    first_item = generated[0]
    return first_item
|
|
|
def save_audio(samples: torch.Tensor, sample_rate: int = 32000, save_path: str = "audio_output/"):
    """Write one clip or a batch of clips to WAV files and return the first path.

    Args:
        samples: Audio tensor of rank 2 (a single clip) or rank 3 (a batch of
            clips along dim 0). Each clip is passed straight to
            ``torchaudio.save``; presumably laid out as (channels, time) —
            TODO confirm against the producer.
        sample_rate: Sample rate written into each WAV header. Defaults to the
            32000 that was previously hard-coded here; it presumably has to
            match the generating model's rate — verify if the model changes.
        save_path: Output directory, created if it does not exist. Files are
            named ``audio_<index>.wav`` and overwrite earlier runs.

    Returns:
        The path of ``audio_0.wav`` inside ``save_path``.

    Raises:
        ValueError: If ``samples`` is not 2-D or 3-D.
    """
    # Validate before touching the filesystem. A plain ``assert`` would be
    # stripped under ``python -O`` and let bad input through.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"save_audio expects a 2-D or 3-D tensor, got {samples.dim()}-D"
        )

    os.makedirs(save_path, exist_ok=True)

    # Move off the autograd graph and the GPU before serialising.
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        samples = samples.unsqueeze(0)

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)

    return os.path.join(save_path, "audio_0.wav")
|
|
|
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML ``<a>`` tag that downloads ``bin_file`` via a data URI.

    The file's bytes are base64-encoded inline, so the whole file ends up in
    the returned string.

    Args:
        bin_file: Path of the file to embed.
        file_label: Value for the ``download`` attribute (the suggested file
            name) and for the visible link text.

    Returns:
        The anchor tag as a string, intended for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html  # local import keeps the file's top-level import block untouched

    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode('ascii')
    # The label is rendered as raw HTML, so escape it. Quote characters must
    # not break out of the attribute and tags must not be injected into the page.
    safe_label = html.escape(file_label, quote=True)
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{safe_label}">Download {safe_label} from here</a>'
    return href
|
|
|
# Configure the browser tab (title and icon) at import time, ahead of main().
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
|
|
|
def main():
    """Render the page: prompt and duration inputs, then playback and download.

    Once a description has been entered, the script generates audio for it,
    writes it to disk, plays it inline and offers a download link.
    """
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")

    text_area = st.text_area("Enter description")
    # Keyword arguments make the order explicit. The old positional call
    # (2, 5, 20) meant min=2, max=5, default=20, and Streamlit rejects a
    # default outside [min_value, max_value].
    time_slider = st.slider(
        "Select time duration(s)", min_value=2, max_value=20, value=5
    )

    if text_area and time_slider:
        st.json(
            {
                "Description": text_area,
                "Selected duration": time_slider,
            }
        )
        st.subheader("Generated Music")

        music_tensors = generate_music_tensors(text_area, time_slider)
        audio_file_path = save_audio(music_tensors)

        # Close the handle deterministically instead of leaking one per rerun.
        with open(audio_file_path, 'rb') as audio_file:
            audio_bytes = audio_file.read()

        st.audio(audio_bytes)
        st.markdown(get_binary_file_downloader_html(audio_file_path, 'Audio'), unsafe_allow_html=True)


if __name__ == "__main__":
    main()