ginipick committed on
Commit b32c775 · verified · 1 Parent(s): 5ad049a

Delete app.py

Files changed (1)
  1. app.py +0 -802
app.py DELETED
@@ -1,802 +0,0 @@
- #!/usr/bin/env python
-
- import os
- import re
- import tempfile
- import gc  # garbage collector
- from collections.abc import Iterator
- from threading import Thread
- import json
- import requests
- import cv2
- import base64
- import logging
- import time
- from urllib.parse import quote  # URL encoding (used when needed)
-
- import gradio as gr
- import spaces
- import torch
- from loguru import logger
- from PIL import Image
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
-
- # CSV/TXT/PDF analysis
- import pandas as pd
- import PyPDF2
-
- # =============================================================================
- # (New) Image-API helper functions
- # =============================================================================
- from gradio_client import Client
-
- API_URL = "http://211.233.58.201:7896"
-
- logging.basicConfig(
-     level=logging.DEBUG,
-     format='%(asctime)s - %(levelname)s - %(message)s'
- )
-
- def test_api_connection() -> str:
-     """Test the connection to the API server."""
-     try:
-         client = Client(API_URL)
-         return "API 연결 성공: 정상 작동 중"
-     except Exception as e:
-         logging.error(f"API connection test failed: {e}")
-         return f"API 연결 실패: {e}"
-
- def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
-     """
-     Image-generation function.
-     The server is assumed to return the final image directly as Base64 (or a data:image/... URL).
-     No /tmp/... paths are used and no extra download is attempted.
-     """
-     if not prompt:
-         return None, "Error: Prompt is required"
-     try:
-         logging.info(f"Calling image generation API with prompt: {prompt}")
-
-         client = Client(API_URL)
-         result = client.predict(
-             prompt=prompt,
-             width=int(width),
-             height=int(height),
-             guidance=float(guidance),
-             inference_steps=int(inference_steps),
-             seed=int(seed),
-             do_img2img=False,
-             init_image=None,
-             image2image_strength=0.8,
-             resize_img=True,
-             api_name="/generate_image"
-         )
-
-         logging.info(
-             f"Image generation result: {type(result)}, "
-             f"length: {len(result) if isinstance(result, (list, tuple)) else 'unknown'}"
-         )
-
-         # The result is assumed to be a tuple/list: [image_base64 or data_url, seed_info]
-         if isinstance(result, (list, tuple)) and len(result) > 0:
-             image_data = result[0]  # first element is the image data (Base64, data:image/..., etc.)
-             seed_info = result[1] if len(result) > 1 else "Unknown seed"
-             return image_data, seed_info
-         else:
-             # The result came back in some other shape
-             return result, "Unknown seed"
-
-     except Exception as e:
-         logging.error(f"Image generation failed: {str(e)}")
-         return None, f"Error: {str(e)}"
-
- # Base64 padding fix helper (use if needed)
- def fix_base64_padding(data):
-     """Fix the padding of a Base64 string."""
-     if isinstance(data, bytes):
-         data = data.decode('utf-8')
-
-     if "base64," in data:
-         data = data.split("base64,", 1)[1]
-
-     missing_padding = len(data) % 4
-     if missing_padding:
-         data += '=' * (4 - missing_padding)
-
-     return data
-
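As an aside, a minimal sketch of how the padding helper above would typically be used to recover raw bytes. The sample payload is hypothetical; `UklGRg` happens to be the start of a WebP/RIFF header:

```python
import base64

# Hypothetical data URL whose base64 payload lost its '=' padding in transit.
data_url = "data:image/webp;base64,UklGRg"
payload = fix_base64_padding(data_url)  # strips the prefix, pads to a multiple of 4
print(base64.b64decode(payload))        # b'RIFF' -- decodes without a padding error
```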
- # =============================================================================
- # Memory-cleanup helper
- # =============================================================================
- def clear_cuda_cache():
-     """Explicitly empty the CUDA cache."""
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-         gc.collect()
-
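A quick way to sanity-check the helper's effect on a CUDA host (numbers will vary; on CPU-only machines the branch is a no-op):

```python
if torch.cuda.is_available():
    before = torch.cuda.memory_reserved()
    clear_cuda_cache()
    print(f"reserved: {before} -> {torch.cuda.memory_reserved()} bytes")
```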
- # =============================================================================
- # SerpHouse helper functions
- # =============================================================================
- SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
-
- def extract_keywords(text: str, top_k: int = 5) -> str:
-     """Simple keyword extraction: keep only Korean, English letters, digits, and whitespace."""
-     text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
-     tokens = text.split()
-     return " ".join(tokens[:top_k])
-
- def do_web_search(query: str) -> str:
-     """
-     Call the SerpHouse LIVE API and return the search results as Markdown.
-     (Adjust or remove as needed.)
-     """
-     try:
-         url = "https://api.serphouse.com/serp/live"
-         params = {
-             "q": query,
-             "domain": "google.com",
-             "serp_type": "web",
-             "device": "desktop",
-             "lang": "en",
-             "num": "20"
-         }
-         headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
-         logger.info(f"SerpHouse API 호출 중... 검색어: {query}")
-         response = requests.get(url, headers=headers, params=params, timeout=60)
-         response.raise_for_status()
-         data = response.json()
-         results = data.get("results", {})
-         organic = None
-         if isinstance(results, dict) and "organic" in results:
-             organic = results["organic"]
-         elif isinstance(results, dict) and "results" in results:
-             if isinstance(results["results"], dict) and "organic" in results["results"]:
-                 organic = results["results"]["organic"]
-         elif "organic" in data:
-             organic = data["organic"]
-         if not organic:
-             logger.warning("응답에서 organic 결과를 찾을 수 없습니다.")
-             return "No web search results found or unexpected API response structure."
-         max_results = min(20, len(organic))
-         limited_organic = organic[:max_results]
-         summary_lines = []
-         for idx, item in enumerate(limited_organic, start=1):
-             title = item.get("title", "No title")
-             link = item.get("link", "#")
-             snippet = item.get("snippet", "No description")
-             displayed_link = item.get("displayed_link", link)
-             summary_lines.append(
-                 f"### Result {idx}: {title}\n\n"
-                 f"{snippet}\n\n"
-                 f"**출처**: [{displayed_link}]({link})\n\n"
-                 f"---\n"
-             )
-         instructions = """
- # 웹 검색 결과
- 아래는 검색 결과입니다. 질문에 답변할 때 이 정보를 활용하세요:
- 1. 여러 출처 내용을 종합하여 답변.
- 2. 출처 인용 시 "[출처 제목](링크)" 마크다운 형식 사용.
- 3. 답변 마지막에 '참고 자료:' 섹션에 사용한 주요 출처를 나열.
- """
-         return instructions + "\n".join(summary_lines)
-     except Exception as e:
-         logger.error(f"Web search failed: {e}")
-         return f"Web search failed: {str(e)}"
-
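For intuition, a small check of what `extract_keywords` actually feeds into the search call; note that the character filter also strips apostrophes, and `top_k` cuts the token list short:

```python
print(extract_keywords("What's the best cafe in Seoul? (2024)", top_k=5))
# -> "Whats the best cafe in"
```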
- # =============================================================================
- # Model and processor loading
- # =============================================================================
- MAX_CONTENT_CHARS = 2000
- MAX_INPUT_LENGTH = 2096
-
- model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
- processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
- model = Gemma3ForConditionalGeneration.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
-     attn_implementation="eager"
- )
-
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
-
- # =============================================================================
- # CSV, TXT, and PDF analysis functions
- # =============================================================================
- def analyze_csv_file(path: str) -> str:
-     try:
-         df = pd.read_csv(path)
-         if df.shape[0] > 50 or df.shape[1] > 10:
-             df = df.iloc[:50, :10]
-         df_str = df.to_string()
-         if len(df_str) > MAX_CONTENT_CHARS:
-             df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
-         return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
-     except Exception as e:
-         return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
-
- def analyze_txt_file(path: str) -> str:
-     try:
-         with open(path, "r", encoding="utf-8") as f:
-             text = f.read()
-         if len(text) > MAX_CONTENT_CHARS:
-             text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
-         return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
-     except Exception as e:
-         return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
-
- def pdf_to_markdown(pdf_path: str) -> str:
-     text_chunks = []
-     try:
-         with open(pdf_path, "rb") as f:
-             reader = PyPDF2.PdfReader(f)
-             max_pages = min(5, len(reader.pages))
-             for page_num in range(max_pages):
-                 page_text = reader.pages[page_num].extract_text() or ""
-                 page_text = page_text.strip()
-                 if page_text:
-                     if len(page_text) > MAX_CONTENT_CHARS // max_pages:
-                         page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
-                     text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
-             if len(reader.pages) > max_pages:
-                 text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
-     except Exception as e:
-         return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
-     full_text = "\n".join(text_chunks)
-     if len(full_text) > MAX_CONTENT_CHARS:
-         full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
-     return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
-
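A minimal smoke test for the analyzers above, using a throwaway file whose path is generated at runtime:

```python
import os
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("hello world")
print(analyze_txt_file(f.name))  # -> **[TXT File: tmpXXXX.txt]** ... hello world
os.unlink(f.name)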
- # =============================================================================
- # Image/video upload-limit checks
- # =============================================================================
- def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
-     image_count = 0
-     video_count = 0
-     for path in paths:
-         if path.endswith(".mp4"):
-             video_count += 1
-         elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
-             image_count += 1
-     return image_count, video_count
-
- def count_files_in_history(history: list[dict]) -> tuple[int, int]:
-     image_count = 0
-     video_count = 0
-     for item in history:
-         if item["role"] != "user" or isinstance(item["content"], str):
-             continue
-         if isinstance(item["content"], list) and len(item["content"]) > 0:
-             file_path = item["content"][0]
-             if isinstance(file_path, str):
-                 if file_path.endswith(".mp4"):
-                     video_count += 1
-                 elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
-                     image_count += 1
-     return image_count, video_count
-
- def validate_media_constraints(message: dict, history: list[dict]) -> bool:
-     """Check the image/video upload limits."""
-     media_files = [f for f in message["files"]
-                    if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
-     new_image_count, new_video_count = count_files_in_new_message(media_files)
-     history_image_count, history_video_count = count_files_in_history(history)
-
-     image_count = history_image_count + new_image_count
-     video_count = history_video_count + new_video_count
-
-     if video_count > 1:
-         gr.Warning("Only one video is supported.")
-         return False
-     if video_count == 1:
-         if image_count > 0:
-             gr.Warning("Mixing images and videos is not allowed.")
-             return False
-         if "<image>" in message["text"]:
-             gr.Warning("Using <image> tags with video files is not supported.")
-             return False
-     if video_count == 0 and image_count > MAX_NUM_IMAGES:
-         gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
-         return False
-     if "<image>" in message["text"]:
-         image_files = [f for f in message["files"]
-                        if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
-         image_tag_count = message["text"].count("<image>")
-         if image_tag_count != len(image_files):
-             gr.Warning("The number of <image> tags in the text does not match the number of image files.")
-             return False
-     return True
-
- # =============================================================================
- # Video-processing helpers
- # =============================================================================
- def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
-     vidcap = cv2.VideoCapture(video_path)
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     if fps <= 0:
-         fps = 30.0  # assumed fallback when FPS metadata is missing, avoiding division by zero
-     frame_interval = max(1, int(fps), int(total_frames / 10))  # never step by 0
-     frames = []
-     for i in range(0, total_frames, frame_interval):
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-             if len(frames) >= 5:
-                 break
-     vidcap.release()
-     return frames
-
- def process_video(video_path: str) -> tuple[list[dict], list[str]]:
-     content = []
-     temp_files = []
-     frames = downsample_video(video_path)
-     for pil_image, timestamp in frames:
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
-             pil_image.save(temp_file.name)
-             temp_files.append(temp_file.name)
-             content.append({"type": "text", "text": f"Frame {timestamp}:"})
-             content.append({"type": "image", "url": temp_file.name})
-     return content, temp_files
-
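To make the sampling stride concrete: for a 30 fps clip with 900 frames, the interval works out to 90, i.e. roughly one frame every three seconds, and the five-frame cap stops sampling a third of the way in. A standalone sketch:

```python
fps, total_frames = 30.0, 900
frame_interval = max(1, int(fps), int(total_frames / 10))        # -> 90
frame_indices = list(range(0, total_frames, frame_interval))[:5]  # five-frame cap
print(frame_indices)  # [0, 90, 180, 270, 360]
```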
- # =============================================================================
- # Interleaved <image> handling (mixing <image> tags with uploaded images)
- # =============================================================================
- def process_interleaved_images(message: dict) -> list[dict]:
-     parts = re.split(r"(<image>)", message["text"])
-     content = []
-     image_files = [f for f in message["files"]
-                    if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
-     image_index = 0
-     for part in parts:
-         if part == "<image>" and image_index < len(image_files):
-             content.append({"type": "image", "url": image_files[image_index]})
-             image_index += 1
-         elif part.strip():
-             content.append({"type": "text", "text": part.strip()})
-         else:
-             # Whitespace-only fragment: keep it so spacing around tags is preserved
-             if isinstance(part, str) and part != "<image>":
-                 content.append({"type": "text", "text": part})
-     return content
-
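The interleaving relies on `re.split` keeping the delimiter in the output when the pattern contains a capturing group; a quick illustration:

```python
import re

parts = re.split(r"(<image>)", "Describe <image> and compare it with <image>.")
print(parts)
# ['Describe ', '<image>', ' and compare it with ', '<image>', '.']
```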
- # =============================================================================
- # File handling -> content-list construction
- # =============================================================================
- def is_image_file(file_path: str) -> bool:
-     return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
-
- def is_video_file(file_path: str) -> bool:
-     return file_path.endswith(".mp4")
-
- def is_document_file(file_path: str) -> bool:
-     return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
-
- def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
-     """Convert the user's new message plus uploaded files into a single content list."""
-     temp_files = []
-     if not message["files"]:
-         return [{"type": "text", "text": message["text"]}], temp_files
-
-     video_files = [f for f in message["files"] if is_video_file(f)]
-     image_files = [f for f in message["files"] if is_image_file(f)]
-     csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
-     txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
-     pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
-
-     content_list = [{"type": "text", "text": message["text"]}]
-
-     # Documents
-     for csv_path in csv_files:
-         content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
-     for txt_path in txt_files:
-         content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
-     for pdf_path in pdf_files:
-         content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
-
-     # Video handling
-     if video_files:
-         video_content, video_temp_files = process_video(video_files[0])
-         content_list += video_content
-         temp_files.extend(video_temp_files)
-         return content_list, temp_files
-
-     # Image handling
-     if "<image>" in message["text"] and image_files:
-         interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
-         # Drop the duplicate raw-text item; the interleaved content already carries the text
-         if content_list and content_list[0]["type"] == "text":
-             content_list = content_list[1:]
-         return interleaved_content + content_list, temp_files
-     else:
-         for img_path in image_files:
-             content_list.append({"type": "image", "url": img_path})
-
-     return content_list, temp_files
-
- # =============================================================================
- # history -> LLM message conversion
- # =============================================================================
- def process_history(history: list[dict]) -> list[dict]:
-     """
-     Convert the existing chat history into the LLM message format.
-     - user -> {"role": "user", "content": [{type, text}, ...]}
-     - assistant -> {"role": "assistant", "content": [{type: "text", text}, ...]}
-     """
-     messages = []
-     current_user_content = []
-     for item in history:
-         if item["role"] == "assistant":
-             # Flush any accumulated user content as a single user turn first
-             if current_user_content:
-                 messages.append({"role": "user", "content": current_user_content})
-                 current_user_content = []
-             # Then append the assistant turn
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
-         else:
-             content = item["content"]
-             if isinstance(content, str):
-                 current_user_content.append({"type": "text", "text": content})
-             elif isinstance(content, list) and len(content) > 0:
-                 file_path = content[0]
-                 if is_image_file(file_path):
-                     current_user_content.append({"type": "image", "url": file_path})
-                 else:
-                     current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
-     if current_user_content:
-         messages.append({"role": "user", "content": current_user_content})
-     return messages
-
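A tiny worked example of the conversion (values are illustrative):

```python
history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
]
print(process_history(history))
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]},
#  {'role': 'assistant', 'content': [{'type': 'text', 'text': 'hello!'}]}]
```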
- # =============================================================================
- # Model generation wrapper (catches OOM)
- # =============================================================================
- def _model_gen_with_oom_catch(**kwargs):
-     try:
-         model.generate(**kwargs)
-     except torch.cuda.OutOfMemoryError:
-         raise RuntimeError("[OutOfMemoryError] GPU 메모리가 부족합니다.")
-     finally:
-         clear_cuda_cache()
-
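The wrapper is meant to run on a worker thread while the main thread drains the streamer, which is the same handshake `run()` below uses. A minimal sketch, assuming `inputs` is a tokenized batch like the one built inside `run()`:

```python
from threading import Thread

streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=_model_gen_with_oom_catch,
       kwargs=dict(inputs, streamer=streamer, max_new_tokens=64)).start()
for piece in streamer:  # blocks until the worker pushes new tokens
    print(piece, end="", flush=True)
```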
- # =============================================================================
- # Main inference function
- # =============================================================================
- @spaces.GPU(duration=120)
- def run(
-     message: dict,
-     history: list[dict],
-     system_prompt: str = "",
-     max_new_tokens: int = 512,
-     use_web_search: bool = False,
-     web_search_query: str = "",
-     age_group: str = "20대",
-     mbti_personality: str = "INTP",
-     sexual_openness: int = 2,
-     image_gen: bool = False
- ) -> Iterator[str]:
-     """
-     LLM inference function.
-     - For image generation, the server is assumed to return Base64 (or data:image/...) directly.
-     - No re-download of /tmp/... files is attempted (avoids the 403 Forbidden issue).
-     """
-     if not validate_media_constraints(message, history):
-         yield ""
-         return
-
-     temp_files = []
-     try:
-         # 1) System prompt + persona info
-         persona = (
-             f"{system_prompt.strip()}\n\n"
-             f"Gender: Female\n"
-             f"Age Group: {age_group}\n"
-             f"MBTI Persona: {mbti_personality}\n"
-             f"Sexual Openness (1~5): {sexual_openness}\n"
-         )
-         combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
-
-         # 2) Web search (optional)
-         if use_web_search:
-             user_text = message["text"]
-             ws_query = extract_keywords(user_text)
-             if ws_query.strip():
-                 logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
-                 ws_result = do_web_search(ws_query)
-                 combined_system_msg += f"[Search top-20 Full Items]\n{ws_result}\n\n"
-                 combined_system_msg += (
-                     "[참고: 위 검색결과 link를 출처로 인용하여 답변]\n"
-                     "[중요 지시사항]\n"
-                     "1. 검색 결과에서 찾은 정보의 출처를 반드시 인용.\n"
-                     "2. '[출처 제목](링크)' 형식으로 링크.\n"
-                     "3. 답변 마지막에 '참고 자료:' 섹션.\n"
-                 )
-             else:
-                 combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
-
-         # 3) Existing history + new user message
-         messages = []
-         if combined_system_msg.strip():
-             messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
-         messages.extend(process_history(history))
-
-         user_content, user_temp_files = process_new_user_message(message)
-         temp_files.extend(user_temp_files)
-
-         for item in user_content:
-             if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
-                 item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
-
-         messages.append({"role": "user", "content": user_content})
-
-         # 4) Tokenization
-         inputs = processor.apply_chat_template(
-             messages,
-             add_generation_prompt=True,
-             tokenize=True,
-             return_dict=True,
-             return_tensors="pt",
-         ).to(device=model.device, dtype=torch.bfloat16)
-         # Truncate from the left so the most recent turns survive. Item assignment is
-         # used here because attribute assignment on the returned batch would not update
-         # the underlying mapping that is later unpacked into generate().
-         if inputs["input_ids"].shape[1] > MAX_INPUT_LENGTH:
-             inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_LENGTH:]
-             if 'attention_mask' in inputs:
-                 inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_LENGTH:]
-
-         streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-         gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-
-         t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
-         t.start()
-
-         # Streaming output
-         output_so_far = ""
-         for new_text in streamer:
-             output_so_far += new_text
-             yield output_so_far
-
-         # 5) Image generation (Base64)
-         if image_gen:
-             last_user_text = message["text"].strip()
-             if not last_user_text:
-                 yield output_so_far + "\n\n(이미지 생성 실패: Empty user prompt)"
-             else:
-                 try:
-                     width, height = 512, 512
-                     guidance, steps, seed = 7.5, 30, 42
-
-                     logger.info(f"Generating image with prompt: {last_user_text}")
-
-                     # Call the API to generate a (base64) image
-                     image_result, seed_info = generate_image(
-                         prompt=last_user_text,
-                         width=width,
-                         height=height,
-                         guidance=guidance,
-                         inference_steps=steps,
-                         seed=seed
-                     )
-
-                     logger.info(f"Received image data type: {type(image_result)}")
-
-                     # Handle Base64 or data:image/... payloads
-                     if image_result:
-                         if isinstance(image_result, str):
-                             # Already a data:image/ URL: use it as-is
-                             if image_result.startswith("data:image/"):
-                                 final_md = f"\n\n**[생성된 이미지]**\n\n![생성된 이미지]({image_result})"
-                                 yield output_so_far + final_md
-                             else:
-                                 # Treat as raw base64 (plain URLs or '/tmp/...' paths cannot be handled)
-                                 if len(image_result) > 100 and "/" not in image_result:
-                                     # base64
-                                     image_data = "data:image/webp;base64," + image_result
-                                     final_md = f"\n\n**[생성된 이미지]**\n\n![생성된 이미지]({image_data})"
-                                     yield output_so_far + final_md
-                                 else:
-                                     # Anything else (e.g. http://..., /tmp/...) triggers the 403 issue, so it is not shown
-                                     yield output_so_far + "\n\n(이미지 생성 결과가 base64 형식이 아닙니다)"
-                         else:
-                             yield output_so_far + "\n\n(이미지 생성 결과가 문자열이 아님)"
-                     else:
-                         yield output_so_far + f"\n\n(이미지 생성 실패: {seed_info})"
-
-                 except Exception as e:
-                     logger.error(f"Image generation error: {e}")
-                     yield output_so_far + f"\n\n(이미지 생성 중 오류 발생: {e})"
-
-     except Exception as e:
-         logger.error(f"Error in run: {str(e)}")
-         yield f"죄송합니다. 오류가 발생했습니다: {str(e)}"
-     finally:
-         for tmp in temp_files:
-             try:
-                 if os.path.exists(tmp):
-                     os.unlink(tmp)
-                     logger.info(f"Deleted temp file: {tmp}")
-             except Exception as ee:
-                 logger.warning(f"Failed to delete temp file {tmp}: {ee}")
-         try:
-             del inputs, streamer
-         except Exception:
-             pass
-         clear_cuda_cache()
-
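One detail worth noting in step 4: the slice keeps the tail of the sequence, so the newest turns survive when the context overflows. A standalone illustration:

```python
ids = torch.arange(10).unsqueeze(0)  # pretend token ids, shape (1, 10)
ids = ids[:, -6:]                    # keep the 6 most recent tokens
print(ids.tolist())                  # [[4, 5, 6, 7, 8, 9]]
```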
- # =============================================================================
- # Examples
- # =============================================================================
- examples = [
-     [
-         {
-             "text": "Compare the contents of the two PDF files.",
-             "files": [
-                 "assets/additional-examples/before.pdf",
-                 "assets/additional-examples/after.pdf",
-             ],
-         }
-     ],
-     [
-         {
-             "text": "Summarize and analyze the contents of the CSV file.",
-             "files": ["assets/additional-examples/sample-csv.csv"],
-         }
-     ],
-     # ... add further examples here if needed ...
- ]
-
- # =============================================================================
- # Gradio UI (Blocks) layout
- # =============================================================================
-
- css = """
- .gradio-container {
-     background: rgba(255, 255, 255, 0.7);
-     padding: 30px 40px;
-     margin: 20px auto;
-     width: 100% !important;
-     max-width: none !important;
- }
- """
- title_html = """
- <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 💘 HeartSync : Love Dating AI 💘 </h1>
- <p align="center" style="font-size:1.1em; color:#555;">
- ✅ FLUX Image Generation ✅ Reasoning & Uncensored ✅ Multimodal & VLM ✅ Deep-Research & RAG <br>
- </p>
- """
-
- with gr.Blocks(css=css, title="HeartSync") as demo:
-     gr.Markdown(title_html)
-
-     # Separate example gallery (use if needed)
-     generated_images = gr.Gallery(
-         label="생성된 이미지",
-         show_label=True,
-         visible=False,
-         elem_id="generated_images",
-         columns=2,
-         height="auto",
-         object_fit="contain"
-     )
-
-     with gr.Row():
-         web_search_checkbox = gr.Checkbox(label="Deep Research", value=False)
-         image_gen_checkbox = gr.Checkbox(label="Image Gen", value=False)
-
-     base_system_prompt_box = gr.Textbox(
-         lines=3,
-         value="You are a deep thinking AI...\n페르소나: 당신은 달콤하고...",
-         label="기본 시스템 프롬프트",
-         visible=False
-     )
-     with gr.Row():
-         age_group_dropdown = gr.Dropdown(
-             label="연령대 선택 (기본 20대)",
-             choices=["10대", "20대", "30~40대", "50~60대", "70대 이상"],
-             value="20대",
-             interactive=True
-         )
-         mbti_choices = [
-             "INTJ (용의주도한 전략가)",
-             "INTP (논리적인 사색가)",
-             "ENTJ (대담한 통솔자)",
-             "ENTP (뜨거운 논쟁가)",
-             "INFJ (선의의 옹호자)",
-             "INFP (열정적인 중재자)",
-             "ENFJ (정의로운 사회운동가)",
-             "ENFP (재기발랄한 활동가)",
-             "ISTJ (청렴결백한 논리주의자)",
-             "ISFJ (용감한 수호자)",
-             "ESTJ (엄격한 관리자)",
-             "ESFJ (사교적인 외교관)",
-             "ISTP (만능 재주꾼)",
-             "ISFP (호기심 많은 예술가)",
-             "ESTP (모험을 즐기는 사업가)",
-             "ESFP (자유로운 영혼의 연예인)"
-         ]
-         mbti_dropdown = gr.Dropdown(
-             label="AI 페르소나 MBTI (기본 INTP)",
-             choices=mbti_choices,
-             value="INTP (논리적인 사색가)",
-             interactive=True
-         )
-         sexual_openness_slider = gr.Slider(
-             minimum=1, maximum=5, step=1, value=2,
-             label="섹슈얼 관심도/개방성 (1~5, 기본=2)",
-             interactive=True
-         )
-     max_tokens_slider = gr.Slider(
-         label="Max New Tokens",
-         minimum=100, maximum=8000, step=50, value=1000,
-         visible=False
-     )
-     web_search_text = gr.Textbox(
-         lines=1,
-         label="(Unused) Web Search Query",
-         placeholder="No direct input needed",
-         visible=False
-     )
-
-     def modified_run(
-         message, history, system_prompt, max_new_tokens,
-         use_web_search, web_search_query,
-         age_group, mbti_personality, sexual_openness, image_gen
-     ):
-         """
-         Call run() for the text stream, post-process as needed,
-         and yield the results (gallery updates, etc.).
-         """
-         output_so_far = ""
-         gallery_update = gr.Gallery(visible=False, value=[])
-         yield output_so_far, gallery_update
-
-         text_generator = run(
-             message, history,
-             system_prompt, max_new_tokens,
-             use_web_search, web_search_query,
-             age_group, mbti_personality,
-             sexual_openness, image_gen
-         )
-
-         for text_chunk in text_generator:
-             output_so_far = text_chunk
-             yield output_so_far, gallery_update
-
-         # If run() already inlined the Base64 image into the chat, showing it again
-         # in the gallery may be unnecessary. To surface image_result here, run()
-         # would need to be extended to return that data as well.
-
-     chat = gr.ChatInterface(
-         fn=modified_run,
-         type="messages",
-         chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
-         textbox=gr.MultimodalTextbox(
-             file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
-             file_count="multiple",
-             autofocus=True
-         ),
-         multimodal=True,
-         additional_inputs=[
-             base_system_prompt_box,
-             max_tokens_slider,
-             web_search_checkbox,
-             web_search_text,
-             age_group_dropdown,
-             mbti_dropdown,
-             sexual_openness_slider,
-             image_gen_checkbox,
-         ],
-         additional_outputs=[generated_images],
-         stop_btn=False,
-         title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
-         examples=examples,
-         run_examples_on_click=False,
-         cache_examples=False,
-         css_paths=None,
-         delete_cache=(1800, 1800),
-     )
-
-     with gr.Row(elem_id="examples_row"):
-         with gr.Column(scale=12, elem_id="examples_container"):
-             gr.Markdown("### Example Inputs (click to load)")
-
- if __name__ == "__main__":
-     demo.launch(share=True)