import spaces
import argparse
import os
import random
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import gradio as gr
from gradio_toggle import Toggle
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel, T5EncoderModel, T5Tokenizer
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io
import json
import gc
import csv
from openai import OpenAI
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from xora.utils.conditioning_method import ConditioningMethod
from functools import lru_cache
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Constants and environment variables
MAX_SEED = np.iinfo(np.int32).max
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
MODEL_PATH = "asset"
CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models")
GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery")
VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery")

# API keys
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Load the system prompt used for prompt enhancement
SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt"
with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()
# Directory initialization
def init_directories():
    """Create the directories required at runtime."""
    directories = [GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH]
    for directory in directories:
        os.makedirs(directory, exist_ok=True)
        logger.info(f"Directory initialized: {directory}")

# CUDA configuration
def setup_cuda():
    """Initialize CUDA-related settings."""
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.preferred_blas_library = "cublas"
    torch.set_float32_matmul_precision("highest")
    logger.info("CUDA settings initialized")
# Model initialization (eager warm-up of the Flux pipeline at import time)
if not path.exists(CACHE_PATH):
    os.makedirs(CACHE_PATH, exist_ok=True)
try:
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_PATH
    )
    lora_path = hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
        cache_dir=CACHE_PATH
    )
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora(lora_scale=0.125)
    pipe.to(device="cuda", dtype=torch.bfloat16)
    pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker",
        cache_dir=CACHE_PATH
    )
except Exception as e:
    logger.error(f"Error initializing FluxPipeline: {str(e)}")
    raise
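
# NOTE: load_vae / load_unet / load_scheduler are called by ModelManager._load_xora_model
# below but were not defined anywhere in this file. The minimal sketches below follow the
# checkpoint layout of the Lightricks/LTX-Video reference demo (a config.json plus a
# *_diffusion_pytorch_model.safetensors file per subdirectory); the exact file names are
# assumptions and may need adjusting if the snapshot layout differs.
def load_vae(vae_dir):
    """Build the causal video VAE from a snapshot subdirectory."""
    with open(vae_dir / "config.json", "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae.load_state_dict(load_file(vae_dir / "vae_diffusion_pytorch_model.safetensors"))
    return vae.to(torch.bfloat16)

def load_unet(unet_dir):
    """Build the Transformer3DModel ("unet") from a snapshot subdirectory."""
    transformer_config = Transformer3DModel.load_config(unet_dir / "config.json")
    transformer = Transformer3DModel.from_config(transformer_config)
    transformer.load_state_dict(
        load_file(unet_dir / "unet_diffusion_pytorch_model.safetensors"), strict=True
    )
    return transformer.to(torch.bfloat16)

def load_scheduler(scheduler_dir):
    """Build the rectified-flow scheduler from its config file."""
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_dir / "scheduler_config.json")
    return RectifiedFlowScheduler.from_config(scheduler_config)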
# Model management class
class ModelManager:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}
        self.current_model = None
        logger.info(f"ModelManager initialized with device: {self.device}")

    def load_model(self, model_name):
        """Load a model on demand."""
        if self.current_model == model_name and model_name in self.models:
            return self.models[model_name]
        # Unload whatever model is currently resident
        self.unload_current_model()
        logger.info(f"Loading model: {model_name}")
        try:
            if model_name == "flux":
                model = self._load_flux_model()
            elif model_name == "xora":
                model = self._load_xora_model()
            elif model_name == "clip":
                model = self._load_clip_model()
            else:
                raise ValueError(f"Unknown model: {model_name}")
            self.models[model_name] = model
            self.current_model = model_name
            return model
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            raise

    def unload_current_model(self):
        """Unload the currently loaded model and free GPU memory."""
        if self.current_model:
            logger.info(f"Unloading model: {self.current_model}")
            if self.current_model in self.models:
                del self.models[self.current_model]
            self.current_model = None
            torch.cuda.empty_cache()
            gc.collect()
    def _load_flux_model(self):
        """Load the FLUX.1-dev image pipeline with the Hyper-SD LoRA fused in."""
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16
        )
        pipe.load_lora_weights(
            hf_hub_download(
                "ByteDance/Hyper-SD",
                "Hyper-FLUX.1-dev-8steps-lora.safetensors"
            )
        )
        pipe.fuse_lora(lora_scale=0.125)
        pipe.to(device=self.device, dtype=torch.bfloat16)
        pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker"
        )
        return pipe

    def _load_xora_model(self):
        """Load the Xora (LTX-Video) text-to-video pipeline."""
        if not path.exists(MODEL_PATH):
            snapshot_download(
                "Lightricks/LTX-Video",
                revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc',
                local_dir=MODEL_PATH,
                repo_type="model",
                token=HF_TOKEN
            )
        vae = load_vae(Path(MODEL_PATH) / "vae")
        unet = load_unet(Path(MODEL_PATH) / "unet")
        scheduler = load_scheduler(Path(MODEL_PATH) / "scheduler")
        patchifier = SymmetricPatchifier(patch_size=1)
        text_encoder = T5EncoderModel.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            subfolder="text_encoder"
        ).to(self.device)
        tokenizer = T5Tokenizer.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            subfolder="tokenizer"
        )
        return XoraVideoPipeline(
            transformer=unet,
            patchifier=patchifier,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            vae=vae
        ).to(self.device)

    def _load_clip_model(self):
        """Load the CLIP model and processor."""
        model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=MODEL_PATH
        ).to(self.device)
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=MODEL_PATH
        )
        return {"model": model, "processor": processor}
# Translator initialization
def get_translator():
    """Lazily initialize the Korean-to-English translator."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# OpenAI client initialization
def get_openai_client():
    """Lazily initialize the OpenAI client."""
    return OpenAI(api_key=OPENAI_API_KEY)
# Utility functions
class Timer:
    """Context manager for measuring elapsed time."""
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        logger.info(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        logger.info(f"{self.method} took {str(round(end - self.start, 2))}s")
def process_prompt(prompt):
    """Preprocess the prompt (translate Korean input to English)."""
    # Detect Hangul syllables (U+AC00 '가' through U+D7A3 '힣')
    if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
        translator = get_translator()
        translated = translator(prompt)[0]['translation_text']
        logger.info(f"Translated prompt: {translated}")
        return translated
    return prompt
def filter_prompt(prompt):
    """Filter out inappropriate content."""
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult",
        "xxx", "erotic", "sensual", "seductive", "provocative",
        "intimate", "violence", "gore", "blood", "death", "kill",
        "murder", "torture", "drug", "suicide", "abuse", "hate",
        "discrimination"
    ]
    prompt_lower = prompt.lower()
    for keyword in inappropriate_keywords:
        if keyword in prompt_lower:
            logger.warning(f"Inappropriate content detected: {keyword}")
            return False, "The prompt contains inappropriate content."
    return True, prompt
def enhance_prompt(prompt, enhance_toggle):
    """Refine the prompt using GPT."""
    if not enhance_toggle:
        logger.info("Prompt enhancement disabled")
        return prompt
    try:
        client = get_openai_client()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=200,
        )
        enhanced_prompt = response.choices[0].message.content.strip()
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
        return enhanced_prompt
    except Exception as e:
        logger.error(f"Prompt enhancement failed: {str(e)}")
        return prompt
def save_image(image, directory=GALLERY_PATH):
    """Save a generated image to the gallery."""
    try:
        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_suffix = os.urandom(4).hex()
        filename = f"generated_{timestamp}_{random_suffix}.png"
        filepath = os.path.join(directory, filename)
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image.save(filepath, format='PNG', optimize=True)
        logger.info(f"Image saved: {filepath}")
        return filepath
    except Exception as e:
        logger.error(f"Error saving image: {str(e)}")
        return None
def add_watermark(video_path):
    """Add a watermark to a video."""
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)
        cap.release()
        out.release()
        logger.info(f"Video watermarked: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
def upload_to_catbox(file_path):
    """Upload a file to catbox.moe and return its URL."""
    try:
        logger.info(f"Uploading file: {file_path}")
        url = "https://catbox.moe/user/api.php"
        file_extension = Path(file_path).suffix.lower()
        supported_extensions = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.mp4': 'video/mp4'
        }
        if file_extension not in supported_extensions:
            logger.error(f"Unsupported file type: {file_extension}")
            return None
        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        # Use a context manager so the file handle is always closed
        with open(file_path, 'rb') as f:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    f,
                    supported_extensions[file_extension]
                )
            }
            response = requests.post(url, files=files, data=data)
        if response.status_code == 200 and response.text.startswith('http'):
            logger.info(f"Upload successful: {response.text}")
            return response.text
        else:
            raise Exception(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return None
# Create the model manager instance
model_manager = ModelManager()

# Gradio interface constants
PRESET_OPTIONS = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
    {"label": "768x768, 100 frames", "width": 768, "height": 768, "num_frames": 100},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
    {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
# Main processing functions
def generate_image(
    prompt,
    height,
    width,
    steps,
    scales,
    seed,
    enhance_prompt_toggle=False,
    progress=gr.Progress()
):
    """Generate an image with the Flux pipeline."""
    try:
        # Preprocess the prompt (translation and content filtering)
        processed_prompt = process_prompt(prompt)
        is_safe, filtered_prompt = filter_prompt(processed_prompt)
        if not is_safe:
            raise gr.Error("The prompt contains inappropriate content.")
        if enhance_prompt_toggle:
            filtered_prompt = enhance_prompt(filtered_prompt, True)

        # Load the Flux model
        pipe = model_manager.load_model("flux")
        with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            generated_image = pipe(
                prompt=[filtered_prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

        # Save and return the image
        saved_path = save_image(generated_image)
        if saved_path is None:
            raise gr.Error("Failed to save the image.")
        return Image.open(saved_path)
    except Exception as e:
        logger.error(f"Image generation error: {str(e)}")
        raise gr.Error(f"An error occurred during image generation: {str(e)}")
    finally:
        model_manager.unload_current_model()
        torch.cuda.empty_cache()
        gc.collect()
def generate_video_xora(
    prompt,
    enhance_prompt_toggle,
    negative_prompt,
    frame_rate,
    seed,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_frames,
    progress=gr.Progress()
):
    """Generate a video with the Xora pipeline."""
    try:
        # Preprocess the prompts
        prompt = process_prompt(prompt)
        negative_prompt = process_prompt(negative_prompt)
        if len(prompt.strip()) < 50:
            raise gr.Error("The prompt must be at least 50 characters long.")
        prompt = enhance_prompt(prompt, enhance_prompt_toggle)

        # Load the Xora model
        pipeline = model_manager.load_model("xora")
        sample = {
            "prompt": prompt,
            "prompt_attention_mask": None,
            "negative_prompt": negative_prompt,
            "negative_prompt_attention_mask": None,
            "media_items": None,
        }
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        def progress_callback(step, timestep, kwargs):
            progress((step + 1) / num_inference_steps)

        with torch.no_grad():
            images = pipeline(
                num_inference_steps=int(num_inference_steps),
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=int(height),
                width=int(width),
                num_frames=int(num_frames),
                frame_rate=int(frame_rate),
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=progress_callback,
            ).images

        # Save the video: (B, C, T, H, W) -> (T, H, W, C) uint8 frames
        output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4")
        video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
        video_np = (video_np * 255).astype(np.uint8)
        out = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            int(frame_rate),
            (int(width), int(height))
        )
        for frame in video_np:
            # cv2 expects BGR channel order and a contiguous array
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        out.release()

        # Add the watermark
        final_path = add_watermark(output_path)
        return final_path
    except Exception as e:
        logger.error(f"Video generation error: {str(e)}")
        raise gr.Error(f"An error occurred during video generation: {str(e)}")
    finally:
        model_manager.unload_current_model()
        torch.cuda.empty_cache()
        gc.collect()
def generate_video_replicate(image, prompt):
    """Generate a video via the Replicate API."""
    try:
        is_safe, filtered_prompt = filter_prompt(prompt)
        if not is_safe:
            raise gr.Error("The prompt contains inappropriate content.")
        if not image:
            raise gr.Error("Please upload an image.")

        # Upload the first-frame image
        image_url = upload_to_catbox(image)
        if not image_url:
            raise gr.Error("Failed to upload the image.")

        # Call the Replicate API
        client = replicate.Client(api_token=REPLICATE_API_TOKEN)
        output = client.run(
            "minimax/video-01-live",
            input={
                "prompt": filtered_prompt,
                "first_frame_image": image_url
            }
        )

        # Save the resulting video
        output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4")
        if hasattr(output, 'read'):
            with open(output_path, "wb") as f:
                f.write(output.read())
        elif isinstance(output, str):
            response = requests.get(output)
            with open(output_path, "wb") as f:
                f.write(response.content)
        else:
            raise gr.Error(f"Unexpected Replicate output type: {type(output)}")

        # Add the watermark
        final_path = add_watermark(output_path)
        return final_path
    except Exception as e:
        logger.error(f"Replicate video generation error: {str(e)}")
        raise gr.Error(f"An error occurred during video generation: {str(e)}")
# Gradio UI styles
css = """
.gradio-container {
    font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important;
}
.title {
    text-align: center;
    font-size: 2.5rem;
    font-weight: bold;
    color: #2a9d8f;
    margin: 1rem 0;
    padding: 1rem;
    background: linear-gradient(to right, #264653, #2a9d8f);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.generate-btn {
    background: linear-gradient(to right, #2a9d8f, #264653) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
}
.generate-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important;
}
.gallery {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
    gap: 1rem;
    padding: 1rem;
}
.gallery img {
    width: 100%;
    height: auto;
    border-radius: 8px;
    transition: transform 0.3s ease;
}
.gallery img:hover {
    transform: scale(1.05);
}
"""
# Gradio interface construction
def create_ui():
    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        gr.HTML('<div class="title">AI Image & Video Generator</div>')

        with gr.Tabs():
            # Image generation tab
            with gr.Tab("Image Generation"):
                with gr.Row():
                    with gr.Column(scale=3):
                        img_prompt = gr.Textbox(
                            label="Image Description",
                            placeholder="Enter an image description... (Korean input supported)",
                            lines=3
                        )
                        img_enhance_toggle = Toggle(
                            label="Enhance Prompt",
                            value=False,
                            interactive=True,
                        )
                        with gr.Accordion("Advanced Settings", open=False):
                            with gr.Row():
                                img_height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=1024,
                                    step=64,
                                    value=768
                                )
                                img_width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=1024,
                                    step=64,
                                    value=768
                                )
                            with gr.Row():
                                steps = gr.Slider(
                                    label="Inference Steps",
                                    minimum=6,
                                    maximum=25,
                                    step=1,
                                    value=8
                                )
                                scales = gr.Slider(
                                    label="Guidance Scale",
                                    minimum=0.0,
                                    maximum=5.0,
                                    step=0.1,
                                    value=3.5
                                )
                            seed = gr.Number(
                                label="Seed",
                                value=random.randint(0, MAX_SEED),
                                precision=0
                            )
                        img_generate_btn = gr.Button(
                            "Generate Image",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )
                    with gr.Column(scale=4):
                        img_output = gr.Image(
                            label="Generated Image",
                            type="pil",
                            format="png"
                        )
                        img_gallery = gr.Gallery(
                            label="Image Gallery",
                            show_label=True,
                            elem_id="gallery",
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )
            # Xora video generation tab
            with gr.Tab("Xora Video Generation"):
                with gr.Row():
                    with gr.Column(scale=3):
                        xora_prompt = gr.Textbox(
                            label="Video Description",
                            placeholder="Enter a video description... (at least 50 characters)",
                            lines=5
                        )
                        xora_enhance_toggle = Toggle(
                            label="Enhance Prompt",
                            value=False
                        )
                        xora_negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="low quality, worst quality, deformed, distorted",
                            lines=2
                        )
                        xora_preset = gr.Dropdown(
                            choices=[p["label"] for p in PRESET_OPTIONS],
                            value="512x512, 160 frames",
                            label="Resolution Preset"
                        )
                        xora_frame_rate = gr.Slider(
                            label="Frame Rate",
                            minimum=6,
                            maximum=60,
                            step=1,
                            value=20
                        )
                        with gr.Accordion("Advanced Settings", open=False):
                            xora_seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=random.randint(0, MAX_SEED)
                            )
                            xora_steps = gr.Slider(
                                label="Inference Steps",
                                minimum=5,
                                maximum=150,
                                step=5,
                                value=40
                            )
                            xora_guidance = gr.Slider(
                                label="Guidance Scale",
                                minimum=1.0,
                                maximum=10.0,
                                step=0.1,
                                value=4.2
                            )
                            # Video dimensions and length; kept in sync with the
                            # resolution preset via the change handler below
                            xora_height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1216,
                                step=32,
                                value=512
                            )
                            xora_width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1216,
                                step=32,
                                value=512
                            )
                            xora_num_frames = gr.Slider(
                                label="Number of Frames",
                                minimum=41,
                                maximum=300,
                                step=1,
                                value=160
                            )
                        xora_generate_btn = gr.Button(
                            "Generate Video",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )
                    with gr.Column(scale=4):
                        xora_output = gr.Video(label="Generated Video")
                        xora_gallery = gr.Gallery(
                            label="Video Gallery",
                            show_label=True,
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )

            # Replicate (image-to-video) tab
            with gr.Tab("Image to Video"):
                with gr.Row():
                    with gr.Column(scale=3):
                        upload_image = gr.Image(
                            type="filepath",
                            label="Upload First Frame Image"
                        )
                        replicate_prompt = gr.Textbox(
                            label="Video Description",
                            placeholder="Enter a video description...",
                            lines=3
                        )
                        replicate_generate_btn = gr.Button(
                            "Generate Video",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )
                    with gr.Column(scale=4):
                        replicate_output = gr.Video(label="Generated Video")
                        replicate_gallery = gr.Gallery(
                            label="Video Gallery",
                            show_label=True,
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )
        # Wire up the event handlers
        def apply_preset(label):
            """Copy the selected preset's dimensions and frame count into the sliders."""
            preset = next(p for p in PRESET_OPTIONS if p["label"] == label)
            return preset["height"], preset["width"], preset["num_frames"]

        xora_preset.change(
            fn=apply_preset,
            inputs=xora_preset,
            outputs=[xora_height, xora_width, xora_num_frames]
        )

        img_generate_btn.click(
            fn=generate_image,
            inputs=[
                img_prompt,
                img_height,
                img_width,
                steps,
                scales,
                seed,
                img_enhance_toggle
            ],
            outputs=img_output
        )

        xora_generate_btn.click(
            fn=generate_video_xora,
            inputs=[
                xora_prompt,
                xora_enhance_toggle,
                xora_negative_prompt,
                xora_frame_rate,
                xora_seed,
                xora_steps,
                xora_guidance,
                xora_height,
                xora_width,
                xora_num_frames
            ],
            outputs=xora_output
        )

        replicate_generate_btn.click(
            fn=generate_video_replicate,
            inputs=[upload_image, replicate_prompt],
            outputs=replicate_output
        )
        # Periodically refresh the galleries from disk
        def refresh_galleries():
            images = sorted(Path(GALLERY_PATH).glob("*.png"), key=os.path.getmtime, reverse=True)
            videos = sorted(Path(VIDEO_GALLERY_PATH).glob("*.mp4"), key=os.path.getmtime, reverse=True)
            return [str(p) for p in images], [str(p) for p in videos], [str(p) for p in videos]

        demo.load(refresh_galleries, None, [img_gallery, xora_gallery, replicate_gallery], every=30)
    return demo
if __name__ == "__main__":
    # Initialization
    init_directories()
    setup_cuda()

    # Launch the UI
    demo = create_ui()
    demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
        share=True,
        show_api=False,
        server_name="0.0.0.0",
        server_port=7860,
        debug=False
    )