annapurnapadmaprema-ji commited on
Commit
fe49032
·
verified ·
1 Parent(s): 8d8198a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -33
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from audiocraft.models import MusicGen
2
  import streamlit as st
3
- import os
4
  import torch
5
  import torchaudio
6
  import numpy as np
@@ -8,13 +8,13 @@ import base64
8
 
9
  @st.cache_resource
10
  def load_model():
11
- model=MusicGen.get_pretrained("facebook/musicgen-small")
12
  return model
13
 
14
- def generate_music_tensors(description,duration:int):
15
- print("Description:",description)
16
- print("Duration:",duration)
17
- model=load_model()
18
 
19
  model.set_generation_params(
20
  use_sampling=True,
@@ -22,31 +22,33 @@ def generate_music_tensors(description,duration:int):
22
  duration=duration
23
  )
24
 
25
- output=model.generate(
26
  descriptions=[description],
27
  progress=True,
28
  return_tokens=True
29
  )
30
  return output[0]
31
 
32
- def save_audio(samples:torch.tensor):
33
- sample_rate=32000
34
- save_path="audio_output/"
 
35
 
36
- assert samples.dim()==2 or samples.dim()==3
37
- samples=samples.detach().cpu()
38
 
39
- if samples.dim()==2:
40
- samples=samples[None,...]
41
- for idx,audio in enumerate(samples):
42
- audio_path=os.path.join(save_path,f"audio_0.wav")
43
- torchaudio.save(audio_path,audio,sample_rate)
 
44
 
45
- def get_binary_file_downloader_html(bin_file,file_label='File'):
46
- with open(bin_file,'rb') as f:
47
- data=f.read()
48
- bin_str=base64.b64encode(data).decode()
49
- href=f'<a href="data:application/octet-stream;base64,{bin_str} download {(bin_file)}">Download {file_label} from here</a>'
50
  return href
51
 
52
  st.set_page_config(
@@ -59,23 +61,24 @@ def main():
59
 
60
  with st.expander("See Explanation"):
61
  st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")
62
- text_area=st.text_area("Enter description")
63
- time_slider=st.slider("Select time duration(s)",2,5,20)
 
64
 
65
  if text_area and time_slider:
66
  st.json(
67
  {
68
- "Description":text_area,
69
- "Selected duration:":time_slider
70
  }
71
  )
72
  st.subheader("Generated Music")
73
- music_tensors=generate_music_tensors(text_area,time_slider)
74
- save_music_file=save_audio(music_tensors)
75
- audio_file_path='audio_output/audio_0.wav'
76
- audio_file=open(audio_file_path,'rb')
77
- audio_bytes=audio_file.read()
78
  st.audio(audio_bytes)
79
- st.markdown(get_binary_file_downloader_html,audio_file_path,'Audio',unsafe_allow_html=True)
80
- if __name__=="__main__":
 
81
  main()
 
1
  from audiocraft.models import MusicGen
2
  import streamlit as st
3
+ import os
4
  import torch
5
  import torchaudio
6
  import numpy as np
 
8
 
9
  @st.cache_resource
10
  def load_model():
11
+ model = MusicGen.get_pretrained("facebook/musicgen-small")
12
  return model
13
 
14
+ def generate_music_tensors(description, duration: int):
15
+ print("Description:", description)
16
+ print("Duration:", duration)
17
+ model = load_model()
18
 
19
  model.set_generation_params(
20
  use_sampling=True,
 
22
  duration=duration
23
  )
24
 
25
+ output = model.generate(
26
  descriptions=[description],
27
  progress=True,
28
  return_tokens=True
29
  )
30
  return output[0]
31
 
32
+ def save_audio(samples: torch.Tensor):
33
+ sample_rate = 32000 # corrected to integer
34
+ save_path = "audio_output/"
35
+ os.makedirs(save_path, exist_ok=True) # ensure directory exists
36
 
37
+ assert samples.dim() == 2 or samples.dim() == 3
38
+ samples = samples.detach().cpu()
39
 
40
+ if samples.dim() == 2:
41
+ samples = samples[None, ...]
42
+ for idx, audio in enumerate(samples):
43
+ audio_path = os.path.join(save_path, f"audio_{idx}.wav")
44
+ torchaudio.save(audio_path, audio, sample_rate)
45
+ return os.path.join(save_path, "audio_0.wav")
46
 
47
+ def get_binary_file_downloader_html(bin_file, file_label='File'):
48
+ with open(bin_file, 'rb') as f:
49
+ data = f.read()
50
+ bin_str = base64.b64encode(data).decode()
51
+ href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label} from here</a>'
52
  return href
53
 
54
  st.set_page_config(
 
61
 
62
  with st.expander("See Explanation"):
63
  st.write("App is developed by using Meta's Audiocraft Music Gen model. Write your text and we will generate audio")
64
+
65
+ text_area = st.text_area("Enter description")
66
+ time_slider = st.slider("Select time duration(s)", 2, 5, 20)
67
 
68
  if text_area and time_slider:
69
  st.json(
70
  {
71
+ "Description": text_area,
72
+ "Selected duration": time_slider
73
  }
74
  )
75
  st.subheader("Generated Music")
76
+ music_tensors = generate_music_tensors(text_area, time_slider)
77
+ audio_file_path = save_audio(music_tensors)
78
+ audio_file = open(audio_file_path, 'rb')
79
+ audio_bytes = audio_file.read()
 
80
  st.audio(audio_bytes)
81
+ st.markdown(get_binary_file_downloader_html(audio_file_path, 'Audio'), unsafe_allow_html=True)
82
+
83
+ if __name__ == "__main__":
84
  main()