import spaces
import argparse
import os
import random
import time
from os import path
import shutil
from datetime import datetime
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import gradio as gr
from gradio_toggle import Toggle
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel, T5EncoderModel, T5Tokenizer
import replicate
import logging
import requests
from pathlib import Path
import cv2
import numpy as np
import sys
import io
import json
import gc
import csv
from openai import OpenAI
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from xora.utils.conditioning_method import ConditioningMethod
from functools import lru_cache
from diffusers.pipelines.flux import FluxPipeline

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants and environment-derived settings
MAX_SEED = np.iinfo(np.int32).max
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
MODEL_PATH = "asset"  # local directory the LTX-Video (Xora) snapshot is downloaded into
CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models")
GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery")
VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery")

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")
VIDEO_EXTENSIONS = (".mp4",)

# API keys / tokens
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Overridable from the environment; the previously hard-coded value remains the fallback.
CATBOX_USER_HASH = os.getenv("CATBOX_USER_HASH", "e7a96fc68dd4c7d2954040cd5")
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# System prompt used by enhance_prompt(); read once at import time.
SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt"
with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


def init_directories():
    """Create the image gallery, video gallery and model-cache directories if missing."""
    for directory in (GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH):
        os.makedirs(directory, exist_ok=True)
        logger.info(f"Directory initialized: {directory}")


def setup_cuda():
    """Configure CUDA/cuDNN numeric settings (full-precision matmuls, cuBLAS backend)."""
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = False
    # preferred_blas_library is a *function*; the previous code assigned a string over it,
    # which silently replaced the function and configured nothing. Call it when available.
    set_blas_backend = getattr(torch.backends.cuda, "preferred_blas_library", None)
    if callable(set_blas_backend):
        try:
            set_blas_backend("cublas")
        except Exception as e:  # best-effort: older/CPU-only builds may reject it
            logger.warning(f"Could not set preferred BLAS library: {str(e)}")
    torch.set_float32_matmul_precision("highest")
    logger.info("CUDA settings initialized")


# NOTE: FLUX is loaded lazily by ModelManager below. The former eager module-level load
# referenced an undefined `cache_path` (NameError at import) and, had it worked, would have
# kept a second, never-used copy of FLUX resident alongside the ModelManager copy.


# ---------------------------------------------------------------------------
# Xora (LTX-Video) component loaders
# NOTE(review): file names below assume the layout of the Lightricks/LTX-Video snapshot at the
# pinned revision used in ModelManager._load_xora_model; verify them if the revision changes.
# ---------------------------------------------------------------------------
def load_vae(vae_dir):
    """Build the causal video VAE from `vae_dir` (config.json + safetensors weights), in bfloat16."""
    with open(vae_dir / "config.json", "r", encoding="utf-8") as config_file:
        vae_config = json.load(config_file)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae.load_state_dict(load_file(str(vae_dir / "vae_diffusion_pytorch_model.safetensors")))
    return vae.to(torch.bfloat16)


def load_unet(unet_dir):
    """Build the 3D transformer denoiser from `unet_dir` (config.json + safetensors weights)."""
    transformer_config = Transformer3DModel.load_config(str(unet_dir / "config.json"))
    transformer = Transformer3DModel.from_config(transformer_config)
    state_dict = load_file(str(unet_dir / "unet_diffusion_pytorch_model.safetensors"))
    transformer.load_state_dict(state_dict, strict=True)
    return transformer


def load_scheduler(scheduler_dir):
    """Build the rectified-flow scheduler from `scheduler_dir/scheduler_config.json`."""
    scheduler_config = RectifiedFlowScheduler.load_config(
        str(scheduler_dir / "scheduler_config.json")
    )
    return RectifiedFlowScheduler.from_config(scheduler_config)


class ModelManager:
    """Keep at most one heavy model loaded at a time, loading on demand by name."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = {}  # model name -> loaded model object
        self.current_model = None  # name of the model currently held in self.models
        logger.info(f"ModelManager initialized with device: {self.device}")

    def load_model(self, model_name):
        """Return the model named `model_name` ("flux", "xora" or "clip"), loading it if needed.

        Any other loaded model is unloaded first. Raises ValueError for an unknown name;
        loader errors are logged and re-raised.
        """
        if self.current_model == model_name and model_name in self.models:
            return self.models[model_name]

        self.unload_current_model()

        logger.info(f"Loading model: {model_name}")
        loaders = {
            "flux": self._load_flux_model,
            "xora": self._load_xora_model,
            "clip": self._load_clip_model,
        }
        try:
            loader = loaders.get(model_name)
            if loader is None:
                raise ValueError(f"Unknown model: {model_name}")
            model = loader()
            self.models[model_name] = model
            self.current_model = model_name
            return model
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            raise

    def unload_current_model(self):
        """Drop the currently loaded model and release cached GPU memory."""
        if self.current_model:
            logger.info(f"Unloading model: {self.current_model}")
            self.models.pop(self.current_model, None)
            self.current_model = None
            # Collect first so dropped tensors are actually freed before the cache is emptied.
            gc.collect()
            torch.cuda.empty_cache()

    def _load_flux_model(self):
        """Load FLUX.1-dev with the Hyper-SD 8-step LoRA fused in, cached under CACHE_PATH."""
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            cache_dir=CACHE_PATH,
        )
        pipe.load_lora_weights(
            hf_hub_download(
                "ByteDance/Hyper-SD",
                "Hyper-FLUX.1-dev-8steps-lora.safetensors",
                cache_dir=CACHE_PATH,
            )
        )
        pipe.fuse_lora(lora_scale=0.125)
        pipe.to(device=self.device, dtype=torch.bfloat16)
        # NOTE(review): nothing in this file invokes this checker; confirm whether the
        # FluxPipeline actually uses it before relying on it for moderation.
        pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker",
            cache_dir=CACHE_PATH,
        )
        return pipe

    def _load_xora_model(self):
        """Download (if needed) and assemble the Xora / LTX-Video pipeline."""
        model_root = Path(MODEL_PATH)
        # Check for the component folders, not just the root: a partially populated root
        # directory previously caused the download to be skipped.
        if not all((model_root / sub).is_dir() for sub in ("vae", "unet", "scheduler")):
            snapshot_download(
                "Lightricks/LTX-Video",
                revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc',
                local_dir=MODEL_PATH,
                repo_type="model",
                token=HF_TOKEN,
            )

        vae = load_vae(model_root / "vae")
        unet = load_unet(model_root / "unet")
        scheduler = load_scheduler(model_root / "scheduler")
        patchifier = SymmetricPatchifier(patch_size=1)
        text_encoder = T5EncoderModel.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            subfolder="text_encoder",
        ).to(self.device)
        tokenizer = T5Tokenizer.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            subfolder="tokenizer",
        )

        return XoraVideoPipeline(
            transformer=unet,
            patchifier=patchifier,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            vae=vae,
        ).to(self.device)

    def _load_clip_model(self):
        """Load CLIP ViT-B/32; returns {"model": ..., "processor": ...}."""
        # Cache in CACHE_PATH, not MODEL_PATH: writing into MODEL_PATH would make the Xora
        # snapshot directory look present and interfere with its download check.
        model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=CACHE_PATH,
        ).to(self.device)
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=CACHE_PATH,
        )
        return {"model": model, "processor": processor}


@lru_cache(maxsize=None)
def get_translator():
    """Return the (lazily created, cached) Korean->English translation pipeline."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")


@lru_cache(maxsize=None)
def get_openai_client():
    """Return the (lazily created, cached) OpenAI client."""
    return OpenAI(api_key=OPENAI_API_KEY)


class Timer:
    """Context manager that logs when a step starts and how long it took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None

    def __enter__(self):
        self.start = time.perf_counter()
        logger.info(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start
        logger.info(f"{self.method} took {str(round(elapsed, 2))}s")


def process_prompt(prompt):
    """Translate `prompt` to English if it contains any Hangul syllable; otherwise return it as-is."""
    if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
        translator = get_translator()
        translated = translator(prompt)[0]['translation_text']
        logger.info(f"Translated prompt: {translated}")
        return translated
    return prompt


def filter_prompt(prompt):
    """Keyword-based content filter.

    Returns (True, prompt) when no blocked keyword appears as a case-insensitive substring,
    otherwise (False, <Korean rejection message>).
    """
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]

    prompt_lower = prompt.lower()
    for keyword in inappropriate_keywords:
        if keyword in prompt_lower:
            logger.warning(f"Inappropriate content detected: {keyword}")
            return False, "부적절한 내용이 포함된 프롬프트입니다."

    return True, prompt


def enhance_prompt(prompt, enhance_toggle):
    """Rewrite `prompt` with an OpenAI chat model using SYSTEM_PROMPT.

    Returns the original prompt when disabled, on any API error, or on an empty reply.
    """
    if not enhance_toggle:
        logger.info("Prompt enhancement disabled")
        return prompt

    try:
        client = get_openai_client()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        # "gpt-4-mini" is not an OpenAI model id; every call failed and silently fell back.
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=200,
        )

        enhanced_prompt = (response.choices[0].message.content or "").strip()
        if not enhanced_prompt:
            return prompt
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
        return enhanced_prompt
    except Exception as e:
        logger.error(f"Prompt enhancement failed: {str(e)}")
        return prompt


def save_image(image, directory=GALLERY_PATH):
    """Save a PIL image (or array) as an RGB PNG with a unique name; return its path or None on error."""
    try:
        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_suffix = os.urandom(4).hex()
        filename = f"generated_{timestamp}_{random_suffix}.png"
        filepath = os.path.join(directory, filename)

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # PNG is lossless; the former quality=100 argument was ignored by PIL for PNG.
        image.save(filepath, format='PNG', optimize=True)
        logger.info(f"Image saved: {filepath}")
        return filepath
    except Exception as e:
        logger.error(f"Error saving image: {str(e)}")
        return None


def add_watermark(video_path):
    """Write a copy of `video_path` with a "GiniGEN.AI" text watermark in the bottom-right corner.

    The copy goes to VIDEO_GALLERY_PATH as "watermarked_<basename>". On any failure the
    original `video_path` is returned unchanged.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97); int() truncated them. Some containers report 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 24.0

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error(f"Error adding watermark: cannot open writer for {output_path}")
            return video_path

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        logger.info(f"Video watermarked: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release in all paths (the writer must be released to finalize the file).
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()


def upload_to_catbox(file_path):
    """Upload `file_path` to catbox.moe and return the public URL, or None on any failure."""
    try:
        logger.info(f"Uploading file: {file_path}")
        url = "https://catbox.moe/user/api.php"

        file_extension = Path(file_path).suffix.lower()
        supported_extensions = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.mp4': 'video/mp4'
        }

        if file_extension not in supported_extensions:
            logger.error(f"Unsupported file type: {file_extension}")
            return None

        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }

        # `with` closes the handle that the previous version leaked; timeout avoids hanging forever.
        with open(file_path, 'rb') as file_handle:
            files = {
                'fileToUpload': (
                    os.path.basename(file_path),
                    file_handle,
                    supported_extensions[file_extension]
                )
            }
            response = requests.post(url, files=files, data=data, timeout=120)

        result = response.text.strip()
        if response.status_code == 200 and result.startswith('http'):
            logger.info(f"Upload successful: {result}")
            return result
        raise RuntimeError(f"Upload failed: {response.text}")

    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return None


def list_gallery_files(directory, extensions, prefix="", limit=40):
    """Return up to `limit` file paths in `directory`, newest first, filtered by extension/prefix."""
    folder = Path(directory)
    if not folder.is_dir():
        return []
    matches = [
        entry for entry in folder.iterdir()
        if entry.is_file()
        and entry.suffix.lower() in extensions
        and entry.name.startswith(prefix)
    ]
    matches.sort(key=lambda entry: entry.stat().st_mtime, reverse=True)
    return [str(entry) for entry in matches[:limit]]


# Shared model manager instance
model_manager = ModelManager()

# Xora resolution presets
# NOTE(review): the two "720x720" entries actually use 768x768; confirm whether the labels or
# the values are intended.
PRESET_OPTIONS = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
    {"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
    {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
DEFAULT_PRESET_LABEL = "512x512, 160 frames"


def preset_dimensions(label):
    """Return (height, width, num_frames) for a preset label, falling back to the default preset."""
    presets = {preset["label"]: preset for preset in PRESET_OPTIONS}
    chosen = presets.get(label, presets[DEFAULT_PRESET_LABEL])
    return chosen["height"], chosen["width"], chosen["num_frames"]


@spaces.GPU(duration=90)
def generate_image(
    prompt,
    height,
    width,
    steps,
    scales,
    seed,
    enhance_prompt_toggle=False,
    progress=gr.Progress()
):
    """Generate one FLUX image, save it to the gallery, and return it as a PIL image.

    Raises gr.Error for filtered prompts, save failures, or any generation error. The FLUX
    model is unloaded after every call.
    """
    try:
        processed_prompt = process_prompt(prompt)
        is_safe, filtered_prompt = filter_prompt(processed_prompt)
        if not is_safe:
            raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.")

        if enhance_prompt_toggle:
            filtered_prompt = enhance_prompt(filtered_prompt, True)

        pipe = model_manager.load_model("flux")

        with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            generated_image = pipe(
                prompt=[filtered_prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

        saved_path = save_image(generated_image)
        if saved_path is None:
            raise gr.Error("이미지 저장에 실패했습니다.")

        return Image.open(saved_path)

    except gr.Error:
        # Already user-facing; don't wrap it in a second "error occurred" message.
        raise
    except Exception as e:
        logger.error(f"Image generation error: {str(e)}")
        raise gr.Error(f"이미지 생성 중 오류가 발생했습니다: {str(e)}") from e
    finally:
        model_manager.unload_current_model()
        gc.collect()
        torch.cuda.empty_cache()


@spaces.GPU(duration=90)
def generate_video_xora(
    prompt,
    enhance_prompt_toggle,
    negative_prompt,
    frame_rate,
    seed,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_frames,
    progress=gr.Progress()
):
    """Generate a text-to-video clip with Xora (LTX-Video) and return the watermarked mp4 path.

    Raises gr.Error for filtered or too-short (< 50 chars) prompts and for any generation
    error. The Xora model is unloaded after every call.
    """
    try:
        prompt = process_prompt(prompt)
        negative_prompt = process_prompt(negative_prompt)

        # Apply the same content filter as the other generators.
        is_safe, prompt = filter_prompt(prompt)
        if not is_safe:
            raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.")

        if len(prompt.strip()) < 50:
            raise gr.Error("프롬프트는 최소 50자 이상이어야 합니다.")

        prompt = enhance_prompt(prompt, enhance_prompt_toggle)

        # Gradio sliders/state may deliver floats; the pipeline and cv2 need ints.
        num_inference_steps = int(num_inference_steps)
        height = int(height)
        width = int(width)
        num_frames = int(num_frames)

        xora_pipe = model_manager.load_model("xora")

        sample = {
            "prompt": prompt,
            "prompt_attention_mask": None,
            "negative_prompt": negative_prompt,
            "negative_prompt_attention_mask": None,
            "media_items": None,
        }

        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        def progress_callback(_pipe, step, _timestep, callback_kwargs):
            # The pipeline passes itself as the first positional argument.
            progress((step + 1) / num_inference_steps)
            return callback_kwargs

        with torch.no_grad():
            images = xora_pipe(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=progress_callback,
            ).images

        output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4")
        # (1, C, F, H, W) -> (F, H, W, C)
        video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
        # Clip before the uint8 cast so slight overshoot doesn't wrap around.
        video_np = (np.clip(video_np, 0.0, 1.0) * 255).astype(np.uint8)

        writer = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            frame_rate,
            (width, height)
        )
        try:
            # RGB -> BGR for OpenCV
            for frame in video_np[..., ::-1]:
                writer.write(np.ascontiguousarray(frame))
        finally:
            writer.release()

        return add_watermark(output_path)

    except gr.Error:
        raise
    except Exception as e:
        logger.error(f"Video generation error: {str(e)}")
        raise gr.Error(f"비디오 생성 중 오류가 발생했습니다: {str(e)}") from e
    finally:
        model_manager.unload_current_model()
        gc.collect()
        torch.cuda.empty_cache()


def generate_video_replicate(image, prompt):
    """Generate a video from a first-frame image via Replicate; return the watermarked mp4 path.

    Raises gr.Error for filtered prompts, missing images, upload failures, or any API error.
    """
    try:
        is_safe, filtered_prompt = filter_prompt(prompt)
        if not is_safe:
            raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.")

        if not image:
            raise gr.Error("이미지를 업로드해주세요.")

        image_url = upload_to_catbox(image)
        if not image_url:
            raise gr.Error("이미지 업로드에 실패했습니다.")

        client = replicate.Client(api_token=REPLICATE_API_TOKEN)
        output = client.run(
            "minimax/video-01-live",
            input={
                "prompt": filtered_prompt,
                "first_frame_image": image_url
            }
        )

        # Some client versions wrap results in a list.
        if isinstance(output, (list, tuple)):
            if not output:
                raise RuntimeError("Replicate returned no output")
            output = output[0]

        output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4")

        if hasattr(output, 'read'):
            with open(output_path, "wb") as f:
                f.write(output.read())
        elif isinstance(output, str):
            response = requests.get(output, timeout=300)
            response.raise_for_status()
            with open(output_path, "wb") as f:
                f.write(response.content)
        else:
            # Previously nothing was written and an empty path was watermarked.
            raise RuntimeError(f"Unexpected Replicate output type: {type(output).__name__}")

        return add_watermark(output_path)

    except gr.Error:
        raise
    except Exception as e:
        logger.error(f"Replicate video generation error: {str(e)}")
        raise gr.Error(f"비디오 생성 중 오류가 발생했습니다: {str(e)}") from e


# Gradio UI styles
css = """
.gradio-container {
    font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important;
}
.title {
    text-align: center;
    font-size: 2.5rem;
    font-weight: bold;
    color: #2a9d8f;
    margin: 1rem 0;
    padding: 1rem;
    background: linear-gradient(to right, #264653, #2a9d8f);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.generate-btn {
    background: linear-gradient(to right, #2a9d8f, #264653) !important;
    border: none !important;
    color: white !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
}
.generate-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important;
}
.gallery {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
    gap: 1rem;
    padding: 1rem;
}
.gallery img {
    width: 100%;
    height: auto;
    border-radius: 8px;
    transition: transform 0.3s ease;
}
.gallery img:hover {
    transform: scale(1.05);
}
"""


def refresh_galleries():
    """Return the current image, Xora-video and Replicate-video gallery contents."""
    return (
        list_gallery_files(GALLERY_PATH, IMAGE_EXTENSIONS),
        list_gallery_files(VIDEO_GALLERY_PATH, VIDEO_EXTENSIONS, prefix="watermarked_generated_"),
        list_gallery_files(VIDEO_GALLERY_PATH, VIDEO_EXTENSIONS, prefix="watermarked_replicate_"),
    )


def create_ui():
    """Build the three-tab Gradio app (image, Xora video, image-to-video) and return the Blocks."""
    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        # Title markup was lost in the source; reconstructed using the `.title` CSS class.
        gr.HTML('<div class="title">AI Image & Video Generator</div>')

        with gr.Tabs():
            # Image generation tab
            with gr.Tab("Image Generation"):
                with gr.Row():
                    with gr.Column(scale=3):
                        img_prompt = gr.Textbox(
                            label="Image Description",
                            placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
                            lines=3
                        )

                        img_enhance_toggle = Toggle(
                            label="Enhance Prompt",
                            value=False,
                            interactive=True,
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            with gr.Row():
                                img_height = gr.Slider(
                                    label="Height",
                                    minimum=256,
                                    maximum=1024,
                                    step=64,
                                    value=768
                                )
                                img_width = gr.Slider(
                                    label="Width",
                                    minimum=256,
                                    maximum=1024,
                                    step=64,
                                    value=768
                                )

                            with gr.Row():
                                steps = gr.Slider(
                                    label="Inference Steps",
                                    minimum=6,
                                    maximum=25,
                                    step=1,
                                    value=8
                                )
                                scales = gr.Slider(
                                    label="Guidance Scale",
                                    minimum=0.0,
                                    maximum=5.0,
                                    step=0.1,
                                    value=3.5
                                )

                            seed = gr.Number(
                                label="Seed",
                                value=random.randint(0, MAX_SEED),
                                precision=0
                            )

                        img_generate_btn = gr.Button(
                            "Generate Image",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )

                    with gr.Column(scale=4):
                        img_output = gr.Image(
                            label="Generated Image",
                            type="pil",
                            format="png"
                        )
                        img_gallery = gr.Gallery(
                            label="Image Gallery",
                            show_label=True,
                            elem_id="gallery",
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )

            # Xora video generation tab
            with gr.Tab("Xora Video Generation"):
                with gr.Row():
                    with gr.Column(scale=3):
                        xora_prompt = gr.Textbox(
                            label="Video Description",
                            placeholder="비디오 설명을 입력하세요... (최소 50자)",
                            lines=5
                        )

                        xora_enhance_toggle = Toggle(
                            label="Enhance Prompt",
                            value=False
                        )

                        xora_negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="low quality, worst quality, deformed, distorted",
                            lines=2
                        )

                        xora_preset = gr.Dropdown(
                            choices=[p["label"] for p in PRESET_OPTIONS],
                            value=DEFAULT_PRESET_LABEL,
                            label="Resolution Preset"
                        )

                        # Per-session resolution derived from the preset dropdown.
                        default_height, default_width, default_frames = preset_dimensions(DEFAULT_PRESET_LABEL)
                        xora_height = gr.State(default_height)
                        xora_width = gr.State(default_width)
                        xora_num_frames = gr.State(default_frames)

                        xora_frame_rate = gr.Slider(
                            label="Frame Rate",
                            minimum=6,
                            maximum=60,
                            step=1,
                            value=20
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            xora_seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=random.randint(0, MAX_SEED)
                            )
                            xora_steps = gr.Slider(
                                label="Inference Steps",
                                minimum=5,
                                maximum=150,
                                step=5,
                                value=40
                            )
                            xora_guidance = gr.Slider(
                                label="Guidance Scale",
                                minimum=1.0,
                                maximum=10.0,
                                step=0.1,
                                value=4.2
                            )

                        xora_generate_btn = gr.Button(
                            "Generate Video",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )

                    with gr.Column(scale=4):
                        xora_output = gr.Video(label="Generated Video")
                        xora_gallery = gr.Gallery(
                            label="Video Gallery",
                            show_label=True,
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )

            # Replicate image-to-video tab
            with gr.Tab("Image to Video"):
                with gr.Row():
                    with gr.Column(scale=3):
                        upload_image = gr.Image(
                            type="filepath",
                            label="Upload First Frame Image"
                        )
                        replicate_prompt = gr.Textbox(
                            label="Video Description",
                            placeholder="비디오 설명을 입력하세요...",
                            lines=3
                        )
                        replicate_generate_btn = gr.Button(
                            "Generate Video",
                            variant="primary",
                            elem_classes=["generate-btn"]
                        )

                    with gr.Column(scale=4):
                        replicate_output = gr.Video(label="Generated Video")
                        replicate_gallery = gr.Gallery(
                            label="Video Gallery",
                            show_label=True,
                            columns=[4],
                            rows=[2],
                            height="auto",
                            object_fit="cover"
                        )

        # Event wiring
        img_generate_btn.click(
            fn=generate_image,
            inputs=[
                img_prompt,
                img_height,
                img_width,
                steps,
                scales,
                seed,
                img_enhance_toggle
            ],
            outputs=img_output
        )

        # Keep the Xora resolution in sync with the preset. queue=False so it isn't stuck
        # behind a running GPU job (the queue's default concurrency limit is 1).
        xora_preset.change(
            fn=preset_dimensions,
            inputs=xora_preset,
            outputs=[xora_height, xora_width, xora_num_frames],
            queue=False
        )

        # Previously wired to the *image* tab's height/width sliders and an ad-hoc stray
        # slider, so the preset dropdown had no effect.
        xora_generate_btn.click(
            fn=generate_video_xora,
            inputs=[
                xora_prompt,
                xora_enhance_toggle,
                xora_negative_prompt,
                xora_frame_rate,
                xora_seed,
                xora_steps,
                xora_guidance,
                xora_height,
                xora_width,
                xora_num_frames
            ],
            outputs=xora_output
        )

        replicate_generate_btn.click(
            fn=generate_video_replicate,
            inputs=[upload_image, replicate_prompt],
            outputs=replicate_output
        )

        # Periodic gallery refresh: one value per output (previously a single None for three).
        demo.load(refresh_galleries, None, [img_gallery, xora_gallery, replicate_gallery], every=30)

    return demo


if __name__ == "__main__":
    init_directories()
    setup_cuda()

    demo = create_ui()
    demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
        share=True,
        show_api=False,
        server_name="0.0.0.0",
        server_port=7860,
        debug=False,
        # Galleries/outputs live under PERSISTENT_DIR, which may be outside the working dir.
        allowed_paths=[GALLERY_PATH, VIDEO_GALLERY_PATH]
    )