# fastvideogen / app.py
import sys
import subprocess
def install_required_packages():
    packages = [
        "git+https://github.com/black-forest-labs/diffusers",
        "transformers>=4.25.1",
        "safetensors>=0.3.1",
        "accelerate>=0.16.0"
    ]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except subprocess.CalledProcessError as e:
            print(f"Error installing {package}: {e}")
            raise

# Install required packages at startup
install_required_packages()
import spaces
import argparse
import os
import gc
import io
import time
import random
import shutil
import logging
from os import path
from datetime import datetime
from functools import lru_cache
from pathlib import Path

import numpy as np
import cv2
import requests
import torch
import gradio as gr
import replicate
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import pipeline, T5EncoderModel, T5Tokenizer, CLIPModel, CLIPProcessor
from openai import OpenAI

try:
    from diffusers.pipelines.flux import FluxPipeline
except ImportError:
    from diffusers import StableDiffusionPipeline as FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
# Logging configuration
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Constants and environment variables
MAX_SEED = np.iinfo(np.int32).max
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
MODEL_PATH = "asset"
CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models")
GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery")
VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery")

# API keys
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Load the system prompt used for prompt enhancement; fall back to an empty
# prompt if the asset file is missing so the app can still start.
SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt"
if os.path.exists(SYSTEM_PROMPT_PATH):
    with open(SYSTEM_PROMPT_PATH, "r") as f:
        SYSTEM_PROMPT = f.read()
else:
    SYSTEM_PROMPT = ""
    logger.warning(f"System prompt file not found: {SYSTEM_PROMPT_PATH}")
# Directory initialization
def init_directories():
    """Create the directories required by the app."""
directories = [GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH]
for directory in directories:
os.makedirs(directory, exist_ok=True)
logger.info(f"Directory initialized: {directory}")
# CUDA configuration
def setup_cuda():
    """Initialize CUDA-related settings."""
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cuda.preferred_blas_library = "cublas"
torch.set_float32_matmul_precision("highest")
logger.info("CUDA settings initialized")
# Model initialization
if not path.exists(CACHE_PATH):
    os.makedirs(CACHE_PATH, exist_ok=True)

try:
    # Initialize the FluxPipeline
    model_id = "black-forest-labs/FLUX.1-dev"
    pipe = FluxPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_PATH,
        local_files_only=False
)
    # Download and apply the LoRA weights
    lora_path = hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=CACHE_PATH
    )
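    # The Hyper-SD 8-step LoRA lets FLUX.1-dev generate images in roughly 8
    # inference steps; fusing it at lora_scale=0.125 follows the scale
    # recommended for this checkpoint.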
if hasattr(pipe, 'load_lora_weights'):
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=0.125)
    # Move the pipeline to the GPU
    pipe = pipe.to("cuda")

    # Configure the safety checker
    if hasattr(pipe, 'safety_checker'):
        pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker",
            cache_dir=CACHE_PATH
        )
logger.info("Model initialized successfully")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise
# Model management class
class ModelManager:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.models = {}
self.current_model = None
logger.info(f"ModelManager initialized with device: {self.device}")
def load_model(self, model_name):
"""๋ชจ๋ธ์„ ๋™์ ์œผ๋กœ ๋กœ๋“œ"""
if self.current_model == model_name and model_name in self.models:
return self.models[model_name]
# ํ˜„์žฌ ๋กœ๋“œ๋œ ๋ชจ๋ธ ์–ธ๋กœ๋“œ
self.unload_current_model()
logger.info(f"Loading model: {model_name}")
try:
if model_name == "flux":
model = self._load_flux_model()
elif model_name == "xora":
model = self._load_xora_model()
elif model_name == "clip":
model = self._load_clip_model()
else:
raise ValueError(f"Unknown model: {model_name}")
self.models[model_name] = model
self.current_model = model_name
return model
except Exception as e:
logger.error(f"Error loading model {model_name}: {str(e)}")
raise
def unload_current_model(self):
"""ํ˜„์žฌ ๋กœ๋“œ๋œ ๋ชจ๋ธ ์–ธ๋กœ๋“œ"""
if self.current_model:
logger.info(f"Unloading model: {self.current_model}")
if self.current_model in self.models:
del self.models[self.current_model]
self.current_model = None
torch.cuda.empty_cache()
gc.collect()
def _load_flux_model(self):
"""Flux ๋ชจ๋ธ ๋กœ๋“œ"""
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
hf_hub_download(
"ByteDance/Hyper-SD",
"Hyper-FLUX.1-dev-8steps-lora.safetensors"
)
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device=self.device, dtype=torch.bfloat16)
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
)
return pipe
def _load_xora_model(self):
"""Xora ๋ชจ๋ธ ๋กœ๋“œ"""
if not path.exists(MODEL_PATH):
snapshot_download(
"Lightricks/LTX-Video",
revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc',
local_dir=MODEL_PATH,
repo_type="model",
token=HF_TOKEN
)
vae = load_vae(Path(MODEL_PATH) / "vae")
unet = load_unet(Path(MODEL_PATH) / "unet")
scheduler = load_scheduler(Path(MODEL_PATH) / "scheduler")
patchifier = SymmetricPatchifier(patch_size=1)
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder"
).to(self.device)
tokenizer = T5Tokenizer.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="tokenizer"
)
return XoraVideoPipeline(
transformer=unet,
patchifier=patchifier,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
vae=vae
).to(self.device)
def _load_clip_model(self):
"""CLIP ๋ชจ๋ธ ๋กœ๋“œ"""
model = CLIPModel.from_pretrained(
"openai/clip-vit-base-patch32",
cache_dir=MODEL_PATH
).to(self.device)
processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32",
cache_dir=MODEL_PATH
)
return {"model": model, "processor": processor}
# Translator initialization
@lru_cache(maxsize=None)
def get_translator():
    """Lazily initialize the Korean-to-English translator."""
return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# OpenAI client initialization
@lru_cache(maxsize=None)
def get_openai_client():
    """Lazily initialize the OpenAI client."""
return OpenAI(api_key=OPENAI_API_KEY)
# Utility functions
class Timer:
    """Context manager for timing operations."""
def __init__(self, method_name="timed process"):
self.method = method_name
def __enter__(self):
self.start = time.time()
logger.info(f"{self.method} starts")
def __exit__(self, exc_type, exc_val, exc_tb):
end = time.time()
logger.info(f"{self.method} took {str(round(end - self.start, 2))}s")
def process_prompt(prompt):
"""ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ (ํ•œ๊ธ€ ๋ฒˆ์—ญ ๋ฐ ํ•„ํ„ฐ๋ง)"""
if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in prompt):
translator = get_translator()
translated = translator(prompt)[0]['translation_text']
logger.info(f"Translated prompt: {translated}")
return translated
return prompt
def filter_prompt(prompt):
"""๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ ํ•„ํ„ฐ๋ง"""
inappropriate_keywords = [
"nude", "naked", "nsfw", "porn", "sex", "explicit", "adult",
"xxx", "erotic", "sensual", "seductive", "provocative",
"intimate", "violence", "gore", "blood", "death", "kill",
"murder", "torture", "drug", "suicide", "abuse", "hate",
"discrimination"
]
prompt_lower = prompt.lower()
for keyword in inappropriate_keywords:
if keyword in prompt_lower:
logger.warning(f"Inappropriate content detected: {keyword}")
            return False, "The prompt contains inappropriate content."
return True, prompt
def enhance_prompt(prompt, enhance_toggle):
"""GPT๋ฅผ ์‚ฌ์šฉํ•œ ํ”„๋กฌํ”„ํŠธ ๊ฐœ์„ """
if not enhance_toggle:
logger.info("Prompt enhancement disabled")
return prompt
try:
client = get_openai_client()
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
response = client.chat.completions.create(
model="gpt-4-mini",
messages=messages,
max_tokens=200,
)
enhanced_prompt = response.choices[0].message.content.strip()
logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
except Exception as e:
logger.error(f"Prompt enhancement failed: {str(e)}")
return prompt
def save_image(image, directory=GALLERY_PATH):
"""์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ์ €์žฅ"""
try:
os.makedirs(directory, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
random_suffix = os.urandom(4).hex()
filename = f"generated_{timestamp}_{random_suffix}.png"
filepath = os.path.join(directory, filename)
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
if image.mode != 'RGB':
image = image.convert('RGB')
image.save(filepath, format='PNG', optimize=True, quality=100)
logger.info(f"Image saved: {filepath}")
return filepath
except Exception as e:
logger.error(f"Error saving image: {str(e)}")
return None
def add_watermark(video_path):
"""๋น„๋””์˜ค์— ์›Œํ„ฐ๋งˆํฌ ์ถ”๊ฐ€"""
try:
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
text = "GiniGEN.AI"
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = height * 0.05 / 30
thickness = 2
color = (255, 255, 255)
(text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
margin = int(height * 0.02)
x_pos = width - text_width - margin
y_pos = height - margin
output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
out.write(frame)
cap.release()
out.release()
logger.info(f"Video watermarked: {output_path}")
return output_path
except Exception as e:
logger.error(f"Error adding watermark: {str(e)}")
return video_path
def upload_to_catbox(file_path):
"""ํŒŒ์ผ์„ catbox.moe์— ์—…๋กœ๋“œ"""
try:
logger.info(f"Uploading file: {file_path}")
url = "https://catbox.moe/user/api.php"
file_extension = Path(file_path).suffix.lower()
supported_extensions = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.mp4': 'video/mp4'
}
if file_extension not in supported_extensions:
logger.error(f"Unsupported file type: {file_extension}")
return None
        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # Keep the file handle scoped to the upload request
        with open(file_path, 'rb') as fh:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    fh,
                    supported_extensions[file_extension]
                )
            }
            response = requests.post(url, files=files, data=data)

        if response.status_code == 200 and response.text.startswith('http'):
logger.info(f"Upload successful: {response.text}")
return response.text
else:
raise Exception(f"Upload failed: {response.text}")
except Exception as e:
logger.error(f"Upload error: {str(e)}")
return None
# Create the model manager instance
model_manager = ModelManager()

# Gradio interface constants and settings
PRESET_OPTIONS = [
{"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
{"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
{"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
{"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
{"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
{"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
{"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
{"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
{"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
{"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
{"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64},
{"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100},
{"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
{"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
{"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
# Main processing functions
@spaces.GPU(duration=90)
def generate_image(
prompt,
height,
width,
steps,
scales,
seed,
enhance_prompt_toggle=False,
progress=gr.Progress()
):
"""์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜"""
try:
# ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ
processed_prompt = process_prompt(prompt)
is_safe, filtered_prompt = filter_prompt(processed_prompt)
if not is_safe:
raise gr.Error("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
if enhance_prompt_toggle:
filtered_prompt = enhance_prompt(filtered_prompt, True)
# Flux ๋ชจ๋ธ ๋กœ๋“œ
pipe = model_manager.load_model("flux")
with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
generated_image = pipe(
prompt=[filtered_prompt],
generator=torch.Generator().manual_seed(int(seed)),
num_inference_steps=int(steps),
guidance_scale=float(scales),
height=int(height),
width=int(width),
max_sequence_length=256
).images[0]
        # Save and return the image
        saved_path = save_image(generated_image)
        if saved_path is None:
            raise gr.Error("Failed to save the image.")
return Image.open(saved_path)
except Exception as e:
logger.error(f"Image generation error: {str(e)}")
        raise gr.Error(f"An error occurred during image generation: {str(e)}")
finally:
model_manager.unload_current_model()
torch.cuda.empty_cache()
gc.collect()
@spaces.GPU(duration=90)
def generate_video_xora(
prompt,
enhance_prompt_toggle,
negative_prompt,
frame_rate,
seed,
num_inference_steps,
guidance_scale,
height,
width,
num_frames,
progress=gr.Progress()
):
"""Xora ๋น„๋””์˜ค ์ƒ์„ฑ ํ•จ์ˆ˜"""
try:
# ํ”„๋กฌํ”„ํŠธ ์ฒ˜๋ฆฌ
prompt = process_prompt(prompt)
negative_prompt = process_prompt(negative_prompt)
if len(prompt.strip()) < 50:
raise gr.Error("ํ”„๋กฌํ”„ํŠธ๋Š” ์ตœ์†Œ 50์ž ์ด์ƒ์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
prompt = enhance_prompt(prompt, enhance_prompt_toggle)
# Xora ๋ชจ๋ธ ๋กœ๋“œ
pipeline = model_manager.load_model("xora")
sample = {
"prompt": prompt,
"prompt_attention_mask": None,
"negative_prompt": negative_prompt,
"negative_prompt_attention_mask": None,
"media_items": None,
}
generator = torch.Generator(device="cuda").manual_seed(seed)
def progress_callback(step, timestep, kwargs):
progress((step + 1) / num_inference_steps)
with torch.no_grad():
images = pipeline(
num_inference_steps=num_inference_steps,
num_images_per_prompt=1,
guidance_scale=guidance_scale,
generator=generator,
output_type="pt",
height=height,
width=width,
num_frames=num_frames,
frame_rate=frame_rate,
**sample,
is_video=True,
vae_per_channel_normalize=True,
conditioning_method=ConditioningMethod.UNCONDITIONAL,
mixed_precision=True,
callback_on_step_end=progress_callback,
).images
        # Save the video
output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4")
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
video_np = (video_np * 255).astype(np.uint8)
out = cv2.VideoWriter(
output_path,
cv2.VideoWriter_fourcc(*"mp4v"),
frame_rate,
(width, height)
)
for frame in video_np[..., ::-1]:
out.write(frame)
out.release()
        # Add the watermark
final_path = add_watermark(output_path)
return final_path
except Exception as e:
logger.error(f"Video generation error: {str(e)}")
        raise gr.Error(f"An error occurred during video generation: {str(e)}")
finally:
model_manager.unload_current_model()
torch.cuda.empty_cache()
gc.collect()
def generate_video_replicate(image, prompt):
"""Replicate API๋ฅผ ์‚ฌ์šฉํ•œ ๋น„๋””์˜ค ์ƒ์„ฑ ํ•จ์ˆ˜"""
try:
is_safe, filtered_prompt = filter_prompt(prompt)
if not is_safe:
raise gr.Error("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
if not image:
raise gr.Error("์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")
# ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ
image_url = upload_to_catbox(image)
if not image_url:
raise gr.Error("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")
# Replicate API ํ˜ธ์ถœ
client = replicate.Client(api_token=REPLICATE_API_TOKEN)
output = client.run(
"minimax/video-01-live",
input={
"prompt": filtered_prompt,
"first_frame_image": image_url
}
)
        # Save the resulting video
output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4")
if hasattr(output, 'read'):
with open(output_path, "wb") as f:
f.write(output.read())
elif isinstance(output, str):
response = requests.get(output)
with open(output_path, "wb") as f:
f.write(response.content)
        # Add the watermark
final_path = add_watermark(output_path)
return final_path
except Exception as e:
logger.error(f"Replicate video generation error: {str(e)}")
        raise gr.Error(f"An error occurred during video generation: {str(e)}")
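
# NOTE: load_gallery() is referenced below and by the UI but is not defined in
# this file. A minimal sketch, assuming the gallery simply lists the PNG files
# saved to GALLERY_PATH, newest first:
def load_gallery(directory=GALLERY_PATH):
    """Return paths of saved gallery images, newest first."""
    try:
        files = [
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(".png")
        ]
        return sorted(files, key=os.path.getmtime, reverse=True)
    except Exception as e:
        logger.error(f"Error loading gallery: {str(e)}")
        return []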
@spaces.GPU
def process_and_save_image(height, width, steps, scales, prompt, seed):
    # Translate, then filter the prompt (process_prompt only translates)
    translated_prompt = process_prompt(prompt)
    is_safe, translated_prompt = filter_prompt(translated_prompt)
    if not is_safe:
        gr.Warning("The prompt contains inappropriate content.")
        return None, load_gallery()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), Timer("inference"):
try:
            # Call the model according to the pipeline interface
if hasattr(pipe, '__call__'):
output = pipe(
prompt=[translated_prompt],
generator=torch.Generator().manual_seed(int(seed)),
num_inference_steps=int(steps),
guidance_scale=float(scales),
height=int(height),
width=int(width),
max_sequence_length=256
)
generated_image = output.images[0]
else:
generated_image = pipe.text2img(
prompt=translated_prompt,
generator=torch.Generator().manual_seed(int(seed)),
num_inference_steps=int(steps),
guidance_scale=float(scales),
height=int(height),
width=int(width)
)[0]
            # Process and save the image
if not isinstance(generated_image, Image.Image):
generated_image = Image.fromarray(generated_image)
if generated_image.mode != 'RGB':
generated_image = generated_image.convert('RGB')
img_byte_arr = io.BytesIO()
generated_image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
saved_path = save_image(generated_image)
if saved_path is None:
logger.warning("Failed to save generated image")
return None, load_gallery()
return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
except Exception as e:
logger.error(f"Error in image generation: {str(e)}")
return None, load_gallery()
# Gradio UI styles
css = """
.gradio-container {
font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important;
}
.title {
text-align: center;
font-size: 2.5rem;
font-weight: bold;
color: #2a9d8f;
margin: 1rem 0;
padding: 1rem;
background: linear-gradient(to right, #264653, #2a9d8f);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.generate-btn {
background: linear-gradient(to right, #2a9d8f, #264653) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important;
}
.gallery {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 1rem;
padding: 1rem;
}
.gallery img {
width: 100%;
height: auto;
border-radius: 8px;
transition: transform 0.3s ease;
}
.gallery img:hover {
transform: scale(1.05);
}
"""
# Gradio interface construction
def create_ui():
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
gr.HTML('<div class="title">AI Image & Video Generator</div>')
with gr.Tabs():
            # Image generation tab
with gr.Tab("Image Generation"):
with gr.Row():
with gr.Column(scale=3):
img_prompt = gr.Textbox(
label="Image Description",
                            placeholder="Enter an image description... (Korean input supported)",
lines=3
)
                        # A standard Checkbox is used in place of the custom Toggle component
                        img_enhance_toggle = gr.Checkbox(
                            label="Enhance Prompt",
                            value=False,
                            interactive=True,
                        )
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
img_height = gr.Slider(
label="Height",
minimum=256,
maximum=1024,
step=64,
value=768
)
img_width = gr.Slider(
label="Width",
minimum=256,
maximum=1024,
step=64,
value=768
)
with gr.Row():
steps = gr.Slider(
label="Inference Steps",
minimum=6,
maximum=25,
step=1,
value=8
)
scales = gr.Slider(
label="Guidance Scale",
minimum=0.0,
maximum=5.0,
step=0.1,
value=3.5
)
seed = gr.Number(
label="Seed",
value=random.randint(0, MAX_SEED),
precision=0
)
img_generate_btn = gr.Button(
"Generate Image",
variant="primary",
elem_classes=["generate-btn"]
)
with gr.Column(scale=4):
img_output = gr.Image(
label="Generated Image",
type="pil",
format="png"
)
img_gallery = gr.Gallery(
label="Image Gallery",
show_label=True,
elem_id="gallery",
columns=[4],
rows=[2],
height="auto",
object_fit="cover"
)
            # Xora video generation tab
with gr.Tab("Xora Video Generation"):
with gr.Row():
with gr.Column(scale=3):
xora_prompt = gr.Textbox(
label="Video Description",
                            placeholder="Enter a video description... (at least 50 characters)",
lines=5
)
                        xora_enhance_toggle = gr.Checkbox(
                            label="Enhance Prompt",
                            value=False
                        )
xora_negative_prompt = gr.Textbox(
label="Negative Prompt",
value="low quality, worst quality, deformed, distorted",
lines=2
)
xora_preset = gr.Dropdown(
choices=[p["label"] for p in PRESET_OPTIONS],
value="512x512, 160 frames",
label="Resolution Preset"
)
xora_frame_rate = gr.Slider(
label="Frame Rate",
minimum=6,
maximum=60,
step=1,
value=20
)
with gr.Accordion("Advanced Settings", open=False):
xora_seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=random.randint(0, MAX_SEED)
)
xora_steps = gr.Slider(
label="Inference Steps",
minimum=5,
maximum=150,
step=5,
value=40
)
xora_guidance = gr.Slider(
label="Guidance Scale",
minimum=1.0,
maximum=10.0,
step=0.1,
value=4.2
)
xora_generate_btn = gr.Button(
"Generate Video",
variant="primary",
elem_classes=["generate-btn"]
)
with gr.Column(scale=4):
xora_output = gr.Video(label="Generated Video")
xora_gallery = gr.Gallery(
label="Video Gallery",
show_label=True,
columns=[4],
rows=[2],
height="auto",
object_fit="cover"
)
            # Replicate image-to-video tab
with gr.Tab("Image to Video"):
with gr.Row():
with gr.Column(scale=3):
upload_image = gr.Image(
type="filepath",
label="Upload First Frame Image"
)
replicate_prompt = gr.Textbox(
label="Video Description",
                            placeholder="Enter a video description...",
lines=3
)
replicate_generate_btn = gr.Button(
"Generate Video",
variant="primary",
elem_classes=["generate-btn"]
)
with gr.Column(scale=4):
replicate_output = gr.Video(label="Generated Video")
replicate_gallery = gr.Gallery(
label="Video Gallery",
show_label=True,
columns=[4],
rows=[2],
height="auto",
object_fit="cover"
)
        # Wire up the event handlers
img_generate_btn.click(
fn=generate_image,
inputs=[
img_prompt,
img_height,
img_width,
steps,
scales,
seed,
img_enhance_toggle
],
outputs=img_output
)
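
        # NOTE: the Xora tab reuses the image tab's height/width sliders and
        # creates a new slider inline for the frame count; the resolution
        # preset dropdown above is not wired into these inputs.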
xora_generate_btn.click(
fn=generate_video_xora,
inputs=[
xora_prompt,
xora_enhance_toggle,
xora_negative_prompt,
xora_frame_rate,
xora_seed,
xora_steps,
xora_guidance,
img_height,
img_width,
gr.Slider(label="Number of Frames", value=60)
],
outputs=xora_output
)
replicate_generate_btn.click(
fn=generate_video_replicate,
inputs=[upload_image, replicate_prompt],
outputs=replicate_output
)
        # Periodically refresh the galleries (one value per output component)
        demo.load(lambda: (None, None, None), None, [img_gallery, xora_gallery, replicate_gallery], every=30)
return demo
if __name__ == "__main__":
    # Initialize directories and CUDA settings
    init_directories()
    setup_cuda()

    # Build and launch the UI
demo = create_ui()
demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
share=True,
show_api=False,
server_name="0.0.0.0",
server_port=7860,
debug=False
)