Update app.py
Browse files
app.py
CHANGED
@@ -131,46 +131,7 @@ def detect_and_select_model(text):
|
|
131 |
else: # ์์ด/๊ธฐํ
|
132 |
return "m-a-p/YuE-s1-7B-anneal-en-cot"
|
133 |
|
134 |
-
|
135 |
-
model_path = detect_and_select_model(lyrics)
|
136 |
-
params = calculate_generation_params(lyrics)
|
137 |
-
|
138 |
-
# ์ฝ๋ฌ์ค ์กด์ฌ ์ฌ๋ถ์ ๋ฐ๋ฅธ ์ค์ ์กฐ์
|
139 |
-
has_chorus = params['sections']['chorus'] > 0
|
140 |
-
|
141 |
-
model_config = {
|
142 |
-
"m-a-p/YuE-s1-7B-anneal-en-cot": {
|
143 |
-
"max_tokens": params['max_tokens'],
|
144 |
-
"temperature": 0.8,
|
145 |
-
"batch_size": 8,
|
146 |
-
"num_segments": params['num_segments'],
|
147 |
-
"tokens_per_segment": params['tokens_per_segment'],
|
148 |
-
"estimated_duration": params['estimated_duration']
|
149 |
-
},
|
150 |
-
"m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
|
151 |
-
"max_tokens": params['max_tokens'],
|
152 |
-
"temperature": 0.7,
|
153 |
-
"batch_size": 8,
|
154 |
-
"num_segments": params['num_segments'],
|
155 |
-
"tokens_per_segment": params['tokens_per_segment'],
|
156 |
-
"estimated_duration": params['estimated_duration']
|
157 |
-
},
|
158 |
-
"m-a-p/YuE-s1-7B-anneal-zh-cot": {
|
159 |
-
"max_tokens": params['max_tokens'],
|
160 |
-
"temperature": 0.7,
|
161 |
-
"batch_size": 8,
|
162 |
-
"num_segments": params['num_segments'],
|
163 |
-
"tokens_per_segment": params['tokens_per_segment'],
|
164 |
-
"estimated_duration": params['estimated_duration']
|
165 |
-
}
|
166 |
-
}
|
167 |
-
|
168 |
-
# ์ฝ๋ฌ์ค๊ฐ ์๋ ๊ฒฝ์ฐ ํ ํฐ ์ ์ฆ๊ฐ
|
169 |
-
if has_chorus:
|
170 |
-
for config in model_config.values():
|
171 |
-
config['max_tokens'] = int(config['max_tokens'] * 1.5) # 50% ๋ ๋ง์ ํ ํฐ ํ ๋น
|
172 |
-
|
173 |
-
return model_path, model_config[model_path], params
|
174 |
|
175 |
# GPU ์ค์ ์ต์ ํ
|
176 |
def optimize_gpu_settings():
|
@@ -279,7 +240,51 @@ def get_last_mp3_file(output_dir):
|
|
279 |
mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
|
280 |
return mp3_files_with_path[0]
|
281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
|
|
|
|
|
|
|
283 |
try:
|
284 |
# ๋ชจ๋ธ ์ ํ ๋ฐ ์ค์
|
285 |
model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
|
@@ -287,23 +292,20 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
|
|
287 |
logging.info(f"Lyrics analysis: {params}")
|
288 |
|
289 |
# ์ฝ๋ฌ์ค ์น์
ํ์ธ ๋ฐ ๋ก๊น
|
290 |
-
has_chorus = params['
|
291 |
estimated_duration = params.get('estimated_duration', 90)
|
292 |
|
293 |
# ํ ํฐ ์์ ์ธ๊ทธ๋จผํธ ์ ์กฐ์
|
294 |
if has_chorus:
|
295 |
-
actual_max_tokens = int(
|
296 |
-
actual_num_segments = max(3,
|
297 |
-
tokens_per_segment = actual_max_tokens // actual_num_segments
|
298 |
else:
|
299 |
-
actual_max_tokens =
|
300 |
-
actual_num_segments =
|
301 |
-
tokens_per_segment = actual_max_tokens // actual_num_segments
|
302 |
|
303 |
logging.info(f"Estimated duration: {estimated_duration} seconds")
|
304 |
logging.info(f"Has chorus sections: {has_chorus}")
|
305 |
logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
|
306 |
-
logging.info(f"Tokens per segment: {tokens_per_segment}")
|
307 |
|
308 |
# ์์ ํ์ผ ์์ฑ
|
309 |
genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
|
@@ -330,8 +332,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
|
|
330 |
# GPU ์ค์
|
331 |
if torch.cuda.is_available():
|
332 |
command.extend([
|
333 |
-
"--disable_offload_model"
|
334 |
-
"--use_bf16" # ๋ ๋น ๋ฅธ ์ฒ๋ฆฌ๋ฅผ ์ํ BF16 ์ฌ์ฉ
|
335 |
])
|
336 |
|
337 |
# CUDA ํ๊ฒฝ ๋ณ์ ์ค์
|
@@ -396,12 +397,19 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
|
|
396 |
raise
|
397 |
finally:
|
398 |
# ์์ ํ์ผ ์ ๋ฆฌ
|
399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
try:
|
401 |
-
os.remove(
|
402 |
-
logging.debug(f"Removed temporary file: {
|
403 |
except Exception as e:
|
404 |
-
logging.warning(f"Failed to remove temporary file {
|
405 |
|
406 |
def main():
|
407 |
# Gradio ์ธํฐํ์ด์ค
|
|
|
131 |
else: # ์์ด/๊ธฐํ
|
132 |
return "m-a-p/YuE-s1-7B-anneal-en-cot"
|
133 |
|
134 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# GPU ์ค์ ์ต์ ํ
|
137 |
def optimize_gpu_settings():
|
|
|
240 |
mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
|
241 |
return mp3_files_with_path[0]
|
242 |
|
243 |
+
def optimize_model_selection(lyrics, genre):
    """Pick a YuE checkpoint for the lyrics and build its generation config.

    Parameters
    ----------
    lyrics : str
        Lyrics text; language detection on it selects the model checkpoint.
    genre : str
        Genre description (currently unused by this function; kept so the
        caller-facing interface stays stable).

    Returns
    -------
    tuple
        ``(model_path, config, params)`` where ``config`` holds
        ``max_tokens`` / ``temperature`` / ``batch_size`` /
        ``num_segments`` / ``estimated_duration`` for the chosen model
        and ``params`` is the raw lyrics-analysis result.
    """
    model_path = detect_and_select_model(lyrics)
    params = calculate_generation_params(lyrics)

    # Adjust settings depending on whether chorus sections exist.
    has_chorus = params['sections']['chorus'] > 0

    # Choruses repeat material, so allocate 50% more tokens when present.
    max_tokens = params['max_tokens']
    if has_chorus:
        max_tokens = int(max_tokens * 1.5)

    # All three checkpoints share the same config shape; only the sampling
    # temperature differs, so build the table from one template instead of
    # three hand-copied dicts.
    temperatures = {
        "m-a-p/YuE-s1-7B-anneal-en-cot": 0.8,
        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": 0.7,
        "m-a-p/YuE-s1-7B-anneal-zh-cot": 0.7,
    }
    model_config = {
        path: {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "batch_size": 8,
            "num_segments": params['num_segments'],
            "estimated_duration": params['estimated_duration'],
        }
        for path, temperature in temperatures.items()
    }

    return model_path, model_config[model_path], params
|
283 |
+
|
284 |
def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
|
285 |
+
genre_txt_path = None
|
286 |
+
lyrics_txt_path = None
|
287 |
+
|
288 |
try:
|
289 |
# ๋ชจ๋ธ ์ ํ ๋ฐ ์ค์
|
290 |
model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
|
|
|
292 |
logging.info(f"Lyrics analysis: {params}")
|
293 |
|
294 |
# ์ฝ๋ฌ์ค ์น์
ํ์ธ ๋ฐ ๋ก๊น
|
295 |
+
has_chorus = params['sections']['chorus'] > 0
|
296 |
estimated_duration = params.get('estimated_duration', 90)
|
297 |
|
298 |
# ํ ํฐ ์์ ์ธ๊ทธ๋จผํธ ์ ์กฐ์
|
299 |
if has_chorus:
|
300 |
+
actual_max_tokens = int(config['max_tokens'] * 1.5) # 50% ๋ ๋ง์ ํ ํฐ
|
301 |
+
actual_num_segments = max(3, config['num_segments']) # ์ต์ 3๊ฐ ์ธ๊ทธ๋จผํธ
|
|
|
302 |
else:
|
303 |
+
actual_max_tokens = config['max_tokens']
|
304 |
+
actual_num_segments = config['num_segments']
|
|
|
305 |
|
306 |
logging.info(f"Estimated duration: {estimated_duration} seconds")
|
307 |
logging.info(f"Has chorus sections: {has_chorus}")
|
308 |
logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
|
|
|
309 |
|
310 |
# ์์ ํ์ผ ์์ฑ
|
311 |
genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
|
|
|
332 |
# GPU ์ค์
|
333 |
if torch.cuda.is_available():
|
334 |
command.extend([
|
335 |
+
"--disable_offload_model"
|
|
|
336 |
])
|
337 |
|
338 |
# CUDA ํ๊ฒฝ ๋ณ์ ์ค์
|
|
|
397 |
raise
|
398 |
finally:
|
399 |
# ์์ ํ์ผ ์ ๋ฆฌ
|
400 |
+
if genre_txt_path and os.path.exists(genre_txt_path):
|
401 |
+
try:
|
402 |
+
os.remove(genre_txt_path)
|
403 |
+
logging.debug(f"Removed temporary file: {genre_txt_path}")
|
404 |
+
except Exception as e:
|
405 |
+
logging.warning(f"Failed to remove temporary file {genre_txt_path}: {e}")
|
406 |
+
|
407 |
+
if lyrics_txt_path and os.path.exists(lyrics_txt_path):
|
408 |
try:
|
409 |
+
os.remove(lyrics_txt_path)
|
410 |
+
logging.debug(f"Removed temporary file: {lyrics_txt_path}")
|
411 |
except Exception as e:
|
412 |
+
logging.warning(f"Failed to remove temporary file {lyrics_txt_path}: {e}")
|
413 |
|
414 |
def main():
|
415 |
# Gradio ์ธํฐํ์ด์ค
|