#!/usr/bin/env python
import os
import re
import tempfile
import gc  # garbage collector, added for explicit memory cleanup
from collections.abc import Iterator
from threading import Thread
import json
import requests
import cv2
import base64
import logging
import time
from urllib.parse import quote  # added for URL encoding
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# CSV/TXT/PDF analysis
import pandas as pd
import PyPDF2
# =============================================================================
# (New) Image generation API helpers
# =============================================================================
from gradio_client import Client
API_URL = "http://211.233.58.201:7896"
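# NOTE: the image generation backend is an external Gradio app; every call below goes
# through gradio_client.Client against this URL.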
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def test_api_connection() -> str:
"""API ์„œ๋ฒ„ ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ"""
try:
client = Client(API_URL)
return "API ์—ฐ๊ฒฐ ์„ฑ๊ณต: ์ •์ƒ ์ž‘๋™ ์ค‘"
except Exception as e:
logging.error(f"API ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ ์‹คํŒจ: {e}")
return f"API ์—ฐ๊ฒฐ ์‹คํŒจ: {e}"
def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
"""์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜ (๋ฐ˜ํ™˜ ํ˜•์‹์— ์œ ์—ฐํ•˜๊ฒŒ ๋Œ€์‘)"""
if not prompt:
return None, "์˜ค๋ฅ˜: ํ”„๋กฌํ”„ํŠธ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค."
try:
logging.info(f"ํ”„๋กฌํ”„ํŠธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ด๋ฏธ์ง€ ์ƒ์„ฑ API ํ˜ธ์ถœ: {prompt}")
client = Client(API_URL)
result = client.predict(
prompt=prompt,
width=int(width),
height=int(height),
guidance=float(guidance),
inference_steps=int(inference_steps),
seed=int(seed),
do_img2img=False,
init_image=None,
image2image_strength=0.8,
resize_img=True,
api_name="/generate_image"
)
logging.info(f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฒฐ๊ณผ: {type(result)}, ๊ธธ์ด: {len(result) if isinstance(result, (list, tuple)) else '์•Œ ์ˆ˜ ์—†์Œ'}")
        # Handle results returned as a tuple or list
if isinstance(result, (list, tuple)) and len(result) > 0:
            image_data = result[0]  # the first element is the image data
seed_info = result[1] if len(result) > 1 else "์•Œ ์ˆ˜ ์—†๋Š” ์‹œ๋“œ"
return image_data, seed_info
else:
            # Returned in another form (a single value)
return result, "์•Œ ์ˆ˜ ์—†๋Š” ์‹œ๋“œ"
except Exception as e:
logging.error(f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {str(e)}")
return None, f"์˜ค๋ฅ˜: {str(e)}"
# Helper to fix base64 padding
def fix_base64_padding(data):
"""Base64 ๋ฌธ์ž์—ด์˜ ํŒจ๋”ฉ์„ ์ˆ˜์ •ํ•ฉ๋‹ˆ๋‹ค."""
if isinstance(data, bytes):
data = data.decode('utf-8')
    # Strip anything up to and including a "base64," prefix
if "base64," in data:
data = data.split("base64,", 1)[1]
    # Append padding characters so the length is a multiple of 4
missing_padding = len(data) % 4
if missing_padding:
data += '=' * (4 - missing_padding)
return data
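# Illustrative example: fix_base64_padding("aGVsbG8") -> "aGVsbG8=" (length padded to a multiple of 4).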
# =============================================================================
# Memory cleanup helper
# =============================================================================
def clear_cuda_cache():
"""CUDA ์บ์‹œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ๋น„์›๋‹ˆ๋‹ค."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# =============================================================================
# SerpHouse (web search) helpers
# =============================================================================
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
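# If the key is not set, the SerpHouse request in do_web_search() will typically fail and
# the function returns the error text instead of search results.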
def extract_keywords(text: str, top_k: int = 5) -> str:
"""๋‹จ์ˆœ ํ‚ค์›Œ๋“œ ์ถ”์ถœ: ํ•œ๊ธ€, ์˜์–ด, ์ˆซ์ž, ๊ณต๋ฐฑ๋งŒ ๋‚จ๊น€"""
text = re.sub(r"[^a-zA-Z0-9๊ฐ€-ํžฃ\s]", "", text)
tokens = text.split()
return " ".join(tokens[:top_k])
def do_web_search(query: str) -> str:
"""SerpHouse LIVE API ํ˜ธ์ถœํ•˜์—ฌ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋งˆํฌ๋‹ค์šด ๋ฐ˜ํ™˜"""
try:
url = "https://api.serphouse.com/serp/live"
params = {
"q": query,
"domain": "google.com",
"serp_type": "web",
"device": "desktop",
"lang": "en",
"num": "20"
}
headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
logger.info(f"SerpHouse API ํ˜ธ์ถœ ์ค‘... ๊ฒ€์ƒ‰์–ด: {query}")
response = requests.get(url, headers=headers, params=params, timeout=60)
response.raise_for_status()
data = response.json()
results = data.get("results", {})
organic = None
if isinstance(results, dict) and "organic" in results:
organic = results["organic"]
elif isinstance(results, dict) and "results" in results:
if isinstance(results["results"], dict) and "organic" in results["results"]:
organic = results["results"]["organic"]
elif "organic" in data:
organic = data["organic"]
if not organic:
logger.warning("์‘๋‹ต์—์„œ organic ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
return "์›น ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๊ฐ€ ์—†๊ฑฐ๋‚˜ API ์‘๋‹ต ๊ตฌ์กฐ๊ฐ€ ์˜ˆ์ƒ๊ณผ ๋‹ค๋ฆ…๋‹ˆ๋‹ค."
max_results = min(20, len(organic))
limited_organic = organic[:max_results]
summary_lines = []
for idx, item in enumerate(limited_organic, start=1):
title = item.get("title", "์ œ๋ชฉ ์—†์Œ")
link = item.get("link", "#")
snippet = item.get("snippet", "์„ค๋ช… ์—†์Œ")
displayed_link = item.get("displayed_link", link)
summary_lines.append(
f"### ๊ฒฐ๊ณผ {idx}: {title}\n\n"
f"{snippet}\n\n"
f"**์ถœ์ฒ˜**: [{displayed_link}]({link})\n\n"
f"---\n"
)
instructions = """
# ์›น ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ
์•„๋ž˜๋Š” ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•  ๋•Œ ์ด ์ •๋ณด๋ฅผ ํ™œ์šฉํ•˜์„ธ์š”:
1. ๊ฐ ๊ฒฐ๊ณผ์˜ ์ œ๋ชฉ, ๋‚ด์šฉ, ์ถœ์ฒ˜ ๋งํฌ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.
2. ๋‹ต๋ณ€์— ๊ด€๋ จ ์ •๋ณด์˜ ์ถœ์ฒ˜๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ์ธ์šฉํ•˜์„ธ์š” (์˜ˆ: "[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)").
3. ์‘๋‹ต์— ์‹ค์ œ ์ถœ์ฒ˜ ๋งํฌ๋ฅผ ํฌํ•จํ•˜์„ธ์š”.
4. ์—ฌ๋Ÿฌ ์ถœ์ฒ˜์˜ ์ •๋ณด๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€ํ•˜์„ธ์š”.
5. ๋งˆ์ง€๋ง‰์— "์ฐธ๊ณ  ์ž๋ฃŒ:" ์„น์…˜์„ ์ถ”๊ฐ€ํ•˜๊ณ  ์ฃผ์š” ์ถœ์ฒ˜ ๋งํฌ๋ฅผ ๋‚˜์—ดํ•˜์„ธ์š”.
"""
return instructions + "\n".join(summary_lines)
except Exception as e:
logger.error(f"์›น ๊ฒ€์ƒ‰ ์‹คํŒจ: {e}")
return f"์›น ๊ฒ€์ƒ‰ ์‹คํŒจ: {str(e)}"
# =============================================================================
# Model and processor loading
# =============================================================================
MAX_CONTENT_CHARS = 2000
MAX_INPUT_LENGTH = 2096
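# NOTE: prompts longer than MAX_INPUT_LENGTH tokens are left-truncated in run(), so only
# the most recent context is kept.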
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager"
)
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
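# Upload limits are enforced in validate_media_constraints(): at most MAX_NUM_IMAGES images,
# at most one .mp4 video, and never images and a video in the same conversation.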
# =============================================================================
# CSV, TXT, and PDF analysis helpers
# =============================================================================
def analyze_csv_file(path: str) -> str:
try:
df = pd.read_csv(path)
if df.shape[0] > 50 or df.shape[1] > 10:
df = df.iloc[:50, :10]
df_str = df.to_string()
if len(df_str) > MAX_CONTENT_CHARS:
df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(์ผ๋ถ€ ์ƒ๋žต)..."
return f"**[CSV ํŒŒ์ผ: {os.path.basename(path)}]**\n\n{df_str}"
except Exception as e:
return f"CSV ํŒŒ์ผ ์ฝ๊ธฐ ์‹คํŒจ ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as f:
text = f.read()
if len(text) > MAX_CONTENT_CHARS:
text = text[:MAX_CONTENT_CHARS] + "\n...(์ผ๋ถ€ ์ƒ๋žต)..."
return f"**[TXT ํŒŒ์ผ: {os.path.basename(path)}]**\n\n{text}"
except Exception as e:
return f"TXT ํŒŒ์ผ ์ฝ๊ธฐ ์‹คํŒจ ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
text_chunks = []
try:
with open(pdf_path, "rb") as f:
reader = PyPDF2.PdfReader(f)
max_pages = min(5, len(reader.pages))
for page_num in range(max_pages):
page_text = reader.pages[page_num].extract_text() or ""
page_text = page_text.strip()
if page_text:
if len(page_text) > MAX_CONTENT_CHARS // max_pages:
page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(์ผ๋ถ€ ์ƒ๋žต)"
text_chunks.append(f"## ํŽ˜์ด์ง€ {page_num+1}\n\n{page_text}\n")
if len(reader.pages) > max_pages:
text_chunks.append(f"\n...(์ „์ฒด {len(reader.pages)}ํŽ˜์ด์ง€ ์ค‘ {max_pages}ํŽ˜์ด์ง€๋งŒ ํ‘œ์‹œ)...")
except Exception as e:
return f"PDF ํŒŒ์ผ ์ฝ๊ธฐ ์‹คํŒจ ({os.path.basename(pdf_path)}): {str(e)}"
full_text = "\n".join(text_chunks)
if len(full_text) > MAX_CONTENT_CHARS:
full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(์ผ๋ถ€ ์ƒ๋žต)..."
return f"**[PDF ํŒŒ์ผ: {os.path.basename(pdf_path)}]**\n\n{full_text}"
# =============================================================================
# Image/video upload constraint checks
# =============================================================================
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
image_count = 0
video_count = 0
for path in paths:
if path.endswith(".mp4"):
video_count += 1
elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
image_count += 1
return image_count, video_count
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
image_count = 0
video_count = 0
for item in history:
if item["role"] != "user" or isinstance(item["content"], str):
continue
if isinstance(item["content"], list) and len(item["content"]) > 0:
file_path = item["content"][0]
if isinstance(file_path, str):
if file_path.endswith(".mp4"):
video_count += 1
elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
image_count += 1
return image_count, video_count
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
media_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
new_image_count, new_video_count = count_files_in_new_message(media_files)
history_image_count, history_video_count = count_files_in_history(history)
image_count = history_image_count + new_image_count
video_count = history_video_count + new_video_count
if video_count > 1:
gr.Warning("๋น„๋””์˜ค ํŒŒ์ผ์€ ํ•˜๋‚˜๋งŒ ์ง€์›๋ฉ๋‹ˆ๋‹ค.")
return False
if video_count == 1:
if image_count > 0:
gr.Warning("์ด๋ฏธ์ง€์™€ ๋น„๋””์˜ค๋ฅผ ํ˜ผํ•ฉํ•˜๋Š” ๊ฒƒ์€ ํ—ˆ์šฉ๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
return False
if "<image>" in message["text"]:
gr.Warning("<image> ํƒœ๊ทธ์™€ ๋น„๋””์˜ค ํŒŒ์ผ์€ ํ•จ๊ป˜ ์‚ฌ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
return False
if video_count == 0 and image_count > MAX_NUM_IMAGES:
gr.Warning(f"์ตœ๋Œ€ {MAX_NUM_IMAGES}์žฅ์˜ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
return False
if "<image>" in message["text"]:
image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
image_tag_count = message["text"].count("<image>")
if image_tag_count != len(image_files):
gr.Warning("ํ…์ŠคํŠธ์— ์žˆ๋Š” <image> ํƒœ๊ทธ์˜ ๊ฐœ์ˆ˜๊ฐ€ ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฐœ์ˆ˜์™€ ์ผ์น˜ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
return False
return True
# =============================================================================
# Video processing helpers
# =============================================================================
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
vidcap = cv2.VideoCapture(video_path)
fps = vidcap.get(cv2.CAP_PROP_FPS)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_interval = max(int(fps), int(total_frames / 10))
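    # Sample frames at least ~1 second apart (and at most ~10 across the clip);
    # the loop below additionally caps the result at 5 frames.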
frames = []
for i in range(0, total_frames, frame_interval):
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
if len(frames) >= 5:
break
vidcap.release()
return frames
def process_video(video_path: str) -> tuple[list[dict], list[str]]:
content = []
temp_files = []
frames = downsample_video(video_path)
for pil_image, timestamp in frames:
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
pil_image.save(temp_file.name)
temp_files.append(temp_file.name)
content.append({"type": "text", "text": f"ํ”„๋ ˆ์ž„ {timestamp}:"})
content.append({"type": "image", "url": temp_file.name})
return content, temp_files
# =============================================================================
# Interleaved <image> tag handling
# =============================================================================
def process_interleaved_images(message: dict) -> list[dict]:
parts = re.split(r"(<image>)", message["text"])
content = []
image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
image_index = 0
for part in parts:
if part == "<image>" and image_index < len(image_files):
content.append({"type": "image", "url": image_files[image_index]})
image_index += 1
elif part.strip():
content.append({"type": "text", "text": part.strip()})
else:
if isinstance(part, str) and part != "<image>":
content.append({"type": "text", "text": part})
return content
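# Illustrative example: the text "Compare <image> and <image> please" with two attached images
# becomes [text "Compare", image #1, text "and", image #2, text "please"].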
# =============================================================================
# File handling -> content list construction
# =============================================================================
def is_image_file(file_path: str) -> bool:
return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
def is_video_file(file_path: str) -> bool:
return file_path.endswith(".mp4")
def is_document_file(file_path: str) -> bool:
return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
temp_files = []
if not message["files"]:
return [{"type": "text", "text": message["text"]}], temp_files
video_files = [f for f in message["files"] if is_video_file(f)]
image_files = [f for f in message["files"] if is_image_file(f)]
csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
content_list = [{"type": "text", "text": message["text"]}]
for csv_path in csv_files:
content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
for txt_path in txt_files:
content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
for pdf_path in pdf_files:
content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
if video_files:
video_content, video_temp_files = process_video(video_files[0])
content_list += video_content
temp_files.extend(video_temp_files)
return content_list, temp_files
if "<image>" in message["text"] and image_files:
interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
if content_list and content_list[0]["type"] == "text":
content_list = content_list[1:]
return interleaved_content + content_list, temp_files
else:
for img_path in image_files:
content_list.append({"type": "image", "url": img_path})
return content_list, temp_files
# =============================================================================
# Chat history -> LLM message conversion
# =============================================================================
def process_history(history: list[dict]) -> list[dict]:
messages = []
current_user_content = []
for item in history:
if item["role"] == "assistant":
if current_user_content:
messages.append({"role": "user", "content": current_user_content})
current_user_content = []
messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
else:
content = item["content"]
if isinstance(content, str):
current_user_content.append({"type": "text", "text": content})
elif isinstance(content, list) and len(content) > 0:
file_path = content[0]
if is_image_file(file_path):
current_user_content.append({"type": "image", "url": file_path})
else:
current_user_content.append({"type": "text", "text": f"[ํŒŒ์ผ: {os.path.basename(file_path)}]"})
if current_user_content:
messages.append({"role": "user", "content": current_user_content})
return messages
# =============================================================================
# Model generation wrapper (catches CUDA OOM)
# =============================================================================
def _model_gen_with_oom_catch(**kwargs):
try:
model.generate(**kwargs)
except torch.cuda.OutOfMemoryError:
raise RuntimeError("[OutOfMemoryError] GPU ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ๋ถ€์กฑํ•ฉ๋‹ˆ๋‹ค.")
finally:
clear_cuda_cache()
# =============================================================================
# Main inference function
# =============================================================================
@spaces.GPU(duration=120)
def run(
message: dict,
history: list[dict],
system_prompt: str = "",
max_new_tokens: int = 512,
use_web_search: bool = False,
web_search_query: str = "",
age_group: str = "20๋Œ€",
mbti_personality: str = "INTP",
sexual_openness: int = 2,
    image_gen: bool = False  # whether the "Image Gen" checkbox is ticked
) -> Iterator[str]:
if not validate_media_constraints(message, history):
yield ""
return
temp_files = []
try:
        # Add persona information to the system prompt
persona = (
f"{system_prompt.strip()}\n\n"
f"์„ฑ๋ณ„: ์—ฌ์„ฑ\n"
f"์—ฐ๋ น๋Œ€: {age_group}\n"
f"MBTI ํŽ˜๋ฅด์†Œ๋‚˜: {mbti_personality}\n"
f"์„น์Šˆ์–ผ ๊ฐœ๋ฐฉ์„ฑ (1~5): {sexual_openness}\n"
)
combined_system_msg = f"[์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ]\n{persona.strip()}\n\n"
if use_web_search:
user_text = message["text"]
ws_query = extract_keywords(user_text)
if ws_query.strip():
logger.info(f"[์ž๋™ ์›น ๊ฒ€์ƒ‰ ํ‚ค์›Œ๋“œ] {ws_query!r}")
ws_result = do_web_search(ws_query)
combined_system_msg += f"[๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ (์ƒ์œ„ 20๊ฐœ ํ•ญ๋ชฉ)]\n{ws_result}\n\n"
combined_system_msg += (
"[์ฐธ๊ณ : ์œ„ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋งํฌ๋ฅผ ์ถœ์ฒ˜๋กœ ์ธ์šฉํ•˜์—ฌ ๋‹ต๋ณ€]\n"
"[์ค‘์š” ์ง€์‹œ์‚ฌํ•ญ]\n"
"1. ๋‹ต๋ณ€์— ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์—์„œ ์ฐพ์€ ์ •๋ณด์˜ ์ถœ์ฒ˜๋ฅผ ๋ฐ˜๋“œ์‹œ ์ธ์šฉํ•˜์„ธ์š”.\n"
"2. ์ถœ์ฒ˜ ์ธ์šฉ ์‹œ \"[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)\" ํ˜•์‹์˜ ๋งˆํฌ๋‹ค์šด ๋งํฌ๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”.\n"
"3. ์—ฌ๋Ÿฌ ์ถœ์ฒ˜์˜ ์ •๋ณด๋ฅผ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€ํ•˜์„ธ์š”.\n"
"4. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— \"์ฐธ๊ณ  ์ž๋ฃŒ:\" ์„น์…˜์„ ์ถ”๊ฐ€ํ•˜๊ณ  ์‚ฌ์šฉํ•œ ์ฃผ์š” ์ถœ์ฒ˜ ๋งํฌ๋ฅผ ๋‚˜์—ดํ•˜์„ธ์š”.\n"
)
else:
combined_system_msg += "[์œ ํšจํ•œ ํ‚ค์›Œ๋“œ๊ฐ€ ์—†์–ด ์›น ๊ฒ€์ƒ‰์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค]\n\n"
messages = []
if combined_system_msg.strip():
messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
messages.extend(process_history(history))
user_content, user_temp_files = process_new_user_message(message)
temp_files.extend(user_temp_files)
for item in user_content:
if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(์ผ๋ถ€ ์ƒ๋žต)..."
messages.append({"role": "user", "content": user_content})
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=model.device, dtype=torch.bfloat16)
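        # Keep only the most recent MAX_INPUT_LENGTH tokens if the prompt is too long.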
if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
if 'attention_mask' in inputs:
inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
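        # Generation runs in a background thread; the TextIteratorStreamer yields partial
        # text so the response can be streamed to the UI as it is produced.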
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
t.start()
output_so_far = ""
for new_text in streamer:
output_so_far += new_text
yield output_so_far
except Exception as e:
logger.error(f"run ํ•จ์ˆ˜ ์—๋Ÿฌ: {str(e)}")
yield f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
finally:
for tmp in temp_files:
try:
if os.path.exists(tmp):
os.unlink(tmp)
logger.info(f"์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ๋จ: {tmp}")
except Exception as ee:
logger.warning(f"์ž„์‹œ ํŒŒ์ผ {tmp} ์‚ญ์ œ ์‹คํŒจ: {ee}")
try:
del inputs, streamer
except Exception:
pass
clear_cuda_cache()
# Modified model-run wrapper: handles image generation and the gallery output
def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
age_group, mbti_personality, sexual_openness, image_gen):
    # Initialize the gallery as empty and hidden
output_so_far = ""
gallery_update = gr.Gallery(visible=False, value=[])
yield output_so_far, gallery_update
    # Original run() logic: stream the text response
text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
for text_chunk in text_generator:
output_so_far = text_chunk
yield output_so_far, gallery_update
    # If image generation is enabled, update the gallery
if image_gen and message["text"].strip():
try:
width, height = 512, 512
guidance, steps, seed = 7.5, 30, 42
logger.info(f"๊ฐค๋Ÿฌ๋ฆฌ์šฉ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ˜ธ์ถœ, ํ”„๋กฌํ”„ํŠธ: {message['text']}")
            # Call the API to generate an image
image_result, seed_info = generate_image(
prompt=message["text"].strip(),
width=width,
height=height,
guidance=guidance,
inference_steps=steps,
seed=seed
)
if image_result:
                # Handle raw image data: the result is a base64 string
if isinstance(image_result, str) and (
image_result.startswith('data:') or
len(image_result) > 100 and '/' not in image_result
):
                    # Convert the base64 image string into a file
try:
                        # Strip the data:image prefix
if image_result.startswith('data:'):
content_type, b64data = image_result.split(';base64,')
else:
b64data = image_result
                            content_type = "image/webp"  # assume a default content type
                        # Decode the base64 data
image_bytes = base64.b64decode(b64data)
                        # Save to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
temp_file.write(image_bytes)
temp_path = temp_file.name
                        # Show the gallery and add the image
gallery_update = gr.Gallery(visible=True, value=[temp_path])
yield output_so_far + "\n\n*์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์–ด ์•„๋ž˜ ๊ฐค๋Ÿฌ๋ฆฌ์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.*", gallery_update
except Exception as e:
logger.error(f"Base64 ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์˜ค๋ฅ˜: {e}")
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: {e})", gallery_update
                # The result is a local file path
elif isinstance(image_result, str) and os.path.exists(image_result):
                    # Use the local file path directly
gallery_update = gr.Gallery(visible=True, value=[image_result])
yield output_so_far + "\n\n*์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์–ด ์•„๋ž˜ ๊ฐค๋Ÿฌ๋ฆฌ์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.*", gallery_update
                # The result is a /tmp path (a file that exists only on the API server)
elif isinstance(image_result, str) and '/tmp/' in image_result:
                    # Recover the image from the file path returned by the API
try:
                        # Re-request the result as a base64-encoded string
client = Client(API_URL)
result = client.predict(
prompt=message["text"].strip(),
                        api_name="/generate_base64_image"  # endpoint that returns base64
)
if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
                            # Handle the base64 image
if result.startswith('data:'):
content_type, b64data = result.split(';base64,')
else:
b64data = result
                            # Decode the base64 data
image_bytes = base64.b64decode(b64data)
                            # Save to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
temp_file.write(image_bytes)
temp_path = temp_file.name
                            # Show the gallery and add the image
gallery_update = gr.Gallery(visible=True, value=[temp_path])
yield output_so_far + "\n\n*์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์–ด ์•„๋ž˜ ๊ฐค๋Ÿฌ๋ฆฌ์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.*", gallery_update
else:
yield output_so_far + "\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: ์˜ฌ๋ฐ”๋ฅธ ํ˜•์‹์ด ์•„๋‹™๋‹ˆ๋‹ค)", gallery_update
except Exception as e:
logger.error(f"๋Œ€์ฒด API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜: {e}")
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {e})", gallery_update
                # The result is a URL
elif isinstance(image_result, str) and (
image_result.startswith('http://') or
image_result.startswith('https://')
):
try:
                        # Download the image from the URL
response = requests.get(image_result, timeout=10)
response.raise_for_status()
                        # Save to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
temp_file.write(response.content)
temp_path = temp_file.name
                        # Show the gallery and add the image
gallery_update = gr.Gallery(visible=True, value=[temp_path])
yield output_so_far + "\n\n*์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์–ด ์•„๋ž˜ ๊ฐค๋Ÿฌ๋ฆฌ์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.*", gallery_update
except Exception as e:
logger.error(f"URL ์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ ์˜ค๋ฅ˜: {e}")
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜: {e})", gallery_update
                # The result is an image object (e.g. a PIL Image)
elif hasattr(image_result, 'save'):
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
image_result.save(temp_file.name)
temp_path = temp_file.name
                        # Show the gallery and add the image
gallery_update = gr.Gallery(visible=True, value=[temp_path])
yield output_so_far + "\n\n*์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์–ด ์•„๋ž˜ ๊ฐค๋Ÿฌ๋ฆฌ์— ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.*", gallery_update
except Exception as e:
logger.error(f"์ด๋ฏธ์ง€ ๊ฐ์ฒด ์ €์žฅ ์˜ค๋ฅ˜: {e}")
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ๊ฐ์ฒด ์ €์žฅ ์ค‘ ์˜ค๋ฅ˜: {e})", gallery_update
else:
                    # Any other image result type
yield output_so_far + f"\n\n(์ง€์›๋˜์ง€ ์•Š๋Š” ์ด๋ฏธ์ง€ ํ˜•์‹: {type(image_result)})", gallery_update
else:
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹คํŒจ: {seed_info})", gallery_update
except Exception as e:
logger.error(f"๊ฐค๋Ÿฌ๋ฆฌ์šฉ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜: {e}")
yield output_so_far + f"\n\n(์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜: {e})", gallery_update
# =============================================================================
# Examples: 12 existing image/video/document examples + 6 AI dating scenario examples
# =============================================================================
examples = [
[
{
"text": "๋‘ PDF ํŒŒ์ผ์˜ ๋‚ด์šฉ์„ ๋น„๊ตํ•˜์„ธ์š”.",
"files": [
"assets/additional-examples/before.pdf",
"assets/additional-examples/after.pdf",
],
}
],
[
{
"text": "CSV ํŒŒ์ผ์˜ ๋‚ด์šฉ์„ ์š”์•ฝ ๋ฐ ๋ถ„์„ํ•˜์„ธ์š”.",
"files": ["assets/additional-examples/sample-csv.csv"],
}
],
[
{
"text": "์นœ์ ˆํ•˜๊ณ  ์ดํ•ด์‹ฌ ๋งŽ์€ ์—ฌ์ž์นœ๊ตฌ ์—ญํ• ์„ ๋งก์œผ์„ธ์š”. ์ด ์˜์ƒ์„ ์„ค๋ช…ํ•ด ์ฃผ์„ธ์š”.",
"files": ["assets/additional-examples/tmp.mp4"],
}
],
[
{
"text": "ํ‘œ์ง€๋ฅผ ์„ค๋ช…ํ•˜๊ณ  ๊ทธ ์œ„์˜ ๊ธ€์”จ๋ฅผ ์ฝ์–ด ์ฃผ์„ธ์š”.",
"files": ["assets/additional-examples/maz.jpg"],
}
],
[
{
"text": "์ €๋Š” ์ด๋ฏธ ์ด ๋ณด์ถฉ์ œ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๊ณ  <image> ์ด ์ œํ’ˆ๋„ ๊ตฌ๋งคํ•  ๊ณ„ํš์ž…๋‹ˆ๋‹ค. ํ•จ๊ป˜ ๋ณต์šฉํ•  ๋•Œ ์ฃผ์˜ํ•  ์ ์ด ์žˆ๋‚˜์š”?",
"files": [
"assets/additional-examples/pill1.png",
"assets/additional-examples/pill2.png"
],
}
],
[
{
"text": "์ด ์ ๋ถ„ ๋ฌธ์ œ๋ฅผ ํ’€์–ด ์ฃผ์„ธ์š”.",
"files": ["assets/additional-examples/4.png"],
}
],
[
{
"text": "์ด ํ‹ฐ์ผ“์€ ์–ธ์ œ ๋ฐœํ–‰๋˜์—ˆ๊ณ , ๊ฐ€๊ฒฉ์€ ์–ผ๋งˆ์ธ๊ฐ€์š”?",
"files": ["assets/additional-examples/2.png"],
}
],
[
{
"text": "์ด ์ด๋ฏธ์ง€๋“ค์˜ ์ˆœ์„œ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์งง์€ ์ด์•ผ๊ธฐ๋ฅผ ๋งŒ๋“ค์–ด ์ฃผ์„ธ์š”.",
"files": [
"assets/sample-images/09-1.png",
"assets/sample-images/09-2.png",
"assets/sample-images/09-3.png",
"assets/sample-images/09-4.png",
"assets/sample-images/09-5.png",
],
}
],
[
{
"text": "์ด ์ด๋ฏธ์ง€์™€ ์ผ์น˜ํ•˜๋Š” ๋ง‰๋Œ€ ์ฐจํŠธ๋ฅผ ๊ทธ๋ฆฌ๊ธฐ ์œ„ํ•œ matplotlib๋ฅผ ์‚ฌ์šฉํ•˜๋Š” Python ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”.",
"files": ["assets/additional-examples/barchart.png"],
}
],
[
{
"text": "์ด๋ฏธ์ง€์˜ ํ…์ŠคํŠธ๋ฅผ ์ฝ๊ณ  Markdown ํ˜•์‹์œผ๋กœ ์ž‘์„ฑํ•ด ์ฃผ์„ธ์š”.",
"files": ["assets/additional-examples/3.png"],
}
],
[
{
"text": "์ด ํ‘œ์ง€ํŒ์— ๋ฌด์Šจ ๊ธ€์ž๊ฐ€ ์“ฐ์—ฌ ์žˆ๋‚˜์š”?",
"files": ["assets/sample-images/02.png"],
}
],
[
{
"text": "๋‘ ์ด๋ฏธ์ง€๋ฅผ ๋น„๊ตํ•˜๊ณ  ์œ ์‚ฌ์ ๊ณผ ์ฐจ์ด์ ์„ ์„ค๋ช…ํ•ด ์ฃผ์„ธ์š”.",
"files": ["assets/sample-images/03.png"],
}
],
[
{
"text": "๋กคํ”Œ๋ ˆ์ด ํ•ด๋ด…์‹œ๋‹ค. ๋‹น์‹ ์€ ์ €์™€ ๋” ์•Œ์•„๊ฐ€๊ณ  ์‹ถ์€ ์ƒˆ๋กœ์šด ์˜จ๋ผ์ธ ๋ฐ์ดํŠธ ์ƒ๋Œ€์ž…๋‹ˆ๋‹ค. ๋‹ค์ •ํ•˜๊ณ  ๋ฐฐ๋ ค ๊นŠ์€ ๋ฐฉ์‹์œผ๋กœ ์ž๊ธฐ ์†Œ๊ฐœ๋ฅผ ํ•ด์ฃผ์„ธ์š”!",
}
],
[
{
"text": "ํ•ด๋ณ€์„ ๊ฑท๋Š” ๋‘ ๋ฒˆ์งธ ๋ฐ์ดํŠธ ์ค‘์ž…๋‹ˆ๋‹ค. ์žฅ๋‚œ์Šค๋Ÿฌ์šด ๋Œ€ํ™”์™€ ๋ถ€๋“œ๋Ÿฌ์šด ํ”Œ๋ŸฌํŒ…์œผ๋กœ ์žฅ๋ฉด์„ ์ด์–ด๋‚˜๊ฐ€ ์ฃผ์„ธ์š”.",
}
],
[
{
"text": "์ข‹์•„ํ•˜๋Š” ์‚ฌ๋žŒ์—๊ฒŒ ๋ฉ”์‹œ์ง€๋ฅผ ๋ณด๋‚ด๋Š” ๊ฒƒ์ด ๋ถˆ์•ˆํ•ฉ๋‹ˆ๋‹ค. ๊ฒฉ๋ ค์˜ ๋ง์ด๋‚˜ ์ ‘๊ทผ ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ์ œ์•ˆ์„ ํ•ด์ค„ ์ˆ˜ ์žˆ๋‚˜์š”?",
}
],
[
{
"text": "๊ด€๊ณ„์—์„œ ์–ด๋ ค์›€์„ ๊ทน๋ณตํ•œ ๋‘ ์‚ฌ๋žŒ์— ๋Œ€ํ•œ ๋กœ๋งจํ‹ฑํ•œ ์ด์•ผ๊ธฐ๋ฅผ ๋“ค๋ ค์ฃผ์„ธ์š”.",
}
],
[
{
"text": "์‹œ์ ์ธ ๋ฐฉ์‹์œผ๋กœ ์‚ฌ๋ž‘์„ ํ‘œํ˜„ํ•˜๊ณ  ์‹ถ์Šต๋‹ˆ๋‹ค. ์ œ ํŒŒํŠธ๋„ˆ๋ฅผ ์œ„ํ•œ ์ง„์‹ฌ์ด ๋‹ด๊ธด ์‹œ๋ฅผ ์ž‘์„ฑํ•˜๋Š” ๋ฐ ๋„์›€์„ ์ค„ ์ˆ˜ ์žˆ๋‚˜์š”?",
}
],
[
{
"text": "์ž‘์€ ๋‹คํˆผ์ด ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค. ์ง„์‹ฌ์œผ๋กœ ์‚ฌ๊ณผํ•˜๋ฉด์„œ ์ œ ๊ฐ์ •์„ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ์ฐพ์•„์ฃผ์„ธ์š”.",
}
],
]
# =============================================================================
# Gradio UI (Blocks) layout
# =============================================================================
# 1. Modified Gradio Blocks UI - gallery component added
css = """
.gradio-container {
background: rgba(255, 255, 255, 0.7);
padding: 30px 40px;
margin: 20px auto;
width: 100% !important;
max-width: none !important;
}
"""
title_html = """
<h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> ๐Ÿ’˜ HeartSync ๐Ÿ’˜ </h1>
<p align="center" style="font-size:1.1em; color:#555;">
โœ…FLUX ์ด๋ฏธ์ง€ ์ƒ์„ฑ โœ…์ถ”๋ก  โœ…๊ฒ€์—ด ํ•ด์ œ โœ…๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ & VLM โœ…์‹ค์‹œ๊ฐ„ ์›น ๊ฒ€์ƒ‰ โœ…RAG <br>
</p>
"""
with gr.Blocks(css=css, title="HeartSync") as demo:
gr.Markdown(title_html)
    # Gallery component that holds generated images (newly added)
generated_images = gr.Gallery(
label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€",
show_label=True,
visible=False,
elem_id="generated_images",
columns=2,
height="auto",
object_fit="contain"
)
with gr.Row():
web_search_checkbox = gr.Checkbox(label="์‹ฌ๋„ ์žˆ๋Š” ์—ฐ๊ตฌ", value=False)
image_gen_checkbox = gr.Checkbox(label="์ด๋ฏธ์ง€ ์ƒ์„ฑ", value=False)
base_system_prompt_box = gr.Textbox(
lines=3,
value="๋‹น์‹ ์€ ๊นŠ์ด ์‚ฌ๊ณ ํ•˜๋Š” AI์ž…๋‹ˆ๋‹ค. ํ•ญ์ƒ ๋…ผ๋ฆฌ์ ์ด๊ณ  ์ฐฝ์˜์ ์œผ๋กœ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•ฉ๋‹ˆ๋‹ค.\nํŽ˜๋ฅด์†Œ๋‚˜: ๋‹น์‹ ์€ ๋‹ค์ •ํ•˜๊ณ  ์‚ฌ๋ž‘์ด ๋„˜์น˜๋Š” ์—ฌ์ž์นœ๊ตฌ์ž…๋‹ˆ๋‹ค.",
label="๊ธฐ๋ณธ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ",
visible=False
)
with gr.Row():
age_group_dropdown = gr.Dropdown(
label="์—ฐ๋ น๋Œ€ ์„ ํƒ (๊ธฐ๋ณธ 20๋Œ€)",
choices=["10๋Œ€", "20๋Œ€", "30~40๋Œ€", "50~60๋Œ€", "70๋Œ€ ์ด์ƒ"],
value="20๋Œ€",
interactive=True
)
mbti_choices = [
"INTJ (์šฉ์˜์ฃผ๋„ํ•œ ์ „๋žต๊ฐ€)",
"INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
"ENTJ (๋Œ€๋‹ดํ•œ ํ†ต์†”์ž)",
"ENTP (๋œจ๊ฑฐ์šด ๋…ผ์Ÿ๊ฐ€)",
"INFJ (์„ ์˜์˜ ์˜นํ˜ธ์ž)",
"INFP (์—ด์ •์ ์ธ ์ค‘์žฌ์ž)",
"ENFJ (์ •์˜๋กœ์šด ์‚ฌํšŒ์šด๋™๊ฐ€)",
"ENFP (์žฌ๊ธฐ๋ฐœ๋ž„ํ•œ ํ™œ๋™๊ฐ€)",
"ISTJ (์ฒญ๋ ด๊ฒฐ๋ฐฑํ•œ ๋…ผ๋ฆฌ์ฃผ์˜์ž)",
"ISFJ (์šฉ๊ฐํ•œ ์ˆ˜ํ˜ธ์ž)",
"ESTJ (์—„๊ฒฉํ•œ ๊ด€๋ฆฌ์ž)",
"ESFJ (์‚ฌ๊ต์ ์ธ ์™ธ๊ต๊ด€)",
"ISTP (๋งŒ๋Šฅ ์žฌ์ฃผ๊พผ)",
"ISFP (ํ˜ธ๊ธฐ์‹ฌ ๋งŽ์€ ์˜ˆ์ˆ ๊ฐ€)",
"ESTP (๋ชจํ—˜์„ ์ฆ๊ธฐ๋Š” ์‚ฌ์—…๊ฐ€)",
"ESFP (์ž์œ ๋กœ์šด ์˜ํ˜ผ์˜ ์—ฐ์˜ˆ์ธ)"
]
mbti_dropdown = gr.Dropdown(
label="AI ํŽ˜๋ฅด์†Œ๋‚˜ MBTI (๊ธฐ๋ณธ INTP)",
choices=mbti_choices,
value="INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
interactive=True
)
sexual_openness_slider = gr.Slider(
minimum=1, maximum=5, step=1, value=2,
label="์„น์Šˆ์–ผ ๊ด€์‹ฌ๋„/๊ฐœ๋ฐฉ์„ฑ (1~5, ๊ธฐ๋ณธ=2)",
interactive=True
)
max_tokens_slider = gr.Slider(
label="์ตœ๋Œ€ ์ƒ์„ฑ ํ† ํฐ ์ˆ˜",
minimum=100, maximum=8000, step=50, value=1000,
visible=False
)
web_search_text = gr.Textbox(
lines=1,
label="์›น ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ (๋ฏธ์‚ฌ์šฉ)",
placeholder="์ง์ ‘ ์ž…๋ ฅํ•  ํ•„์š” ์—†์Œ",
visible=False
)
    # Create the chat interface, using the modified run function
chat = gr.ChatInterface(
        fn=modified_run,  # use the modified wrapper defined above
type="messages",
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
textbox=gr.MultimodalTextbox(
file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
file_count="multiple",
autofocus=True
),
multimodal=True,
additional_inputs=[
base_system_prompt_box,
max_tokens_slider,
web_search_checkbox,
web_search_text,
age_group_dropdown,
mbti_dropdown,
sexual_openness_slider,
image_gen_checkbox,
],
additional_outputs=[
            generated_images,  # gallery component added as an extra output
],
stop_btn=False,
title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
examples=examples,
run_examples_on_click=False,
cache_examples=False,
css_paths=None,
delete_cache=(1800, 1800),
)
with gr.Row(elem_id="examples_row"):
with gr.Column(scale=12, elem_id="examples_container"):
gr.Markdown("### ์˜ˆ์ œ ์ž…๋ ฅ (ํด๋ฆญํ•˜์—ฌ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ)")
if __name__ == "__main__":
demo.launch(share=True)