ginipick committed · verified
Commit 3469b26 · 1 Parent(s): 2cb4fdb

Update app.py

Files changed (1)
  1. app.py +46 -30
app.py CHANGED
@@ -65,40 +65,50 @@ def calculate_generation_params(lyrics):
     # Base time estimate per line (in seconds)
     time_per_line = {
         'verse': 4,   # 4 seconds per verse line
-        'chorus': 6,  # 6 seconds per chorus line (extra time allotted)
+        'chorus': 6,  # 6 seconds per chorus line
         'bridge': 5   # 5 seconds per bridge line
     }
 
     # Estimated duration for each section
-    total_duration = 0
-    for section_type, lines in section_lines.items():
-        total_duration += lines * time_per_line[section_type]
+    section_durations = {
+        'verse': section_lines['verse'] * time_per_line['verse'],
+        'chorus': section_lines['chorus'] * time_per_line['chorus'],
+        'bridge': section_lines['bridge'] * time_per_line['bridge']
+    }
+
+    total_duration = sum(section_durations.values())
 
-    # Guarantee a minimum duration (60 seconds)
-    total_duration = max(60, total_duration)
+    # Guarantee a minimum duration (90 seconds)
+    total_duration = max(90, total_duration)
 
-    # Token budget (roughly 50 tokens per second)
-    tokens_per_second = 50
-    total_tokens = int(total_duration * tokens_per_second)
+    # Token budget (raised to roughly 100 tokens per second)
+    tokens_per_second = 100
+    base_tokens = int(total_duration * tokens_per_second)
+
+    # Allocate extra tokens when a chorus is present
+    if sections['chorus'] > 0:
+        chorus_tokens = int(section_durations['chorus'] * tokens_per_second * 1.5)
+        total_tokens = base_tokens + chorus_tokens
+    else:
+        total_tokens = base_tokens
 
     # Section-based segment count
-    if total_duration > 180:  # over 3 minutes
-        num_segments = 4
-    elif total_duration > 120:  # over 2 minutes
-        num_segments = 3
-    else:  # under 2 minutes
-        num_segments = 2
+    if sections['chorus'] > 0:
+        num_segments = max(3, sections['verse'] + sections['chorus'])
+    else:
+        num_segments = max(2, total_sections)
 
-    # Clamp the token count (guarantee at least 6000 tokens)
-    max_tokens = min(32000, max(6000, total_tokens))
+    # Clamp the token count (guarantee at least 8000 tokens)
+    max_tokens = min(32000, max(8000, total_tokens))
 
     return {
         'max_tokens': max_tokens,
-        'num_segments': num_segments,
+        'num_segments': min(4, num_segments),  # cap at 4 segments
         'sections': sections,
         'section_lines': section_lines,
         'estimated_duration': total_duration,
-        'tokens_per_segment': max_tokens // num_segments
+        'section_durations': section_durations,
+        'has_chorus': sections['chorus'] > 0
     }
 
 def get_audio_duration(file_path):
@@ -277,20 +287,23 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     logging.info(f"Lyrics analysis: {params}")
 
     # Detect chorus sections and log them
-    has_chorus = params['sections']['chorus'] > 0
-    estimated_duration = params.get('estimated_duration', 60)
+    has_chorus = params['has_chorus']
+    estimated_duration = params.get('estimated_duration', 90)
 
-    # Adjust the token count (allocate more tokens when a chorus is present)
+    # Adjust token and segment counts
     if has_chorus:
-        actual_max_tokens = int(config['max_tokens'] * 1.5)  # 50% more tokens
-        actual_num_segments = max(3, config['num_segments'])  # guarantee at least 3 segments
+        actual_max_tokens = int(params['max_tokens'] * 1.5)  # 50% more tokens
+        actual_num_segments = max(3, params['num_segments'])  # at least 3 segments
+        tokens_per_segment = actual_max_tokens // actual_num_segments
     else:
-        actual_max_tokens = config['max_tokens']
-        actual_num_segments = config['num_segments']
+        actual_max_tokens = params['max_tokens']
+        actual_num_segments = params['num_segments']
+        tokens_per_segment = actual_max_tokens // actual_num_segments
 
     logging.info(f"Estimated duration: {estimated_duration} seconds")
     logging.info(f"Has chorus sections: {has_chorus}")
     logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
+    logging.info(f"Tokens per segment: {tokens_per_segment}")
 
     # Create temporary files
     genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
@@ -314,9 +327,12 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         "--max_new_tokens", str(actual_max_tokens)
     ]
 
-    # Apply the extra option only when a GPU is available
+    # GPU settings
     if torch.cuda.is_available():
-        command.append("--disable_offload_model")
+        command.extend([
+            "--disable_offload_model",
+            "--use_bf16"  # use BF16 for faster processing
+        ])
 
     # Set CUDA environment variables
     env = os.environ.copy()
@@ -326,7 +342,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         "CUDA_HOME": "/usr/local/cuda",
         "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
         "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
-        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"
+        "PYTORCH_CUDA_ALLOC_CONF": f"max_split_size_mb:512"
     })
 
     # Handle the transformers cache migration
@@ -366,7 +382,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         logging.info(f"Expected duration: {estimated_duration} seconds")
 
         # Warn if the generated audio is much shorter than expected
-        if duration < estimated_duration * 0.8:  # below 80% of the expected length
+        if duration < estimated_duration * 0.8:
             logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
     except Exception as e:
         logging.warning(f"Failed to get audio duration: {e}")
 