File size: 7,830 Bytes
9b1a8f5
 
 
 
 
 
7d6bada
444d38e
5ebef38
 
444d38e
 
 
 
7d6bada
444d38e
4594c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ebef38
9b1a8f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c76da8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1a8f5
 
 
0e550b3
 
 
9b1a8f5
 
0e550b3
 
 
 
 
 
9b1a8f5
 
c76da8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82669c
9b1a8f5
 
 
 
fb2a474
 
 
9b1a8f5
 
 
 
 
c76da8f
 
6ebc8e0
9b1a8f5
787c403
9b1a8f5
 
 
 
c82669c
14cf1ab
9b1a8f5
 
c82669c
 
9b1a8f5
1caaac4
6ebc8e0
 
1caaac4
 
6ebc8e0
1caaac4
 
6ebc8e0
14cf1ab
 
 
1caaac4
9b1a8f5
c94ca07
9b1a8f5
1caaac4
787c403
c94ca07
787c403
 
 
 
 
 
c76da8f
 
 
 
 
 
 
 
9b1a8f5
787c403
c76da8f
787c403
9b1a8f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82669c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1a8f5
c82669c
 
 
9b1a8f5
 
 
 
c82669c
9b1a8f5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import gradio as gr
import subprocess
import os 
import shutil
import tempfile

"""
# Set the PATH and LD_LIBRARY_PATH for CUDA 12.3
cuda_bin_path = "/usr/local/cuda/bin"
cuda_lib_path = "/usr/local/cuda/lib64"

# Update the environment variables
os.environ['PATH'] = f"{cuda_bin_path}:{os.environ.get('PATH', '')}"
os.environ['LD_LIBRARY_PATH'] = f"{cuda_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"
"""

# Install required package
def install_flash_attn():
    """Install the ``flash-attn`` package with pip, exiting on failure.

    Runs ``pip install flash-attn --no-build-isolation`` through the
    interpreter that is executing this script, so the package is installed
    into the environment this process imports from rather than into whichever
    ``pip`` happens to be first on ``PATH``.

    Raises:
        SystemExit: With status 1 when pip cannot be started or returns a
            non-zero exit code.
    """
    import sys  # local import: only this function needs it

    # Fall back to the bare command name if the interpreter path is unknown.
    pip_command = [sys.executable, "-m", "pip"] if sys.executable else ["pip"]

    try:
        print("Installing flash-attn...")
        subprocess.run(
            [*pip_command, "install", "flash-attn", "--no-build-isolation"],
            check=True,
        )
        print("flash-attn installed successfully!")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the case where the executable itself cannot be run.
        print(f"Failed to install flash-attn: {e}")
        # sys.exit instead of exit(): the latter is a helper added by the
        # site module and is not meant to be used from programs.
        sys.exit(1)

# Install flash-attn
# This runs at import time, before the remaining start-up steps of this
# script; install_flash_attn() terminates the process with status 1 when the
# installation fails.
install_flash_attn()

from huggingface_hub import snapshot_download 

# Create xcodec_mini_infer folder
folder_path = './inference/xcodec_mini_infer'

# Create the folder if it doesn't exist. Attempting the creation and handling
# FileExistsError avoids the gap between a separate existence check and the
# mkdir call. The parent directory is deliberately not created here.
try:
    os.mkdir(folder_path)
except FileExistsError:
    print(f"Folder already exists at: {folder_path}")
else:
    print(f"Folder created at: {folder_path}")

# Download the m-a-p/xcodec_mini_infer repository into that same folder;
# reusing folder_path keeps the two locations from drifting apart.
snapshot_download(
    repo_id="m-a-p/xcodec_mini_infer",
    local_dir=folder_path,
)

# Change to the "inference" directory so that the relative paths used later in
# this script ("infer.py", "./output") resolve inside it.
inference_dir = "./inference"
try:
    os.chdir(inference_dir)
except (FileNotFoundError, NotADirectoryError):
    print(f"Directory not found: {inference_dir}")
    # Raise SystemExit directly: the bare exit() helper is installed by the
    # site module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(1)
else:
    print(f"Changed working directory to: {os.getcwd()}")

def empty_output_folder(output_dir):
    """Delete every entry inside ``output_dir``, keeping the directory itself.

    Real subdirectories are removed recursively. Regular files and symbolic
    links (including links that point at directories) are unlinked, so the
    target of a link is never touched. A failure on one entry is printed and
    the remaining entries are still processed.

    Args:
        output_dir: Path of the directory to empty.

    Raises:
        OSError: If ``output_dir`` itself cannot be listed, for example
            because it does not exist.
    """
    # Take a snapshot of the entries first so nothing is deleted mid-scan.
    with os.scandir(output_dir) as scan:
        entries = list(scan)

    for entry in entries:
        try:
            # follow_symlinks=False: shutil.rmtree refuses to operate on a
            # symbolic link, so a link to a directory must be unlinked like a
            # file instead of being treated as a directory.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)
        except OSError as e:
            print(f"Error deleting file {entry.path}: {e}")

# Function to create a temporary file with string content
def create_temp_file(content, prefix, suffix=".txt"):
    """Write ``content`` to a new named temporary file and return its path.

    Leading and trailing whitespace is stripped, two newlines are appended,
    and ``\\r\\n`` / ``\\r`` line endings are converted to ``\\n`` before
    writing. The file is not deleted automatically; the caller is responsible
    for removing it.

    Args:
        content: Text to store in the file.
        prefix: Prefix for the temporary file name.
        suffix: Suffix for the temporary file name.

    Returns:
        The path of the temporary file that was written.
    """
    # Ensure content ends with newline and normalize line endings
    content = content.strip() + "\n\n"  # Add extra newline at end
    content = content.replace("\r\n", "\n").replace("\r", "\n")

    # The context manager closes the handle even if the write fails.
    # newline="\n" stops text mode from translating the normalized endings.
    # NOTE(review): the encoding is left at the platform default so that the
    # consumer of this file sees the same bytes as before - confirm what the
    # reading side expects before pinning it to UTF-8.
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", prefix=prefix, suffix=suffix, newline="\n"
    ) as temp_file:
        temp_file.write(content)
        temp_path = temp_file.name

    # Debug: Print file contents
    print(f"\nContent written to {prefix}{suffix}:")
    print(content)
    print("---")

    return temp_path

def get_last_mp3_file(output_dir):
    """Return the most recently modified ``.mp3`` path inside ``output_dir``.

    Only names ending in ``.mp3`` that sit directly in ``output_dir`` are
    considered. When there is none, a notice is printed and ``None`` is
    returned.
    """
    # Build the full paths of all .mp3 entries in one pass.
    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.mp3')
    ]

    if len(candidates) == 0:
        print("No .mp3 files found in the output folder.")
        return None

    # The newest file is the one with the largest modification time.
    return max(candidates, key=os.path.getmtime)

def _format_cli_number(value):
    """Render a UI-supplied number as a command-line argument string.

    Integral floats lose their redundant fractional part (``1.0`` -> ``"1"``);
    every other value is formatted exactly as an f-string would format it.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return f"{value}"


def _prepend_search_path(directory, existing):
    """Return ``directory`` placed in front of a ``:``-separated search path.

    No separator is added when ``existing`` is empty, which avoids leaving a
    trailing empty entry in the resulting list.
    """
    return f"{directory}:{existing}" if existing else directory


def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
    """Generate audio by running ``infer.py`` in a subprocess.

    The genre and lyrics texts are written to temporary files, the
    ``./output`` folder is created if needed and emptied, and ``infer.py`` is
    executed with those files as input. The temporary files are removed
    afterwards whether or not the run succeeded.

    Args:
        genre_txt_content: Genre description text.
        lyrics_txt_content: Lyrics text.
        num_segments: Value passed to ``infer.py`` as ``--run_n_segments``.
        max_new_tokens: Value passed to ``infer.py`` as ``--max_new_tokens``.

    Returns:
        The path of the most recently modified ``.mp3`` in the output folder,
        or ``None`` when the subprocess exits with a non-zero status or no
        ``.mp3`` file is present afterwards.
    """
    import sys  # local import: needed only for sys.executable

    temp_paths = []
    try:
        # Create temporary files. Each path is recorded as soon as it exists
        # so the cleanup below also runs when a later step fails.
        genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
        temp_paths.append(genre_txt_path)
        lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
        temp_paths.append(lyrics_txt_path)

        print(f"Genre TXT path: {genre_txt_path}")
        print(f"Lyrics TXT path: {lyrics_txt_path}")

        # Ensure the output folder exists
        output_dir = "./output"
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output folder ensured at: {output_dir}")

        # Start from an empty folder so only this run's files are found below.
        empty_output_folder(output_dir)

        # Run infer.py with the interpreter executing this script; fall back
        # to the bare command name if the interpreter path is unknown.
        # NOTE(review): infer.py is outside this file; presumably it parses
        # the two numeric options as integers - confirm.
        command = [
            sys.executable or "python", "infer.py",
            "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
            "--stage2_model", "m-a-p/YuE-s2-1B-general",
            "--genre_txt", f"{genre_txt_path}",
            "--lyrics_txt", f"{lyrics_txt_path}",
            "--run_n_segments", _format_cli_number(num_segments),
            "--stage2_batch_size", "8",
            "--output_dir", f"{output_dir}",
            "--cuda_idx", "0",
            "--max_new_tokens", _format_cli_number(max_new_tokens),
            "--disable_offload_model",
        ]

        # Environment for the subprocess: a copy of the current environment
        # with CUDA-related settings layered on top. Each key appears once;
        # PYTORCH_CUDA_ALLOC_CONF carries the value that was effective before
        # (the later of two duplicate entries).
        env = os.environ.copy()
        env.update({
            "CUDA_VISIBLE_DEVICES": "0",
            "CUDA_HOME": "/usr/local/cuda",
            "PATH": _prepend_search_path(
                "/usr/local/cuda/bin", env.get("PATH", "")
            ),
            "LD_LIBRARY_PATH": _prepend_search_path(
                "/usr/local/cuda/lib64", env.get("LD_LIBRARY_PATH", "")
            ),
            "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512,garbage_collection_threshold:0.8",
            "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
            "CUDA_LAUNCH_BLOCKING": "0",
        })

        # Execute the command
        try:
            subprocess.run(command, check=True, env=env)
        except subprocess.CalledProcessError as e:
            print(f"Error occurred: {e}")
            return None
        print("Command executed successfully!")

        # Check and print the contents of the output folder
        output_files = os.listdir(output_dir)
        if not output_files:
            print("Output folder is empty.")
            return None

        print("Output folder contents:")
        for file in output_files:
            print(f"- {file}")

        last_mp3 = get_last_mp3_file(output_dir)
        if not last_mp3:
            return None

        print("Last .mp3 file:", last_mp3)
        return last_mp3
    finally:
        # Clean up temporary files. Each removal is attempted independently so
        # one failure neither skips the other file nor masks the result.
        all_removed = True
        for temp_path in temp_paths:
            try:
                os.remove(temp_path)
            except OSError as e:
                all_removed = False
                print(f"Could not delete temporary file {temp_path}: {e}")
        if all_removed:
            print("Temporary files deleted.")

# Gradio 

# Build the UI: genre and lyrics inputs with one example on the left;
# generation settings, the submit button and the audio result on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# YuE")
        with gr.Row():
            with gr.Column():
                genre_txt = gr.Textbox(label="Genre")
                lyrics_txt = gr.Textbox(label="Lyrics")
                gr.Examples(
                    examples = [
                        [
                            "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                            """
                            [chorus]
                            Don't let this moment fade, hold me close tonight
                            With you here beside me, everything's alright
                            Can't imagine life alone, don't want to let you go
                            Stay with me forever, let our love just flow
                            """
                        ]
                    ], 
                    inputs = [genre_txt, lyrics_txt]
                )
            with gr.Column():
                # Both controls are created with interactive=False, so the
                # values passed to infer() are the defaults set here.
                num_segments = gr.Number(label="Number of Song Segments", info="number of paragraphs", value=1, interactive=False)
                # maximum is given as a number (it was the string "24000"),
                # matching the numeric minimum, step and value.
                max_new_tokens = gr.Slider(label="Max New Tokens / Duration", info="1000 token = 10 seconds", minimum=500, maximum=24000, step=500, value=1500, interactive=False)
                submit_btn = gr.Button("Submit")
                music_out = gr.Audio(label="Audio Result")
    
    # Clicking Submit runs infer() and sends its return value to the audio
    # output component.
    submit_btn.click(
        fn = infer, 
        inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
        outputs = [music_out]
    )
demo.queue().launch(show_api=False, show_error=True)