ginipick committed on
Commit
06fb240
·
verified ·
1 Parent(s): 01ee1f1

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -867
app.py DELETED
@@ -1,867 +0,0 @@
1
- import gradio as gr
2
- import subprocess
3
- import os
4
- import shutil
5
- import tempfile
6
- import torch
7
- import logging
8
- import numpy as np
9
- import re
10
- from concurrent.futures import ThreadPoolExecutor
11
- from functools import lru_cache
12
-
13
- from datetime import datetime
14
-
15
- # 로깅 설정
16
# Root logger setup, executed once at import time: records at INFO level and
# above are written both to 'yue_generation.log' (relative to the current
# working directory) and to the console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler('yue_generation.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
24
-
25
def optimize_gpu_settings():
    """Apply process-wide PyTorch/CUDA settings when a GPU is available.

    Does nothing when ``torch.cuda.is_available()`` is false. Otherwise it
    exports the allocator configuration, enables TF32 matmul and cuDNN
    autotuning, selects device 0, logs the device name and total memory, and
    caps the per-process memory fraction at 0.95 when the device name
    contains ``'L40S'``.

    Returns:
        None.
    """
    if not torch.cuda.is_available():
        return

    # Export the allocator configuration before any other CUDA call in this
    # function. NOTE(review): the caching allocator presumably reads this
    # variable when it is first initialised, so it may still be too late if
    # CUDA memory was already allocated elsewhere - confirm.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

    # Favour throughput over bit-exact reproducibility.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False

    torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # The original also evaluated a bare `torch.cuda.Stream(0)` expression
    # here; the stream object was discarded immediately, so it is dropped.

    device_name = torch.cuda.get_device_name(0)
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    logging.info(f"Using GPU: {device_name}")
    logging.info(f"Available GPU memory: {total_memory_gb:.2f} GB")

    # L40S-specific setting.
    if 'L40S' in device_name:
        torch.cuda.set_per_process_memory_fraction(0.95)
49
-
50
def analyze_lyrics(lyrics, repeat_chorus=2):
    """Split tagged lyrics into verse / chorus / bridge sections.

    Blank lines are discarded and remaining lines are stripped. A line that
    contains ``[verse]``, ``[chorus]`` or ``[bridge]`` (case-insensitive)
    starts a new section; the whole tag line is dropped, and every following
    line up to the next tag belongs to that section. Lines that appear before
    the first tag are not assigned to any section.

    Args:
        lyrics: Raw lyrics text. ``None`` is treated as empty text.
        repeat_chorus: How many times the collected chorus lines appear in
            the result. Values of 1 or less leave the chorus unrepeated.

    Returns:
        A 4-tuple ``(sections, total_sections, total_lines, section_lines)``:
        ``sections`` maps each section name to its tag count and
        ``'total_lines'`` to the number of non-blank lines (tag lines
        included); ``total_sections`` is the sum of the three tag counts;
        ``total_lines`` is the same non-blank line count; ``section_lines``
        maps each section name to its list of lyric lines.
    """
    # Tuple order is the precedence used when one line holds several tags.
    section_names = ('verse', 'chorus', 'bridge')

    lines = [raw.strip() for raw in (lyrics or '').split('\n') if raw.strip()]

    sections = {name: 0 for name in section_names}
    sections['total_lines'] = len(lines)
    section_lines = {name: [] for name in section_names}

    current_section = None
    section_start = 0
    for index, line in enumerate(lines):
        lowered = line.lower()
        tag = next((name for name in section_names if f'[{name}]' in lowered), None)
        if tag is None:
            continue
        # A new tag closes the section that was open so far.
        if current_section is not None:
            section_lines[current_section].extend(lines[section_start:index])
        current_section = tag
        sections[tag] += 1
        section_start = index + 1

    # Flush the lines of the final open section.
    if current_section is not None:
        section_lines[current_section].extend(lines[section_start:])

    # Repeat the chorus lines as requested.
    if sections['chorus'] > 0 and repeat_chorus > 1:
        section_lines['chorus'] = section_lines['chorus'] * repeat_chorus

    logging.info(f"Section line counts - Verse: {len(section_lines['verse'])}, "
                 f"Chorus: {len(section_lines['chorus'])}, "
                 f"Bridge: {len(section_lines['bridge'])}")

    total_sections = sections['verse'] + sections['chorus'] + sections['bridge']
    return sections, total_sections, len(lines), section_lines
115
-
116
def calculate_generation_params(lyrics):
    """Derive token, segment and duration estimates from tagged lyrics.

    Args:
        lyrics: Raw lyrics text, passed unchanged to ``analyze_lyrics``.

    Returns:
        A dict with the keys ``max_tokens``, ``num_segments``, ``sections``,
        ``section_lines``, ``estimated_duration``, ``section_durations`` and
        ``has_chorus``.
    """
    sections, _total_sections, total_lines, section_lines = analyze_lyrics(lyrics)

    # Assumed singing time per lyric line, in seconds.
    seconds_per_line = {'verse': 4, 'chorus': 6, 'bridge': 5}
    section_durations = {
        name: len(section_lines[name]) * seconds_per_line[name]
        for name in ('verse', 'chorus', 'bridge')
    }

    # Add 20% headroom and never estimate less than one minute.
    padded_seconds = int(sum(section_durations.values()) * 1.2)
    total_duration = max(60, padded_seconds)

    # Token budget: fixed base + per-line allowance + extra for the last section.
    base_tokens = 3000
    tokens_per_line = 200
    extra_tokens = 1000
    total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens

    # Lyrics with a chorus get one more segment than lyrics without.
    has_chorus = sections['chorus'] > 0
    num_segments = 4 if has_chorus else 3

    # Hard ceiling on the token budget.
    max_tokens = min(12000, total_tokens)

    return {
        'max_tokens': max_tokens,
        'num_segments': num_segments,
        'sections': sections,
        'section_lines': section_lines,
        'estimated_duration': total_duration,
        'section_durations': section_durations,
        'has_chorus': has_chorus,
    }
161
-
162
def detect_and_select_model(text):
    """Choose the stage-1 YuE checkpoint that matches the script used in *text*.

    Checks run in this order:

    1. Hangul (compatibility jamo or syllables) -> jp-kr model.
    2. Hiragana or katakana -> jp-kr model.
    3. CJK unified ideographs -> zh model.
    4. Anything else -> en model.

    Kana must be tested before ideographs: Japanese text normally mixes kana
    with kanji, and kanji share the ideograph range with Chinese, so testing
    ideographs first would send Japanese lyrics to the Chinese model.

    Args:
        text: Lyrics text to inspect.

    Returns:
        The Hugging Face repo id of the selected model.
    """
    jp_kr_model = "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"

    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):
        return jp_kr_model
    if re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):
        return jp_kr_model
    if re.search(r'[\u4e00-\u9fff]', text):
        return "m-a-p/YuE-s1-7B-anneal-zh-cot"
    return "m-a-p/YuE-s1-7B-anneal-en-cot"
171
-
172
def install_flash_attn():
    """Make sure ``flash_attn`` is importable, installing it with pip if needed.

    Installation is skipped when no GPU is available or when the installed
    torch build reports no CUDA version. The install runs as
    ``<current interpreter> -m pip`` so the package lands in the environment
    that is running this application rather than whichever ``pip`` happens to
    be first on PATH.

    Returns:
        True if ``flash_attn`` was already importable or the pip install
        succeeded; False if the install was skipped or anything failed.
        Failures are logged as warnings and never raised.
    """
    import sys  # local import: the file's import block does not provide sys

    try:
        if not torch.cuda.is_available():
            logging.warning("GPU not available, skipping flash-attn installation")
            return False

        cuda_version = torch.version.cuda
        if cuda_version is None:
            logging.warning("CUDA not available, skipping flash-attn installation")
            return False

        logging.info(f"Detected CUDA version: {cuda_version}")

        try:
            import flash_attn  # noqa: F401 - imported only to probe availability
        except ImportError:
            logging.info("Installing flash-attn...")
        else:
            logging.info("flash-attn already installed")
            return True

        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
            capture_output=True,
        )
        logging.info("flash-attn installed successfully!")
        return True

    except Exception as e:
        # Best effort: the caller continues without flash-attn.
        logging.warning(f"Failed to install flash-attn: {e}")
        return False
203
-
204
-
205
def download_missing_files():
    """Download the xcodec files and all YuE checkpoints, then verify them.

    Steps, all relative to the current working directory:

    1. Fetch ``config/config.json`` and the two decoder checkpoints from the
       primary repo and copy them flat into ``./xcodec_mini_infer``. If a
       file fails, retry it once from the alternate repo.
    2. Snapshot-download the three stage-1 models and the stage-2 model into
       ``./models/<model name>``.
    3. Check that each xcodec file exists and is not empty.

    NOTE(review): the primary repo id ("hf-internal-testing/...") and the
    alternate repo ("facebook/musicgen-small") look unlikely to host these
    files - confirm. ``force_download=True`` also presumably re-fetches every
    model on each call - confirm this is intended.

    Raises:
        Exception: Any download error that also failed on the alternate
            repo, or any snapshot error, is logged and re-raised.
        FileNotFoundError: An xcodec file is missing or empty afterwards.
    """
    try:
        from huggingface_hub import hf_hub_download, snapshot_download

        # Target file name inside xcodec_dir -> path inside the source repo.
        repo_id = "hf-internal-testing/xcodec_mini_infer"
        files_to_download = {
            "config.json": "config/config.json",
            "vocal_decoder.pth": "checkpoints/vocal_decoder.pth",
            "inst_decoder.pth": "checkpoints/inst_decoder.pth"
        }

        xcodec_dir = "./xcodec_mini_infer"
        os.makedirs(xcodec_dir, exist_ok=True)
        os.makedirs(os.path.join(xcodec_dir, "checkpoints"), exist_ok=True)

        for target_name, source_path in files_to_download.items():
            destination = os.path.join(xcodec_dir, target_name)
            try:
                fetched = hf_hub_download(
                    repo_id=repo_id,
                    filename=source_path,
                    cache_dir="./models/cache",
                    force_download=True,
                    local_files_only=False
                )
                logging.info(f"Downloaded {source_path} to: {fetched}")

                # Copy the cached file to its expected location.
                shutil.copy2(fetched, destination)
                logging.info(f"Copied to: {destination}")

            except Exception as e:
                logging.error(f"Error downloading {source_path}: {e}")
                # Second attempt from the alternate repository.
                try:
                    alt_repo_id = "facebook/musicgen-small"
                    fetched = hf_hub_download(
                        repo_id=alt_repo_id,
                        filename=source_path,
                        cache_dir="./models/cache",
                        force_download=True
                    )
                    shutil.copy2(fetched, destination)
                    logging.info(f"Downloaded from alternate source to: {destination}")
                except Exception as alt_e:
                    logging.error(f"Error with alternate download: {alt_e}")
                    raise

        # YuE model checkpoints.
        models = [
            "m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
            "m-a-p/YuE-s1-7B-anneal-en-cot",
            "m-a-p/YuE-s1-7B-anneal-zh-cot",
            "m-a-p/YuE-s2-1B-general"
        ]

        for model in models:
            model_name = model.split('/')[-1]
            model_path = snapshot_download(
                repo_id=model,
                local_dir=f"./models/{model_name}",
                cache_dir="./models/cache",
                resume_download=True,
                force_download=True
            )
            logging.info(f"Downloaded {model_name} to: {model_path}")

        # Verify that every xcodec file exists and has content.
        for target_name in files_to_download:
            file_path = os.path.join(xcodec_dir, target_name)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Failed to download {target_name}")
            file_size = os.path.getsize(file_path)
            if file_size == 0:
                raise FileNotFoundError(f"Downloaded file is empty: {target_name}")
            logging.info(f"Verified {target_name}: {file_size} bytes")

        logging.info("All required models downloaded successfully")

    except Exception as e:
        logging.error(f"Error downloading models: {e}")
        raise
289
-
290
def check_model_files():
    """Ensure the xcodec files exist under ``<cwd>/xcodec_mini_infer``.

    A required file counts as missing when it does not exist or is empty, so
    a truncated zero-byte file triggers a re-download instead of only failing
    the final verification. When anything is missing,
    ``download_missing_files()`` is called once. Every required file is then
    verified.

    Raises:
        FileNotFoundError: A required file is still absent or empty after the
            download attempt.
    """
    xcodec_dir = os.path.join(os.getcwd(), "xcodec_mini_infer")
    required_files = ("config.json", "vocal_decoder.pth", "inst_decoder.pth")

    # First pass: find out whether a download is needed at all.
    missing = False
    for file_name in required_files:
        file_path = os.path.join(xcodec_dir, file_name)
        if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
            missing = True
            logging.warning(f"Missing file: {file_path}")

    if missing:
        logging.info("Downloading missing files...")
        download_missing_files()

    # Second pass: verify the final state of every required file.
    for file_name in required_files:
        file_path = os.path.join(xcodec_dir, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Failed to download or locate required file: {file_name}")
        file_size = os.path.getsize(file_path)
        if file_size == 0:
            raise FileNotFoundError(f"Downloaded file is empty: {file_name}")
        logging.info(f"Verified {file_name}: {file_size} bytes")
323
-
324
def initialize_system():
    """Prepare GPU settings, the working directory and all model files.

    Applies ``optimize_gpu_settings()``, creates ``./inference`` with a
    ``models`` sub-directory, and changes the process working directory to
    it. It then downloads any xcodec file that is not yet present and, in a
    thread pool, installs flash-attn and snapshot-downloads the four YuE
    models in parallel. Finally it checks that the xcodec files exist.

    Raises:
        FileNotFoundError: An xcodec file is still missing at the end.
        Exception: Any other failure is logged and re-raised.
    """
    optimize_gpu_settings()

    try:
        # Base directory layout.
        base_dir = os.path.abspath("./inference")
        os.makedirs(base_dir, exist_ok=True)
        os.makedirs(os.path.join(base_dir, "models"), exist_ok=True)

        # Everything below runs relative to the inference directory.
        os.chdir(base_dir)
        logging.info(f"Working directory changed to: {os.getcwd()}")

        from huggingface_hub import snapshot_download, hf_hub_download

        xcodec_dir = os.path.join(base_dir, "xcodec_mini_infer")
        os.makedirs(xcodec_dir, exist_ok=True)

        # Files that must be present in xcodec_dir.
        required_names = ("config.json", "vocal_decoder.pth", "inst_decoder.pth")

        for file_name in required_names:
            try:
                file_path = os.path.join(xcodec_dir, file_name)
                if os.path.exists(file_path):
                    continue
                fetched = hf_hub_download(
                    repo_id="m-a-p/xcodec_mini_infer",
                    filename=file_name,
                    local_dir=xcodec_dir,
                    force_download=True
                )
                # Copy only when the hub placed the file somewhere else.
                if fetched != file_path:
                    shutil.copy2(fetched, file_path)
                logging.info(f"Downloaded {file_name} to {file_path}")
            except Exception as e:
                logging.error(f"Error downloading {file_name}: {e}")
                raise

        # YuE model checkpoints.
        models = [
            "m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
            "m-a-p/YuE-s1-7B-anneal-en-cot",
            "m-a-p/YuE-s1-7B-anneal-zh-cot",
            "m-a-p/YuE-s2-1B-general"
        ]

        with ThreadPoolExecutor(max_workers=4) as executor:
            # The flash-attn install is submitted first, then one download
            # task per model.
            futures = [executor.submit(install_flash_attn)]
            futures.extend(
                executor.submit(
                    snapshot_download,
                    repo_id=model,
                    local_dir=os.path.join(base_dir, "models", model.split('/')[-1]),
                    force_download=True
                )
                for model in models
            )

            # Wait for every task; result() re-raises a task's exception.
            for future in futures:
                future.result()

        # Final existence check for the xcodec files.
        for file_name in required_names:
            file_path = os.path.join(xcodec_dir, file_name)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Required file still missing after download: {file_path}")
            file_size = os.path.getsize(file_path)
            logging.info(f"Verified {file_name}: {file_size} bytes")

        logging.info("System initialization completed successfully")

    except Exception as e:
        logging.error(f"Directory error: {e}")
        raise
410
-
411
@lru_cache(maxsize=100)
def get_cached_file_path(content_hash, prefix):
    """Return a temp-file path for the given arguments, memoised per pair.

    The first call for a ``(content_hash, prefix)`` pair passes
    ``content_hash`` to ``create_temp_file`` as the file content; later calls
    with the same pair return the remembered path without creating a file.

    NOTE(review): the remembered path is not re-checked, so it may point to a
    file that has since been deleted - confirm callers tolerate that.
    """
    cached_path = create_temp_file(content_hash, prefix)
    return cached_path
414
-
415
def empty_output_folder(output_dir):
    """Leave *output_dir* as an existing, empty directory.

    An existing directory is removed with all of its contents and recreated.
    A directory that does not exist yet is simply created, instead of the
    removal step raising.

    Args:
        output_dir: Path of the folder to reset.

    Raises:
        Exception: Any filesystem error is logged and re-raised.
    """
    try:
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        logging.info(f"Output folder cleaned: {output_dir}")
    except Exception as e:
        logging.error(f"Error cleaning output folder: {e}")
        raise
423
-
424
def create_temp_file(content, prefix, suffix=".txt"):
    """Write *content* to a new persistent temporary file and return its path.

    The text is stripped, terminated with two newlines, and has its line
    endings normalised to ``\\n``. The file is written as UTF-8 explicitly so
    non-ASCII lyrics (Korean, Japanese, Chinese) do not depend on the locale's
    default encoding.

    Args:
        content: Text to store.
        prefix: File name prefix.
        suffix: File name suffix, ``".txt"`` by default.

    Returns:
        Path of the created file. The file is not deleted automatically; the
        caller is responsible for removing it.
    """
    normalized = content.strip() + "\n\n"
    normalized = normalized.replace("\r\n", "\n").replace("\r", "\n")

    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", encoding="utf-8", prefix=prefix, suffix=suffix
    ) as temp_file:
        temp_file.write(normalized)

    logging.debug(f"Temporary file created: {temp_file.name}")
    return temp_file.name
432
-
433
def get_last_mp3_file(output_dir):
    """Return the most recently modified ``.mp3`` directly inside *output_dir*.

    Args:
        output_dir: Directory to scan (not recursive).

    Returns:
        The path (``output_dir`` joined with the file name) of the ``.mp3``
        with the latest modification time, or None, after logging a warning,
        when the directory contains no such file.
    """
    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.mp3')
    ]
    if not candidates:
        logging.warning("No MP3 files found")
        return None

    return max(candidates, key=os.path.getmtime)
442
-
443
def get_audio_duration(file_path):
    """Return the duration reported by librosa for *file_path*.

    librosa is imported on demand inside the guarded block, so a missing
    librosa installation is handled the same way as an unreadable file.

    Args:
        file_path: Path of the audio file.

    Returns:
        The value of ``librosa.get_duration(path=file_path)``, or None if
        anything goes wrong; the error is logged.
    """
    try:
        import librosa
        return librosa.get_duration(path=file_path)
    except Exception as e:
        logging.error(f"Failed to get audio duration: {e}")
        return None
451
-
452
-
453
-
454
def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
    """Run the external ``infer.py`` pipeline and return the generated MP3.

    The stage-1 model, segment count and token budget are derived from the
    lyrics via ``optimize_model_selection``. Genre and lyrics are written to
    temporary files, ``./output`` is emptied, and ``infer.py`` is run as a
    child process using the interpreter that is running this application
    (``sys.executable``) rather than whichever ``python`` is first on PATH.

    Args:
        genre_txt_content: Genre/style description text.
        lyrics_txt_content: Tagged lyrics text.
        num_segments: Accepted for interface compatibility; this function
            does not read it (the value is computed from the lyrics).
        max_new_tokens: Accepted for interface compatibility; this function
            does not read it (the value is computed from the lyrics).

    Returns:
        Path of the newest ``.mp3`` in ``./output``, or None when the run
        produced no MP3.

    Raises:
        RuntimeError: ``infer.py`` exited with a non-zero return code.
        Exception: Any other error is logged and re-raised. The temporary
            input files are removed in every case.
    """
    import sys  # local import: the file's import block does not provide sys

    check_model_files()  # verify / download the required model files

    genre_txt_path = None
    lyrics_txt_path = None

    try:
        model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
        logging.info(f"Selected model: {model_path}")
        logging.info(f"Lyrics analysis: {params}")

        has_chorus = params['sections']['chorus'] > 0
        estimated_duration = params.get('estimated_duration', 90)

        # Final budgets: lyrics with a chorus get a larger token allowance
        # and more segments, each capped at a fixed ceiling.
        if has_chorus:
            actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3))
            actual_num_segments = min(5, params['num_segments'] + 2)
        else:
            actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2))
            actual_num_segments = min(4, params['num_segments'] + 1)

        logging.info(f"Estimated duration: {estimated_duration} seconds")
        logging.info(f"Has chorus sections: {has_chorus}")
        logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")

        genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
        lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")

        output_dir = "./output"
        os.makedirs(output_dir, exist_ok=True)
        empty_output_folder(output_dir)

        command = [
            sys.executable, "infer.py",
            "--stage1_model", model_path,
            "--stage2_model", "m-a-p/YuE-s2-1B-general",
            "--genre_txt", genre_txt_path,
            "--lyrics_txt", lyrics_txt_path,
            "--run_n_segments", str(actual_num_segments),
            "--stage2_batch_size", "16",
            "--output_dir", output_dir,
            "--cuda_idx", "0",
            "--max_new_tokens", str(actual_max_tokens),
            "--disable_offload_model"
        ]

        # The child inherits the current environment; CUDA-related variables
        # are added only when a GPU is available.
        env = os.environ.copy()
        if torch.cuda.is_available():
            env.update({
                "CUDA_VISIBLE_DEVICES": "0",
                "CUDA_HOME": "/usr/local/cuda",
                "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
                "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
                "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
                "CUDA_LAUNCH_BLOCKING": "0",
                "TRANSFORMERS_CACHE": "./models/cache",
                "HF_HOME": "./models/cache"
            })

        # Best-effort transformers cache migration; failure is not fatal.
        try:
            from transformers.utils import move_cache
            move_cache()
        except Exception as e:
            logging.warning(f"Cache migration warning (non-critical): {e}")

        # check=False: the return code is inspected below so that stdout and
        # stderr can be logged before raising.
        process = subprocess.run(
            command,
            env=env,
            check=False,
            capture_output=True,
            text=True
        )

        logging.info(f"Command output: {process.stdout}")
        if process.stderr:
            logging.error(f"Command error: {process.stderr}")

        if process.returncode != 0:
            logging.error(f"Command failed with return code: {process.returncode}")
            logging.error(f"Command: {' '.join(command)}")
            raise RuntimeError(f"Inference failed: {process.stderr}")

        last_mp3 = get_last_mp3_file(output_dir)
        if not last_mp3:
            logging.warning("No output audio file generated")
            return None

        # The duration check is informational only and never fails the run.
        try:
            duration = get_audio_duration(last_mp3)
            logging.info(f"Generated audio file: {last_mp3}")
            if duration:
                logging.info(f"Audio duration: {duration:.2f} seconds")
                logging.info(f"Expected duration: {estimated_duration} seconds")

                if duration < estimated_duration * 0.8:
                    logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
        except Exception as e:
            logging.warning(f"Failed to get audio duration: {e}")
        return last_mp3

    except Exception as e:
        logging.error(f"Inference error: {e}")
        raise
    finally:
        # Always remove the temporary input files.
        for path in [genre_txt_path, lyrics_txt_path]:
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                    logging.debug(f"Removed temporary file: {path}")
                except Exception as e:
                    logging.warning(f"Failed to remove temporary file {path}: {e}")
575
-
576
def optimize_model_selection(lyrics, genre):
    """Select the stage-1 model for *lyrics* and build its generation config.

    Args:
        lyrics: Tagged lyrics text; drives both model choice and parameters.
        genre: Genre text. Accepted for interface compatibility; this
            function does not read it.

    Returns:
        A 3-tuple ``(model_path, config, params)``. ``config`` contains
        ``max_tokens`` (multiplied by 1.5 when the lyrics contain a chorus),
        ``temperature``, ``batch_size``, ``num_segments`` and
        ``estimated_duration``. ``params`` is the unmodified result of
        ``calculate_generation_params``.

    Raises:
        KeyError: The detected model has no entry in the temperature table.
    """
    model_path = detect_and_select_model(lyrics)
    params = calculate_generation_params(lyrics)

    # Sampling temperature is the only per-model difference.
    temperature_by_model = {
        "m-a-p/YuE-s1-7B-anneal-en-cot": 0.8,
        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": 0.7,
        "m-a-p/YuE-s1-7B-anneal-zh-cot": 0.7,
    }

    max_tokens = params['max_tokens']
    if params['sections']['chorus'] > 0:
        # Lyrics with a chorus get a 50% larger token budget.
        max_tokens = int(max_tokens * 1.5)

    config = {
        "max_tokens": max_tokens,
        "temperature": temperature_by_model[model_path],
        "batch_size": 16,
        "num_segments": params['num_segments'],
        "estimated_duration": params['estimated_duration'],
    }

    return model_path, config, params
612
-
613
# Custom stylesheet for the Gradio page; the selectors match the elem_id
# values assigned to components in main().
css = """
#main-container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
#header {
    text-align: center;
    margin-bottom: 30px;
}
#genre-input, #lyrics-input {
    border-radius: 8px;
}
#generate-btn {
    margin-top: 20px;
    min-height: 45px;
}
.label {
    font-weight: bold;
}
.example-container {
    background: #f8f9fa;
    padding: 15px;
    border-radius: 8px;
    margin: 10px 0;
}
"""
640
-
641
def main():
    """Build the Gradio Blocks application and return it (not launched).

    The page has a "Create Music" tab (genre and lyrics inputs, live duration
    and section estimates, generate button, audio output, status box), a
    "History" tab rendered as HTML, an examples accordion and a help
    accordion. ``initialize_system()`` is called while the UI is being built.

    User-supplied genre and lyrics are HTML-escaped before they are put into
    the history markup, and a failed generation leaves the history panel
    unchanged instead of blanking it.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    import html  # local import: used only to escape text in the history markup

    example_en_lyrics = (
        "[verse]\n"
        "In the quiet of the evening, shadows start to fall\n"
        "Whispers of the night wind echo through the hall\n"
        "Lost within the silence, I hear your gentle voice\n"
        "Guiding me back homeward, making my heart rejoice\n"
        "\n"
        "[chorus]\n"
        "Don't let this moment fade, hold me close tonight\n"
    )
    example_ko_lyrics = (
        "[verse]\n"
        "언젠가 마주한 눈빛 속에서\n"
        "어두운 밤을 지날 때마다\n"
        "\n"
        "[chorus]\n"
        "다시 한 번 내게 말해줘\n"
    )

    with gr.Blocks(theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
        font=["Arial", "sans-serif"]
    ), css=css) as demo:
        with gr.Column(elem_id="main-container"):
            # Header section.
            with gr.Row(elem_id="header"):
                gr.Markdown(
                    """
                    # 🎵 Open SUNO: Transform Your Lyrics into Complete Songs
                    ### Create complete songs from your lyrics in multiple languages
                    """,
                    elem_id="title"
                )

            # Main content, organised as tabs.
            with gr.Tabs():
                # Creation tab.
                with gr.TabItem("✨ Create Music", id="create"):
                    with gr.Row():
                        # Input column.
                        with gr.Column(scale=1):
                            genre_txt = gr.Textbox(
                                label="🎸 Music Genre & Style",
                                placeholder="e.g., K-pop bright energetic synth dance electronic...",
                                elem_id="genre-input"
                            )
                            lyrics_txt = gr.Textbox(
                                label="📝 Lyrics",
                                placeholder="Enter lyrics with section tags: [verse], [chorus], [bridge]...",
                                lines=10,
                                elem_id="lyrics-input"
                            )

                            # Live information about the entered lyrics.
                            with gr.Row():
                                with gr.Column(scale=1):
                                    duration_info = gr.Label(
                                        label="⏱️ Estimated Duration",
                                        elem_id="duration-info"
                                    )
                                with gr.Column(scale=1):
                                    sections_info = gr.Label(
                                        label="📊 Section Analysis",
                                        elem_id="sections-info"
                                    )

                            # Generate button.
                            submit_btn = gr.Button(
                                "🎼 Generate Music",
                                variant="primary",
                                elem_id="generate-btn"
                            )

                        # Output column.
                        with gr.Column(scale=1):
                            music_out = gr.Audio(
                                label="🎵 Generated Music",
                                elem_id="music-output"
                            )

                            # Generation status text.
                            progress = gr.Textbox(
                                label="Generation Status",
                                interactive=False,
                                elem_id="progress-status"
                            )

                # History tab.
                with gr.TabItem("📚 History", id="history"):
                    with gr.Row():
                        history_container = gr.HTML("""
                        <div id="history-container" style="width: 100%; padding: 10px;">
                            <h3>🎵 Generation History</h3>
                            <div id="history-list"></div>
                        </div>
                        """)

            # Per-session list of history entries (newest first).
            history_state = gr.State([])

            # Examples section.
            with gr.Accordion("📖 Examples", open=False):
                gr.Examples(
                    examples=[
                        [
                            "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                            example_en_lyrics
                        ],
                        [
                            "K-pop bright energetic synth dance electronic",
                            example_ko_lyrics
                        ]
                    ],
                    inputs=[genre_txt, lyrics_txt]
                )

            # Help and information section.
            with gr.Accordion("ℹ️ Help & Information", open=False):
                gr.Markdown(
                    """
                    ### 🎵 How to Use
                    1. **Enter Genre & Style**: Describe the musical style you want (e.g., "K-pop", "Jazz", "Rock")
                    2. **Input Lyrics**: Write your lyrics using section tags:
                       - Use `[verse]` for verses
                       - Use `[chorus]` for choruses
                       - Use `[bridge]` for bridges
                    3. **Generate**: Click the Generate button and wait for your music!

                    ### 🌏 Supported Languages
                    - English
                    - Korean (한국어)
                    - Japanese (日本語)
                    - Chinese (中文)

                    ### ⚡ Tips
                    - Be specific with your genre descriptions
                    - Include emotion and instrument preferences
                    - Make sure to properly tag your lyrics sections
                    - For best results, include both verse and chorus sections
                    """
                )

        # Hidden state forwarded to infer(), which derives its own values
        # from the lyrics.
        num_segments = gr.State(value=2)
        max_new_tokens = gr.State(value=4000)

        # System initialisation (GPU settings, model downloads).
        initialize_system()

        def update_info(lyrics):
            """Return the duration and section summaries for the lyrics box."""
            if not lyrics:
                return "No lyrics entered", "No sections detected"
            params = calculate_generation_params(lyrics)
            duration = params['estimated_duration']
            sections = params['sections']
            return (
                f"⏱️ Estimated: {duration:.1f} seconds",
                f"📊 Verses: {sections['verse']}, Chorus: {sections['chorus']}"
            )

        def update_history(audio_path, genre, lyrics, history):
            """Prepend a new entry and return ``(history, rendered_html)``."""
            if not audio_path:
                return history, ""

            new_entry = {
                "audio": audio_path,
                "genre": genre,
                "lyrics": lyrics,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
            history = [new_entry] + (history or [])

            rendered = []
            for entry in history:
                # Entries hold raw user text; escape it at render time so it
                # cannot inject markup. quote=True also covers the
                # single-quoted src attribute.
                # NOTE(review): src is a local filesystem path - confirm the
                # browser can actually load it through Gradio.
                audio_src = html.escape(str(entry["audio"]), quote=True)
                genre_text = html.escape(str(entry["genre"]), quote=True)
                lyrics_text = html.escape(str(entry["lyrics"]), quote=True)
                timestamp = html.escape(str(entry["timestamp"]), quote=True)
                rendered.append(f"""
                <div class='history-entry' style='margin: 10px 0; padding: 10px; border: 1px solid #ddd; border-radius: 8px;'>
                    <audio controls src='{audio_src}'></audio>
                    <div style='margin-top: 5px;'><strong>Genre:</strong> {genre_text}</div>
                    <div style='margin-top: 5px;'><strong>Lyrics:</strong><pre>{lyrics_text}</pre></div>
                    <div style='color: #666; font-size: 0.9em;'>{timestamp}</div>
                </div>
                """)
            history_html = "<div class='history-entries'>" + "".join(rendered) + "</div>"

            return history, history_html

        def generate_with_progress(genre, lyrics, segments, tokens, history):
            """Run inference and return values for audio, history and status."""
            try:
                result = infer(genre, lyrics, segments, tokens)
            except Exception as e:
                # gr.update() with no arguments leaves the history panel as is.
                return None, history, gr.update(), f"❌ Error: {str(e)}"

            if not result:
                return None, history, gr.update(), "❌ Generation failed"

            new_history, history_html = update_history(result, genre, lyrics, history)
            return result, new_history, history_html, "✅ Generation complete!"

        # Event wiring.
        lyrics_txt.change(
            fn=update_info,
            inputs=[lyrics_txt],
            outputs=[duration_info, sections_info]
        )

        submit_btn.click(
            fn=generate_with_progress,
            inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens, history_state],
            outputs=[music_out, history_state, history_container, progress]
        )

    return demo
856
-
857
-
858
if __name__ == "__main__":
    # Script entry point: build the UI, enable a request queue holding up to
    # 20 entries, then start the server on all interfaces at port 7860 with a
    # public share link requested.
    demo = main()
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_api=True,
        show_error=True,
        max_threads=8,
    )