seawolf2357 committed on
Commit
1c72d37
·
verified ·
1 Parent(s): a9e7179

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -37
app.py CHANGED
@@ -22,7 +22,7 @@ import PyPDF2
22
  ##################################################
23
  # ์ƒ์ˆ˜ ๋ฐ ๋ชจ๋ธ ๋กœ๋”ฉ
24
  ##################################################
25
- MAX_CONTENT_CHARS = 8000 # ๋„ˆ๋ฌด ํฐ ํŒŒ์ผ ๋‚ด์šฉ์€ ์ด ์ •๋„๊นŒ์ง€๋งŒ ํ‘œ์‹œ
26
  model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
27
 
28
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
@@ -40,7 +40,7 @@ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
40
  # 1) CSV, TXT, PDF ๋ถ„์„ ํ•จ์ˆ˜
41
  ##################################################
42
  def analyze_csv_file(path: str) -> str:
43
- """CSV ํŒŒ์ผ์„ ์ฝ์–ด ๋ฌธ์ž์—ดํ™”. ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ผ๋ถ€๋งŒ ์ถœ๋ ฅ."""
44
  try:
45
  df = pd.read_csv(path)
46
  df_str = df.to_string()
@@ -52,7 +52,7 @@ def analyze_csv_file(path: str) -> str:
52
 
53
 
54
  def analyze_txt_file(path: str) -> str:
55
- """TXT ํŒŒ์ผ ์ „์ฒด๋ฅผ ์ฝ์–ด ๋ฌธ์ž์—ด ๋ฐ˜ํ™˜. ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„."""
56
  try:
57
  with open(path, "r", encoding="utf-8") as f:
58
  text = f.read()
@@ -64,9 +64,9 @@ def analyze_txt_file(path: str) -> str:
64
 
65
 
66
  def pdf_to_markdown(pdf_path: str) -> str:
67
- """PDF -> ํ…์ŠคํŠธ ์ถ”์ถœ -> Markdown ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜. ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž๋ฆ„."""
68
- text_chunks = []
69
  try:
 
70
  with open(pdf_path, "rb") as f:
71
  reader = PyPDF2.PdfReader(f)
72
  for page_num, page in enumerate(reader.pages, start=1):
@@ -85,7 +85,7 @@ def pdf_to_markdown(pdf_path: str) -> str:
85
 
86
 
87
  ##################################################
88
- # 2) ์ด๋ฏธ์ง€/๋น„๋””์˜ค ๊ฐœ์ˆ˜ ์ œํ•œ ๊ฒ€์‚ฌ
89
  ##################################################
90
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
91
  image_count = 0
@@ -102,11 +102,11 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
102
  image_count = 0
103
  video_count = 0
104
  for item in history:
105
- # assistant ๋ฉ”์‹œ์ง€์ด๊ฑฐ๋‚˜ content๊ฐ€ str์ด๋ฉด ์ œ์™ธ
106
  if item["role"] != "user" or isinstance(item["content"], str):
107
  continue
108
- # ์ด๋ฏธ์ง€/๋น„๋””์˜ค ๊ฒฝ๋กœ๋กœ๋งŒ ์นด์šดํŠธ
109
- if item["content"][0].endswith(".mp4"):
 
110
  video_count += 1
111
  else:
112
  image_count += 1
@@ -115,13 +115,11 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
115
 
116
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
117
  """
118
- - ์ด๋ฏธ์ง€/๋น„๋””์˜ค๋งŒ ๋Œ€์ƒ์œผ๋กœ ๊ฐœ์ˆ˜ยทํ˜ผํ•ฉ ์ œํ•œ
119
- - CSV, PDF, TXT ๋“ฑ์€ ๋Œ€์ƒ ์ œ์™ธ
120
- - <image> ํƒœ๊ทธ์™€ ์‹ค์ œ ์ด๋ฏธ์ง€ ์ˆ˜๊ฐ€ ์ผ์น˜ํ•˜๋Š”์ง€ ๋“ฑ
121
  """
122
  media_files = []
123
  for f in message["files"]:
124
- # ์ด๋ฏธ์ง€ ํ™•์žฅ์ž ๋˜๋Š” .mp4
125
  if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
126
  media_files.append(f)
127
 
@@ -146,7 +144,7 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
146
  if video_count == 0 and image_count > MAX_NUM_IMAGES:
147
  gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
148
  return False
149
- # <image> ํƒœ๊ทธ์™€ ์‹ค์ œ ์ด๋ฏธ์ง€ ์ˆ˜๊ฐ€ ์ผ์น˜?
150
  if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
151
  gr.Warning("The number of <image> tags in the text does not match the number of images.")
152
  return False
@@ -158,7 +156,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
158
  # 3) ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
159
  ##################################################
160
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
161
- """์˜์ƒ์—์„œ ์ผ์ • ๊ฐ„๊ฒฉ์œผ๋กœ ํ”„๋ ˆ์ž„์„ ์ถ”์ถœ, PIL ์ด๋ฏธ์ง€์™€ timestamp ๋ฐ˜ํ™˜."""
162
  vidcap = cv2.VideoCapture(video_path)
163
  fps = vidcap.get(cv2.CAP_PROP_FPS)
164
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -203,14 +200,13 @@ def process_interleaved_images(message: dict) -> list[dict]:
203
  elif part.strip():
204
  content.append({"type": "text", "text": part.strip()})
205
  else:
206
- # ๊ณต๋ฐฑ๋งŒ ์žˆ๋Š” ๊ฒฝ์šฐ
207
  if isinstance(part, str) and part != "<image>":
208
  content.append({"type": "text", "text": part})
209
  return content
210
 
211
 
212
  ##################################################
213
- # 5) CSV/PDF/TXT๋Š” ํ…์ŠคํŠธ๋กœ๋งŒ, ์ด๋ฏธ์ง€/๋น„๋””์˜ค๋Š” ๊ฒฝ๋กœ๋กœ
214
  ##################################################
215
  def process_new_user_message(message: dict) -> list[dict]:
216
  if not message["files"]:
@@ -223,13 +219,12 @@ def process_new_user_message(message: dict) -> list[dict]:
223
  txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
224
  pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
225
 
226
- # user ํ…์ŠคํŠธ ๋จผ์ € ์ถ”๊ฐ€
227
  content_list = [{"type": "text", "text": message["text"]}]
228
 
229
  # CSV
230
  for csv_path in csv_files:
231
  csv_analysis = analyze_csv_file(csv_path)
232
- # ๋ถ„์„ ๋‚ด์šฉ๋งŒ ๋„ฃ์Œ (ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ํžˆ์Šคํ† ๋ฆฌ์— ์ถ”๊ฐ€ํ•˜์ง€ ์•Š์Œ)
233
  content_list.append({"type": "text", "text": csv_analysis})
234
 
235
  # TXT
@@ -249,10 +244,8 @@ def process_new_user_message(message: dict) -> list[dict]:
249
 
250
  # ์ด๋ฏธ์ง€
251
  if "<image>" in message["text"]:
252
- # interleaved
253
  return process_interleaved_images(message)
254
  else:
255
- # ์—ฌ๋Ÿฌ ์žฅ ์ด๋ฏธ์ง€
256
  for img_path in image_files:
257
  content_list.append({"type": "image", "url": img_path})
258
 
@@ -260,45 +253,58 @@ def process_new_user_message(message: dict) -> list[dict]:
260
 
261
 
262
  ##################################################
263
- # 6) history -> LLM ๋ฉ”์‹œ์ง€ ๋ณ€ํ™˜
264
  ##################################################
265
  def process_history(history: list[dict]) -> list[dict]:
 
 
 
 
266
  messages = []
267
- current_user_content: list[dict] = []
268
  for item in history:
269
  if item["role"] == "assistant":
270
  if current_user_content:
271
  messages.append({"role": "user", "content": current_user_content})
272
  current_user_content = []
 
273
  messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
274
  else:
 
275
  content = item["content"]
276
  if isinstance(content, str):
 
277
  current_user_content.append({"type": "text", "text": content})
278
  else:
279
- # ์ด๋ฏธ์ง€ or ๊ธฐํƒ€ ํŒŒ์ผ url
280
- current_user_content.append({"type": "image", "url": content[0]})
 
 
 
 
 
 
281
  return messages
282
 
283
 
284
  ##################################################
285
- # 7) ๋ฉ”์ธ ์ถ”๋ก  ํ•จ์ˆ˜
286
  ##################################################
287
  @spaces.GPU(duration=120)
288
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
289
- # a) ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์ œํ•œ ๊ฒ€์‚ฌ
290
  if not validate_media_constraints(message, history):
291
  yield ""
292
  return
293
 
294
- # b) ๋Œ€ํ™” ๊ธฐ๋ก + ์ด๋ฒˆ ๋ฉ”์‹œ์ง€
295
  messages = []
296
  if system_prompt:
297
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
298
  messages.extend(process_history(history))
299
  messages.append({"role": "user", "content": process_new_user_message(message)})
300
 
301
- # c) ๋ชจ๋ธ ์ถ”๋ก 
302
  inputs = processor.apply_chat_template(
303
  messages,
304
  add_generation_prompt=True,
@@ -308,11 +314,11 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
308
  ).to(device=model.device, dtype=torch.bfloat16)
309
 
310
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
311
- gen_kwargs = dict(
312
- inputs,
313
- streamer=streamer,
314
- max_new_tokens=max_new_tokens,
315
- )
316
  t = Thread(target=model.generate, kwargs=gen_kwargs)
317
  t.start()
318
 
@@ -322,6 +328,8 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
322
  yield output
323
 
324
 
 
 
325
  ##################################################
326
  # 예시들 (한글화 버전)
327
  ##################################################
@@ -457,6 +465,7 @@ examples = [
457
 
458
 
459
 
 
460
  ##################################################
461
  # 9) Gradio ChatInterface
462
  ##################################################
@@ -464,7 +473,7 @@ demo = gr.ChatInterface(
464
  fn=run,
465
  type="messages",
466
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
467
- # ์ด๋ฏธ์ง€/๋™์˜์ƒ + CSV/TXT/PDF ํ—ˆ์šฉ (์ด๋ฏธ์ง€: webp ํฌํ•จ)
468
  textbox=gr.MultimodalTextbox(
469
  file_types=[
470
  ".png", ".jpg", ".jpeg", ".gif", ".webp",
@@ -496,8 +505,7 @@ demo = gr.ChatInterface(
496
  delete_cache=(1800, 1800),
497
  )
498
 
 
499
  if __name__ == "__main__":
500
  demo.launch()
501
 
502
-
503
-
 
22
  ##################################################
23
  # ์ƒ์ˆ˜ ๋ฐ ๋ชจ๋ธ ๋กœ๋”ฉ
24
  ##################################################
25
+ MAX_CONTENT_CHARS = 8000 # ํ…์ŠคํŠธ๋กœ ์ „๋‹ฌ ์‹œ ์ตœ๋Œ€ 8000์ž๊นŒ์ง€๋งŒ
26
  model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
27
 
28
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 
40
  # 1) CSV, TXT, PDF ๋ถ„์„ ํ•จ์ˆ˜
41
  ##################################################
42
  def analyze_csv_file(path: str) -> str:
43
+ """CSV ํŒŒ์ผ -> ๋ฌธ์ž์—ด. ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„."""
44
  try:
45
  df = pd.read_csv(path)
46
  df_str = df.to_string()
 
52
 
53
 
54
  def analyze_txt_file(path: str) -> str:
55
+ """TXT ํŒŒ์ผ -> ์ „์ฒด ๋ฌธ์ž์—ด. ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„."""
56
  try:
57
  with open(path, "r", encoding="utf-8") as f:
58
  text = f.read()
 
64
 
65
 
66
  def pdf_to_markdown(pdf_path: str) -> str:
67
+ """PDF -> ํ…์ŠคํŠธ ์ถ”์ถœ -> Markdown. ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„."""
 
68
  try:
69
+ text_chunks = []
70
  with open(pdf_path, "rb") as f:
71
  reader = PyPDF2.PdfReader(f)
72
  for page_num, page in enumerate(reader.pages, start=1):
 
85
 
86
 
87
  ##################################################
88
+ # 2) ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์ œํ•œ ๊ฒ€์‚ฌ (CSV, PDF, TXT ์ œ์™ธ)
89
  ##################################################
90
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
91
  image_count = 0
 
102
  image_count = 0
103
  video_count = 0
104
  for item in history:
 
105
  if item["role"] != "user" or isinstance(item["content"], str):
106
  continue
107
+ # item["content"]๊ฐ€ ["๊ฒฝ๋กœ"] ํ˜•ํƒœ์ผ ๋•Œ, ํ™•์žฅ์ž๋ฅผ ํ™•์ธ
108
+ file_path = item["content"][0]
109
+ if file_path.endswith(".mp4"):
110
  video_count += 1
111
  else:
112
  image_count += 1
 
115
 
116
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
117
  """
118
+ ์ด๋ฏธ์ง€ & ๋น„๋””์˜ค ์ œํ•œ
 
 
119
  """
120
  media_files = []
121
  for f in message["files"]:
122
+ # ์ด๋ฏธ์ง€/๋น„๋””์˜ค๋งŒ
123
  if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
124
  media_files.append(f)
125
 
 
144
  if video_count == 0 and image_count > MAX_NUM_IMAGES:
145
  gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
146
  return False
147
+ # <image> ํƒœ๊ทธ์™€ ์‹ค์ œ ์ด๋ฏธ์ง€ ์ˆ˜ ์ผ์น˜?
148
  if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
149
  gr.Warning("The number of <image> tags in the text does not match the number of images.")
150
  return False
 
156
  # 3) ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
157
  ##################################################
158
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
 
159
  vidcap = cv2.VideoCapture(video_path)
160
  fps = vidcap.get(cv2.CAP_PROP_FPS)
161
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
 
200
  elif part.strip():
201
  content.append({"type": "text", "text": part.strip()})
202
  else:
 
203
  if isinstance(part, str) and part != "<image>":
204
  content.append({"type": "text", "text": part})
205
  return content
206
 
207
 
208
  ##################################################
209
+ # 5) CSV/PDF/TXT๋Š” ํ…์ŠคํŠธ ๋ณ€ํ™˜๋งŒ, ์ด๋ฏธ์ง€/๋น„๋””์˜ค๋Š” ๊ฒฝ๋กœ๋กœ
210
  ##################################################
211
  def process_new_user_message(message: dict) -> list[dict]:
212
  if not message["files"]:
 
219
  txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
220
  pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
221
 
222
+ # user ํ…์ŠคํŠธ ์ถ”๊ฐ€
223
  content_list = [{"type": "text", "text": message["text"]}]
224
 
225
  # CSV
226
  for csv_path in csv_files:
227
  csv_analysis = analyze_csv_file(csv_path)
 
228
  content_list.append({"type": "text", "text": csv_analysis})
229
 
230
  # TXT
 
244
 
245
  # ์ด๋ฏธ์ง€
246
  if "<image>" in message["text"]:
 
247
  return process_interleaved_images(message)
248
  else:
 
249
  for img_path in image_files:
250
  content_list.append({"type": "image", "url": img_path})
251
 
 
253
 
254
 
255
  ##################################################
256
+ # 6) ํžˆ์Šคํ† ๋ฆฌ -> LLM ๋ฉ”์‹œ์ง€ ๋ณ€ํ™˜
257
  ##################################################
258
  def process_history(history: list[dict]) -> list[dict]:
259
+ """
260
+ ์—ฌ๊ธฐ์„œ, ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์™ธ์˜ ํŒŒ์ผ(.csv, .pdf, .txt) ๊ฒฝ๋กœ๋Š”
261
+ ๋ชจ๋ธ๋กœ ์ „๋‹ฌ๋˜์ง€ ์•Š๋„๋ก ์ œ๊ฑฐ (or ๋ฌด์‹œ)
262
+ """
263
  messages = []
264
+ current_user_content = []
265
  for item in history:
266
  if item["role"] == "assistant":
267
  if current_user_content:
268
  messages.append({"role": "user", "content": current_user_content})
269
  current_user_content = []
270
+ # assistant -> ๊ทธ๋ƒฅ ํ…์ŠคํŠธ๋กœ
271
  messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
272
  else:
273
+ # user
274
  content = item["content"]
275
  if isinstance(content, str):
276
+ # ๋‹จ์ˆœ ํ…์ŠคํŠธ
277
  current_user_content.append({"type": "text", "text": content})
278
  else:
279
+ # ๋ณดํ†ต [ํŒŒ์ผ๊ฒฝ๋กœ] ํ˜•ํƒœ
280
+ file_path = content[0]
281
+ # ๋งŒ์•ฝ ์ด๋ฏธ์ง€๋‚˜ mp4๊ฐ€ ์•„๋‹ˆ๋ผ๋ฉด -> ๋ฌด์‹œ
282
+ if re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE) or file_path.endswith(".mp4"):
283
+ current_user_content.append({"type": "image", "url": file_path})
284
+ else:
285
+ # csv, pdf, txt ๋“ฑ์€ ์ œ๊ฑฐ
286
+ pass
287
  return messages
288
 
289
 
290
  ##################################################
291
+ # 7) ๋ฉ”์ธ ์ถ”๋ก 
292
  ##################################################
293
  @spaces.GPU(duration=120)
294
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
295
+ # a) ๋ฏธ๋””์–ด ์ œํ•œ ๊ฒ€์‚ฌ
296
  if not validate_media_constraints(message, history):
297
  yield ""
298
  return
299
 
300
+ # b) ๊ธฐ์กด ํžˆ์Šคํ† ๋ฆฌ -> LLM ๋ฉ”์‹œ์ง€
301
  messages = []
302
  if system_prompt:
303
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
304
  messages.extend(process_history(history))
305
  messages.append({"role": "user", "content": process_new_user_message(message)})
306
 
307
+ # c) ๋ชจ๋ธ ํ˜ธ์ถœ
308
  inputs = processor.apply_chat_template(
309
  messages,
310
  add_generation_prompt=True,
 
314
  ).to(device=model.device, dtype=torch.bfloat16)
315
 
316
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
317
+ gen_kwargs = {
318
+ "inputs": inputs,
319
+ "streamer": streamer,
320
+ "max_new_tokens": max_new_tokens,
321
+ }
322
  t = Thread(target=model.generate, kwargs=gen_kwargs)
323
  t.start()
324
 
 
328
  yield output
329
 
330
 
331
+
332
+
333
  ##################################################
334
  # 예시들 (한글화 버전)
335
  ##################################################
 
465
 
466
 
467
 
468
+
469
  ##################################################
470
  # 9) Gradio ChatInterface
471
  ##################################################
 
473
  fn=run,
474
  type="messages",
475
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
476
+ # ์ด๋ฏธ์ง€(์—ฌ๋Ÿฌ ํ™•์žฅ์ž), mp4, csv, txt, pdf ํ—ˆ์šฉ
477
  textbox=gr.MultimodalTextbox(
478
  file_types=[
479
  ".png", ".jpg", ".jpeg", ".gif", ".webp",
 
505
  delete_cache=(1800, 1800),
506
  )
507
 
508
+
509
  if __name__ == "__main__":
510
  demo.launch()
511