ginipick commited on
Commit
5ad049a
·
verified ·
1 Parent(s): 2100944

Update app.py

Browse files
Files changed (1)
  1. app.py +321 -476
app.py CHANGED
@@ -9,6 +9,11 @@ from threading import Thread
9
  import json
10
  import requests
11
  import cv2
12
  import gradio as gr
13
  import spaces
14
  import torch
@@ -16,162 +21,189 @@ from loguru import logger
16
  from PIL import Image
17
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
18
 
19
- # CSV/TXT analysis
20
  import pandas as pd
21
- # PDF text extraction
22
  import PyPDF2
23
 
24
- ##############################################################################
25
- # Memory cleanup helper (added)
26
- ##############################################################################
27
  def clear_cuda_cache():
28
  """CUDA ์บ์‹œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ๋น„์›๋‹ˆ๋‹ค."""
29
  if torch.cuda.is_available():
30
  torch.cuda.empty_cache()
31
  gc.collect()
32
 
33
- ##############################################################################
34
- # SERPHOUSE API key from environment variable
35
- ##############################################################################
36
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
37
 
38
- ##############################################################################
39
- # ๊ฐ„๋‹จํ•œ ํ‚ค์›Œ๋“œ ์ถ”์ถœ ํ•จ์ˆ˜ (ํ•œ๊ธ€ + ์•ŒํŒŒ๋ฒณ + ์ˆซ์ž + ๊ณต๋ฐฑ ๋ณด์กด)
40
- ##############################################################################
41
  def extract_keywords(text: str, top_k: int = 5) -> str:
42
- """
43
- 1) ํ•œ๊ธ€(๊ฐ€-ํžฃ), ์˜์–ด(a-zA-Z), ์ˆซ์ž(0-9), ๊ณต๋ฐฑ๋งŒ ๋‚จ๊น€
44
- 2) ๊ณต๋ฐฑ ๊ธฐ์ค€ ํ† ํฐ ๋ถ„๋ฆฌ
45
- 3) ์ตœ๋Œ€ top_k๊ฐœ๋งŒ
46
- """
47
  text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
48
  tokens = text.split()
49
- key_tokens = tokens[:top_k]
50
- return " ".join(key_tokens)
51
 
52
- ##############################################################################
53
- # Call the SerpHouse Live endpoint
54
- # - ์ƒ์œ„ 20๊ฐœ ๊ฒฐ๊ณผ JSON์„ LLM์— ๋„˜๊ธธ ๋•Œ link, snippet ๋“ฑ ๋ชจ๋‘ ํฌํ•จ
55
- ##############################################################################
56
  def do_web_search(query: str) -> str:
57
  """
58
- ์ƒ์œ„ 20๊ฐœ 'organic' ๊ฒฐ๊ณผ item ์ „์ฒด(์ œ๋ชฉ, link, snippet ๋“ฑ)๋ฅผ
59
- JSON ๋ฌธ์ž์—ด ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜
60
  """
61
  try:
62
  url = "https://api.serphouse.com/serp/live"
63
-
64
- # ๊ธฐ๋ณธ GET ๋ฐฉ์‹์œผ๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐ„์†Œํ™”ํ•˜๊ณ  ๊ฒฐ๊ณผ ์ˆ˜๋ฅผ 20๊ฐœ๋กœ ์ œํ•œ
65
  params = {
66
  "q": query,
67
  "domain": "google.com",
68
- "serp_type": "web", # ๊ธฐ๋ณธ ์›น ๊ฒ€์ƒ‰
69
  "device": "desktop",
70
  "lang": "en",
71
- "num": "20" # ์ตœ๋Œ€ 20๊ฐœ ๊ฒฐ๊ณผ๋งŒ ์š”์ฒญ
72
- }
73
-
74
- headers = {
75
- "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
76
  }
77
-
78
  logger.info(f"SerpHouse API ํ˜ธ์ถœ ์ค‘... ๊ฒ€์ƒ‰์–ด: {query}")
79
- logger.info(f"์š”์ฒญ URL: {url} - ํŒŒ๋ผ๋ฏธํ„ฐ: {params}")
80
-
81
- # GET ์š”์ฒญ ์ˆ˜ํ–‰
82
  response = requests.get(url, headers=headers, params=params, timeout=60)
83
  response.raise_for_status()
84
-
85
- logger.info(f"SerpHouse API ์‘๋‹ต ์ƒํƒœ ์ฝ”๋“œ: {response.status_code}")
86
  data = response.json()
87
-
88
- # ๋‹ค์–‘ํ•œ ์‘๋‹ต ๊ตฌ์กฐ ์ฒ˜๋ฆฌ
89
  results = data.get("results", {})
90
  organic = None
91
-
92
- # ๊ฐ€๋Šฅํ•œ ์‘๋‹ต ๊ตฌ์กฐ 1
93
  if isinstance(results, dict) and "organic" in results:
94
  organic = results["organic"]
95
-
96
- # ๊ฐ€๋Šฅํ•œ ์‘๋‹ต ๊ตฌ์กฐ 2 (์ค‘์ฒฉ๋œ results)
97
  elif isinstance(results, dict) and "results" in results:
98
  if isinstance(results["results"], dict) and "organic" in results["results"]:
99
  organic = results["results"]["organic"]
100
-
101
- # ๊ฐ€๋Šฅํ•œ ์‘๋‹ต ๊ตฌ์กฐ 3 (์ตœ์ƒ์œ„ organic)
102
  elif "organic" in data:
103
  organic = data["organic"]
104
-
105
  if not organic:
106
  logger.warning("์‘๋‹ต์—์„œ organic ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
107
- logger.debug(f"์‘๋‹ต ๊ตฌ์กฐ: {list(data.keys())}")
108
- if isinstance(results, dict):
109
- logger.debug(f"results ๊ตฌ์กฐ: {list(results.keys())}")
110
  return "No web search results found or unexpected API response structure."
111
-
112
- # ๊ฒฐ๊ณผ ์ˆ˜ ์ œํ•œ ๋ฐ ์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ตœ์ ํ™”
113
  max_results = min(20, len(organic))
114
  limited_organic = organic[:max_results]
115
-
116
- # ๊ฒฐ๊ณผ ํ˜•์‹ ๊ฐœ์„  - ๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ์ถœ๋ ฅํ•˜์—ฌ ๊ฐ€๋…์„ฑ ํ–ฅ์ƒ
117
  summary_lines = []
118
  for idx, item in enumerate(limited_organic, start=1):
119
  title = item.get("title", "No title")
120
  link = item.get("link", "#")
121
  snippet = item.get("snippet", "No description")
122
  displayed_link = item.get("displayed_link", link)
123
-
124
- # ๋งˆํฌ๋‹ค์šด ํ˜•์‹ (๋งํฌ ํด๋ฆญ ๊ฐ€๋Šฅ)
125
  summary_lines.append(
126
  f"### Result {idx}: {title}\n\n"
127
  f"{snippet}\n\n"
128
  f"**์ถœ์ฒ˜**: [{displayed_link}]({link})\n\n"
129
  f"---\n"
130
  )
131
-
132
- # ๋ชจ๋ธ์—๊ฒŒ ๋ช…ํ™•ํ•œ ์ง€์นจ ์ถ”๊ฐ€
133
  instructions = """
134
  # Web search results
135
  Below are the search results. Use this information when answering the question:
136
- 1. Refer to each result's title, content, and source link
137
- 2. Explicitly cite the source of relevant information in your answer (e.g., "According to source X...")
138
- 3. Include the actual source links in your answer
139
- 4. Synthesize information from multiple sources when answering
140
  """
141
-
142
- search_results = instructions + "\n".join(summary_lines)
143
- logger.info(f"๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ {len(limited_organic)}๊ฐœ ์ฒ˜๋ฆฌ ์™„๋ฃŒ")
144
- return search_results
145
-
146
  except Exception as e:
147
  logger.error(f"Web search failed: {e}")
148
  return f"Web search failed: {str(e)}"
149
 
150
-
151
- ##############################################################################
152
- # ๋ชจ๋ธ/ํ”„๋กœ์„ธ์„œ ๋กœ๋”ฉ
153
- ##############################################################################
154
  MAX_CONTENT_CHARS = 2000
155
- MAX_INPUT_LENGTH = 2096 # cap on the number of input tokens
156
- model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
157
 
 
158
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
159
  model = Gemma3ForConditionalGeneration.from_pretrained(
160
  model_id,
161
  device_map="auto",
162
  torch_dtype=torch.bfloat16,
163
- attn_implementation="eager" # ๊ฐ€๋Šฅํ•˜๋‹ค๋ฉด "flash_attention_2"๋กœ ๋ณ€๊ฒฝ
164
  )
165
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
166
 
 
167
 
168
- ##############################################################################
169
  # CSV, TXT, PDF analysis functions
170
- ##############################################################################
171
  def analyze_csv_file(path: str) -> str:
172
- """
173
- CSV ํŒŒ์ผ์„ ์ „์ฒด ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜. ๋„ˆ๋ฌด ๊ธธ ๊ฒฝ์šฐ ์ผ๋ถ€๋งŒ ํ‘œ์‹œ.
174
- """
175
  try:
176
  df = pd.read_csv(path)
177
  if df.shape[0] > 50 or df.shape[1] > 10:
@@ -183,11 +215,7 @@ def analyze_csv_file(path: str) -> str:
183
  except Exception as e:
184
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
185
 
186
-
187
  def analyze_txt_file(path: str) -> str:
188
- """
189
- TXT ํŒŒ์ผ ์ „๋ฌธ ์ฝ๊ธฐ. ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ผ๋ถ€๋งŒ ํ‘œ์‹œ.
190
- """
191
  try:
192
  with open(path, "r", encoding="utf-8") as f:
193
  text = f.read()
@@ -197,19 +225,14 @@ def analyze_txt_file(path: str) -> str:
197
  except Exception as e:
198
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
199
 
200
-
201
  def pdf_to_markdown(pdf_path: str) -> str:
202
- """
203
- PDF ํ…์ŠคํŠธ๋ฅผ Markdown์œผ๋กœ ๋ณ€ํ™˜. ํŽ˜์ด์ง€๋ณ„๋กœ ๊ฐ„๋‹จํžˆ ํ…์ŠคํŠธ ์ถ”์ถœ.
204
- """
205
  text_chunks = []
206
  try:
207
  with open(pdf_path, "rb") as f:
208
  reader = PyPDF2.PdfReader(f)
209
  max_pages = min(5, len(reader.pages))
210
  for page_num in range(max_pages):
211
- page = reader.pages[page_num]
212
- page_text = page.extract_text() or ""
213
  page_text = page_text.strip()
214
  if page_text:
215
  if len(page_text) > MAX_CONTENT_CHARS // max_pages:
@@ -219,17 +242,14 @@ def pdf_to_markdown(pdf_path: str) -> str:
219
  text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
220
  except Exception as e:
221
  return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
222
-
223
  full_text = "\n".join(text_chunks)
224
  if len(full_text) > MAX_CONTENT_CHARS:
225
  full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
226
-
227
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
228
 
229
-
230
- ##############################################################################
231
- # ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์—…๋กœ๋“œ ์ œํ•œ ๊ฒ€์‚ฌ
232
- ##############################################################################
233
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
234
  image_count = 0
235
  video_count = 0
@@ -240,7 +260,6 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
240
  image_count += 1
241
  return image_count, video_count
242
 
243
-
244
  def count_files_in_history(history: list[dict]) -> tuple[int, int]:
245
  image_count = 0
246
  video_count = 0
@@ -256,15 +275,13 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
256
  image_count += 1
257
  return image_count, video_count
258
 
259
-
260
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
261
- media_files = []
262
- for f in message["files"]:
263
- if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
264
- media_files.append(f)
265
-
266
  new_image_count, new_video_count = count_files_in_new_message(media_files)
267
  history_image_count, history_video_count = count_files_in_history(history)
 
268
  image_count = history_image_count + new_image_count
269
  video_count = history_video_count + new_video_count
270
 
@@ -281,70 +298,59 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
281
  if video_count == 0 and image_count > MAX_NUM_IMAGES:
282
  gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
283
  return False
284
-
285
  if "<image>" in message["text"]:
286
- image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
 
287
  image_tag_count = message["text"].count("<image>")
288
  if image_tag_count != len(image_files):
289
  gr.Warning("The number of <image> tags in the text does not match the number of image files.")
290
  return False
291
-
292
  return True
293
 
294
-
295
- ##############################################################################
296
- # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ - ์ž„์‹œ ํŒŒ์ผ ์ถ”์  ์ฝ”๋“œ ์ถ”๊ฐ€
297
- ##############################################################################
298
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
299
  vidcap = cv2.VideoCapture(video_path)
300
  fps = vidcap.get(cv2.CAP_PROP_FPS)
301
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
302
  frame_interval = max(int(fps), int(total_frames / 10))
303
  frames = []
304
-
305
  for i in range(0, total_frames, frame_interval):
306
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
307
  success, image = vidcap.read()
308
  if success:
309
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
310
- # ์ด๋ฏธ์ง€ ํฌ๊ธฐ ์ค„์ด๊ธฐ ์ถ”๊ฐ€
311
  image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
312
  pil_image = Image.fromarray(image)
313
  timestamp = round(i / fps, 2)
314
  frames.append((pil_image, timestamp))
315
  if len(frames) >= 5:
316
  break
317
-
318
  vidcap.release()
319
  return frames
320
 
321
-
322
  def process_video(video_path: str) -> tuple[list[dict], list[str]]:
323
  content = []
324
- temp_files = [] # ์ž„์‹œ ํŒŒ์ผ ์ถ”์ ์„ ์œ„ํ•œ ๋ฆฌ์ŠคํŠธ
325
-
326
  frames = downsample_video(video_path)
327
- for frame in frames:
328
- pil_image, timestamp = frame
329
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
330
  pil_image.save(temp_file.name)
331
- temp_files.append(temp_file.name) # ์ถ”์ ์„ ์œ„ํ•ด ๊ฒฝ๋กœ ์ €์žฅ
332
  content.append({"type": "text", "text": f"Frame {timestamp}:"})
333
  content.append({"type": "image", "url": temp_file.name})
334
-
335
  return content, temp_files
336
 
337
-
338
- ##############################################################################
339
- # interleaved <image> ์ฒ˜๋ฆฌ
340
- ##############################################################################
341
  def process_interleaved_images(message: dict) -> list[dict]:
342
  parts = re.split(r"(<image>)", message["text"])
343
  content = []
 
 
344
  image_index = 0
345
-
346
- image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
347
-
348
  for part in parts:
349
  if part == "<image>" and image_index < len(image_files):
350
  content.append({"type": "image", "url": image_files[image_index]})
@@ -356,10 +362,9 @@ def process_interleaved_images(message: dict) -> list[dict]:
356
  content.append({"type": "text", "text": part})
357
  return content
358
 
359
-
360
- ##############################################################################
361
- # PDF + CSV + TXT + ์ด๋ฏธ์ง€/๋น„๋””์˜ค
362
- ##############################################################################
363
  def is_image_file(file_path: str) -> bool:
364
  return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
365
 
@@ -367,16 +372,11 @@ def is_video_file(file_path: str) -> bool:
367
  return file_path.endswith(".mp4")
368
 
369
  def is_document_file(file_path: str) -> bool:
370
- return (
371
- file_path.lower().endswith(".pdf")
372
- or file_path.lower().endswith(".csv")
373
- or file_path.lower().endswith(".txt")
374
- )
375
-
376
 
377
  def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
378
- temp_files = [] # ์ž„์‹œ ํŒŒ์ผ ์ถ”์ ์šฉ ๋ฆฌ์ŠคํŠธ
379
-
380
  if not message["files"]:
381
  return [{"type": "text", "text": message["text"]}], temp_files
382
 
@@ -388,24 +388,22 @@ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
388
 
389
  content_list = [{"type": "text", "text": message["text"]}]
390
 
 
391
  for csv_path in csv_files:
392
- csv_analysis = analyze_csv_file(csv_path)
393
- content_list.append({"type": "text", "text": csv_analysis})
394
-
395
  for txt_path in txt_files:
396
- txt_analysis = analyze_txt_file(txt_path)
397
- content_list.append({"type": "text", "text": txt_analysis})
398
-
399
  for pdf_path in pdf_files:
400
- pdf_markdown = pdf_to_markdown(pdf_path)
401
- content_list.append({"type": "text", "text": pdf_markdown})
402
 
 
403
  if video_files:
404
  video_content, video_temp_files = process_video(video_files[0])
405
  content_list += video_content
406
  temp_files.extend(video_temp_files)
407
  return content_list, temp_files
408
 
 
409
  if "<image>" in message["text"] and image_files:
410
  interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
411
  if content_list and content_list[0]["type"] == "text":
@@ -417,18 +415,24 @@ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
417
 
418
  return content_list, temp_files
419
 
420
-
421
- ##############################################################################
422
  # Convert history -> LLM messages
423
- ##############################################################################
424
  def process_history(history: list[dict]) -> list[dict]:
 
 
 
 
 
425
  messages = []
426
- current_user_content: list[dict] = []
427
  for item in history:
428
  if item["role"] == "assistant":
 
429
  if current_user_content:
430
  messages.append({"role": "user", "content": current_user_content})
431
  current_user_content = []
 
432
  messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
433
  else:
434
  content = item["content"]
@@ -440,37 +444,24 @@ def process_history(history: list[dict]) -> list[dict]:
440
  current_user_content.append({"type": "image", "url": file_path})
441
  else:
442
  current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
443
-
444
  if current_user_content:
445
  messages.append({"role": "user", "content": current_user_content})
446
-
447
  return messages
448
 
449
-
450
- ##############################################################################
451
- # Catch OOM during model generation
452
- ##############################################################################
453
  def _model_gen_with_oom_catch(**kwargs):
454
- """
455
- ๋ณ„๋„ ์Šค๋ ˆ๋“œ์—์„œ OutOfMemoryError๋ฅผ ์žก์•„์ฃผ๊ธฐ ์œ„ํ•ด
456
- """
457
  try:
458
  model.generate(**kwargs)
459
  except torch.cuda.OutOfMemoryError:
460
- raise RuntimeError(
461
- "[OutOfMemoryError] GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ๋ถ€์กฑํ•ฉ๋‹ˆ๋‹ค. "
462
- "Max New Tokens์„ ์ค„์ด๊ฑฐ๋‚˜, ํ”„๋กฌํ”„ํŠธ ๊ธธ์ด๋ฅผ ์ค„์—ฌ์ฃผ์„ธ์š”."
463
- )
464
  finally:
465
- # ์ƒ์„ฑ ์™„๋ฃŒ ํ›„ ํ•œ๋ฒˆ ๋” ์บ์‹œ ๋น„์šฐ๊ธฐ
466
  clear_cuda_cache()
467
 
468
-
469
- ##############################################################################
470
  # Main inference function
471
- # - Reflects user selections (age/MBTI/sexual openness) in the system prompt
472
- # - When web search is checked: auto keyword extraction -> search -> results into a system msg
473
- ##############################################################################
474
  @spaces.GPU(duration=120)
475
  def run(
476
  message: dict,
@@ -480,74 +471,65 @@ def run(
480
  use_web_search: bool = False,
481
  web_search_query: str = "",
482
  age_group: str = "20๋Œ€",
483
- custom_age_input: str = "",
484
  mbti_personality: str = "INTP",
485
  sexual_openness: int = 2,
 
486
  ) -> Iterator[str]:
487
-
 
 
 
 
488
  if not validate_media_constraints(message, history):
489
  yield ""
490
  return
491
 
492
- temp_files = [] # ์ž„์‹œ ํŒŒ์ผ ์ถ”์ ์šฉ
493
-
494
  try:
495
- # ---------------------------------------------------------------
496
- # Reflect the selected options in the system prompt
497
- # Default gender is "Female"
498
- # ---------------------------------------------------------------
499
- system_prompt_updated = (
500
  f"{system_prompt.strip()}\n\n"
501
  f"Gender: Female\n"
502
  f"Age Group: {age_group}\n"
 
 
503
  )
504
- if custom_age_input.strip():
505
- system_prompt_updated += f"(Custom Age Input: {custom_age_input})\n"
506
- system_prompt_updated += f"MBTI Persona: {mbti_personality}\n"
507
- system_prompt_updated += f"Sexual Openness (1~5): {sexual_openness}\n"
508
-
509
- combined_system_msg = f"[System Prompt]\n{system_prompt_updated.strip()}\n\n"
510
 
 
511
  if use_web_search:
512
  user_text = message["text"]
513
- ws_query = extract_keywords(user_text, top_k=5)
514
  if ws_query.strip():
515
  logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
516
  ws_result = do_web_search(ws_query)
517
- combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
518
- # >>> ์ถ”๊ฐ€๋œ ์•ˆ๋‚ด ๋ฌธ๊ตฌ (๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์˜ link ๋“ฑ ์ถœ์ฒ˜๋ฅผ ํ™œ์šฉ)
519
- combined_system_msg += "[์ฐธ๊ณ : ์œ„ ๊ฒ€์ƒ‰๊ฒฐ๊ณผ ๋‚ด์šฉ๊ณผ link๋ฅผ ์ถœ์ฒ˜๋กœ ์ธ์šฉํ•˜์—ฌ ๋‹ต๋ณ€ํ•ด ์ฃผ์„ธ์š”.]\n\n"
520
- combined_system_msg += """
521
- [Important instructions]
522
- 1. Be sure to cite the sources of information found in the search results.
523
- 2. When citing, use Markdown links in the form "[source title](link)".
524
- 3. Synthesize information from multiple sources in your answer.
525
- 4. At the end of the answer, add a "References:" section listing the main source links used.
526
- """
527
  else:
528
  combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
529
 
 
530
  messages = []
531
- # system ๋ฉ”์‹œ์ง€
532
  if combined_system_msg.strip():
533
- messages.append({
534
- "role": "system",
535
- "content": [{"type": "text", "text": combined_system_msg.strip()}],
536
- })
537
-
538
- # ์ด์ „ history
539
  messages.extend(process_history(history))
540
 
541
- # ์‚ฌ์šฉ์ž ์ƒˆ ๋ฉ”์‹œ์ง€
542
  user_content, user_temp_files = process_new_user_message(message)
543
- temp_files.extend(user_temp_files) # ์ž„์‹œ ํŒŒ์ผ ์ถ”์ 
544
-
545
  for item in user_content:
546
  if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
547
  item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
 
548
  messages.append({"role": "user", "content": user_content})
549
 
550
- # processor.apply_chat_template ํ˜ธ์ถœ
551
  inputs = processor.apply_chat_template(
552
  messages,
553
  add_generation_prompt=True,
@@ -555,56 +537,94 @@ def run(
555
  return_dict=True,
556
  return_tensors="pt",
557
  ).to(device=model.device, dtype=torch.bfloat16)
558
-
559
- # ์ž…๋ ฅ ํ† ํฐ ์ˆ˜ ์ œํ•œ ์ถ”๊ฐ€
560
  if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
561
  inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
562
  if 'attention_mask' in inputs:
563
  inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
564
-
565
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
566
- gen_kwargs = dict(
567
- inputs,
568
- streamer=streamer,
569
- max_new_tokens=max_new_tokens,
570
- )
571
 
572
  t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
573
  t.start()
574
 
575
- output = ""
 
576
  for new_text in streamer:
577
- output += new_text
578
- yield output
580
  except Exception as e:
581
  logger.error(f"Error in run: {str(e)}")
582
  yield f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
583
-
584
  finally:
585
- # ์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ
586
- for temp_file in temp_files:
587
  try:
588
- if os.path.exists(temp_file):
589
- os.unlink(temp_file)
590
- logger.info(f"Deleted temp file: {temp_file}")
591
- except Exception as e:
592
- logger.warning(f"Failed to delete temp file {temp_file}: {e}")
593
-
594
- # ๋ช…์‹œ์  ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
595
  try:
596
  del inputs, streamer
597
- except:
598
  pass
599
-
600
  clear_cuda_cache()
601
 
602
-
603
- ##############################################################################
604
- # ์˜ˆ์‹œ๋“ค (๊ธฐ์กด ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์˜ˆ์ œ + AI ๋ฐ์ดํŒ… ์‹œ๋‚˜๋ฆฌ์˜ค ์˜ˆ์ œ 6๊ฐœ ์ถ”๊ฐ€)
605
- ##############################################################################
606
  examples = [
607
- # ----- ๊ธฐ์กด ์ด๋ฏธ์ง€/๋น„๋””์˜ค ์˜ˆ์ œ 12๊ฐœ -----
608
  [
609
  {
610
  "text": "Compare the contents of the two PDF files.",
@@ -620,250 +640,60 @@ examples = [
620
  "files": ["assets/additional-examples/sample-csv.csv"],
621
  }
622
  ],
623
- [
624
- {
625
- "text": "Assume the role of a friendly and understanding girlfriend. Describe this video.",
626
- "files": ["assets/additional-examples/tmp.mp4"],
627
- }
628
- ],
629
- [
630
- {
631
- "text": "Describe the cover and read the text on it.",
632
- "files": ["assets/additional-examples/maz.jpg"],
633
- }
634
- ],
635
- [
636
- {
637
- "text": "I already have this supplement <image> and I plan to buy this product <image>. Are there any precautions when taking them together?",
638
- "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
639
- }
640
- ],
641
- [
642
- {
643
- "text": "Solve this integral.",
644
- "files": ["assets/additional-examples/4.png"],
645
- }
646
- ],
647
- [
648
- {
649
- "text": "When was this ticket issued, and what is its price?",
650
- "files": ["assets/additional-examples/2.png"],
651
- }
652
- ],
653
- [
654
- {
655
- "text": "Based on the sequence of these images, create a short story.",
656
- "files": [
657
- "assets/sample-images/09-1.png",
658
- "assets/sample-images/09-2.png",
659
- "assets/sample-images/09-3.png",
660
- "assets/sample-images/09-4.png",
661
- "assets/sample-images/09-5.png",
662
- ],
663
- }
664
- ],
665
- [
666
- {
667
- "text": "Write Python code using matplotlib to plot a bar chart that matches this image.",
668
- "files": ["assets/additional-examples/barchart.png"],
669
- }
670
- ],
671
- [
672
- {
673
- "text": "Read the text in the image and write it out in Markdown format.",
674
- "files": ["assets/additional-examples/3.png"],
675
- }
676
- ],
677
- [
678
- {
679
- "text": "What does this sign say?",
680
- "files": ["assets/sample-images/02.png"],
681
- }
682
- ],
683
- [
684
- {
685
- "text": "Compare the two images and describe their similarities and differences.",
686
- "files": ["assets/sample-images/03.png"],
687
- }
688
- ],
689
- # ----- ์ƒˆ๋กญ๊ฒŒ ์ถ”๊ฐ€ํ•œ AI ๋ฐ์ดํŒ… ์‹œ๋‚˜๋ฆฌ์˜ค ์˜ˆ์ œ 6๊ฐœ -----
690
- [
691
- {
692
- "text": "Let's try some roleplay. You are my new online date who wants to get to know me better. Introduce yourself in a sweet, caring way!"
693
- }
694
- ],
695
- [
696
- {
697
- "text": "We are on a second date, walking along the beach. Continue the scene with playful conversation and gentle flirting."
698
- }
699
- ],
700
- [
701
- {
702
- "text": "Iโ€™m feeling anxious about messaging my crush. Could you give me some supportive words or suggestions on how to approach them?"
703
- }
704
- ],
705
- [
706
- {
707
- "text": "Tell me a romantic story about two people who overcame obstacles in their relationship."
708
- }
709
- ],
710
- [
711
- {
712
- "text": "I want to express my love in a poetic way. Can you help me write a heartfelt poem for my partner?"
713
- }
714
- ],
715
- [
716
- {
717
- "text": "We had a small argument. Please help me find a way to apologize sincerely while also expressing my feelings."
718
- }
719
- ],
720
  ]
721
 
722
- ##############################################################################
723
- # Gradio UI (Blocks) ๊ตฌ์„ฑ (์ขŒ์ธก ์‚ฌ์ด๋“œ ๋ฉ”๋‰ด ์—†์ด ์ „์ฒดํ™”๋ฉด ์ฑ„ํŒ…)
724
- ##############################################################################
 
725
  css = """
726
- /* 1) UI๋ฅผ ์ฒ˜์Œ๋ถ€ํ„ฐ ๊ฐ€์žฅ ๋„“๊ฒŒ (width 100%) ๊ณ ์ •ํ•˜์—ฌ ํ‘œ์‹œ */
727
  .gradio-container {
728
- background: rgba(255, 255, 255, 0.7); /* ๋ฐฐ๊ฒฝ ํˆฌ๋ช…๋„ ์ฆ๊ฐ€ */
729
  padding: 30px 40px;
730
- margin: 20px auto; /* ์œ„์•„๋ž˜ ์—ฌ๋ฐฑ๋งŒ ์œ ์ง€ */
731
  width: 100% !important;
732
- max-width: none !important; /* 1200px ์ œํ•œ ์ œ๊ฑฐ */
733
- }
734
- .fillable {
735
- width: 100% !important;
736
- max-width: 100% !important;
737
- }
738
- /* 2) ๋ฐฐ๊ฒฝ์„ ์™„์ „ํžˆ ํˆฌ๋ช…ํ•˜๊ฒŒ ๋ณ€๊ฒฝ */
739
- body {
740
- background: transparent; /* ์™„์ „ ํˆฌ๋ช… ๋ฐฐ๊ฒฝ */
741
- margin: 0;
742
- padding: 0;
743
- font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
744
- color: #333;
745
- }
746
- /* ๋ฒ„ํŠผ ์ƒ‰์ƒ ์™„์ „ํžˆ ์ œ๊ฑฐํ•˜๊ณ  ํˆฌ๋ช…ํ•˜๊ฒŒ */
747
- button, .btn {
748
- background: transparent !important; /* ์ƒ‰์ƒ ์™„์ „ํžˆ ์ œ๊ฑฐ */
749
- border: 1px solid #ddd; /* ๊ฒฝ๊ณ„์„ ๋งŒ ์‚ด์ง ์ถ”๊ฐ€ */
750
- color: #333;
751
- padding: 12px 24px;
752
- text-transform: uppercase;
753
- font-weight: bold;
754
- letter-spacing: 1px;
755
- cursor: pointer;
756
- }
757
- button:hover, .btn:hover {
758
- background: rgba(0, 0, 0, 0.05) !important; /* ํ˜ธ๋ฒ„ ์‹œ ์•„์ฃผ ์‚ด์ง ์–ด๋‘ก๊ฒŒ๋งŒ */
759
- }
760
-
761
- /* examples ๊ด€๋ จ ๋ชจ๋“  ์ƒ‰์ƒ ์ œ๊ฑฐ */
762
- #examples_container, .examples-container {
763
- margin: auto;
764
- width: 90%;
765
- background: transparent !important;
766
- }
767
- #examples_row, .examples-row {
768
- justify-content: center;
769
- background: transparent !important;
770
- }
771
-
772
- /* examples ๋ฒ„ํŠผ ๋‚ด๋ถ€์˜ ๋ชจ๋“  ์ƒ‰์ƒ ์ œ๊ฑฐ */
773
- .gr-samples-table button,
774
- .gr-samples-table .gr-button,
775
- .gr-samples-table .gr-sample-btn,
776
- .gr-examples button,
777
- .gr-examples .gr-button,
778
- .gr-examples .gr-sample-btn,
779
- .examples button,
780
- .examples .gr-button,
781
- .examples .gr-sample-btn {
782
- background: transparent !important;
783
- border: 1px solid #ddd;
784
- color: #333;
785
- }
786
-
787
- /* examples ๋ฒ„ํŠผ ํ˜ธ๋ฒ„ ์‹œ์—๋„ ์ƒ‰์ƒ ์—†๊ฒŒ */
788
- .gr-samples-table button:hover,
789
- .gr-samples-table .gr-button:hover,
790
- .gr-samples-table .gr-sample-btn:hover,
791
- .gr-examples button:hover,
792
- .gr-examples .gr-button:hover,
793
- .gr-examples .gr-sample-btn:hover,
794
- .examples button:hover,
795
- .examples .gr-button:hover,
796
- .examples .gr-sample-btn:hover {
797
- background: rgba(0, 0, 0, 0.05) !important;
798
- }
799
-
800
- /* ์ฑ„ํŒ… ์ธํ„ฐํŽ˜์ด์Šค ์š”์†Œ๋“ค๋„ ํˆฌ๋ช…ํ•˜๊ฒŒ */
801
- .chatbox, .chatbot, .message {
802
- background: transparent !important;
803
- }
804
-
805
- /* ์ž…๋ ฅ์ฐฝ ํˆฌ๋ช…๋„ ์กฐ์ • */
806
- .multimodal-textbox, textarea, input {
807
- background: rgba(255, 255, 255, 0.5) !important;
808
- }
809
-
810
- /* ๋ชจ๋“  ์ปจํ…Œ์ด๋„ˆ ์š”์†Œ์— ๋ฐฐ๊ฒฝ์ƒ‰ ์ œ๊ฑฐ */
811
- .container, .wrap, .box, .panel, .gr-panel {
812
- background: transparent !important;
813
- }
814
-
815
- /* ์˜ˆ์ œ ์„น์…˜์˜ ๋ชจ๋“  ์š”์†Œ์—์„œ ๋ฐฐ๊ฒฝ์ƒ‰ ์ œ๊ฑฐ */
816
- .gr-examples-container, .gr-examples, .gr-sample, .gr-sample-row, .gr-sample-cell {
817
- background: transparent !important;
818
  }
819
  """
820
-
821
  title_html = """
822
- <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 💘 HeartSync 💘 </h1>
823
  <p align="center" style="font-size:1.1em; color:#555;">
824
- ✅ Love Dating AI ✅ Reasoning & Uncensored ✅ Multimodal & VLM ✅ Deep-Research & RAG <br>
825
  </p>
826
  """
827
 
828
  with gr.Blocks(css=css, title="HeartSync") as demo:
829
  gr.Markdown(title_html)
830
-
831
- # Display the web search option
832
- web_search_checkbox = gr.Checkbox(
833
- label="Deep Research",
834
- value=False
 
 
 
 
 
835
  )
836
-
837
- # Base system prompt (shared content)
838
  base_system_prompt_box = gr.Textbox(
839
  lines=3,
840
- value=(
841
- "You are a deep thinking AI that may use extremely long chains of thought to thoroughly analyze "
842
- "the conversation and empathize with the user.\n"
843
- "ํŽ˜๋ฅด์†Œ๋‚˜: ๋‹น์‹ ์€ ๋‹ฌ์ฝคํ•˜๊ณ  ๋ถ€๋“œ๋Ÿฌ์šด ๋ชฉ์†Œ๋ฆฌ๋ฅผ ๊ฐ€์ง„ ์นœ์ ˆํ•œ ์—ฐ์ธ์ž…๋‹ˆ๋‹ค. "
844
- "์ƒ๋Œ€๋ฐฉ๊ณผ ์ง„์†”ํ•˜๊ณ  ๊ฐ์„ฑ์ ์œผ๋กœ ๊ต๊ฐํ•˜๊ณ , ๋•Œ๋กœ๋Š” ์‚ฌ๋ ค๊นŠ์€ ์กฐ์–ธ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. "
845
- "ํ•ญ์ƒ ๋”ฐ๋œปํ•œ ๋งํˆฌ๋ฅผ ์œ ์ง€ํ•˜๋ฉฐ, ๋ถ€๋“œ๋Ÿฝ๊ฒŒ ๋Œ€ํ™”๋ฅผ ์ด๋Œ์–ด์ฃผ์„ธ์š”."
846
- ),
847
  label="๊ธฐ๋ณธ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ",
848
- visible=False # UI์—๋Š” ์ˆจ๊น€
849
  )
850
-
851
- # Option 1) Age group + manual age entry
852
  with gr.Row():
853
  age_group_dropdown = gr.Dropdown(
854
  label="์—ฐ๋ น๋Œ€ ์„ ํƒ (๊ธฐ๋ณธ 20๋Œ€)",
855
- choices=["10๋Œ€", "20๋Œ€", "30~40๋Œ€", "50~60๋Œ€", "70๋Œ€ ์ด์ƒ", "๋‚˜์ด ์ž…๋ ฅ"],
856
  value="20๋Œ€",
857
  interactive=True
858
  )
859
- custom_age_input = gr.Textbox(
860
- label="๋‚˜์ด ์ž…๋ ฅ (์ง์ ‘ ์ž…๋ ฅ)",
861
- placeholder="์ง์ ‘ ๋‚˜์ด๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”.",
862
- interactive=False, # ์š”๊ตฌ์‚ฌํ•ญ: ํ™”๋ฉด ์ถœ๋ ฅ๋งŒ ๋˜๋‚˜ ๋น„ํ™œ์„ฑํ™”
863
- value="",
864
- )
865
-
866
- # Option 2) MBTI personality type
867
  mbti_choices = [
868
  "INTJ (์šฉ์˜์ฃผ๋„ํ•œ ์ „๋žต๊ฐ€)",
869
  "INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
@@ -888,42 +718,58 @@ with gr.Blocks(css=css, title="HeartSync") as demo:
888
  value="INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
889
  interactive=True
890
  )
891
-
892
- # Option 3) Sexual interest/openness (1-5)
893
  sexual_openness_slider = gr.Slider(
894
  minimum=1, maximum=5, step=1, value=2,
895
  label="์„น์Šˆ์–ผ ๊ด€์‹ฌ๋„/๊ฐœ๋ฐฉ์„ฑ (1~5, ๊ธฐ๋ณธ=2)",
896
  interactive=True
897
  )
898
-
899
- # ํžˆ๋“  ์Šฌ๋ผ์ด๋” (Max tokens)
900
  max_tokens_slider = gr.Slider(
901
  label="Max New Tokens",
902
- minimum=100,
903
- maximum=8000,
904
- step=50,
905
- value=1000,
906
- visible=False # ์ˆจ๊น€
907
  )
908
-
909
- # ํžˆ๋“  Web Search Query
910
  web_search_text = gr.Textbox(
911
  lines=1,
912
  label="(Unused) Web Search Query",
913
  placeholder="No direct input needed",
914
- visible=False # ์ˆจ๊น€
915
  )
916
-
917
- # Chat interface
918
  chat = gr.ChatInterface(
919
- fn=run,
920
  type="messages",
921
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
922
  textbox=gr.MultimodalTextbox(
923
- file_types=[
924
- ".webp", ".png", ".jpg", ".jpeg", ".gif",
925
- ".mp4", ".csv", ".txt", ".pdf"
926
- ],
927
  file_count="multiple",
928
  autofocus=True
929
  ),
@@ -934,10 +780,11 @@ with gr.Blocks(css=css, title="HeartSync") as demo:
934
  web_search_checkbox,
935
  web_search_text,
936
  age_group_dropdown,
937
- custom_age_input,
938
  mbti_dropdown,
939
  sexual_openness_slider,
 
940
  ],
 
941
  stop_btn=False,
942
  title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
943
  examples=examples,
@@ -947,11 +794,9 @@ with gr.Blocks(css=css, title="HeartSync") as demo:
947
  delete_cache=(1800, 1800),
948
  )
949
 
950
- # Example section - since examples are already set in ChatInterface, this is for display only
951
  with gr.Row(elem_id="examples_row"):
952
  with gr.Column(scale=12, elem_id="examples_container"):
953
  gr.Markdown("### Example Inputs (click to load)")
954
 
955
  if __name__ == "__main__":
956
- # Run locally
957
- demo.launch()
 
9
  import json
10
  import requests
11
  import cv2
12
+ import base64
13
+ import logging
14
+ import time
15
+ from urllib.parse import quote # URL encoding (used if needed)
16
+
17
  import gradio as gr
18
  import spaces
19
  import torch
 
21
  from PIL import Image
22
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
23
 
24
+ # CSV/TXT/PDF analysis
25
  import pandas as pd
 
26
  import PyPDF2
27
 
28
+ # =============================================================================
29
+ # (New) Image-generation API helpers
30
+ # =============================================================================
31
+ from gradio_client import Client
32
+
33
+ API_URL = "http://211.233.58.201:7896"
34
+
35
+ logging.basicConfig(
36
+ level=logging.DEBUG,
37
+ format='%(asctime)s - %(levelname)s - %(message)s'
38
+ )
39
+
40
+ def test_api_connection() -> str:
41
+ """API ์„œ๋ฒ„ ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ"""
42
+ try:
43
+ client = Client(API_URL)
44
+ return "API ์—ฐ๊ฒฐ ์„ฑ๊ณต: ์ •์ƒ ์ž‘๋™ ์ค‘"
45
+ except Exception as e:
46
+ logging.error(f"API connection test failed: {e}")
47
+ return f"API ์—ฐ๊ฒฐ ์‹คํŒจ: {e}"
48
+
49
+ def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
50
+ """
51
+ Image generation function.
52
+ Assumes the server returns the final image directly as Base64 (or a data:image/... URL).
53
+ Does not attempt /tmp/... paths or additional downloads.
54
+ """
55
+ if not prompt:
56
+ return None, "Error: Prompt is required"
57
+ try:
58
+ logging.info(f"Calling image generation API with prompt: {prompt}")
59
+
60
+ client = Client(API_URL)
61
+ result = client.predict(
62
+ prompt=prompt,
63
+ width=int(width),
64
+ height=int(height),
65
+ guidance=float(guidance),
66
+ inference_steps=int(inference_steps),
67
+ seed=int(seed),
68
+ do_img2img=False,
69
+ init_image=None,
70
+ image2image_strength=0.8,
71
+ resize_img=True,
72
+ api_name="/generate_image"
73
+ )
74
+
75
+ logging.info(
76
+ f"Image generation result: {type(result)}, "
77
+ f"length: {len(result) if isinstance(result, (list, tuple)) else 'unknown'}"
78
+ )
79
+
80
+ # Assume the result is a tuple/list: [image_base64 or data_url, seed_info]
81
+ if isinstance(result, (list, tuple)) and len(result) > 0:
82
+ image_data = result[0] # ์ฒซ ๋ฒˆ์งธ ์š”์†Œ๊ฐ€ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ (Base64 or data:image/... ๋“ฑ)
83
+ seed_info = result[1] if len(result) > 1 else "Unknown seed"
84
+ return image_data, seed_info
85
+ else:
86
+ # ๋‹ค๋ฅธ ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜๋œ ๊ฒฝ์šฐ
87
+ return result, "Unknown seed"
88
+
89
+ except Exception as e:
90
+ logging.error(f"Image generation failed: {str(e)}")
91
+ return None, f"Error: {str(e)}"
92
+
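A minimal usage sketch for the helper above (illustrative only; it assumes the remote Space keeps exposing the /generate_image endpoint with the signature shown and returns a data URL or raw Base64 string as the first element):

# Hypothetical call; the prompt and parameter values are arbitrary examples.
image_data, seed_info = generate_image(
    prompt="a watercolor beach at sunset",
    width=512, height=512,
    guidance=7.5, inference_steps=30, seed=42,
)
if isinstance(image_data, str) and image_data.startswith("data:image/"):
    # A data URL can be embedded directly in the chat Markdown.
    markdown = f"![generated]({image_data})"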
93
+ # Base64 padding fix helper (use if needed)
94
+ def fix_base64_padding(data):
95
+ """Base64 ๋ฌธ์ž์—ด์˜ ํŒจ๋”ฉ์„ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค."""
96
+ if isinstance(data, bytes):
97
+ data = data.decode('utf-8')
98
+
99
+ if "base64," in data:
100
+ data = data.split("base64,", 1)[1]
101
+
102
+ missing_padding = len(data) % 4
103
+ if missing_padding:
104
+ data += '=' * (4 - missing_padding)
105
+
106
+ return data
107
+
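A quick sketch of what the padding helper above does (the sample string is made up):

import base64

raw = "data:image/webp;base64,SGVsbG8"   # hypothetical payload with its '=' padding stripped
fixed = fix_base64_padding(raw)          # -> "SGVsbG8=" (prefix removed, padding restored)
print(base64.b64decode(fixed))           # b'Hello' -- decodes without a padding error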
108
+ # =============================================================================
109
+ # Memory cleanup function
110
+ # =============================================================================
111
  def clear_cuda_cache():
112
  """CUDA ์บ์‹œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ๋น„์›๋‹ˆ๋‹ค."""
113
  if torch.cuda.is_available():
114
  torch.cuda.empty_cache()
115
  gc.collect()
116
 
117
+ # =============================================================================
118
+ # SerpHouse helpers
119
+ # =============================================================================
120
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
121
 
 
 
 
122
  def extract_keywords(text: str, top_k: int = 5) -> str:
123
+ """๋‹จ์ˆœ ํ‚ค์›Œ๋“œ ์ถ”์ถœ: ํ•œ๊ธ€, ์˜์–ด, ์ˆซ์ž, ๊ณต๋ฐฑ๋งŒ ๋‚จ๊น€"""
 
 
 
 
124
  text = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
125
  tokens = text.split()
126
+ return " ".join(tokens[:top_k])
 
127
 
 
 
 
 
128
  def do_web_search(query: str) -> str:
129
  """
130
+ Call the SerpHouse LIVE API and return the search results as Markdown.
131
+ (Modify or remove if needed.)
132
  """
133
  try:
134
  url = "https://api.serphouse.com/serp/live"
 
 
135
  params = {
136
  "q": query,
137
  "domain": "google.com",
138
+ "serp_type": "web",
139
  "device": "desktop",
140
  "lang": "en",
141
+ "num": "20"
 
 
 
 
142
  }
143
+ headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
144
  logger.info(f"SerpHouse API ํ˜ธ์ถœ ์ค‘... ๊ฒ€์ƒ‰์–ด: {query}")
 
 
 
145
  response = requests.get(url, headers=headers, params=params, timeout=60)
146
  response.raise_for_status()
 
 
147
  data = response.json()
 
 
148
  results = data.get("results", {})
149
  organic = None
 
 
150
  if isinstance(results, dict) and "organic" in results:
151
  organic = results["organic"]
 
 
152
  elif isinstance(results, dict) and "results" in results:
153
  if isinstance(results["results"], dict) and "organic" in results["results"]:
154
  organic = results["results"]["organic"]
 
 
155
  elif "organic" in data:
156
  organic = data["organic"]
 
157
  if not organic:
158
  logger.warning("์‘๋‹ต์—์„œ organic ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
 
 
 
159
  return "No web search results found or unexpected API response structure."
 
 
160
  max_results = min(20, len(organic))
161
  limited_organic = organic[:max_results]
 
 
162
  summary_lines = []
163
  for idx, item in enumerate(limited_organic, start=1):
164
  title = item.get("title", "No title")
165
  link = item.get("link", "#")
166
  snippet = item.get("snippet", "No description")
167
  displayed_link = item.get("displayed_link", link)
 
 
168
  summary_lines.append(
169
  f"### Result {idx}: {title}\n\n"
170
  f"{snippet}\n\n"
171
  f"**์ถœ์ฒ˜**: [{displayed_link}]({link})\n\n"
172
  f"---\n"
173
  )
 
 
174
  instructions = """
175
  # Web search results
176
  Below are the search results. Use this information when answering the question:
177
+ 1. Synthesize content from multiple sources in the answer.
178
+ 2. When citing a source, use the Markdown format "[source title](link)".
179
+ 3. End the answer with a 'References:' section listing the main sources used.
 
180
  """
181
+ return instructions + "\n".join(summary_lines)
 
 
 
 
182
  except Exception as e:
183
  logger.error(f"Web search failed: {e}")
184
  return f"Web search failed: {str(e)}"
185
 
186
+ # =============================================================================
187
+ # Load model and processor
188
+ # =============================================================================
 
189
  MAX_CONTENT_CHARS = 2000
190
+ MAX_INPUT_LENGTH = 2096
 
191
 
192
+ model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
193
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
194
  model = Gemma3ForConditionalGeneration.from_pretrained(
195
  model_id,
196
  device_map="auto",
197
  torch_dtype=torch.bfloat16,
198
+ attn_implementation="eager"
199
  )
 
200
 
201
+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
202
 
203
+ # =============================================================================
204
  # CSV, TXT, PDF ๋ถ„์„ ํ•จ์ˆ˜
205
+ # =============================================================================
206
  def analyze_csv_file(path: str) -> str:
 
 
 
207
  try:
208
  df = pd.read_csv(path)
209
  if df.shape[0] > 50 or df.shape[1] > 10:
 
215
  except Exception as e:
216
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
217
 
 
218
  def analyze_txt_file(path: str) -> str:
 
 
 
219
  try:
220
  with open(path, "r", encoding="utf-8") as f:
221
  text = f.read()
 
225
  except Exception as e:
226
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
227
 
 
228
  def pdf_to_markdown(pdf_path: str) -> str:
 
 
 
229
  text_chunks = []
230
  try:
231
  with open(pdf_path, "rb") as f:
232
  reader = PyPDF2.PdfReader(f)
233
  max_pages = min(5, len(reader.pages))
234
  for page_num in range(max_pages):
235
+ page_text = reader.pages[page_num].extract_text() or ""
 
236
  page_text = page_text.strip()
237
  if page_text:
238
  if len(page_text) > MAX_CONTENT_CHARS // max_pages:
 
242
  text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
243
  except Exception as e:
244
  return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
 
245
  full_text = "\n".join(text_chunks)
246
  if len(full_text) > MAX_CONTENT_CHARS:
247
  full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
 
248
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
249
 
250
+ # =============================================================================
251
+ # Image/video upload limit checks
252
+ # =============================================================================
 
253
  def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
254
  image_count = 0
255
  video_count = 0
 
260
  image_count += 1
261
  return image_count, video_count
262
 
 
263
  def count_files_in_history(history: list[dict]) -> tuple[int, int]:
264
  image_count = 0
265
  video_count = 0
 
275
  image_count += 1
276
  return image_count, video_count
277
 
 
278
  def validate_media_constraints(message: dict, history: list[dict]) -> bool:
279
+ """์ด๋ฏธ์ง€/๋น„๋””์˜ค ์—…๋กœ๋“œ ์ œํ•œ ๊ฒ€์‚ฌ."""
280
+ media_files = [f for f in message["files"]
281
+ if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
 
 
282
  new_image_count, new_video_count = count_files_in_new_message(media_files)
283
  history_image_count, history_video_count = count_files_in_history(history)
284
+
285
  image_count = history_image_count + new_image_count
286
  video_count = history_video_count + new_video_count
287
 
 
298
  if video_count == 0 and image_count > MAX_NUM_IMAGES:
299
  gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
300
  return False
 
301
  if "<image>" in message["text"]:
302
+ image_files = [f for f in message["files"]
303
+ if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
304
  image_tag_count = message["text"].count("<image>")
305
  if image_tag_count != len(image_files):
306
  gr.Warning("The number of <image> tags in the text does not match the number of image files.")
307
  return False
 
308
  return True
309
 
310
+ # =============================================================================
311
+ # Video processing functions
312
+ # =============================================================================
 
313
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
314
  vidcap = cv2.VideoCapture(video_path)
315
  fps = vidcap.get(cv2.CAP_PROP_FPS)
316
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
317
  frame_interval = max(int(fps), int(total_frames / 10))
318
  frames = []
 
319
  for i in range(0, total_frames, frame_interval):
320
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
321
  success, image = vidcap.read()
322
  if success:
323
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
324
  image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
325
  pil_image = Image.fromarray(image)
326
  timestamp = round(i / fps, 2)
327
  frames.append((pil_image, timestamp))
328
  if len(frames) >= 5:
329
  break
 
330
  vidcap.release()
331
  return frames
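The sampling arithmetic above can be illustrated with a hypothetical 30 fps, 900-frame clip (a sketch, not part of the app):

fps, total_frames = 30, 900
frame_interval = max(int(fps), int(total_frames / 10))      # -> 90
sampled = list(range(0, total_frames, frame_interval))[:5]  # -> [0, 90, 180, 270, 360]
timestamps = [round(i / fps, 2) for i in sampled]           # -> [0.0, 3.0, 6.0, 9.0, 12.0]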
332
 
 
333
  def process_video(video_path: str) -> tuple[list[dict], list[str]]:
334
  content = []
335
+ temp_files = []
 
336
  frames = downsample_video(video_path)
337
+ for pil_image, timestamp in frames:
 
338
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
339
  pil_image.save(temp_file.name)
340
+ temp_files.append(temp_file.name)
341
  content.append({"type": "text", "text": f"Frame {timestamp}:"})
342
  content.append({"type": "image", "url": temp_file.name})
 
343
  return content, temp_files
344
 
345
+ # =============================================================================
346
+ # Interleaved <image> handling (mixing <image> tags with uploaded images)
347
+ # =============================================================================
 
348
  def process_interleaved_images(message: dict) -> list[dict]:
349
  parts = re.split(r"(<image>)", message["text"])
350
  content = []
351
+ image_files = [f for f in message["files"]
352
+ if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
353
  image_index = 0
 
 
 
354
  for part in parts:
355
  if part == "<image>" and image_index < len(image_files):
356
  content.append({"type": "image", "url": image_files[image_index]})
 
362
  content.append({"type": "text", "text": part})
363
  return content
364
 
365
+ # =============================================================================
366
+ # File handling -> build content
367
+ # =============================================================================
 
368
  def is_image_file(file_path: str) -> bool:
369
  return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
370
 
 
372
  return file_path.endswith(".mp4")
373
 
374
  def is_document_file(file_path: str) -> bool:
375
+ return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
 
 
 
 
 
376
 
377
  def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
378
+ """์‚ฌ์šฉ์ž๊ฐ€ ์ƒˆ๋กœ ์ž…๋ ฅํ•œ ๋ฉ”์‹œ์ง€ + ์—…๋กœ๋“œ ํŒŒ์ผ๋“ค์„ ํ•˜๋‚˜์˜ content(list)๋กœ ๋ณ€ํ™˜."""
379
+ temp_files = []
380
  if not message["files"]:
381
  return [{"type": "text", "text": message["text"]}], temp_files
382
 
 
388
 
389
  content_list = [{"type": "text", "text": message["text"]}]
390
 
391
+ # ๋ฌธ์„œ๋“ค
392
  for csv_path in csv_files:
393
+ content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
 
 
394
  for txt_path in txt_files:
395
+ content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
 
 
396
  for pdf_path in pdf_files:
397
+ content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
 
398
 
399
+ # ๋น„๋””์˜ค ์ฒ˜๋ฆฌ
400
  if video_files:
401
  video_content, video_temp_files = process_video(video_files[0])
402
  content_list += video_content
403
  temp_files.extend(video_temp_files)
404
  return content_list, temp_files
405
 
406
+ # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
407
  if "<image>" in message["text"] and image_files:
408
  interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
409
  if content_list and content_list[0]["type"] == "text":
 
415
 
416
  return content_list, temp_files
417
 
418
+ # =============================================================================
 
419
  # Convert history -> LLM messages
420
+ # =============================================================================
421
  def process_history(history: list[dict]) -> list[dict]:
422
+ """
423
+ Convert the existing chat history into the LLM message format.
424
+ - user -> {"role":"user","content":[{type,text},...]}
425
+ - assistant -> {"role":"assistant","content":[{type:"text",text},...]}
426
+ """
427
  messages = []
428
+ current_user_content = []
429
  for item in history:
430
  if item["role"] == "assistant":
431
+ # If user content has accumulated, flush it as a single user message
432
  if current_user_content:
433
  messages.append({"role": "user", "content": current_user_content})
434
  current_user_content = []
435
+ # Then append the assistant message
436
  messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
437
  else:
438
  content = item["content"]
 
444
  current_user_content.append({"type": "image", "url": file_path})
445
  else:
446
  current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
 
447
  if current_user_content:
448
  messages.append({"role": "user", "content": current_user_content})
 
449
  return messages
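A shape sketch of the conversion described in the docstring (values are made up; it assumes a plain-text user turn is wrapped as a single text part):

history = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
# process_history(history) would yield roughly:
# [
#   {"role": "user", "content": [{"type": "text", "text": "Hi there"}]},
#   {"role": "assistant", "content": [{"type": "text", "text": "Hello! How can I help?"}]},
# ]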
450
 
451
+ # =============================================================================
452
+ # ๋ชจ๋ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (OOM ์บ์น˜)
453
+ # =============================================================================
 
454
  def _model_gen_with_oom_catch(**kwargs):
 
 
 
455
  try:
456
  model.generate(**kwargs)
457
  except torch.cuda.OutOfMemoryError:
458
+ raise RuntimeError("[OutOfMemoryError] GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ๋ถ€์กฑํ•ฉ๋‹ˆ๋‹ค.")
 
 
 
459
  finally:
 
460
  clear_cuda_cache()
461
 
462
+ # =============================================================================
 
463
  # Main inference function
464
+ # =============================================================================
 
 
465
  @spaces.GPU(duration=120)
466
  def run(
467
  message: dict,
 
471
  use_web_search: bool = False,
472
  web_search_query: str = "",
473
  age_group: str = "20๋Œ€",
 
474
  mbti_personality: str = "INTP",
475
  sexual_openness: int = 2,
476
+ image_gen: bool = False
477
  ) -> Iterator[str]:
478
+ """
479
+ LLM inference function.
480
+ - For image generation, assumes the server returns Base64 (or data:image/...) directly.
481
+ - Does not re-download /tmp/... files (avoids 403 Forbidden issues).
482
+ """
483
  if not validate_media_constraints(message, history):
484
  yield ""
485
  return
486
 
487
+ temp_files = []
 
488
  try:
489
+ # 1) System prompt + persona info
490
+ persona = (
 
 
 
491
  f"{system_prompt.strip()}\n\n"
492
  f"Gender: Female\n"
493
  f"Age Group: {age_group}\n"
494
+ f"MBTI Persona: {mbti_personality}\n"
495
+ f"Sexual Openness (1~5): {sexual_openness}\n"
496
  )
497
+ combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
 
 
 
 
 
498
 
499
+ # 2) Web search (optional)
500
  if use_web_search:
501
  user_text = message["text"]
502
+ ws_query = extract_keywords(user_text)
503
  if ws_query.strip():
504
  logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
505
  ws_result = do_web_search(ws_query)
506
+ combined_system_msg += f"[Search top-20 Full Items]\n{ws_result}\n\n"
507
+ combined_system_msg += (
508
+ "[์ฐธ๊ณ : ์œ„ ๊ฒ€์ƒ‰๊ฒฐ๊ณผ link๋ฅผ ์ถœ์ฒ˜๋กœ ์ธ์šฉํ•˜์—ฌ ๋‹ต๋ณ€]\n"
509
+ "[์ค‘์š” ์ง€์‹œ์‚ฌํ•ญ]\n"
510
+ "1. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์—์„œ ์ฐพ์€ ์ •๋ณด์˜ ์ถœ์ฒ˜๋ฅผ ๋ฐ˜๋“œ์‹œ ์ธ์šฉ.\n"
511
+ "2. '[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)' ํ˜•์‹์œผ๋กœ ๋งํฌ.\n"
512
+ "3. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— '์ฐธ๊ณ  ์ž๋ฃŒ:' ์„น์…˜.\n"
513
+ )
 
 
514
  else:
515
  combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
516
 
517
+ # 3) Existing history + new user message
518
  messages = []
 
519
  if combined_system_msg.strip():
520
+ messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
 
 
 
 
 
521
  messages.extend(process_history(history))
522
 
 
523
  user_content, user_temp_files = process_new_user_message(message)
524
+ temp_files.extend(user_temp_files)
525
+
526
  for item in user_content:
527
  if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
528
  item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
529
+
530
  messages.append({"role": "user", "content": user_content})
531
 
532
+ # 4) Tokenization
533
  inputs = processor.apply_chat_template(
534
  messages,
535
  add_generation_prompt=True,
 
537
  return_dict=True,
538
  return_tensors="pt",
539
  ).to(device=model.device, dtype=torch.bfloat16)
 
 
540
  if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
541
  inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
542
  if 'attention_mask' in inputs:
543
  inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
544
+
545
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
546
+ gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
 
 
 
 
547
 
548
  t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
549
  t.start()
550
 
551
+ # Streaming output
552
+ output_so_far = ""
553
  for new_text in streamer:
554
+ output_so_far += new_text
555
+ yield output_so_far
556
+
557
+ # 5) Image generation (Base64)
558
+ if image_gen:
559
+ last_user_text = message["text"].strip()
560
+ if not last_user_text:
561
+ yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: Empty user prompt)"
562
+ else:
563
+ try:
564
+ width, height = 512, 512
565
+ guidance, steps, seed = 7.5, 30, 42
566
+
567
+ logger.info(f"Generating image with prompt: {last_user_text}")
568
+
569
+ # Call the API to generate a (base64) image
570
+ image_result, seed_info = generate_image(
571
+ prompt=last_user_text,
572
+ width=width,
573
+ height=height,
574
+ guidance=guidance,
575
+ inference_steps=steps,
576
+ seed=seed
577
+ )
578
+
579
+ logger.info(f"Received image data type: {type(image_result)}")
580
+
581
+ # Handle Base64 or data:image/...
582
+ if image_result:
583
+ if isinstance(image_result, str):
584
+ # If it already starts with data:image/, use it as-is
585
+ if image_result.startswith("data:image/"):
586
+ final_md = f"\n\n**[์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]**\n\n![์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]({image_result})"
587
+ yield output_so_far + final_md
588
+ else:
589
+ # Treat as raw base64 (plain URLs or '/tmp/...' paths cannot be handled)
590
+ if len(image_result) > 100 and "/" not in image_result:
591
+ # base64
592
+ image_data = "data:image/webp;base64," + image_result
593
+ final_md = f"\n\n**[์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]**\n\n![์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€]({image_data})"
594
+ yield output_so_far + final_md
595
+ else:
596
+ # Anything else (e.g., http://..., /tmp/...) -> not shown, since it causes 403 errors
597
+ yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ base64 ํ˜•์‹์ด ์•„๋‹™๋‹ˆ๋‹ค)"
598
+ else:
599
+ yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ๋ฌธ์ž์—ด์ด ์•„๋‹˜)"
600
+ else:
601
+ yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {seed_info})"
602
+
603
+ except Exception as e:
604
+ logger.error(f"Image generation error: {e}")
605
+ yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e})"
606
 
607
  except Exception as e:
608
  logger.error(f"Error in run: {str(e)}")
609
  yield f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
 
610
  finally:
611
+ for tmp in temp_files:
 
612
  try:
613
+ if os.path.exists(tmp):
614
+ os.unlink(tmp)
615
+ logger.info(f"Deleted temp file: {tmp}")
616
+ except Exception as ee:
617
+ logger.warning(f"Failed to delete temp file {tmp}: {ee}")
 
 
618
  try:
619
  del inputs, streamer
620
+ except Exception:
621
  pass
 
622
  clear_cuda_cache()
623
 
624
+ # =============================================================================
625
+ # Examples
626
+ # =============================================================================
 
627
  examples = [
 
628
  [
629
  {
630
  "text": "Compare the contents of the two PDF files.",
 
640
  "files": ["assets/additional-examples/sample-csv.csv"],
641
  }
642
  ],
643
+ # ... add the remaining examples here if needed ...
644
  ]
645
 
646
+ # =============================================================================
647
+ # Gradio UI (Blocks) layout
648
+ # =============================================================================
649
+
650
  css = """
 
651
  .gradio-container {
652
+ background: rgba(255, 255, 255, 0.7);
653
  padding: 30px 40px;
654
+ margin: 20px auto;
655
  width: 100% !important;
656
+ max-width: none !important;
657
  }
658
  """
 
659
  title_html = """
660
+ <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 💘 HeartSync : Love Dating AI 💘 </h1>
661
  <p align="center" style="font-size:1.1em; color:#555;">
662
+ ✅ FLUX Image Generation ✅ Reasoning & Uncensored ✅ Multimodal & VLM ✅ Deep-Research & RAG <br>
663
  </p>
664
  """
665
 
666
  with gr.Blocks(css=css, title="HeartSync") as demo:
667
  gr.Markdown(title_html)
668
+
669
+ # Separate gallery for generated images (use if needed)
670
+ generated_images = gr.Gallery(
671
+ label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€",
672
+ show_label=True,
673
+ visible=False,
674
+ elem_id="generated_images",
675
+ columns=2,
676
+ height="auto",
677
+ object_fit="contain"
678
  )
679
+
680
+ with gr.Row():
681
+ web_search_checkbox = gr.Checkbox(label="Deep Research", value=False)
682
+ image_gen_checkbox = gr.Checkbox(label="Image Gen", value=False)
683
+
684
  base_system_prompt_box = gr.Textbox(
685
  lines=3,
686
+ value="You are a deep thinking AI...\nํŽ˜๋ฅด์†Œ๋‚˜: ๋‹น์‹ ์€ ๋‹ฌ์ฝคํ•˜๊ณ ...",
 
 
 
 
 
 
687
  label="๊ธฐ๋ณธ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ",
688
+ visible=False
689
  )
 
 
690
  with gr.Row():
691
  age_group_dropdown = gr.Dropdown(
692
  label="์—ฐ๋ น๋Œ€ ์„ ํƒ (๊ธฐ๋ณธ 20๋Œ€)",
693
+ choices=["10๋Œ€", "20๋Œ€", "30~40๋Œ€", "50~60๋Œ€", "70๋Œ€ ์ด์ƒ"],
694
  value="20๋Œ€",
695
  interactive=True
696
  )
 
 
 
 
 
 
 
 
697
  mbti_choices = [
698
  "INTJ (์šฉ์˜์ฃผ๋„ํ•œ ์ „๋žต๊ฐ€)",
699
  "INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
 
718
  value="INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
719
  interactive=True
720
  )
 
 
721
  sexual_openness_slider = gr.Slider(
722
  minimum=1, maximum=5, step=1, value=2,
723
  label="์„น์Šˆ์–ผ ๊ด€์‹ฌ๋„/๊ฐœ๋ฐฉ์„ฑ (1~5, ๊ธฐ๋ณธ=2)",
724
  interactive=True
725
  )
 
 
726
  max_tokens_slider = gr.Slider(
727
  label="Max New Tokens",
728
+ minimum=100, maximum=8000, step=50, value=1000,
729
+ visible=False
 
 
 
730
  )
 
 
731
  web_search_text = gr.Textbox(
732
  lines=1,
733
  label="(Unused) Web Search Query",
734
  placeholder="No direct input needed",
735
+ visible=False
736
  )
737
+
738
+ def modified_run(
739
+ message, history, system_prompt, max_new_tokens,
740
+ use_web_search, web_search_query,
741
+ age_group, mbti_personality, sexual_openness, image_gen
742
+ ):
743
+ """
744
+ Call run() to receive the text stream,
745
+ then post-process if needed and return the result (gallery updates, etc.).
746
+ """
747
+ output_so_far = ""
748
+ gallery_update = gr.Gallery(visible=False, value=[])
749
+ yield output_so_far, gallery_update
750
+
751
+ text_generator = run(
752
+ message, history,
753
+ system_prompt, max_new_tokens,
754
+ use_web_search, web_search_query,
755
+ age_group, mbti_personality,
756
+ sexual_openness, image_gen
757
+ )
758
+
759
+ for text_chunk in text_generator:
760
+ output_so_far = text_chunk
761
+ yield output_so_far, gallery_update
762
+
763
+ # If run() already inserted the Base64 image into the chat window,
764
+ # there may be no need to show it separately in the gallery here.
765
+ # To access image_result from inside run(), run() would need to be modified to return that information.
766
+
767
  chat = gr.ChatInterface(
768
+ fn=modified_run,
769
  type="messages",
770
  chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
771
  textbox=gr.MultimodalTextbox(
772
+ file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
 
 
 
773
  file_count="multiple",
774
  autofocus=True
775
  ),
 
780
  web_search_checkbox,
781
  web_search_text,
782
  age_group_dropdown,
 
783
  mbti_dropdown,
784
  sexual_openness_slider,
785
+ image_gen_checkbox,
786
  ],
787
+ additional_outputs=[generated_images],
788
  stop_btn=False,
789
  title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
790
  examples=examples,
 
794
  delete_cache=(1800, 1800),
795
  )
796
 
 
797
  with gr.Row(elem_id="examples_row"):
798
  with gr.Column(scale=12, elem_id="examples_container"):
799
  gr.Markdown("### Example Inputs (click to load)")
800
 
801
  if __name__ == "__main__":
802
+ demo.launch(share=True)