seawolf2357 committed
Commit 4bf30b7 · verified · 1 Parent(s): 0889c6d

Update app.py

Files changed (1):
  app.py +24 -23

app.py CHANGED
@@ -30,17 +30,13 @@ SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
 ##############################################################################
 def extract_keywords(text: str, top_k: int = 5) -> str:
     """
-    1) Changed the regex to keep only Korean, English, digits, and whitespace
+    1) Keep only Korean (가-힣), English (a-zA-Z), digits (0-9), and whitespace
     2) Split into tokens on whitespace
     3) Keep at most top_k tokens
     """
-    # Preserve only Korean (가-힣), English letters (upper/lower), digits, and whitespace
     text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
-    # Split into tokens
     tokens = text.split()
-    # Take at most top_k
     key_tokens = tokens[:top_k]
-    # Join back together
     return " ".join(key_tokens)
 
 ##############################################################################
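For reference, the new keyword filter can be sanity-checked in isolation; a minimal sketch (the sample input is made up):

    import re

    def extract_keywords(text: str, top_k: int = 5) -> str:
        # Keep only Korean syllables (가-힣), ASCII letters, digits, and whitespace
        text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
        return " ".join(text.split()[:top_k])

    # Note the hyphen is stripped, so "Gemma-3" fuses into "Gemma3"
    print(extract_keywords("Gemma-3 27B 모델, 성능은 어떤가요?"))
    # -> Gemma3 27B 모델 성능은 어떤가요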
@@ -74,7 +70,6 @@ def do_web_search(query: str) -> str:
 
         summary_lines = []
         for idx, item in enumerate(organic[:20], start=1):
-            # Serialize the entire item as a JSON string
             item_json = json.dumps(item, ensure_ascii=False, indent=2)
             summary_lines.append(f"Result {idx}:\n{item_json}\n")
 
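Worth noting: ensure_ascii=False is what keeps Korean fields in the search results human-readable inside the prompt. A standalone illustration (the item is made up):

    import json

    item = {"title": "젬마 3 모델 공개", "link": "https://example.com"}
    print(json.dumps(item, indent=2))                      # escapes to "\uc824\ub9c8 ..."
    print(json.dumps(item, ensure_ascii=False, indent=2))  # keeps readable 한글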
@@ -89,6 +84,7 @@ def do_web_search(query: str) -> str:
 ##############################################################################
 MAX_CONTENT_CHARS = 4000
 model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
+
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id,
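The hunk cuts off inside the from_pretrained call, so the remaining kwargs are not visible in this diff. A plausible completion; every kwarg below is an assumption, not something the commit shows:

    import torch
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # assumed: bf16 to fit a 27B model in memory
        device_map="auto",           # assumed: let accelerate place the weights
    )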
@@ -390,47 +386,36 @@ def run(
         return
 
     try:
-        # (1) Buffer for merging the system messages into one
         combined_system_msg = ""
 
-        # If the user provided a system_prompt
         if system_prompt.strip():
             combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"
 
-        # (2) When web search is enabled, extract keywords
         if use_web_search:
             user_text = message["text"]
             ws_query = extract_keywords(user_text, top_k=5)
-            # If the extracted keywords are empty, skip the search
             if ws_query.strip():
                 logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                 ws_result = do_web_search(ws_query)
-                # Append the search results to the end of the system message
                 combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
             else:
-                # With no extracted keywords, don't bother attempting a search
                 combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
 
-        # (3) If the system message ended up non-empty
         messages = []
         if combined_system_msg.strip():
-            # Create a single system-role message
             messages.append({
                 "role": "system",
                 "content": [{"type": "text", "text": combined_system_msg.strip()}],
             })
 
-        # (4) Previous conversation history
         messages.extend(process_history(history))
 
-        # (5) New user message
         user_content = process_new_user_message(message)
         for item in user_content:
             if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
                 item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
         messages.append({"role": "user", "content": user_content})
 
-        # (6) Build the LLM input
         inputs = processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
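The messages list built above follows the multimodal chat schema that processor.apply_chat_template consumes; roughly the following, with all values illustrative and the image-entry shape assumed rather than taken from this diff:

    messages = [
        {"role": "system",
         "content": [{"type": "text",
                      "text": "[System Prompt]\n...\n\n[Search top-20 Full Items Based on user prompt]\n..."}]},
        # ...entries appended by process_history(history)...
        {"role": "user",
         "content": [{"type": "text", "text": "user question"},
                     {"type": "image", "url": "/tmp/upload.png"}]},  # shape assumed
    ]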
@@ -446,7 +431,7 @@ def run(
             max_new_tokens=max_new_tokens,
         )
 
-        t = Thread(target=model.generate, kwargs=gen_kwargs)
+        t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
         t.start()
 
         output = ""
@@ -459,6 +444,22 @@ def run(
         yield f"Sorry, an error occurred: {str(e)}"
 
 
+##############################################################################
+# [Added] Call model.generate(...) in a separate function so OOM can be caught
+##############################################################################
+def _model_gen_with_oom_catch(**kwargs):
+    """
+    Catches OutOfMemoryError raised inside the separate thread
+    """
+    try:
+        model.generate(**kwargs)
+    except torch.cuda.OutOfMemoryError:
+        raise RuntimeError(
+            "[OutOfMemoryError] Not enough GPU memory. "
+            "Please reduce Max New Tokens or shorten the prompt."
+        )
+
+
 ##############################################################################
 # Examples (Korean-localized)
 ##############################################################################
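One caveat with the new _model_gen_with_oom_catch: an exception raised in a worker thread never reaches the main thread's try/except. Python hands it to threading.excepthook (which just logs it), so the streamer loop would stall, or hit its timeout if one was set. A hedged sketch of one way to actually surface the error, using only a stdlib queue; _gen_with_error_report is a hypothetical variant, not part of the commit:

    import queue
    from threading import Thread

    errors: "queue.Queue[Exception]" = queue.Queue()

    def _gen_with_error_report(**kwargs):
        try:
            model.generate(**kwargs)
        except torch.cuda.OutOfMemoryError as e:
            errors.put(e)   # hand the exception over to the consuming thread

    t = Thread(target=_gen_with_error_report, kwargs=gen_kwargs)
    t.start()
    # ...drain the streamer as before...
    t.join()
    if not errors.empty():
        raise RuntimeError("[OutOfMemoryError] Reduce Max New Tokens or shorten the prompt.")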
@@ -658,7 +659,7 @@ with gr.Blocks(css=css, title="Vidraft-Gemma-3-27B") as demo:
                 minimum=100,
                 maximum=8000,
                 step=50,
-                value=2000,
+                value=512,  # Slightly lower default to save GPU memory
             )
 
         gr.Markdown("<br><br>")
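In context this is presumably the Max New Tokens slider; a full definition might look like the following, with the label and variable name assumed since the hunk does not show them:

    max_new_tokens = gr.Slider(
        label="Max New Tokens",  # assumed label
        minimum=100,
        maximum=8000,
        step=50,
        value=512,               # lowered from 2000 to save GPU memory
    )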
@@ -698,12 +699,12 @@ with gr.Blocks(css=css, title="Vidraft-Gemma-3-27B") as demo:
     gr.Markdown("### Example Inputs (click to load)")
     gr.Examples(
         examples=examples,
-        inputs=[],  # empty list since there are no inputs to wire up
+        inputs=[],
         cache_examples=False
     )
 
 if __name__ == "__main__":
-    # To pad the 615 lines out to 715, add filler comments below
-    demo.launch(share=True)
-
+    # share=True triggers a warning on HF Spaces - it only takes effect locally
+    # demo.launch(share=True)
+    demo.launch()
 
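A note on the new launch call: older Gradio releases required an explicit demo.queue() for generator-based streaming to work, while recent ones enable queuing by default. If streaming ever stalls, a conservative variant is:

    if __name__ == "__main__":
        demo.queue().launch()   # explicit queue; harmless where queuing is already the default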