Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ import PyPDF2
|
|
22 |
##################################################
|
23 |
# ์์ ๋ฐ ๋ชจ๋ธ ๋ก๋ฉ
|
24 |
##################################################
|
25 |
-
MAX_CONTENT_CHARS = 8000 #
|
26 |
model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
|
27 |
|
28 |
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
@@ -40,7 +40,7 @@ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
|
|
40 |
# 1) CSV, TXT, PDF ๋ถ์ ํจ์
|
41 |
##################################################
|
42 |
def analyze_csv_file(path: str) -> str:
|
43 |
-
"""CSV
|
44 |
try:
|
45 |
df = pd.read_csv(path)
|
46 |
df_str = df.to_string()
|
@@ -52,7 +52,7 @@ def analyze_csv_file(path: str) -> str:
|
|
52 |
|
53 |
|
54 |
def analyze_txt_file(path: str) -> str:
|
55 |
-
"""TXT ํ์ผ
|
56 |
try:
|
57 |
with open(path, "r", encoding="utf-8") as f:
|
58 |
text = f.read()
|
@@ -64,9 +64,9 @@ def analyze_txt_file(path: str) -> str:
|
|
64 |
|
65 |
|
66 |
def pdf_to_markdown(pdf_path: str) -> str:
|
67 |
-
"""PDF -> ํ
์คํธ ์ถ์ถ -> Markdown
|
68 |
-
text_chunks = []
|
69 |
try:
|
|
|
70 |
with open(pdf_path, "rb") as f:
|
71 |
reader = PyPDF2.PdfReader(f)
|
72 |
for page_num, page in enumerate(reader.pages, start=1):
|
@@ -85,7 +85,7 @@ def pdf_to_markdown(pdf_path: str) -> str:
|
|
85 |
|
86 |
|
87 |
##################################################
|
88 |
-
# 2) ์ด๋ฏธ์ง/๋น๋์ค
|
89 |
##################################################
|
90 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
91 |
image_count = 0
|
@@ -102,11 +102,11 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
|
|
102 |
image_count = 0
|
103 |
video_count = 0
|
104 |
for item in history:
|
105 |
-
# assistant ๋ฉ์์ง์ด๊ฑฐ๋ content๊ฐ str์ด๋ฉด ์ ์ธ
|
106 |
if item["role"] != "user" or isinstance(item["content"], str):
|
107 |
continue
|
108 |
-
#
|
109 |
-
|
|
|
110 |
video_count += 1
|
111 |
else:
|
112 |
image_count += 1
|
@@ -115,13 +115,11 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
|
|
115 |
|
116 |
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
117 |
"""
|
118 |
-
|
119 |
-
- CSV, PDF, TXT ๋ฑ์ ๋์ ์ ์ธ
|
120 |
-
- <image> ํ๊ทธ์ ์ค์ ์ด๋ฏธ์ง ์๊ฐ ์ผ์นํ๋์ง ๋ฑ
|
121 |
"""
|
122 |
media_files = []
|
123 |
for f in message["files"]:
|
124 |
-
#
|
125 |
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
|
126 |
media_files.append(f)
|
127 |
|
@@ -146,7 +144,7 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
|
146 |
if video_count == 0 and image_count > MAX_NUM_IMAGES:
|
147 |
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
|
148 |
return False
|
149 |
-
# <image> ํ๊ทธ์ ์ค์ ์ด๋ฏธ์ง
|
150 |
if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
|
151 |
gr.Warning("The number of <image> tags in the text does not match the number of images.")
|
152 |
return False
|
@@ -158,7 +156,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
|
158 |
# 3) ๋น๋์ค ์ฒ๋ฆฌ
|
159 |
##################################################
|
160 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
161 |
-
"""์์์์ ์ผ์ ๊ฐ๊ฒฉ์ผ๋ก ํ๋ ์์ ์ถ์ถ, PIL ์ด๋ฏธ์ง์ timestamp ๋ฐํ."""
|
162 |
vidcap = cv2.VideoCapture(video_path)
|
163 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
164 |
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
@@ -203,14 +200,13 @@ def process_interleaved_images(message: dict) -> list[dict]:
|
|
203 |
elif part.strip():
|
204 |
content.append({"type": "text", "text": part.strip()})
|
205 |
else:
|
206 |
-
# ๊ณต๋ฐฑ๋ง ์๋ ๊ฒฝ์ฐ
|
207 |
if isinstance(part, str) and part != "<image>":
|
208 |
content.append({"type": "text", "text": part})
|
209 |
return content
|
210 |
|
211 |
|
212 |
##################################################
|
213 |
-
# 5) CSV/PDF/TXT๋
|
214 |
##################################################
|
215 |
def process_new_user_message(message: dict) -> list[dict]:
|
216 |
if not message["files"]:
|
@@ -223,13 +219,12 @@ def process_new_user_message(message: dict) -> list[dict]:
|
|
223 |
txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
|
224 |
pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
|
225 |
|
226 |
-
# user ํ
์คํธ
|
227 |
content_list = [{"type": "text", "text": message["text"]}]
|
228 |
|
229 |
# CSV
|
230 |
for csv_path in csv_files:
|
231 |
csv_analysis = analyze_csv_file(csv_path)
|
232 |
-
# ๋ถ์ ๋ด์ฉ๋ง ๋ฃ์ (ํ์ผ ๊ฒฝ๋ก๋ฅผ ํ์คํ ๋ฆฌ์ ์ถ๊ฐํ์ง ์์)
|
233 |
content_list.append({"type": "text", "text": csv_analysis})
|
234 |
|
235 |
# TXT
|
@@ -249,10 +244,8 @@ def process_new_user_message(message: dict) -> list[dict]:
|
|
249 |
|
250 |
# ์ด๋ฏธ์ง
|
251 |
if "<image>" in message["text"]:
|
252 |
-
# interleaved
|
253 |
return process_interleaved_images(message)
|
254 |
else:
|
255 |
-
# ์ฌ๋ฌ ์ฅ ์ด๋ฏธ์ง
|
256 |
for img_path in image_files:
|
257 |
content_list.append({"type": "image", "url": img_path})
|
258 |
|
@@ -260,45 +253,58 @@ def process_new_user_message(message: dict) -> list[dict]:
|
|
260 |
|
261 |
|
262 |
##################################################
|
263 |
-
# 6)
|
264 |
##################################################
|
265 |
def process_history(history: list[dict]) -> list[dict]:
|
|
|
|
|
|
|
|
|
266 |
messages = []
|
267 |
-
current_user_content
|
268 |
for item in history:
|
269 |
if item["role"] == "assistant":
|
270 |
if current_user_content:
|
271 |
messages.append({"role": "user", "content": current_user_content})
|
272 |
current_user_content = []
|
|
|
273 |
messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
|
274 |
else:
|
|
|
275 |
content = item["content"]
|
276 |
if isinstance(content, str):
|
|
|
277 |
current_user_content.append({"type": "text", "text": content})
|
278 |
else:
|
279 |
-
#
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
return messages
|
282 |
|
283 |
|
284 |
##################################################
|
285 |
-
# 7) ๋ฉ์ธ ์ถ๋ก
|
286 |
##################################################
|
287 |
@spaces.GPU(duration=120)
|
288 |
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
|
289 |
-
# a)
|
290 |
if not validate_media_constraints(message, history):
|
291 |
yield ""
|
292 |
return
|
293 |
|
294 |
-
# b)
|
295 |
messages = []
|
296 |
if system_prompt:
|
297 |
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
|
298 |
messages.extend(process_history(history))
|
299 |
messages.append({"role": "user", "content": process_new_user_message(message)})
|
300 |
|
301 |
-
# c) ๋ชจ๋ธ
|
302 |
inputs = processor.apply_chat_template(
|
303 |
messages,
|
304 |
add_generation_prompt=True,
|
@@ -308,11 +314,11 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
|
|
308 |
).to(device=model.device, dtype=torch.bfloat16)
|
309 |
|
310 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
311 |
-
gen_kwargs =
|
312 |
-
inputs,
|
313 |
-
streamer
|
314 |
-
max_new_tokens
|
315 |
-
|
316 |
t = Thread(target=model.generate, kwargs=gen_kwargs)
|
317 |
t.start()
|
318 |
|
@@ -322,6 +328,8 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
|
|
322 |
yield output
|
323 |
|
324 |
|
|
|
|
|
325 |
##################################################
|
326 |
# ์์๋ค (ํ๊ธํ ๋ฒ์ )
|
327 |
##################################################
|
@@ -457,6 +465,7 @@ examples = [
|
|
457 |
|
458 |
|
459 |
|
|
|
460 |
##################################################
|
461 |
# 9) Gradio ChatInterface
|
462 |
##################################################
|
@@ -464,7 +473,7 @@ demo = gr.ChatInterface(
|
|
464 |
fn=run,
|
465 |
type="messages",
|
466 |
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
|
467 |
-
#
|
468 |
textbox=gr.MultimodalTextbox(
|
469 |
file_types=[
|
470 |
".png", ".jpg", ".jpeg", ".gif", ".webp",
|
@@ -496,8 +505,7 @@ demo = gr.ChatInterface(
|
|
496 |
delete_cache=(1800, 1800),
|
497 |
)
|
498 |
|
|
|
499 |
if __name__ == "__main__":
|
500 |
demo.launch()
|
501 |
|
502 |
-
|
503 |
-
|
|
|
22 |
##################################################
|
23 |
# ์์ ๋ฐ ๋ชจ๋ธ ๋ก๋ฉ
|
24 |
##################################################
|
25 |
+
MAX_CONTENT_CHARS = 8000 # ํ
์คํธ๋ก ์ ๋ฌ ์ ์ต๋ 8000์๊น์ง๋ง
|
26 |
model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
|
27 |
|
28 |
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
|
|
40 |
# 1) CSV, TXT, PDF ๋ถ์ ํจ์
|
41 |
##################################################
|
42 |
def analyze_csv_file(path: str) -> str:
|
43 |
+
"""CSV ํ์ผ -> ๋ฌธ์์ด. ๊ธธ๋ฉด ์๋ผ๋."""
|
44 |
try:
|
45 |
df = pd.read_csv(path)
|
46 |
df_str = df.to_string()
|
|
|
52 |
|
53 |
|
54 |
def analyze_txt_file(path: str) -> str:
|
55 |
+
"""TXT ํ์ผ -> ์ ์ฒด ๋ฌธ์์ด. ๊ธธ๋ฉด ์๋ผ๋."""
|
56 |
try:
|
57 |
with open(path, "r", encoding="utf-8") as f:
|
58 |
text = f.read()
|
|
|
64 |
|
65 |
|
66 |
def pdf_to_markdown(pdf_path: str) -> str:
|
67 |
+
"""PDF -> ํ
์คํธ ์ถ์ถ -> Markdown. ๊ธธ๋ฉด ์๋ผ๋."""
|
|
|
68 |
try:
|
69 |
+
text_chunks = []
|
70 |
with open(pdf_path, "rb") as f:
|
71 |
reader = PyPDF2.PdfReader(f)
|
72 |
for page_num, page in enumerate(reader.pages, start=1):
|
|
|
85 |
|
86 |
|
87 |
##################################################
|
88 |
+
# 2) ์ด๋ฏธ์ง/๋น๋์ค ์ ํ ๊ฒ์ฌ (CSV, PDF, TXT ์ ์ธ)
|
89 |
##################################################
|
90 |
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
|
91 |
image_count = 0
|
|
|
102 |
image_count = 0
|
103 |
video_count = 0
|
104 |
for item in history:
|
|
|
105 |
if item["role"] != "user" or isinstance(item["content"], str):
|
106 |
continue
|
107 |
+
# item["content"]๊ฐ ["๊ฒฝ๋ก"] ํํ์ผ ๋, ํ์ฅ์๋ฅผ ํ์ธ
|
108 |
+
file_path = item["content"][0]
|
109 |
+
if file_path.endswith(".mp4"):
|
110 |
video_count += 1
|
111 |
else:
|
112 |
image_count += 1
|
|
|
115 |
|
116 |
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
|
117 |
"""
|
118 |
+
์ด๋ฏธ์ง & ๋น๋์ค ์ ํ
|
|
|
|
|
119 |
"""
|
120 |
media_files = []
|
121 |
for f in message["files"]:
|
122 |
+
# ์ด๋ฏธ์ง/๋น๋์ค๋ง
|
123 |
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
|
124 |
media_files.append(f)
|
125 |
|
|
|
144 |
if video_count == 0 and image_count > MAX_NUM_IMAGES:
|
145 |
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
|
146 |
return False
|
147 |
+
# <image> ํ๊ทธ์ ์ค์ ์ด๋ฏธ์ง ์ ์ผ์น?
|
148 |
if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
|
149 |
gr.Warning("The number of <image> tags in the text does not match the number of images.")
|
150 |
return False
|
|
|
156 |
# 3) ๋น๋์ค ์ฒ๋ฆฌ
|
157 |
##################################################
|
158 |
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
|
|
|
159 |
vidcap = cv2.VideoCapture(video_path)
|
160 |
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
161 |
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
200 |
elif part.strip():
|
201 |
content.append({"type": "text", "text": part.strip()})
|
202 |
else:
|
|
|
203 |
if isinstance(part, str) and part != "<image>":
|
204 |
content.append({"type": "text", "text": part})
|
205 |
return content
|
206 |
|
207 |
|
208 |
##################################################
|
209 |
+
# 5) CSV/PDF/TXT๋ ํ
์คํธ ๋ณํ๋ง, ์ด๋ฏธ์ง/๋น๋์ค๋ ๊ฒฝ๋ก๋ก
|
210 |
##################################################
|
211 |
def process_new_user_message(message: dict) -> list[dict]:
|
212 |
if not message["files"]:
|
|
|
219 |
txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
|
220 |
pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
|
221 |
|
222 |
+
# user ํ
์คํธ ์ถ๊ฐ
|
223 |
content_list = [{"type": "text", "text": message["text"]}]
|
224 |
|
225 |
# CSV
|
226 |
for csv_path in csv_files:
|
227 |
csv_analysis = analyze_csv_file(csv_path)
|
|
|
228 |
content_list.append({"type": "text", "text": csv_analysis})
|
229 |
|
230 |
# TXT
|
|
|
244 |
|
245 |
# ์ด๋ฏธ์ง
|
246 |
if "<image>" in message["text"]:
|
|
|
247 |
return process_interleaved_images(message)
|
248 |
else:
|
|
|
249 |
for img_path in image_files:
|
250 |
content_list.append({"type": "image", "url": img_path})
|
251 |
|
|
|
253 |
|
254 |
|
255 |
##################################################
|
256 |
+
# 6) ํ์คํ ๋ฆฌ -> LLM ๋ฉ์์ง ๋ณํ
|
257 |
##################################################
|
258 |
def process_history(history: list[dict]) -> list[dict]:
|
259 |
+
"""
|
260 |
+
์ฌ๊ธฐ์, ์ด๋ฏธ์ง/๋น๋์ค ์ธ์ ํ์ผ(.csv, .pdf, .txt) ๊ฒฝ๋ก๋
|
261 |
+
๋ชจ๋ธ๋ก ์ ๋ฌ๋์ง ์๋๋ก ์ ๊ฑฐ (or ๋ฌด์)
|
262 |
+
"""
|
263 |
messages = []
|
264 |
+
current_user_content = []
|
265 |
for item in history:
|
266 |
if item["role"] == "assistant":
|
267 |
if current_user_content:
|
268 |
messages.append({"role": "user", "content": current_user_content})
|
269 |
current_user_content = []
|
270 |
+
# assistant -> ๊ทธ๋ฅ ํ
์คํธ๋ก
|
271 |
messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
|
272 |
else:
|
273 |
+
# user
|
274 |
content = item["content"]
|
275 |
if isinstance(content, str):
|
276 |
+
# ๋จ์ ํ
์คํธ
|
277 |
current_user_content.append({"type": "text", "text": content})
|
278 |
else:
|
279 |
+
# ๋ณดํต [ํ์ผ๊ฒฝ๋ก] ํํ
|
280 |
+
file_path = content[0]
|
281 |
+
# ๋ง์ฝ ์ด๋ฏธ์ง๋ mp4๊ฐ ์๋๋ผ๋ฉด -> ๋ฌด์
|
282 |
+
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE) or file_path.endswith(".mp4"):
|
283 |
+
current_user_content.append({"type": "image", "url": file_path})
|
284 |
+
else:
|
285 |
+
# csv, pdf, txt ๋ฑ์ ์ ๊ฑฐ
|
286 |
+
pass
|
287 |
return messages
|
288 |
|
289 |
|
290 |
##################################################
|
291 |
+
# 7) ๋ฉ์ธ ์ถ๋ก
|
292 |
##################################################
|
293 |
@spaces.GPU(duration=120)
|
294 |
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
|
295 |
+
# a) ๋ฏธ๋์ด ์ ํ ๊ฒ์ฌ
|
296 |
if not validate_media_constraints(message, history):
|
297 |
yield ""
|
298 |
return
|
299 |
|
300 |
+
# b) ๊ธฐ์กด ํ์คํ ๋ฆฌ -> LLM ๋ฉ์์ง
|
301 |
messages = []
|
302 |
if system_prompt:
|
303 |
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
|
304 |
messages.extend(process_history(history))
|
305 |
messages.append({"role": "user", "content": process_new_user_message(message)})
|
306 |
|
307 |
+
# c) ๋ชจ๋ธ ํธ์ถ
|
308 |
inputs = processor.apply_chat_template(
|
309 |
messages,
|
310 |
add_generation_prompt=True,
|
|
|
314 |
).to(device=model.device, dtype=torch.bfloat16)
|
315 |
|
316 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
317 |
+
gen_kwargs = {
|
318 |
+
"inputs": inputs,
|
319 |
+
"streamer": streamer,
|
320 |
+
"max_new_tokens": max_new_tokens,
|
321 |
+
}
|
322 |
t = Thread(target=model.generate, kwargs=gen_kwargs)
|
323 |
t.start()
|
324 |
|
|
|
328 |
yield output
|
329 |
|
330 |
|
331 |
+
|
332 |
+
|
333 |
##################################################
|
334 |
# ์์๋ค (ํ๊ธํ ๋ฒ์ )
|
335 |
##################################################
|
|
|
465 |
|
466 |
|
467 |
|
468 |
+
|
469 |
##################################################
|
470 |
# 9) Gradio ChatInterface
|
471 |
##################################################
|
|
|
473 |
fn=run,
|
474 |
type="messages",
|
475 |
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
|
476 |
+
# ์ด๋ฏธ์ง(์ฌ๋ฌ ํ์ฅ์), mp4, csv, txt, pdf ํ์ฉ
|
477 |
textbox=gr.MultimodalTextbox(
|
478 |
file_types=[
|
479 |
".png", ".jpg", ".jpeg", ".gif", ".webp",
|
|
|
505 |
delete_cache=(1800, 1800),
|
506 |
)
|
507 |
|
508 |
+
|
509 |
if __name__ == "__main__":
|
510 |
demo.launch()
|
511 |
|
|
|
|