# File size: 2,395 bytes
# bc62b2b 8092c47
from audiocraft.models import MusicGen
import streamlit as st
import os
import torch
import torchaudio
import numpy as np
import base64
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    ``st.cache_resource`` caches the returned object, so later calls (e.g. on
    Streamlit reruns) reuse the already-loaded model instead of reloading it.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
def generate_music_tensors(description,duration:int):
    """Generate audio for a text prompt with the cached MusicGen model.

    Args:
        description: Text prompt describing the desired music.
        duration: Passed as MusicGen's ``duration`` generation parameter.

    Returns:
        Element 0 of ``model.generate(..., return_tokens=True)``. This is
        presumably the generated audio tensor, with the tokens discarded.
        TODO confirm against the audiocraft version in use.
    """
    # Echo the request to the server log.
    for label, value in (("Description:", description), ("Duration:", duration)):
        print(label, value)

    musicgen = load_model()
    musicgen.set_generation_params(use_sampling=True, top_k=250, duration=duration)

    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    audio = generated[0]
    return audio
def save_audio(samples: torch.Tensor, sample_rate: int = 32000, save_path: str = "audio_output/"):
    """Write each clip in *samples* to ``<save_path>/audio_<idx>.wav``.

    Args:
        samples: Audio tensor, either 2-D (a single clip) or 3-D (a batch of
            clips indexed along dim 0). Each clip is handed to
            ``torchaudio.save`` with its default ``channels_first=True``, so
            clips are assumed to be (channels, frames). TODO confirm against
            the model output.
        sample_rate: Sample rate written into the WAV files. Defaults to
            32000, the value this function previously hard-coded.
        save_path: Output directory. It is created if it does not exist.

    Returns:
        list[str]: Paths of the written files, in batch order.

    Raises:
        ValueError: If *samples* is neither 2-D nor 3-D.
    """
    # Validate with an exception rather than ``assert``, which ``python -O`` strips.
    if samples.dim() not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D audio tensor, got {samples.dim()}-D")
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop handles both shapes.
        samples = samples[None, ...]
    # torchaudio.save does not create missing parent directories.
    os.makedirs(save_path, exist_ok=True)
    written = []
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        # sample_rate must be a plain int (the old `32000,` produced a tuple).
        torchaudio.save(audio_path, audio, sample_rate)
        written.append(audio_path)
    return written
def get_binary_file_downloader_html(bin_file,file_label='File'):
    """Build an HTML ``<a>`` tag that downloads *bin_file* through a base64 data URI.

    Args:
        bin_file: Path of the file to embed. It is read fully into memory.
        file_label: Text shown in the link ("Download <label> from here").

    Returns:
        str: An anchor element with two attributes. Its ``href`` is an
        ``application/octet-stream`` data URI holding the file's bytes. Its
        ``download`` attribute suggests the file's base name as the saved
        filename.
    """
    import html

    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    # The base64 alphabet is HTML-safe; escape only the human-supplied text.
    file_name = html.escape(os.path.basename(bin_file), quote=True)
    label = html.escape(str(file_label))
    # `download` must be its own attribute, outside the quoted href value.
    return (
        f'<a href="data:application/octet-stream;base64,{bin_str}" '
        f'download="{file_name}">Download {label} from here</a>'
    )
# Configure the browser tab's title and icon for this Streamlit page.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
def main():
    """Render the Music Gen page.

    The page reads a prompt and a duration, generates audio, plays it back,
    and offers a download link for the generated WAV file.
    """
    st.title("Your Music")
    with st.expander("See Explanation"):
        st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")

    text_area = st.text_area("Enter description")
    # st.slider(label, min_value, max_value, value): a 2-20 s range with a
    # 5 s default. The default must lie inside [min_value, max_value].
    time_slider = st.slider("Select time duration(s)", 2, 20, 5)

    if text_area and time_slider:
        st.json(
            {
                "Description": text_area,
                "Selected duration:": time_slider
            }
        )
        st.subheader("Generated Music")
        music_tensors = generate_music_tensors(text_area, time_slider)
        save_audio(music_tensors)

        # save_audio writes the first (and only) clip here.
        audio_file_path = 'audio_output/audio_0.wav'
        with open(audio_file_path, 'rb') as audio_file:
            audio_bytes = audio_file.read()
        st.audio(audio_bytes)
        # Call the helper; st.markdown needs the HTML string, not the function.
        st.markdown(
            get_binary_file_downloader_html(audio_file_path, 'Audio'),
            unsafe_allow_html=True,
        )
# Script entry point. `streamlit run` executes this file as __main__, so the
# guard also fires there.
if __name__ == "__main__":
    main()