seawolf2357 commited on
Commit
1670280
·
verified ·
1 Parent(s): 75b15f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -74
app.py CHANGED
@@ -14,30 +14,26 @@ from loguru import logger
14
  from PIL import Image
15
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
16
 
17
 - # [CSV/TXT 분석용]
18
  import pandas as pd
19
 
20
- ##################################################
21
 - # 전체 전문을 넘기되, 너무 클 경우 잘라내기 위한 상수
22
- ##################################################
23
 - MAX_CONTENT_CHARS = 8000 # 예: 8000자 초과 시 잘라냄
24
 
25
  model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
26
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
27
  model = Gemma3ForConditionalGeneration.from_pretrained(
28
- model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 
 
 
29
  )
30
 
31
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
32
 
33
 
34
- ##################################################
35
 - # CSV/TXT 전문 처리 함수
36
- ##################################################
37
  def analyze_csv_file(path: str) -> str:
38
  """
39
- CSV ํŒŒ์ผ ์ „์ฒด๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋ฆฌํ„ด.
40
- ๋„ˆ๋ฌด ๊ธธ๋ฉด MAX_CONTENT_CHARS๊นŒ์ง€๋งŒ ์ž˜๋ผ๋ƒ„.
41
  """
42
  try:
43
  df = pd.read_csv(path)
@@ -45,37 +41,26 @@ def analyze_csv_file(path: str) -> str:
45
  if len(df_str) > MAX_CONTENT_CHARS:
46
  df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
47
 
48
- return (
49
- f"**[CSV File: {os.path.basename(path)}]**\n\n"
50
- f"{df_str}"
51
- )
52
  except Exception as e:
53
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
54
 
55
 
56
  def analyze_txt_file(path: str) -> str:
57
  """
58
- TXT ํŒŒ์ผ ์ „์ฒด ๋‚ด์šฉ์„ ์ฝ์–ด์„œ ๋ชจ๋ธ์— ๋„˜๊น€.
59
- ๋„ˆ๋ฌด ๊ธธ๋ฉด MAX_CONTENT_CHARS๊นŒ์ง€๋งŒ ์ž˜๋ผ๋ƒ„.
60
  """
61
  try:
62
  with open(path, "r", encoding="utf-8") as f:
63
  text = f.read()
64
-
65
  if len(text) > MAX_CONTENT_CHARS:
66
  text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
67
 
68
- return (
69
- f"**[TXT File: {os.path.basename(path)}]**\n\n"
70
- f"{text}"
71
- )
72
  except Exception as e:
73
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
74
 
75
 
76
- ##################################################
77
- # ๊ธฐ์กด ๋ฏธ๋””์–ด ํŒŒ์ผ ๊ฒ€์‚ฌ ๋กœ์ง (์ด๋ฏธ์ง€/๋น„๋””์˜ค)
78
- ##################################################
79
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
80
  image_count = 0
81
  video_count = 0
@@ -105,14 +90,13 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
105
  - ๋น„๋””์˜ค 1๊ฐœ ์ดˆ๊ณผ ๋ถˆ๊ฐ€
106
  - ๋น„๋””์˜ค/์ด๋ฏธ์ง€ ํ˜ผํ•ฉ ๋ถˆ๊ฐ€
107
  - ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ MAX_NUM_IMAGES ์ดˆ๊ณผ ๋ถˆ๊ฐ€
108
- - <image> ํƒœ๊ทธ๊ฐ€ ์žˆ์œผ๋ฉด ํƒœ๊ทธ ์ˆ˜์™€ ์ด๋ฏธ์ง€ ์ˆ˜ ์ผ์น˜
109
- CSV, TXT, PDF ๋“ฑ์€ ์—ฌ๊ธฐ์„œ ์ œํ•œํ•˜์ง€ ์•Š์Œ.
110
  """
111
  media_files = []
112
  for f in message["files"]:
113
- # mp4๋‚˜ ๋Œ€ํ‘œ ์ด๋ฏธ์ง€ ํ™•์žฅ์ž๋งŒ ๊ฒ€์‚ฌ
114
- # (ํŒŒ์ผ๋ช…์— .png / .jpg / .gif / .webp ๋“ฑ ์žˆ์„ ๋•Œ)
115
- if f.endswith(".mp4") or re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE):
116
  media_files.append(f)
117
 
118
  new_image_count, new_video_count = count_files_in_new_message(media_files)
@@ -140,9 +124,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
140
  return True
141
 
142
 
143
- ##################################################
144
- # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
145
- ##################################################
146
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
147
  vidcap = cv2.VideoCapture(video_path)
148
  fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -177,9 +158,6 @@ def process_video(video_path: str) -> list[dict]:
177
  return content
178
 
179
 
180
- ##################################################
181
- # interleaved <image> ํƒœ๊ทธ ์ฒ˜๋ฆฌ
182
- ##################################################
183
  def process_interleaved_images(message: dict) -> list[dict]:
184
  logger.debug(f"{message['files']=}")
185
  parts = re.split(r"(<image>)", message["text"])
@@ -188,7 +166,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
188
  content = []
189
  image_index = 0
190
  for part in parts:
191
- logger.debug(f"{part=}")
192
  if part == "<image>":
193
  content.append({"type": "image", "url": message["files"][image_index]})
194
  logger.debug(f"file: {message['files'][image_index]}")
@@ -201,16 +178,7 @@ def process_interleaved_images(message: dict) -> list[dict]:
201
  return content
202
 
203
 
204
- ##################################################
205
- # CSV, TXT ํŒŒ์ผ๋„ ์ „๋ฌธ์„ LLM์— ๋„˜๊ธฐ๋„๋ก
206
- ##################################################
207
  def process_new_user_message(message: dict) -> list[dict]:
208
- """
209
- - mp4 -> ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
210
- - ์ด๋ฏธ์ง€ -> interleaved or multiple
211
- - CSV -> ์ „์ฒด df.to_string() (๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„)
212
- - TXT -> ์ „์ฒด text (๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„)
213
- """
214
  if not message["files"]:
215
  return [{"type": "text", "text": message["text"]}]
216
 
@@ -220,7 +188,7 @@ def process_new_user_message(message: dict) -> list[dict]:
220
  csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
221
  txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
222
 
223
- # ์‚ฌ์šฉ์ž ํ…์ŠคํŠธ
224
  content_list = [{"type": "text", "text": message["text"]}]
225
 
226
  # CSV ์ „๋ฌธ
@@ -233,7 +201,7 @@ def process_new_user_message(message: dict) -> list[dict]:
233
  txt_analysis = analyze_txt_file(txt_path)
234
  content_list.append({"type": "text", "text": txt_analysis})
235
 
236
- # ๋น„๋””์˜ค
237
  if video_files:
238
  content_list += process_video(video_files[0])
239
  return content_list
@@ -242,7 +210,7 @@ def process_new_user_message(message: dict) -> list[dict]:
242
  if "<image>" in message["text"]:
243
  return process_interleaved_images(message)
244
 
245
- # ์ผ๋ฐ˜ ์ด๋ฏธ์ง€(์—ฌ๋Ÿฌ ์žฅ)
246
  if image_files:
247
  for img_path in image_files:
248
  content_list.append({"type": "image", "url": img_path})
@@ -250,9 +218,6 @@ def process_new_user_message(message: dict) -> list[dict]:
250
  return content_list
251
 
252
 
253
- ##################################################
254
- # history -> LLM ๋ฉ”์‹œ์ง€ ๋ณ€ํ™˜
255
- ##################################################
256
  def process_history(history: list[dict]) -> list[dict]:
257
  messages = []
258
  current_user_content: list[dict] = []
@@ -271,9 +236,6 @@ def process_history(history: list[dict]) -> list[dict]:
271
  return messages
272
 
273
 
274
- ##################################################
275
- # ๋ฉ”์ธ ์ถ”๋ก  ํ•จ์ˆ˜
276
- ##################################################
277
  @spaces.GPU(duration=120)
278
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
279
  if not validate_media_constraints(message, history):
@@ -309,9 +271,6 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
309
  yield output
310
 
311
 
312
- ##################################################
313
- # ์˜ˆ์‹œ ๋ชฉ๋ก (๊ธฐ์กด)
314
- ##################################################
315
  examples = [
316
  [
317
  {
@@ -435,16 +394,16 @@ examples = [
435
  ]
436
 
437
 
438
- ##################################################
439
- # Gradio ChatInterface
440
- ##################################################
441
  demo = gr.ChatInterface(
442
  fn=run,
443
  type="messages",
444
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
445
- # ์—ฌ๊ธฐ์„œ WEBP๋ฅผ ํฌํ•จํ•œ ๋ชจ๋“  ์ด๋ฏธ์ง€, mp4, csv, txt, pdf ํ—ˆ์šฉ
446
  textbox=gr.MultimodalTextbox(
447
- file_types=["image/*", ".mp4", ".csv", ".txt", ".pdf"],
 
 
 
448
  file_count="multiple",
449
  autofocus=True
450
  ),
@@ -452,18 +411,9 @@ demo = gr.ChatInterface(
452
  additional_inputs=[
453
  gr.Textbox(
454
  label="System Prompt",
455
- value=(
456
- "You are a deeply thoughtful AI. Consider problems thoroughly and derive "
457
- "correct solutions through systematic reasoning. Please answer in korean."
458
- )
459
- ),
460
- gr.Slider(
461
- label="Max New Tokens",
462
- minimum=100,
463
- maximum=8000,
464
- step=50,
465
- value=2000
466
  ),
 
467
  ],
468
  stop_btn=False,
469
  title="Gemma 3 27B IT",
 
14
  from PIL import Image
15
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
16
 
17
 + # CSV/TXT 분석
18
  import pandas as pd
19
 
20
 + MAX_CONTENT_CHARS = 8000 # 파일에서 읽은 내용이 너무 길 경우 이 정도에서 잘라냄
 
 
 
21
 
22
  model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
23
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
24
  model = Gemma3ForConditionalGeneration.from_pretrained(
25
+ model_id,
26
+ device_map="auto",
27
+ torch_dtype=torch.bfloat16,
28
+ attn_implementation="eager"
29
  )
30
 
31
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
32
 
33
 
 
 
 
34
  def analyze_csv_file(path: str) -> str:
35
  """
36
+ CSV ํŒŒ์ผ์„ ์ฝ์–ด ๋ฌธ์ž์—ดํ™”. ๋„ˆ๋ฌด ํฌ๋ฉด ์ผ๋ถ€๋งŒ ์ž˜๋ผ๋ƒ„.
 
37
  """
38
  try:
39
  df = pd.read_csv(path)
 
41
  if len(df_str) > MAX_CONTENT_CHARS:
42
  df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
43
 
44
+ return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
 
 
 
45
  except Exception as e:
46
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
47
 
48
 
49
  def analyze_txt_file(path: str) -> str:
50
  """
51
+ TXT ํŒŒ์ผ ์ „๋ฌธ ์ฝ์–ด๋“ค์ด๋˜, ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„.
 
52
  """
53
  try:
54
  with open(path, "r", encoding="utf-8") as f:
55
  text = f.read()
 
56
  if len(text) > MAX_CONTENT_CHARS:
57
  text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
58
 
59
+ return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
 
 
 
60
  except Exception as e:
61
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
62
 
63
 
 
 
 
64
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
65
  image_count = 0
66
  video_count = 0
 
90
  - ๋น„๋””์˜ค 1๊ฐœ ์ดˆ๊ณผ ๋ถˆ๊ฐ€
91
  - ๋น„๋””์˜ค/์ด๋ฏธ์ง€ ํ˜ผํ•ฉ ๋ถˆ๊ฐ€
92
  - ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ MAX_NUM_IMAGES ์ดˆ๊ณผ ๋ถˆ๊ฐ€
93
 + - <image> 태그가 있으면 태그 수와 실제 이미지 개수 일치
94
 + - CSV, TXT, PDF 등은 여기서 제한하지 않음.
95
  """
96
  media_files = []
97
  for f in message["files"]:
98
 + # 이미지(여러 확장자)나 mp4만 체크
99
+ if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
 
100
  media_files.append(f)
101
 
102
  new_image_count, new_video_count = count_files_in_new_message(media_files)
 
124
  return True
125
 
126
 
 
 
 
127
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
128
  vidcap = cv2.VideoCapture(video_path)
129
  fps = vidcap.get(cv2.CAP_PROP_FPS)
 
158
  return content
159
 
160
 
 
 
 
161
  def process_interleaved_images(message: dict) -> list[dict]:
162
  logger.debug(f"{message['files']=}")
163
  parts = re.split(r"(<image>)", message["text"])
 
166
  content = []
167
  image_index = 0
168
  for part in parts:
 
169
  if part == "<image>":
170
  content.append({"type": "image", "url": message["files"][image_index]})
171
  logger.debug(f"file: {message['files'][image_index]}")
 
178
  return content
179
 
180
 
 
 
 
181
  def process_new_user_message(message: dict) -> list[dict]:
 
 
 
 
 
 
182
  if not message["files"]:
183
  return [{"type": "text", "text": message["text"]}]
184
 
 
188
  csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
189
  txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
190
 
191
 + # 사용자 입력 텍스트를 먼저
192
  content_list = [{"type": "text", "text": message["text"]}]
193
 
194
  # CSV ์ „๋ฌธ
 
201
  txt_analysis = analyze_txt_file(txt_path)
202
  content_list.append({"type": "text", "text": txt_analysis})
203
 
204
+ # ๋™์˜์ƒ ์ฒ˜๋ฆฌ
205
  if video_files:
206
  content_list += process_video(video_files[0])
207
  return content_list
 
210
  if "<image>" in message["text"]:
211
  return process_interleaved_images(message)
212
 
213
 + # 일반 이미지들
214
  if image_files:
215
  for img_path in image_files:
216
  content_list.append({"type": "image", "url": img_path})
 
218
  return content_list
219
 
220
 
 
 
 
221
  def process_history(history: list[dict]) -> list[dict]:
222
  messages = []
223
  current_user_content: list[dict] = []
 
236
  return messages
237
 
238
 
 
 
 
239
  @spaces.GPU(duration=120)
240
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
241
  if not validate_media_constraints(message, history):
 
271
  yield output
272
 
273
 
 
 
 
274
  examples = [
275
  [
276
  {
 
394
  ]
395
 
396
 
 
 
 
397
  demo = gr.ChatInterface(
398
  fn=run,
399
  type="messages",
400
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
401
 + # .webp, .png, .jpg, .jpeg, .gif, .mp4, .csv, .txt, .pdf 모두 허용
402
  textbox=gr.MultimodalTextbox(
403
+ file_types=[
404
+ ".webp", ".png", ".jpg", ".jpeg", ".gif",
405
+ ".mp4", ".csv", ".txt", ".pdf"
406
+ ],
407
  file_count="multiple",
408
  autofocus=True
409
  ),
 
411
  additional_inputs=[
412
  gr.Textbox(
413
  label="System Prompt",
414
+ value="You are a deeply thoughtful AI. Consider problems thoroughly and derive correct solutions through systematic reasoning. Please answer in korean."
 
 
 
 
 
 
 
 
 
 
415
  ),
416
+ gr.Slider(label="Max New Tokens", minimum=100, maximum=8000, step=50, value=2000),
417
  ],
418
  stop_btn=False,
419
  title="Gemma 3 27B IT",