ginipick committed
Commit 6d09855 · verified
1 Parent(s): 13bfd1b

Update app.py

Files changed (1)
  1. app.py +146 -199
app.py CHANGED
@@ -20,6 +20,31 @@ logging.basicConfig(
     ]
 )
 
 def analyze_lyrics(lyrics, repeat_chorus=2):
     lines = [line.strip() for line in lyrics.split('\n') if line.strip()]
 
@@ -36,84 +61,64 @@ def analyze_lyrics(lyrics, repeat_chorus=2):
         'chorus': [],
         'bridge': []
     }
-
-    # variable for tracking where the last section started
-    last_section_start = 0
 
-    for i, line in enumerate(lines):
         lower_line = line.lower()
         if '[verse]' in lower_line:
-            if current_section:  # save the previous section's lines
-                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'verse'
             sections['verse'] += 1
-            last_section_start = i + 1
         elif '[chorus]' in lower_line:
-            if current_section:
-                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'chorus'
             sections['chorus'] += 1
-            last_section_start = i + 1
         elif '[bridge]' in lower_line:
-            if current_section:
-                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'bridge'
             sections['bridge'] += 1
-            last_section_start = i + 1
 
-    # append the last section's lines
-    if current_section:
-        section_lines[current_section].extend(lines[last_section_start:])
 
-    # repeat the chorus if needed
     if sections['chorus'] == 1 and repeat_chorus > 1:
         chorus_block = section_lines['chorus'][:]
         for _ in range(repeat_chorus - 1):
             section_lines['chorus'].extend(chorus_block)
 
-    # recalculate the total line count
     new_total_lines = sum(len(section_lines[sec]) for sec in section_lines)
 
     return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), new_total_lines, section_lines
 
-
 def calculate_generation_params(lyrics):
     sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics)
 
-    # base time per line (in seconds)
     time_per_line = {
-        'verse': 4,   # 4 seconds per verse line
-        'chorus': 6,  # 6 seconds per chorus line
-        'bridge': 5   # 5 seconds per bridge line
     }
 
-    # estimated duration for each section
     section_durations = {}
     for section_type in ['verse', 'chorus', 'bridge']:
-        # multiply the section's line count by its per-line time
         if isinstance(section_lines[section_type], list):
             section_durations[section_type] = len(section_lines[section_type]) * time_per_line[section_type]
         else:
             section_durations[section_type] = section_lines[section_type] * time_per_line[section_type]
 
-    # total duration
     total_duration = sum(duration for duration in section_durations.values())
-    total_duration = max(60, total_duration)  # at least 60 seconds
 
-    # token budget
-    base_tokens = 3000       # base token count
-    tokens_per_line = 200    # tokens per lyric line
 
     total_tokens = base_tokens + (total_lines * tokens_per_line)
 
-    # segment count based on the sections
     if sections['chorus'] > 0:
-        num_segments = 3  # three segments when a chorus is present
     else:
-        num_segments = 2  # two segments when there is no chorus
 
-    # cap the token count
-    max_tokens = min(8000, total_tokens)  # at most 8000 tokens
 
     return {
         'max_tokens': max_tokens,
@@ -125,43 +130,15 @@ def calculate_generation_params(lyrics):
         'has_chorus': sections['chorus'] > 0
     }
 
-def get_audio_duration(file_path):
-    try:
-        import librosa
-        duration = librosa.get_duration(path=file_path)
-        return duration
-    except Exception as e:
-        logging.error(f"Failed to get audio duration: {e}")
-        return None
-
-# language detection and model selection
 def detect_and_select_model(text):
-    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):  # Korean
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
-    elif re.search(r'[\u4e00-\u9fff]', text):  # Chinese
         return "m-a-p/YuE-s1-7B-anneal-zh-cot"
-    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):  # Japanese
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
-    else:  # English / other
-        return "m-a-p/YuE-s1-7B-anneal-en-cot"
-
-
-
-# optimize GPU settings
-def optimize_gpu_settings():
-    if torch.cuda.is_available():
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.benchmark = True
-        torch.backends.cudnn.deterministic = False
-        torch.backends.cudnn.enabled = True
-
-        torch.cuda.empty_cache()
-        torch.cuda.set_device(0)
-
-        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
-        logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
     else:
-        logging.warning("GPU not available!")
 
 def install_flash_attn():
     try:
@@ -183,17 +160,13 @@ def install_flash_attn():
     except ImportError:
         logging.info("Installing flash-attn...")
 
-        try:
-            subprocess.run(
-                ["pip", "install", "flash-attn", "--no-build-isolation"],
-                check=True,
-                capture_output=True
-            )
-            logging.info("flash-attn installed successfully!")
-            return True
-        except subprocess.CalledProcessError:
-            logging.warning("Failed to install flash-attn via pip, skipping...")
-            return False
 
     except Exception as e:
         logging.warning(f"Failed to install flash-attn: {e}")
@@ -201,19 +174,27 @@ def install_flash_attn():
 
 def initialize_system():
     optimize_gpu_settings()
-    has_flash_attn = install_flash_attn()
 
-    from huggingface_hub import snapshot_download
-
-    folder_path = './inference/xcodec_mini_infer'
-    os.makedirs(folder_path, exist_ok=True)
-    logging.info(f"Created folder at: {folder_path}")
-
-    snapshot_download(
-        repo_id="m-a-p/xcodec_mini_infer",
-        local_dir="./inference/xcodec_mini_infer",
-        resume_download=True
-    )
 
     try:
         os.chdir("./inference")
@@ -222,7 +203,7 @@ def initialize_system():
         logging.error(f"Directory error: {e}")
         raise
 
-@lru_cache(maxsize=50)
 def get_cached_file_path(content_hash, prefix):
     return create_temp_file(content_hash, prefix)
 
@@ -254,84 +235,46 @@ def get_last_mp3_file(output_dir):
     mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
     return mp3_files_with_path[0]
 
-def optimize_model_selection(lyrics, genre):
-    model_path = detect_and_select_model(lyrics)
-    params = calculate_generation_params(lyrics)
-
-    # adjust settings based on whether a chorus is present
-    has_chorus = params['sections']['chorus'] > 0
-
-    # token budget per segment
-    tokens_per_segment = params['max_tokens'] // params['num_segments']
-
-    model_config = {
-        "m-a-p/YuE-s1-7B-anneal-en-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.8,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        },
-        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.7,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        },
-        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.7,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        }
-    }
-
-    # allocate more tokens when a chorus is present
-    if has_chorus:
-        for config in model_config.values():
-            config['max_tokens'] = int(config['max_tokens'] * 1.5)  # 50% more tokens
-
-    return model_path, model_config[model_path], params
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     genre_txt_path = None
     lyrics_txt_path = None
 
     try:
-        # model selection and configuration
         model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
         logging.info(f"Selected model: {model_path}")
         logging.info(f"Lyrics analysis: {params}")
 
-        # check and log the chorus sections
         has_chorus = params['sections']['chorus'] > 0
         estimated_duration = params.get('estimated_duration', 90)
-
-        # adjust the segment count
         if has_chorus:
-            actual_num_segments = min(4, actual_num_segments + 1)  # add one segment
-            actual_max_tokens = min(8000, int(config['max_tokens'] * 1.3))  # 30% more tokens
-        else:
-            actual_num_segments = min(3, actual_num_segments + 1)
             actual_max_tokens = min(8000, int(config['max_tokens'] * 1.2))
 
-
-
-
         logging.info(f"Estimated duration: {estimated_duration} seconds")
         logging.info(f"Has chorus sections: {has_chorus}")
         logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
 
-        # create temporary files
         genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
         lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
 
        output_dir = "./output"
        os.makedirs(output_dir, exist_ok=True)
        empty_output_folder(output_dir)
-        # build the base command
        command = [
            "python", "infer.py",
            "--stage1_model", model_path,
@@ -339,19 +282,15 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             "--genre_txt", genre_txt_path,
             "--lyrics_txt", lyrics_txt_path,
             "--run_n_segments", str(actual_num_segments),
-            "--stage2_batch_size", "4",  # reduced batch size
             "--output_dir", output_dir,
             "--cuda_idx", "0",
-            "--max_new_tokens", str(actual_max_tokens)
         ]
 
-        # GPU settings
-        if torch.cuda.is_available():
-            command.append("--disable_offload_model")
-        # GPU settings
-
-
-        # set CUDA environment variables
         env = os.environ.copy()
         if torch.cuda.is_available():
             env.update({
@@ -359,17 +298,11 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
                 "CUDA_HOME": "/usr/local/cuda",
                 "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
                 "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
-                "PYTORCH_CUDA_ALLOC_CONF": f"max_split_size_mb:512"
             })
 
-        # handle the transformers cache migration
-        try:
-            from transformers.utils import move_cache
-            move_cache()
-        except Exception as e:
-            logging.warning(f"Cache migration warning (non-critical): {e}")
-
-        # run the command
         process = subprocess.run(
             command,
             env=env,
@@ -378,7 +311,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             text=True
         )
 
-        # log the run results
         logging.info(f"Command output: {process.stdout}")
         if process.stderr:
             logging.error(f"Command error: {process.stderr}")
@@ -388,7 +320,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             logging.error(f"Command: {' '.join(command)}")
             raise RuntimeError(f"Inference failed: {process.stderr}")
 
-        # process the results
         last_mp3 = get_last_mp3_file(output_dir)
         if last_mp3:
             try:
@@ -398,7 +329,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
                 logging.info(f"Audio duration: {duration:.2f} seconds")
                 logging.info(f"Expected duration: {estimated_duration} seconds")
 
-                # warn if the generated audio is much shorter than expected
                 if duration < estimated_duration * 0.8:
                     logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
             except Exception as e:
@@ -412,27 +342,55 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         logging.error(f"Inference error: {e}")
         raise
     finally:
-        # clean up the temporary files
-        if genre_txt_path and os.path.exists(genre_txt_path):
-            try:
-                os.remove(genre_txt_path)
-                logging.debug(f"Removed temporary file: {genre_txt_path}")
-            except Exception as e:
-                logging.warning(f"Failed to remove temporary file {genre_txt_path}: {e}")
-
-        if lyrics_txt_path and os.path.exists(lyrics_txt_path):
-            try:
-                os.remove(lyrics_txt_path)
-                logging.debug(f"Removed temporary file: {lyrics_txt_path}")
-            except Exception as e:
-                logging.warning(f"Failed to remove temporary file {lyrics_txt_path}: {e}")
 
 def main():
-    # Gradio interface
     with gr.Blocks() as demo:
         with gr.Column():
             gr.Markdown("# Open SUNO: Full-Song Generation (Multi-Language Support)")
-
 
         with gr.Row():
             with gr.Column():
@@ -469,10 +427,8 @@ def main():
             submit_btn = gr.Button("Generate Music", variant="primary")
             music_out = gr.Audio(label="Generated Audio")
 
-            # multi-language examples
             gr.Examples(
                 examples=[
-                    # English example
                     [
                         "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                         """[verse]
@@ -497,36 +453,27 @@ Guiding me back homeward, making my heart rejoice
 Don't let this moment fade, hold me close tonight
 With you here beside me, everything's alright
 Can't imagine life alone, don't want to let you go
-Stay with me forever, let our love just flow
-"""
                     ],
-                    # Korean example
                     [
                         "K-pop bright energetic synth dance electronic",
                         """[verse]
 언젠가 마주한 눈빛 속에서
-우린 서로를 알아보았지
 
 [chorus]
 다시 한 번 내게 말해줘
-너의 진심을 숨기지 말아 줘
 
 [verse]
 어두운 밤을 지낼 때마다
-너의 목소리를 떠올려
 
 [chorus]
 다시 한 번 내게 말해줘
-너의 진심을 숨기지 말아 줘
-
-
-"""
                     ]
                 ],
                 inputs=[genre_txt, lyrics_txt]
             )
 
-            # system initialization
             initialize_system()
 
             def update_info(lyrics):
@@ -540,9 +487,6 @@ Stay with me forever, let our love just flow
                     f"Verses: {sections['verse']}, Chorus: {sections['chorus']} (Expected full length including chorus)"
                 )
 
-
-
-            # event handlers
             lyrics_txt.change(
                 fn=update_info,
                 inputs=[lyrics_txt],
@@ -565,5 +509,8 @@ if __name__ == "__main__":
         share=True,
         show_api=True,
         show_error=True,
-        max_threads=2
-    )
 
     ]
 )
 
+def optimize_gpu_settings():
+    if torch.cuda.is_available():
+        # optimize GPU memory management
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.deterministic = False
+
+        # memory settings tuned for the L40S
+        torch.cuda.empty_cache()
+        torch.cuda.set_device(0)
+
+        # CUDA stream optimization
+        torch.cuda.Stream(0)
+
+        # memory allocator optimization
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
+
+        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
+        logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+
+        # L40S-specific settings
+        if 'L40S' in torch.cuda.get_device_name(0):
+            torch.cuda.set_per_process_memory_fraction(0.95)
+
 def analyze_lyrics(lyrics, repeat_chorus=2):
     lines = [line.strip() for line in lyrics.split('\n') if line.strip()]
 
         'chorus': [],
         'bridge': []
     }
 
+    for line in lines:
         lower_line = line.lower()
         if '[verse]' in lower_line:
             current_section = 'verse'
             sections['verse'] += 1
+            continue
         elif '[chorus]' in lower_line:
             current_section = 'chorus'
             sections['chorus'] += 1
+            continue
         elif '[bridge]' in lower_line:
             current_section = 'bridge'
             sections['bridge'] += 1
+            continue
 
+        if current_section:
+            section_lines[current_section].append(line)
 
     if sections['chorus'] == 1 and repeat_chorus > 1:
         chorus_block = section_lines['chorus'][:]
         for _ in range(repeat_chorus - 1):
             section_lines['chorus'].extend(chorus_block)
 
     new_total_lines = sum(len(section_lines[sec]) for sec in section_lines)
 
     return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), new_total_lines, section_lines
 
 def calculate_generation_params(lyrics):
     sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics)
 
     time_per_line = {
+        'verse': 4,
+        'chorus': 6,
+        'bridge': 5
     }
 
     section_durations = {}
     for section_type in ['verse', 'chorus', 'bridge']:
         if isinstance(section_lines[section_type], list):
             section_durations[section_type] = len(section_lines[section_type]) * time_per_line[section_type]
         else:
             section_durations[section_type] = section_lines[section_type] * time_per_line[section_type]
 
     total_duration = sum(duration for duration in section_durations.values())
+    total_duration = max(60, total_duration)
 
+    base_tokens = 3000
+    tokens_per_line = 200
 
     total_tokens = base_tokens + (total_lines * tokens_per_line)
 
     if sections['chorus'] > 0:
+        num_segments = 3
     else:
+        num_segments = 2
 
+    max_tokens = min(8000, total_tokens)
 
     return {
         'max_tokens': max_tokens,
 
         'has_chorus': sections['chorus'] > 0
     }
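
To make the sizing arithmetic above concrete, a hypothetical worked example (illustrative only, not part of the commit): a lyric with two [verse] blocks of four lines each and two [chorus] blocks of two lines each would give:

# 8 verse lines, 4 chorus lines, 0 bridge lines (two [chorus] tags, so no auto-repeat)
# total_duration = max(60, 8*4 + 4*6 + 0*5) = max(60, 56) = 60 seconds
# total_tokens   = 3000 + 12 * 200 = 5400
# max_tokens     = min(8000, 5400) = 5400
# num_segments   = 3 (a chorus is present)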
 
 def detect_and_select_model(text):
+    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
+    elif re.search(r'[\u4e00-\u9fff]', text):
         return "m-a-p/YuE-s1-7B-anneal-zh-cot"
+    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
     else:
+        return "m-a-p/YuE-s1-7B-anneal-en-cot"
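
For reference, the character ranges above dispatch as follows (hypothetical inputs, shown only to illustrate the regex branches; these checks are not part of the commit):

# assumes detect_and_select_model as defined above
assert detect_and_select_model("다시 한 번 내게 말해줘") == "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"  # Hangul
assert detect_and_select_model("月光下我们相遇") == "m-a-p/YuE-s1-7B-anneal-zh-cot"              # CJK ideographs
assert detect_and_select_model("きみとあるいた") == "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"           # kana
assert detect_and_select_model("Hold me close tonight") == "m-a-p/YuE-s1-7B-anneal-en-cot"       # fallback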
 
 def install_flash_attn():
     try:
 
     except ImportError:
         logging.info("Installing flash-attn...")
 
+        subprocess.run(
+            ["pip", "install", "flash-attn", "--no-build-isolation"],
+            check=True,
+            capture_output=True
+        )
+        logging.info("flash-attn installed successfully!")
+        return True
 
     except Exception as e:
         logging.warning(f"Failed to install flash-attn: {e}")
 
 def initialize_system():
     optimize_gpu_settings()
 
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        futures = []
+
+        futures.append(executor.submit(install_flash_attn))
+
+        from huggingface_hub import snapshot_download
+
+        folder_path = './inference/xcodec_mini_infer'
+        os.makedirs(folder_path, exist_ok=True)
+        logging.info(f"Created folder at: {folder_path}")
+
+        futures.append(executor.submit(
+            snapshot_download,
+            repo_id="m-a-p/xcodec_mini_infer",
+            local_dir="./inference/xcodec_mini_infer",
+            resume_download=True
+        ))
+
+        for future in futures:
+            future.result()
 
     try:
         os.chdir("./inference")
 
         logging.error(f"Directory error: {e}")
         raise
 
+@lru_cache(maxsize=100)
 def get_cached_file_path(content_hash, prefix):
     return create_temp_file(content_hash, prefix)
 
     mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
     return mp3_files_with_path[0]
 
+def get_audio_duration(file_path):
+    try:
+        import librosa
+        duration = librosa.get_duration(path=file_path)
+        return duration
+    except Exception as e:
+        logging.error(f"Failed to get audio duration: {e}")
+        return None
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     genre_txt_path = None
     lyrics_txt_path = None
 
     try:
         model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
         logging.info(f"Selected model: {model_path}")
         logging.info(f"Lyrics analysis: {params}")
 
         has_chorus = params['sections']['chorus'] > 0
         estimated_duration = params.get('estimated_duration', 90)
+
+        # set the segment and token counts
         if has_chorus:
             actual_max_tokens = min(8000, int(config['max_tokens'] * 1.2))
+            actual_num_segments = min(4, params['num_segments'] + 1)
+        else:
+            actual_max_tokens = config['max_tokens']
+            actual_num_segments = params['num_segments']
 
         logging.info(f"Estimated duration: {estimated_duration} seconds")
         logging.info(f"Has chorus sections: {has_chorus}")
         logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
 
         genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
         lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
 
         output_dir = "./output"
         os.makedirs(output_dir, exist_ok=True)
         empty_output_folder(output_dir)
+
         command = [
             "python", "infer.py",
             "--stage1_model", model_path,
 
             "--genre_txt", genre_txt_path,
             "--lyrics_txt", lyrics_txt_path,
             "--run_n_segments", str(actual_num_segments),
+            "--stage2_batch_size", "16",
             "--output_dir", output_dir,
             "--cuda_idx", "0",
+            "--max_new_tokens", str(actual_max_tokens),
+            "--use_flash_attention", "True",
+            "--use_bettertransformer", "True",
+            "--use_compile", "True"
         ]
 
         env = os.environ.copy()
         if torch.cuda.is_available():
             env.update({
 
                 "CUDA_HOME": "/usr/local/cuda",
                 "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
                 "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
+                "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
+                "CUDA_LAUNCH_BLOCKING": "0",
+                "TORCH_DISTRIBUTED_DEBUG": "DETAIL"
             })
 
         process = subprocess.run(
             command,
             env=env,
 
             text=True
         )
 
         logging.info(f"Command output: {process.stdout}")
         if process.stderr:
             logging.error(f"Command error: {process.stderr}")
 
             logging.error(f"Command: {' '.join(command)}")
             raise RuntimeError(f"Inference failed: {process.stderr}")
 
         last_mp3 = get_last_mp3_file(output_dir)
         if last_mp3:
             try:
 
                 logging.info(f"Audio duration: {duration:.2f} seconds")
                 logging.info(f"Expected duration: {estimated_duration} seconds")
 
                 if duration < estimated_duration * 0.8:
                     logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
             except Exception as e:
 
         logging.error(f"Inference error: {e}")
         raise
     finally:
+        for path in [genre_txt_path, lyrics_txt_path]:
+            if path and os.path.exists(path):
+                try:
+                    os.remove(path)
+                    logging.debug(f"Removed temporary file: {path}")
+                except Exception as e:
+                    logging.warning(f"Failed to remove temporary file {path}: {e}")
+
+def optimize_model_selection(lyrics, genre):
+    model_path = detect_and_select_model(lyrics)
+    params = calculate_generation_params(lyrics)
+
+    has_chorus = params['sections']['chorus'] > 0
+    tokens_per_segment = params['max_tokens'] // params['num_segments']
+
+    model_config = {
+        "m-a-p/YuE-s1-7B-anneal-en-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.8,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        },
+        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.7,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        },
+        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.7,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        }
+    }
+
+    if has_chorus:
+        for config in model_config.values():
+            config['max_tokens'] = int(config['max_tokens'] * 1.5)
+
+    return model_path, model_config[model_path], params
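
Continuing the hypothetical sizing example above (illustrative only, not part of the commit), the 1.5x chorus scaling here combines with the caps applied in infer():

# config['max_tokens']  = int(5400 * 1.5)            = 8100  (chorus present)
# actual_max_tokens     = min(8000, int(8100 * 1.2)) = 8000  (infer(), chorus branch)
# actual_num_segments   = min(4, 3 + 1)              = 4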
 
 def main():
     with gr.Blocks() as demo:
         with gr.Column():
             gr.Markdown("# Open SUNO: Full-Song Generation (Multi-Language Support)")
 
         with gr.Row():
             with gr.Column():
 
             submit_btn = gr.Button("Generate Music", variant="primary")
             music_out = gr.Audio(label="Generated Audio")
 
             gr.Examples(
                 examples=[
                     [
                         "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                         """[verse]
 
 Don't let this moment fade, hold me close tonight
 With you here beside me, everything's alright
 Can't imagine life alone, don't want to let you go
+Stay with me forever, let our love just flow"""
 
                     ],
                     [
                         "K-pop bright energetic synth dance electronic",
                         """[verse]
 언젠가 마주한 눈빛 속에서
 
 [chorus]
 다시 한 번 내게 말해줘
 
 [verse]
 어두운 밤을 지낼 때마다
 
 [chorus]
 다시 한 번 내게 말해줘
+"""
                     ]
                 ],
                 inputs=[genre_txt, lyrics_txt]
             )
 
             initialize_system()
 
             def update_info(lyrics):
 
                     f"Verses: {sections['verse']}, Chorus: {sections['chorus']} (Expected full length including chorus)"
                 )
 
             lyrics_txt.change(
                 fn=update_info,
                 inputs=[lyrics_txt],
 
         share=True,
         show_api=True,
         show_error=True,
+        max_threads=8,
+        enable_queue=True,
+        cache_examples=True,
+        analytics_enabled=False
+    )