quazim commited on
Commit
8c46cbe
·
1 Parent(s): 43f4544

added elastic model

Browse files
Files changed (1) hide show
  1. app.py +32 -32
app.py CHANGED
@@ -8,41 +8,41 @@ import subprocess
8
  import sys
9
  import os
10
 
11
- def setup_flash_attention():
12
- """One-time setup for flash-attention with special flags"""
13
- # Check if flash-attn is already installed
14
- try:
15
- import flash_attn
16
- print("flash-attn already installed")
17
- return
18
- except ImportError:
19
- pass
20
 
21
- # Check if we've already tried to install it in this session
22
- if os.path.exists("/tmp/flash_attn_installed"):
23
- return
24
 
25
- try:
26
- print("Installing flash-attn with --no-build-isolation...")
27
- subprocess.run([
28
- sys.executable, "-m", "pip", "install",
29
- "flash-attn==2.7.3", "--no-build-isolation"
30
- ], check=True)
31
 
32
- # Uninstall apex if it exists
33
- subprocess.run([
34
- sys.executable, "-m", "pip", "uninstall", "apex", "-y"
35
- ], check=False) # Don't fail if apex isn't installed
36
 
37
- # Mark as installed
38
- with open("/tmp/flash_attn_installed", "w") as f:
39
- f.write("installed")
40
 
41
- print("flash-attn installation completed")
42
 
43
- except subprocess.CalledProcessError as e:
44
- print(f"Warning: Failed to install flash-attn: {e}")
45
- # Continue anyway - the model might work without it
46
 
47
  # Run setup once when the module is imported
48
  # setup_flash_attention()
@@ -85,8 +85,7 @@ def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0
85
  cache_implementation="paged"
86
  )
87
 
88
- # Convert to numpy array and prepare for output
89
- audio_data = audio_values[0, 0].cpu().numpy()
90
  sample_rate = model.config.sample_rate
91
 
92
  # Normalize audio
@@ -95,7 +94,8 @@ def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0
95
  return sample_rate, audio_data
96
 
97
  except Exception as e:
98
- return None, f"Error generating music: {str(e)}"
 
99
 
100
  # Create Gradio interface
101
  with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
 
8
  import sys
9
  import os
10
 
11
+ # def setup_flash_attention():
12
+ # """One-time setup for flash-attention with special flags"""
13
+ # # Check if flash-attn is already installed
14
+ # try:
15
+ # import flash_attn
16
+ # print("flash-attn already installed")
17
+ # return
18
+ # except ImportError:
19
+ # pass
20
 
21
+ # # Check if we've already tried to install it in this session
22
+ # if os.path.exists("/tmp/flash_attn_installed"):
23
+ # return
24
 
25
+ # try:
26
+ # print("Installing flash-attn with --no-build-isolation...")
27
+ # subprocess.run([
28
+ # sys.executable, "-m", "pip", "install",
29
+ # "flash-attn==2.7.3", "--no-build-isolation"
30
+ # ], check=True)
31
 
32
+ # # Uninstall apex if it exists
33
+ # subprocess.run([
34
+ # sys.executable, "-m", "pip", "uninstall", "apex", "-y"
35
+ # ], check=False) # Don't fail if apex isn't installed
36
 
37
+ # # Mark as installed
38
+ # with open("/tmp/flash_attn_installed", "w") as f:
39
+ # f.write("installed")
40
 
41
+ # print("flash-attn installation completed")
42
 
43
+ # except subprocess.CalledProcessError as e:
44
+ # print(f"Warning: Failed to install flash-attn: {e}")
45
+ # # Continue anyway - the model might work without it
46
 
47
  # Run setup once when the module is imported
48
  # setup_flash_attention()
 
85
  cache_implementation="paged"
86
  )
87
 
88
+ audio_data = audio_values[0, 0].cpu().numpy().astype(np.float32)
 
89
  sample_rate = model.config.sample_rate
90
 
91
  # Normalize audio
 
94
  return sample_rate, audio_data
95
 
96
  except Exception as e:
97
+ print(f"Error: {str(e)}")
98
+ return None
99
 
100
  # Create Gradio interface
101
  with gr.Blocks(title="MusicGen Large - Music Generation") as demo: