ginipick committed on
Commit 4fb8e24 · verified · 1 Parent(s): 1d7d926

Update app.py

Files changed (1)
  1. app.py +21 -26
app.py CHANGED
@@ -77,33 +77,26 @@ def calculate_generation_params(lyrics):
     }
 
     total_duration = sum(section_durations.values())
+    total_duration = max(60, total_duration)  # guarantee at least 60 seconds
 
-    # Guarantee a minimum duration (90 seconds)
-    total_duration = max(90, total_duration)
-
-    # Token estimate (raised to about 100 tokens per second)
-    tokens_per_second = 100
-    base_tokens = int(total_duration * tokens_per_second)
-
-    # Allocate extra tokens when a chorus is present
-    if sections['chorus'] > 0:
-        chorus_tokens = int(section_durations['chorus'] * tokens_per_second * 1.5)
-        total_tokens = base_tokens + chorus_tokens
-    else:
-        total_tokens = base_tokens
+    # Token estimate (use more conservative values)
+    base_tokens = 3000      # base token count
+    tokens_per_line = 200   # tokens per lyric line
+
+    total_tokens = base_tokens + (total_lines * tokens_per_line)
 
     # Derive the segment count from the sections
     if sections['chorus'] > 0:
-        num_segments = max(3, sections['verse'] + sections['chorus'])
+        num_segments = 3  # three segments when a chorus is present
     else:
-        num_segments = max(2, total_sections)
+        num_segments = 2  # two segments when there is no chorus
 
-    # Token limit (guarantee at least 8000 tokens)
-    max_tokens = min(32000, max(8000, total_tokens))
+    # Token limit
+    max_tokens = min(8000, total_tokens)  # cap at 8000 tokens
 
     return {
         'max_tokens': max_tokens,
-        'num_segments': min(4, num_segments),  # limit to at most 4 segments
+        'num_segments': num_segments,
         'sections': sections,
         'section_lines': section_lines,
         'estimated_duration': total_duration,
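
Taken together, this hunk swaps the duration-based token estimate (100 tokens per second, a chorus multiplier, and an 8000-32000 range) for a flat per-line budget hard-capped at 8000. A minimal sketch of the new arithmetic, assuming total_lines counts the lyric lines parsed earlier in the function (the helper name and sample inputs are illustrative, not from app.py):

def estimate_params(total_lines, has_chorus):
    # New scheme: flat base plus a per-line budget, capped at 8000 tokens.
    base_tokens = 3000
    tokens_per_line = 200
    total_tokens = base_tokens + total_lines * tokens_per_line
    max_tokens = min(8000, total_tokens)
    num_segments = 3 if has_chorus else 2
    return max_tokens, num_segments

print(estimate_params(20, True))   # (7000, 3): 3000 + 20*200 stays under the cap
print(estimate_params(40, False))  # (8000, 2): 3000 + 40*200 = 11000 is trimmed to 8000
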
@@ -294,14 +287,17 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     # Check for chorus sections and log them
     has_chorus = params['sections']['chorus'] > 0
     estimated_duration = params.get('estimated_duration', 90)
-
+
+
     # Adjust the token and segment counts
     if has_chorus:
-        actual_max_tokens = int(config['max_tokens'] * 1.5)  # 50% more tokens
-        actual_num_segments = max(3, config['num_segments'])  # at least 3 segments
+        actual_max_tokens = min(8000, int(config['max_tokens'] * 1.2))  # 20% more, capped at 8000
+        actual_num_segments = 3
     else:
         actual_max_tokens = config['max_tokens']
-        actual_num_segments = config['num_segments']
+        actual_num_segments = 2
+
+
 
     logging.info(f"Estimated duration: {estimated_duration} seconds")
     logging.info(f"Has chorus sections: {has_chorus}")
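
Because the chorus bump is re-capped at min(8000, ...), it only has an effect when calculate_generation_params returned a value below the ceiling. A quick check of both regimes (the helper name is illustrative only):

def adjust_tokens(max_tokens, has_chorus):
    # Mirrors the new adjustment: +20% for a chorus, re-capped at 8000.
    if has_chorus:
        return min(8000, int(max_tokens * 1.2))
    return max_tokens

print(adjust_tokens(5000, True))  # 6000: the 20% bump applies below the cap
print(adjust_tokens(8000, True))  # 8000: int(8000 * 1.2) = 9600 is clamped back down
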
@@ -314,7 +310,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     output_dir = "./output"
     os.makedirs(output_dir, exist_ok=True)
     empty_output_folder(output_dir)
-
     # Build the base command
     command = [
         "python", "infer.py",
@@ -323,7 +318,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         "--genre_txt", genre_txt_path,
         "--lyrics_txt", lyrics_txt_path,
         "--run_n_segments", str(actual_num_segments),
-        "--stage2_batch_size", str(config['batch_size']),
+        "--stage2_batch_size", "4",  # reduced batch size
         "--output_dir", output_dir,
         "--cuda_idx", "0",
         "--max_new_tokens", str(actual_max_tokens)
@@ -331,9 +326,9 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
 
     # GPU settings
     if torch.cuda.is_available():
-        command.extend([
-            "--disable_offload_model"
-        ])
+        command.append("--disable_offload_model")
+        # GPU settings
+
 
     # Set CUDA environment variables
     env = os.environ.copy()
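
The env copy at the end of this hunk suggests the assembled argument list is later handed to a child process with an adjusted environment. A hedged sketch of that pattern, under assumptions: the full command list is abbreviated here, and CUDA_VISIBLE_DEVICES is only the common choice, since the variables app.py actually sets fall outside this diff:

import os
import subprocess
import torch

command = ["python", "infer.py", "--cuda_idx", "0"]  # abbreviated stand-in for the full list
if torch.cuda.is_available():
    command.append("--disable_offload_model")  # flag appended, as in the new code

env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "0"  # assumption: pin the child process to GPU 0
subprocess.run(command, check=True, env=env)  # check=True raises if infer.py exits non-zero

Passing env explicitly keeps the setting scoped to the child process rather than mutating the environment of the Gradio app itself.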
 
77
  }
78
 
79
  total_duration = sum(section_durations.values())
80
+ total_duration = max(60, total_duration) # ์ตœ์†Œ 60์ดˆ
81
 
82
+ # ํ† ํฐ ๊ณ„์‚ฐ (๋” ๋ณด์ˆ˜์ ์ธ ๊ฐ’ ์‚ฌ์šฉ)
83
+ base_tokens = 3000 # ๊ธฐ๋ณธ ํ† ํฐ ์ˆ˜
84
+ tokens_per_line = 200 # ์ค„๋‹น ํ† ํฐ ์ˆ˜
85
 
86
+ total_tokens = base_tokens + (total_lines * tokens_per_line)
 
 
 
 
 
 
 
 
 
87
 
88
  # ์„น์…˜ ๊ธฐ๋ฐ˜ ์„ธ๊ทธ๋จผํŠธ ์ˆ˜ ๊ณ„์‚ฐ
89
  if sections['chorus'] > 0:
90
+ num_segments = 3 # ์ฝ”๋Ÿฌ์Šค๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ 3๊ฐœ ์„ธ๊ทธ๋จผํŠธ
91
  else:
92
+ num_segments = 2 # ์ฝ”๋Ÿฌ์Šค๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ 2๊ฐœ ์„ธ๊ทธ๋จผํŠธ
93
 
94
+ # ํ† ํฐ ์ˆ˜ ์ œํ•œ
95
+ max_tokens = min(8000, total_tokens) # ์ตœ๋Œ€ 8000 ํ† ํฐ์œผ๋กœ ์ œํ•œ
96
 
97
  return {
98
  'max_tokens': max_tokens,
99
+ 'num_segments': num_segments,
100
  'sections': sections,
101
  'section_lines': section_lines,
102
  'estimated_duration': total_duration,
 
287
  # ์ฝ”๋Ÿฌ์Šค ์„น์…˜ ํ™•์ธ ๋ฐ ๋กœ๊น…
288
  has_chorus = params['sections']['chorus'] > 0
289
  estimated_duration = params.get('estimated_duration', 90)
290
+
291
+
292
  # ํ† ํฐ ์ˆ˜์™€ ์„ธ๊ทธ๋จผํŠธ ์ˆ˜ ์กฐ์ •
293
  if has_chorus:
294
+ actual_max_tokens = min(8000, int(config['max_tokens'] * 1.2)) # 20% ์ฆ๊ฐ€, ์ตœ๋Œ€ 8000
295
+ actual_num_segments = 3
296
  else:
297
  actual_max_tokens = config['max_tokens']
298
+ actual_num_segments = 2
299
+
300
+
301
 
302
  logging.info(f"Estimated duration: {estimated_duration} seconds")
303
  logging.info(f"Has chorus sections: {has_chorus}")
 
310
  output_dir = "./output"
311
  os.makedirs(output_dir, exist_ok=True)
312
  empty_output_folder(output_dir)
 
313
  # ๊ธฐ๋ณธ ๋ช…๋ น์–ด ๊ตฌ์„ฑ
314
  command = [
315
  "python", "infer.py",
 
318
  "--genre_txt", genre_txt_path,
319
  "--lyrics_txt", lyrics_txt_path,
320
  "--run_n_segments", str(actual_num_segments),
321
+ "--stage2_batch_size", "4", # ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ ๊ฐ์†Œ
322
  "--output_dir", output_dir,
323
  "--cuda_idx", "0",
324
  "--max_new_tokens", str(actual_max_tokens)
 
326
 
327
  # GPU ์„ค์ •
328
  if torch.cuda.is_available():
329
+ command.append("--disable_offload_model")
330
+ # GPU ์„ค์ •
331
+
332
 
333
  # CUDA ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
334
  env = os.environ.copy()