import sys import subprocess def install_required_packages(): packages = [ "git+https://github.com/black-forest-labs/diffusers", "transformers>=4.25.1", "safetensors>=0.3.1", "accelerate>=0.16.0" ] for package in packages: try: subprocess.check_call([sys.executable, "-m", "pip", "install", package]) except subprocess.CalledProcessError as e: print(f"Error installing {package}: {e}") raise # 필요한 패키지 설치 install_required_packages() import spaces import argparse import os import time from os import path import shutil from datetime import datetime from safetensors.torch import load_file from huggingface_hub import hf_hub_download import gradio as gr import torch try: from diffusers.pipelines.flux import FluxPipeline except ImportError: from diffusers import StableDiffusionPipeline as FluxPipeline from diffusers.pipelines.stable_diffusion import safety_checker from PIL import Image from transformers import pipeline import replicate import logging import requests from pathlib import Path import cv2 import numpy as np import sys import io # 로깅 설정 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # 상수 및 환경 변수 설정 MAX_SEED = np.iinfo(np.int32).max PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".") MODEL_PATH = "asset" CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models") GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery") VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery") # API 키 설정 HF_TOKEN = os.getenv("HF_TOKEN") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5" REPLICATE_API_TOKEN = os.getenv("API_KEY") # 시스템 프롬프트 로드 SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt" with open(SYSTEM_PROMPT_PATH, "r") as f: SYSTEM_PROMPT = f.read() # 디렉토리 초기화 def init_directories(): """필요한 디렉토리들을 생성""" directories = [GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH] for directory in directories: os.makedirs(directory, exist_ok=True) logger.info(f"Directory initialized: {directory}") # CUDA 설정 def setup_cuda(): """CUDA 관련 설정 초기화""" torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.deterministic = False torch.backends.cuda.preferred_blas_library = "cublas" torch.set_float32_matmul_precision("highest") logger.info("CUDA settings initialized") # Model initialization if not path.exists(cache_path): os.makedirs(cache_path, exist_ok=True) try: # FluxPipeline 초기화 시도 model_id = "black-forest-labs/FLUX.1-dev" pipe = FluxPipeline.from_pretrained( model_id, torch_dtype=torch.bfloat16, cache_dir=cache_path, local_files_only=False ) # LoRA 가중치 다운로드 및 적용 lora_path = hf_hub_download( "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors", cache_dir=cache_path ) if hasattr(pipe, 'load_lora_weights'): pipe.load_lora_weights(lora_path) pipe.fuse_lora(lora_scale=0.125) # 디바이스 설정 pipe = pipe.to("cuda") # 안전성 검사기 설정 if hasattr(pipe, 'safety_checker'): pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker", cache_dir=cache_path ) logger.info("Model initialized successfully") except Exception as e: logger.error(f"Error initializing model: {str(e)}") raise # 모델 관리 클래스 class ModelManager: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.models = {} self.current_model = None logger.info(f"ModelManager initialized with device: {self.device}") def load_model(self, model_name): """모델을 동적으로 로드""" if self.current_model == model_name and model_name in self.models: return self.models[model_name] # 현재 로드된 모델 언로드 self.unload_current_model() logger.info(f"Loading model: {model_name}") try: if model_name == "flux": model = self._load_flux_model() elif model_name == "xora": model = self._load_xora_model() elif model_name == "clip": model = self._load_clip_model() else: raise ValueError(f"Unknown model: {model_name}") self.models[model_name] = model self.current_model = model_name return model except Exception as e: logger.error(f"Error loading model {model_name}: {str(e)}") raise def unload_current_model(self): """현재 로드된 모델 언로드""" if self.current_model: logger.info(f"Unloading model: {self.current_model}") if self.current_model in self.models: del self.models[self.current_model] self.current_model = None torch.cuda.empty_cache() gc.collect() def _load_flux_model(self): """Flux 모델 로드""" pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 ) pipe.load_lora_weights( hf_hub_download( "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors" ) ) pipe.fuse_lora(lora_scale=0.125) pipe.to(device=self.device, dtype=torch.bfloat16) pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker" ) return pipe def _load_xora_model(self): """Xora 모델 로드""" if not path.exists(MODEL_PATH): snapshot_download( "Lightricks/LTX-Video", revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc', local_dir=MODEL_PATH, repo_type="model", token=HF_TOKEN ) vae = load_vae(Path(MODEL_PATH) / "vae") unet = load_unet(Path(MODEL_PATH) / "unet") scheduler = load_scheduler(Path(MODEL_PATH) / "scheduler") patchifier = SymmetricPatchifier(patch_size=1) text_encoder = T5EncoderModel.from_pretrained( "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder" ).to(self.device) tokenizer = T5Tokenizer.from_pretrained( "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer" ) return XoraVideoPipeline( transformer=unet, patchifier=patchifier, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae ).to(self.device) def _load_clip_model(self): """CLIP 모델 로드""" model = CLIPModel.from_pretrained( "openai/clip-vit-base-patch32", cache_dir=MODEL_PATH ).to(self.device) processor = CLIPProcessor.from_pretrained( "openai/clip-vit-base-patch32", cache_dir=MODEL_PATH ) return {"model": model, "processor": processor} # 번역기 초기화 @lru_cache(maxsize=None) def get_translator(): """번역기를 lazy loading으로 초기화""" return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en") # OpenAI 클라이언트 초기화 @lru_cache(maxsize=None) def get_openai_client(): """OpenAI 클라이언트를 lazy loading으로 초기화""" return OpenAI(api_key=OPENAI_API_KEY) # 유틸리티 함수들 class Timer: """작업 시간 측정을 위한 컨텍스트 매니저""" def __init__(self, method_name="timed process"): self.method = method_name def __enter__(self): self.start = time.time() logger.info(f"{self.method} starts") def __exit__(self, exc_type, exc_val, exc_tb): end = time.time() logger.info(f"{self.method} took {str(round(end - self.start, 2))}s") def process_prompt(prompt): """프롬프트 전처리 (한글 번역 및 필터링)""" if any(ord('가') <= ord(char) <= ord('힣') for char in prompt): translator = get_translator() translated = translator(prompt)[0]['translation_text'] logger.info(f"Translated prompt: {translated}") return translated return prompt def filter_prompt(prompt): """부적절한 내용 필터링""" inappropriate_keywords = [ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx", "erotic", "sensual", "seductive", "provocative", "intimate", "violence", "gore", "blood", "death", "kill", "murder", "torture", "drug", "suicide", "abuse", "hate", "discrimination" ] prompt_lower = prompt.lower() for keyword in inappropriate_keywords: if keyword in prompt_lower: logger.warning(f"Inappropriate content detected: {keyword}") return False, "부적절한 내용이 포함된 프롬프트입니다." return True, prompt def enhance_prompt(prompt, enhance_toggle): """GPT를 사용한 프롬프트 개선""" if not enhance_toggle: logger.info("Prompt enhancement disabled") return prompt try: client = get_openai_client() messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ] response = client.chat.completions.create( model="gpt-4-mini", messages=messages, max_tokens=200, ) enhanced_prompt = response.choices[0].message.content.strip() logger.info(f"Enhanced prompt: {enhanced_prompt}") return enhanced_prompt except Exception as e: logger.error(f"Prompt enhancement failed: {str(e)}") return prompt def save_image(image, directory=GALLERY_PATH): """생성된 이미지 저장""" try: os.makedirs(directory, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") random_suffix = os.urandom(4).hex() filename = f"generated_{timestamp}_{random_suffix}.png" filepath = os.path.join(directory, filename) if not isinstance(image, Image.Image): image = Image.fromarray(image) if image.mode != 'RGB': image = image.convert('RGB') image.save(filepath, format='PNG', optimize=True, quality=100) logger.info(f"Image saved: {filepath}") return filepath except Exception as e: logger.error(f"Error saving image: {str(e)}") return None def add_watermark(video_path): """비디오에 워터마크 추가""" try: cap = cv2.VideoCapture(video_path) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = int(cap.get(cv2.CAP_PROP_FPS)) text = "GiniGEN.AI" font = cv2.FONT_HERSHEY_SIMPLEX font_scale = height * 0.05 / 30 thickness = 2 color = (255, 255, 255) (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness) margin = int(height * 0.02) x_pos = width - text_width - margin y_pos = height - margin output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}") fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) while cap.isOpened(): ret, frame = cap.read() if not ret: break cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness) out.write(frame) cap.release() out.release() logger.info(f"Video watermarked: {output_path}") return output_path except Exception as e: logger.error(f"Error adding watermark: {str(e)}") return video_path def upload_to_catbox(file_path): """파일을 catbox.moe에 업로드""" try: logger.info(f"Uploading file: {file_path}") url = "https://catbox.moe/user/api.php" file_extension = Path(file_path).suffix.lower() supported_extensions = { '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png', '.gif': 'image/gif', '.mp4': 'video/mp4' } if file_extension not in supported_extensions: logger.error(f"Unsupported file type: {file_extension}") return None files = { 'fileToUpload': ( os.path.basename(file_path), open(file_path, 'rb'), supported_extensions[file_extension] ) } data = { 'reqtype': 'fileupload', 'userhash': CATBOX_USER_HASH } response = requests.post(url, files=files, data=data) if response.status_code == 200 and response.text.startswith('http'): logger.info(f"Upload successful: {response.text}") return response.text else: raise Exception(f"Upload failed: {response.text}") except Exception as e: logger.error(f"Upload error: {str(e)}") return None # 모델 매니저 인스턴스 생성 model_manager = ModelManager() # Gradio 인터페이스 관련 상수 및 설정 PRESET_OPTIONS = [ {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41}, {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49}, {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57}, {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100}, {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200}, {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300}, {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80}, {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120}, {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64}, {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90}, {"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64}, {"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100}, {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97}, {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160}, {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200}, ] # 메인 처리 함수들 @spaces.GPU(duration=90) def generate_image( prompt, height, width, steps, scales, seed, enhance_prompt_toggle=False, progress=gr.Progress() ): """이미지 생성 함수""" try: # 프롬프트 전처리 processed_prompt = process_prompt(prompt) is_safe, filtered_prompt = filter_prompt(processed_prompt) if not is_safe: raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.") if enhance_prompt_toggle: filtered_prompt = enhance_prompt(filtered_prompt, True) # Flux 모델 로드 pipe = model_manager.load_model("flux") with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): generated_image = pipe( prompt=[filtered_prompt], generator=torch.Generator().manual_seed(int(seed)), num_inference_steps=int(steps), guidance_scale=float(scales), height=int(height), width=int(width), max_sequence_length=256 ).images[0] # 이미지 저장 및 반환 saved_path = save_image(generated_image) if saved_path is None: raise gr.Error("이미지 저장에 실패했습니다.") return Image.open(saved_path) except Exception as e: logger.error(f"Image generation error: {str(e)}") raise gr.Error(f"이미지 생성 중 오류가 발생했습니다: {str(e)}") finally: model_manager.unload_current_model() torch.cuda.empty_cache() gc.collect() @spaces.GPU(duration=90) def generate_video_xora( prompt, enhance_prompt_toggle, negative_prompt, frame_rate, seed, num_inference_steps, guidance_scale, height, width, num_frames, progress=gr.Progress() ): """Xora 비디오 생성 함수""" try: # 프롬프트 처리 prompt = process_prompt(prompt) negative_prompt = process_prompt(negative_prompt) if len(prompt.strip()) < 50: raise gr.Error("프롬프트는 최소 50자 이상이어야 합니다.") prompt = enhance_prompt(prompt, enhance_prompt_toggle) # Xora 모델 로드 pipeline = model_manager.load_model("xora") sample = { "prompt": prompt, "prompt_attention_mask": None, "negative_prompt": negative_prompt, "negative_prompt_attention_mask": None, "media_items": None, } generator = torch.Generator(device="cuda").manual_seed(seed) def progress_callback(step, timestep, kwargs): progress((step + 1) / num_inference_steps) with torch.no_grad(): images = pipeline( num_inference_steps=num_inference_steps, num_images_per_prompt=1, guidance_scale=guidance_scale, generator=generator, output_type="pt", height=height, width=width, num_frames=num_frames, frame_rate=frame_rate, **sample, is_video=True, vae_per_channel_normalize=True, conditioning_method=ConditioningMethod.UNCONDITIONAL, mixed_precision=True, callback_on_step_end=progress_callback, ).images # 비디오 저장 output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4") video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy() video_np = (video_np * 255).astype(np.uint8) out = cv2.VideoWriter( output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height) ) for frame in video_np[..., ::-1]: out.write(frame) out.release() # 워터마크 추가 final_path = add_watermark(output_path) return final_path except Exception as e: logger.error(f"Video generation error: {str(e)}") raise gr.Error(f"비디오 생성 중 오류가 발생했습니다: {str(e)}") finally: model_manager.unload_current_model() torch.cuda.empty_cache() gc.collect() def generate_video_replicate(image, prompt): """Replicate API를 사용한 비디오 생성 함수""" try: is_safe, filtered_prompt = filter_prompt(prompt) if not is_safe: raise gr.Error("부적절한 내용이 포함된 프롬프트입니다.") if not image: raise gr.Error("이미지를 업로드해주세요.") # 이미지 업로드 image_url = upload_to_catbox(image) if not image_url: raise gr.Error("이미지 업로드에 실패했습니다.") # Replicate API 호출 client = replicate.Client(api_token=REPLICATE_API_TOKEN) output = client.run( "minimax/video-01-live", input={ "prompt": filtered_prompt, "first_frame_image": image_url } ) # 결과 비디오 저장 output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4") if hasattr(output, 'read'): with open(output_path, "wb") as f: f.write(output.read()) elif isinstance(output, str): response = requests.get(output) with open(output_path, "wb") as f: f.write(response.content) # 워터마크 추가 final_path = add_watermark(output_path) return final_path except Exception as e: logger.error(f"Replicate video generation error: {str(e)}") raise gr.Error(f"비디오 생성 중 오류가 발생했습니다: {str(e)}") @spaces.GPU def process_and_save_image(height, width, steps, scales, prompt, seed): is_safe, translated_prompt = process_prompt(prompt) if not is_safe: gr.Warning("부적절한 내용이 포함된 프롬프트입니다.") return None, load_gallery() with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"): try: # 모델 호출 방식 수정 if hasattr(pipe, '__call__'): output = pipe( prompt=[translated_prompt], generator=torch.Generator().manual_seed(int(seed)), num_inference_steps=int(steps), guidance_scale=float(scales), height=int(height), width=int(width), max_sequence_length=256 ) generated_image = output.images[0] else: generated_image = pipe.text2img( prompt=translated_prompt, generator=torch.Generator().manual_seed(int(seed)), num_inference_steps=int(steps), guidance_scale=float(scales), height=int(height), width=int(width) )[0] # 이미지 처리 및 저장 if not isinstance(generated_image, Image.Image): generated_image = Image.fromarray(generated_image) if generated_image.mode != 'RGB': generated_image = generated_image.convert('RGB') img_byte_arr = io.BytesIO() generated_image.save(img_byte_arr, format='PNG') img_byte_arr = img_byte_arr.getvalue() saved_path = save_image(generated_image) if saved_path is None: logger.warning("Failed to save generated image") return None, load_gallery() return Image.open(io.BytesIO(img_byte_arr)), load_gallery() except Exception as e: logger.error(f"Error in image generation: {str(e)}") return None, load_gallery() # Gradio UI 스타일 css = """ .gradio-container { font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important; } .title { text-align: center; font-size: 2.5rem; font-weight: bold; color: #2a9d8f; margin: 1rem 0; padding: 1rem; background: linear-gradient(to right, #264653, #2a9d8f); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .generate-btn { background: linear-gradient(to right, #2a9d8f, #264653) !important; border: none !important; color: white !important; font-weight: bold !important; transition: all 0.3s ease !important; } .generate-btn:hover { transform: translateY(-2px) !important; box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important; } .gallery { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 1rem; padding: 1rem; } .gallery img { width: 100%; height: auto; border-radius: 8px; transition: transform 0.3s ease; } .gallery img:hover { transform: scale(1.05); } """ # Gradio 인터페이스 구성 def create_ui(): with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: gr.HTML('