fantaxy committed on
Commit ee64981 · verified · 1 Parent(s): b459565

Update app.py

Browse files
Files changed (1)
  1. app.py +293 -903
app.py CHANGED
@@ -1,438 +1,111 @@
1
- import sys
2
- import subprocess
3
-
4
- def install_required_packages():
5
- packages = [
6
- "git+https://github.com/black-forest-labs/diffusers",
7
- "transformers>=4.25.1",
8
- "safetensors>=0.3.1",
9
- "accelerate>=0.16.0"
10
- ]
11
- for package in packages:
12
- try:
13
- subprocess.check_call([sys.executable, "-m", "pip", "install", package])
14
- except subprocess.CalledProcessError as e:
15
- print(f"Error installing {package}: {e}")
16
- raise
17
-
18
- # ν•„μš”ν•œ νŒ¨ν‚€μ§€ μ„€μΉ˜
19
- install_required_packages()
20
-
21
  import spaces
22
- import argparse
23
- import os
24
- import time
25
- from os import path
26
- import shutil
27
- from datetime import datetime
28
- from safetensors.torch import load_file
29
- from huggingface_hub import hf_hub_download
30
  import gradio as gr
 
31
  import torch
32
- try:
33
- from diffusers.pipelines.flux import FluxPipeline
34
- except ImportError:
35
- from diffusers import StableDiffusionPipeline as FluxPipeline
36
- from diffusers.pipelines.stable_diffusion import safety_checker
37
- from PIL import Image
38
- from transformers import pipeline
39
- import replicate
40
- import logging
41
- import requests
42
  from pathlib import Path
43
- import cv2
 
44
  import numpy as np
45
- import sys
46
- import io
 
47
 
48
- # λ‘œκΉ… μ„€μ •
49
- logging.basicConfig(
50
- level=logging.INFO,
51
- format='%(asctime)s - %(levelname)s - %(message)s'
52
- )
53
- logger = logging.getLogger(__name__)
 
54
 
55
- # μƒμˆ˜ 및 ν™˜κ²½ λ³€μˆ˜ μ„€μ •
56
  MAX_SEED = np.iinfo(np.int32).max
57
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
58
- MODEL_PATH = "asset"
59
- CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models")
60
- GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery")
61
- VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery")
62
-
63
- # API ν‚€ μ„€μ •
64
- HF_TOKEN = os.getenv("HF_TOKEN")
65
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
66
- CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
67
- REPLICATE_API_TOKEN = os.getenv("API_KEY")
68
-
69
- # μ‹œμŠ€ν…œ ν”„λ‘¬ν”„νŠΈ λ‘œλ“œ
70
- SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt"
71
- with open(SYSTEM_PROMPT_PATH, "r") as f:
72
- SYSTEM_PROMPT = f.read()
73
-
74
- # 디렉토리 μ΄ˆκΈ°ν™”
75
- def init_directories():
76
- """ν•„μš”ν•œ 디렉토리듀을 생성"""
77
- directories = [GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH]
78
- for directory in directories:
79
- os.makedirs(directory, exist_ok=True)
80
- logger.info(f"Directory initialized: {directory}")
81
-
82
- # CUDA μ„€μ •
83
- def setup_cuda():
84
- """CUDA κ΄€λ ¨ μ„€μ • μ΄ˆκΈ°ν™”"""
85
- torch.backends.cuda.matmul.allow_tf32 = False
86
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
87
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
88
- torch.backends.cudnn.allow_tf32 = False
89
- torch.backends.cudnn.deterministic = False
90
- torch.backends.cuda.preferred_blas_library = "cublas"
91
- torch.set_float32_matmul_precision("highest")
92
- logger.info("CUDA settings initialized")
93
-
94
-
95
- # Model initialization
96
- if not path.exists(cache_path):
97
- os.makedirs(cache_path, exist_ok=True)
98
-
99
- try:
100
- # FluxPipeline μ΄ˆκΈ°ν™” μ‹œλ„
101
- model_id = "black-forest-labs/FLUX.1-dev"
102
- pipe = FluxPipeline.from_pretrained(
103
- model_id,
104
- torch_dtype=torch.bfloat16,
105
- cache_dir=cache_path,
106
- local_files_only=False
107
- )
108
-
109
- # LoRA κ°€μ€‘μΉ˜ λ‹€μš΄λ‘œλ“œ 및 적용
110
- lora_path = hf_hub_download(
111
- "ByteDance/Hyper-SD",
112
- "Hyper-FLUX.1-dev-8steps-lora.safetensors",
113
- cache_dir=cache_path
114
- )
115
-
116
- if hasattr(pipe, 'load_lora_weights'):
117
- pipe.load_lora_weights(lora_path)
118
- pipe.fuse_lora(lora_scale=0.125)
119
-
120
- # λ””λ°”μ΄μŠ€ μ„€μ •
121
- pipe = pipe.to("cuda")
122
-
123
- # μ•ˆμ „μ„± 검사기 μ„€μ •
124
- if hasattr(pipe, 'safety_checker'):
125
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
126
- "CompVis/stable-diffusion-safety-checker",
127
- cache_dir=cache_path
128
- )
129
-
130
- logger.info("Model initialized successfully")
131
- except Exception as e:
132
- logger.error(f"Error initializing model: {str(e)}")
133
- raise
134
-
135
- # λͺ¨λΈ 관리 클래슀
136
- class ModelManager:
137
- def __init__(self):
138
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
139
- self.models = {}
140
- self.current_model = None
141
- logger.info(f"ModelManager initialized with device: {self.device}")
142
-
143
- def load_model(self, model_name):
144
- """λͺ¨λΈμ„ λ™μ μœΌλ‘œ λ‘œλ“œ"""
145
- if self.current_model == model_name and model_name in self.models:
146
- return self.models[model_name]
147
-
148
- # ν˜„μž¬ λ‘œλ“œλœ λͺ¨λΈ μ–Έλ‘œλ“œ
149
- self.unload_current_model()
150
-
151
- logger.info(f"Loading model: {model_name}")
152
- try:
153
- if model_name == "flux":
154
- model = self._load_flux_model()
155
- elif model_name == "xora":
156
- model = self._load_xora_model()
157
- elif model_name == "clip":
158
- model = self._load_clip_model()
159
- else:
160
- raise ValueError(f"Unknown model: {model_name}")
161
-
162
- self.models[model_name] = model
163
- self.current_model = model_name
164
- return model
165
-
166
- except Exception as e:
167
- logger.error(f"Error loading model {model_name}: {str(e)}")
168
- raise
169
-
170
- def unload_current_model(self):
171
- """ν˜„μž¬ λ‘œλ“œλœ λͺ¨λΈ μ–Έλ‘œλ“œ"""
172
- if self.current_model:
173
- logger.info(f"Unloading model: {self.current_model}")
174
- if self.current_model in self.models:
175
- del self.models[self.current_model]
176
- self.current_model = None
177
- torch.cuda.empty_cache()
178
- gc.collect()
179
-
180
- def _load_flux_model(self):
181
- """Flux λͺ¨λΈ λ‘œλ“œ"""
182
- pipe = FluxPipeline.from_pretrained(
183
- "black-forest-labs/FLUX.1-dev",
184
- torch_dtype=torch.bfloat16
185
- )
186
- pipe.load_lora_weights(
187
- hf_hub_download(
188
- "ByteDance/Hyper-SD",
189
- "Hyper-FLUX.1-dev-8steps-lora.safetensors"
190
- )
191
- )
192
- pipe.fuse_lora(lora_scale=0.125)
193
- pipe.to(device=self.device, dtype=torch.bfloat16)
194
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
195
- "CompVis/stable-diffusion-safety-checker"
196
- )
197
- return pipe
198
-
199
- def _load_xora_model(self):
200
- """Xora λͺ¨λΈ λ‘œλ“œ"""
201
- if not path.exists(MODEL_PATH):
202
- snapshot_download(
203
- "Lightricks/LTX-Video",
204
- revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc',
205
- local_dir=MODEL_PATH,
206
- repo_type="model",
207
- token=HF_TOKEN
208
- )
209
-
210
- vae = load_vae(Path(MODEL_PATH) / "vae")
211
- unet = load_unet(Path(MODEL_PATH) / "unet")
212
- scheduler = load_scheduler(Path(MODEL_PATH) / "scheduler")
213
- patchifier = SymmetricPatchifier(patch_size=1)
214
- text_encoder = T5EncoderModel.from_pretrained(
215
- "PixArt-alpha/PixArt-XL-2-1024-MS",
216
- subfolder="text_encoder"
217
- ).to(self.device)
218
- tokenizer = T5Tokenizer.from_pretrained(
219
- "PixArt-alpha/PixArt-XL-2-1024-MS",
220
- subfolder="tokenizer"
221
- )
222
 
223
- return XoraVideoPipeline(
224
- transformer=unet,
225
- patchifier=patchifier,
226
- text_encoder=text_encoder,
227
- tokenizer=tokenizer,
228
- scheduler=scheduler,
229
- vae=vae
230
- ).to(self.device)
231
-
232
- def _load_clip_model(self):
233
- """CLIP λͺ¨λΈ λ‘œλ“œ"""
234
- model = CLIPModel.from_pretrained(
235
- "openai/clip-vit-base-patch32",
236
- cache_dir=MODEL_PATH
237
- ).to(self.device)
238
- processor = CLIPProcessor.from_pretrained(
239
- "openai/clip-vit-base-patch32",
240
- cache_dir=MODEL_PATH
241
- )
242
- return {"model": model, "processor": processor}
243
 
244
- # λ²ˆμ—­κΈ° μ΄ˆκΈ°ν™”
245
- @lru_cache(maxsize=None)
246
- def get_translator():
247
- """λ²ˆμ—­κΈ°λ₯Ό lazy loading으둜 μ΄ˆκΈ°ν™”"""
248
- return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
249
 
250
- # OpenAI ν΄λΌμ΄μ–ΈνŠΈ μ΄ˆκΈ°ν™”
251
- @lru_cache(maxsize=None)
252
- def get_openai_client():
253
- """OpenAI ν΄λΌμ΄μ–ΈνŠΈλ₯Ό lazy loading으둜 μ΄ˆκΈ°ν™”"""
254
- return OpenAI(api_key=OPENAI_API_KEY)
255
 
 
256
 
 
 
257
 
258
- # μœ ν‹Έλ¦¬ν‹° ν•¨μˆ˜λ“€
259
- class Timer:
260
- """μž‘μ—… μ‹œκ°„ 츑정을 μœ„ν•œ μ»¨ν…μŠ€νŠΈ λ§€λ‹ˆμ €"""
261
- def __init__(self, method_name="timed process"):
262
- self.method = method_name
263
-
264
- def __enter__(self):
265
- self.start = time.time()
266
- logger.info(f"{self.method} starts")
267
-
268
- def __exit__(self, exc_type, exc_val, exc_tb):
269
- end = time.time()
270
- logger.info(f"{self.method} took {str(round(end - self.start, 2))}s")
271
 
272
  def process_prompt(prompt):
273
- """ν”„λ‘¬ν”„νŠΈ μ „μ²˜λ¦¬ (ν•œκΈ€ λ²ˆμ—­ 및 필터링)"""
274
  if any(ord('κ°€') <= ord(char) <= ord('힣') for char in prompt):
275
- translator = get_translator()
276
  translated = translator(prompt)[0]['translation_text']
277
- logger.info(f"Translated prompt: {translated}")
278
  return translated
279
  return prompt
280
 
281
- def filter_prompt(prompt):
282
- """λΆ€μ μ ˆν•œ λ‚΄μš© 필터링"""
283
- inappropriate_keywords = [
284
- "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult",
285
- "xxx", "erotic", "sensual", "seductive", "provocative",
286
- "intimate", "violence", "gore", "blood", "death", "kill",
287
- "murder", "torture", "drug", "suicide", "abuse", "hate",
288
- "discrimination"
289
- ]
290
-
291
- prompt_lower = prompt.lower()
292
- for keyword in inappropriate_keywords:
293
- if keyword in prompt_lower:
294
- logger.warning(f"Inappropriate content detected: {keyword}")
295
- return False, "λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€."
296
- return True, prompt
297
-
298
- def enhance_prompt(prompt, enhance_toggle):
299
- """GPTλ₯Ό μ‚¬μš©ν•œ ν”„λ‘¬ν”„νŠΈ κ°œμ„ """
300
- if not enhance_toggle:
301
- logger.info("Prompt enhancement disabled")
302
- return prompt
303
-
304
- try:
305
- client = get_openai_client()
306
- messages = [
307
- {"role": "system", "content": SYSTEM_PROMPT},
308
- {"role": "user", "content": prompt},
309
- ]
310
-
311
- response = client.chat.completions.create(
312
- model="gpt-4-mini",
313
- messages=messages,
314
- max_tokens=200,
315
- )
316
-
317
- enhanced_prompt = response.choices[0].message.content.strip()
318
- logger.info(f"Enhanced prompt: {enhanced_prompt}")
319
- return enhanced_prompt
320
- except Exception as e:
321
- logger.error(f"Prompt enhancement failed: {str(e)}")
322
- return prompt
323
-
324
- def save_image(image, directory=GALLERY_PATH):
325
- """μƒμ„±λœ 이미지 μ €μž₯"""
326
- try:
327
- os.makedirs(directory, exist_ok=True)
328
-
329
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
330
- random_suffix = os.urandom(4).hex()
331
- filename = f"generated_{timestamp}_{random_suffix}.png"
332
- filepath = os.path.join(directory, filename)
333
-
334
- if not isinstance(image, Image.Image):
335
- image = Image.fromarray(image)
336
-
337
- if image.mode != 'RGB':
338
- image = image.convert('RGB')
339
-
340
- image.save(filepath, format='PNG', optimize=True, quality=100)
341
- logger.info(f"Image saved: {filepath}")
342
- return filepath
343
- except Exception as e:
344
- logger.error(f"Error saving image: {str(e)}")
345
- return None
346
-
347
- def add_watermark(video_path):
348
- """λΉ„λ””μ˜€μ— μ›Œν„°λ§ˆν¬ μΆ”κ°€"""
349
- try:
350
- cap = cv2.VideoCapture(video_path)
351
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
352
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
353
- fps = int(cap.get(cv2.CAP_PROP_FPS))
354
-
355
- text = "GiniGEN.AI"
356
- font = cv2.FONT_HERSHEY_SIMPLEX
357
- font_scale = height * 0.05 / 30
358
- thickness = 2
359
- color = (255, 255, 255)
360
-
361
- (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
362
- margin = int(height * 0.02)
363
- x_pos = width - text_width - margin
364
- y_pos = height - margin
365
-
366
- output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}")
367
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
368
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
369
-
370
- while cap.isOpened():
371
- ret, frame = cap.read()
372
- if not ret:
373
- break
374
- cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
375
- out.write(frame)
376
-
377
- cap.release()
378
- out.release()
379
- logger.info(f"Video watermarked: {output_path}")
380
- return output_path
381
-
382
- except Exception as e:
383
- logger.error(f"Error adding watermark: {str(e)}")
384
- return video_path
385
-
386
- def upload_to_catbox(file_path):
387
- """νŒŒμΌμ„ catbox.moe에 μ—…λ‘œλ“œ"""
388
- try:
389
- logger.info(f"Uploading file: {file_path}")
390
- url = "https://catbox.moe/user/api.php"
391
-
392
- file_extension = Path(file_path).suffix.lower()
393
- supported_extensions = {
394
- '.jpg': 'image/jpeg',
395
- '.jpeg': 'image/jpeg',
396
- '.png': 'image/png',
397
- '.gif': 'image/gif',
398
- '.mp4': 'video/mp4'
399
- }
400
-
401
- if file_extension not in supported_extensions:
402
- logger.error(f"Unsupported file type: {file_extension}")
403
- return None
404
-
405
- files = {
406
- 'fileToUpload': (
407
- os.path.basename(file_path),
408
- open(file_path, 'rb'),
409
- supported_extensions[file_extension]
410
- )
411
- }
412
-
413
- data = {
414
- 'reqtype': 'fileupload',
415
- 'userhash': CATBOX_USER_HASH
416
- }
417
-
418
- response = requests.post(url, files=files, data=data)
419
-
420
- if response.status_code == 200 and response.text.startswith('http'):
421
- logger.info(f"Upload successful: {response.text}")
422
- return response.text
423
- else:
424
- raise Exception(f"Upload failed: {response.text}")
425
-
426
- except Exception as e:
427
- logger.error(f"Upload error: {str(e)}")
428
- return None
429
-
430
- # λͺ¨λΈ λ§€λ‹ˆμ € μΈμŠ€ν„΄μŠ€ 생성
431
- model_manager = ModelManager()
432
-
433
-
434
- # Gradio μΈν„°νŽ˜μ΄μŠ€ κ΄€λ ¨ μƒμˆ˜ 및 μ„€μ •
435
- PRESET_OPTIONS = [
436
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
437
  {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
438
  {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
@@ -450,99 +123,106 @@ PRESET_OPTIONS = [
450
  {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
451
  ]
452
 
453
- # 메인 처리 ν•¨μˆ˜λ“€
454
- @spaces.GPU(duration=90)
455
- def generate_image(
456
- prompt,
457
- height,
458
- width,
459
- steps,
460
- scales,
461
- seed,
462
- enhance_prompt_toggle=False,
463
- progress=gr.Progress()
464
- ):
465
- """이미지 생성 ν•¨μˆ˜"""
466
- try:
467
- # ν”„λ‘¬ν”„νŠΈ μ „μ²˜λ¦¬
468
- processed_prompt = process_prompt(prompt)
469
- is_safe, filtered_prompt = filter_prompt(processed_prompt)
470
- if not is_safe:
471
- raise gr.Error("λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€.")
472
-
473
- if enhance_prompt_toggle:
474
- filtered_prompt = enhance_prompt(filtered_prompt, True)
475
-
476
- # Flux λͺ¨λΈ λ‘œλ“œ
477
- pipe = model_manager.load_model("flux")
478
-
479
- with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
480
- generated_image = pipe(
481
- prompt=[filtered_prompt],
482
- generator=torch.Generator().manual_seed(int(seed)),
483
- num_inference_steps=int(steps),
484
- guidance_scale=float(scales),
485
- height=int(height),
486
- width=int(width),
487
- max_sequence_length=256
488
- ).images[0]
489
-
490
- # 이미지 μ €μž₯ 및 λ°˜ν™˜
491
- saved_path = save_image(generated_image)
492
- if saved_path is None:
493
- raise gr.Error("이미지 저장에 실패했습니다.")
494
-
495
- return Image.open(saved_path)
496
 
497
  except Exception as e:
498
- logger.error(f"Image generation error: {str(e)}")
499
- raise gr.Error(f"이미지 생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}")
500
- finally:
501
- model_manager.unload_current_model()
502
- torch.cuda.empty_cache()
503
- gc.collect()
504
 
505
  @spaces.GPU(duration=90)
506
- def generate_video_xora(
507
- prompt,
508
- enhance_prompt_toggle,
509
- negative_prompt,
510
- frame_rate,
511
- seed,
512
- num_inference_steps,
513
- guidance_scale,
514
- height,
515
- width,
516
- num_frames,
517
- progress=gr.Progress()
518
  ):
519
- """Xora λΉ„λ””μ˜€ 생성 ν•¨μˆ˜"""
520
- try:
521
- # ν”„λ‘¬ν”„νŠΈ 처리
522
- prompt = process_prompt(prompt)
523
- negative_prompt = process_prompt(negative_prompt)
524
-
525
- if len(prompt.strip()) < 50:
526
- raise gr.Error("ν”„λ‘¬ν”„νŠΈλŠ” μ΅œμ†Œ 50자 이상이어야 ν•©λ‹ˆλ‹€.")
527
-
528
- prompt = enhance_prompt(prompt, enhance_prompt_toggle)
529
 
530
- # Xora λͺ¨λΈ λ‘œλ“œ
531
- pipeline = model_manager.load_model("xora")
532
 
533
- sample = {
534
- "prompt": prompt,
535
- "prompt_attention_mask": None,
536
- "negative_prompt": negative_prompt,
537
- "negative_prompt_attention_mask": None,
538
- "media_items": None,
539
- }
540
 
541
- generator = torch.Generator(device="cuda").manual_seed(seed)
542
 
543
- def progress_callback(step, timestep, kwargs):
544
- progress((step + 1) / num_inference_steps)
545
 
 
546
  with torch.no_grad():
547
  images = pipeline(
548
  num_inference_steps=num_inference_steps,
@@ -559,427 +239,137 @@ def generate_video_xora(
559
  vae_per_channel_normalize=True,
560
  conditioning_method=ConditioningMethod.UNCONDITIONAL,
561
  mixed_precision=True,
562
- callback_on_step_end=progress_callback,
563
  ).images
564
 
565
- # λΉ„λ””μ˜€ μ €μž₯
566
- output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4")
567
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
568
- video_np = (video_np * 255).astype(np.uint8)
569
 
570
- out = cv2.VideoWriter(
571
- output_path,
572
- cv2.VideoWriter_fourcc(*"mp4v"),
573
- frame_rate,
574
- (width, height)
575
  )
576
-
577
- for frame in video_np[..., ::-1]:
578
- out.write(frame)
579
- out.release()
580
 
581
- # μ›Œν„°λ§ˆν¬ μΆ”κ°€
582
- final_path = add_watermark(output_path)
583
- return final_path
584
 
585
- except Exception as e:
586
- logger.error(f"Video generation error: {str(e)}")
587
- raise gr.Error(f"λΉ„λ””μ˜€ 생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}")
588
- finally:
589
- model_manager.unload_current_model()
590
- torch.cuda.empty_cache()
591
- gc.collect()
592
 
593
- def generate_video_replicate(image, prompt):
594
- """Replicate APIλ₯Ό μ‚¬μš©ν•œ λΉ„λ””μ˜€ 생성 ν•¨μˆ˜"""
595
- try:
596
- is_safe, filtered_prompt = filter_prompt(prompt)
597
- if not is_safe:
598
- raise gr.Error("λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€.")
599
-
600
- if not image:
601
- raise gr.Error("이미지λ₯Ό μ—…λ‘œλ“œν•΄μ£Όμ„Έμš”.")
602
-
603
- # 이미지 μ—…λ‘œλ“œ
604
- image_url = upload_to_catbox(image)
605
- if not image_url:
606
- raise gr.Error("이미지 μ—…λ‘œλ“œμ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€.")
607
-
608
- # Replicate API 호좜
609
- client = replicate.Client(api_token=REPLICATE_API_TOKEN)
610
- output = client.run(
611
- "minimax/video-01-live",
612
- input={
613
- "prompt": filtered_prompt,
614
- "first_frame_image": image_url
615
- }
616
  )
617
 
618
- # κ²°κ³Ό λΉ„λ””μ˜€ μ €μž₯
619
- output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4")
620
-
621
- if hasattr(output, 'read'):
622
- with open(output_path, "wb") as f:
623
- f.write(output.read())
624
- elif isinstance(output, str):
625
- response = requests.get(output)
626
- with open(output_path, "wb") as f:
627
- f.write(response.content)
628
-
629
- # μ›Œν„°λ§ˆν¬ μΆ”κ°€
630
- final_path = add_watermark(output_path)
631
- return final_path
632
 
633
- except Exception as e:
634
- logger.error(f"Replicate video generation error: {str(e)}")
635
- raise gr.Error(f"λΉ„λ””μ˜€ 생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}")
636
-
637
-
638
- @spaces.GPU
639
- def process_and_save_image(height, width, steps, scales, prompt, seed):
640
- is_safe, translated_prompt = process_prompt(prompt)
641
- if not is_safe:
642
- gr.Warning("λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€.")
643
- return None, load_gallery()
644
-
645
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
646
- try:
647
- # λͺ¨λΈ 호좜 방식 μˆ˜μ •
648
- if hasattr(pipe, '__call__'):
649
- output = pipe(
650
- prompt=[translated_prompt],
651
- generator=torch.Generator().manual_seed(int(seed)),
652
- num_inference_steps=int(steps),
653
- guidance_scale=float(scales),
654
- height=int(height),
655
- width=int(width),
656
- max_sequence_length=256
657
- )
658
- generated_image = output.images[0]
659
- else:
660
- generated_image = pipe.text2img(
661
- prompt=translated_prompt,
662
- generator=torch.Generator().manual_seed(int(seed)),
663
- num_inference_steps=int(steps),
664
- guidance_scale=float(scales),
665
- height=int(height),
666
- width=int(width)
667
- )[0]
668
-
669
- # 이미지 처리 및 μ €μž₯
670
- if not isinstance(generated_image, Image.Image):
671
- generated_image = Image.fromarray(generated_image)
672
-
673
- if generated_image.mode != 'RGB':
674
- generated_image = generated_image.convert('RGB')
675
-
676
- img_byte_arr = io.BytesIO()
677
- generated_image.save(img_byte_arr, format='PNG')
678
- img_byte_arr = img_byte_arr.getvalue()
679
-
680
- saved_path = save_image(generated_image)
681
- if saved_path is None:
682
- logger.warning("Failed to save generated image")
683
- return None, load_gallery()
684
-
685
- return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
686
- except Exception as e:
687
- logger.error(f"Error in image generation: {str(e)}")
688
- return None, load_gallery()
689
-
690
- # Gradio UI μŠ€νƒ€μΌ
691
- css = """
692
- .gradio-container {
693
- font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important;
694
- }
695
-
696
- .title {
697
- text-align: center;
698
- font-size: 2.5rem;
699
- font-weight: bold;
700
- color: #2a9d8f;
701
- margin: 1rem 0;
702
- padding: 1rem;
703
- background: linear-gradient(to right, #264653, #2a9d8f);
704
- -webkit-background-clip: text;
705
- -webkit-text-fill-color: transparent;
706
- }
707
-
708
- .generate-btn {
709
- background: linear-gradient(to right, #2a9d8f, #264653) !important;
710
- border: none !important;
711
- color: white !important;
712
- font-weight: bold !important;
713
- transition: all 0.3s ease !important;
714
- }
715
-
716
- .generate-btn:hover {
717
- transform: translateY(-2px) !important;
718
- box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important;
719
- }
720
-
721
- .gallery {
722
- display: grid;
723
- grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
724
- gap: 1rem;
725
- padding: 1rem;
726
- }
727
-
728
- .gallery img {
729
- width: 100%;
730
- height: auto;
731
- border-radius: 8px;
732
- transition: transform 0.3s ease;
733
- }
734
-
735
- .gallery img:hover {
736
- transform: scale(1.05);
737
- }
738
- """
739
-
740
- # Gradio μΈν„°νŽ˜μ΄μŠ€ ꡬ성
741
- def create_ui():
742
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
743
- gr.HTML('<div class="title">AI Image & Video Generator</div>')
744
-
745
- with gr.Tabs():
746
- # 이미지 생성 νƒ­
747
- with gr.Tab("Image Generation"):
748
- with gr.Row():
749
- with gr.Column(scale=3):
750
- img_prompt = gr.Textbox(
751
- label="Image Description",
752
- placeholder="이미지 μ„€λͺ…을 μž…λ ₯ν•˜μ„Έμš”... (ν•œκΈ€ μž…λ ₯ κ°€λŠ₯)",
753
- lines=3
754
- )
755
-
756
- img_enhance_toggle = Toggle(
757
- label="Enhance Prompt",
758
- value=False,
759
- interactive=True,
760
- )
761
-
762
- with gr.Accordion("Advanced Settings", open=False):
763
- with gr.Row():
764
- img_height = gr.Slider(
765
- label="Height",
766
- minimum=256,
767
- maximum=1024,
768
- step=64,
769
- value=768
770
- )
771
- img_width = gr.Slider(
772
- label="Width",
773
- minimum=256,
774
- maximum=1024,
775
- step=64,
776
- value=768
777
- )
778
-
779
- with gr.Row():
780
- steps = gr.Slider(
781
- label="Inference Steps",
782
- minimum=6,
783
- maximum=25,
784
- step=1,
785
- value=8
786
- )
787
- scales = gr.Slider(
788
- label="Guidance Scale",
789
- minimum=0.0,
790
- maximum=5.0,
791
- step=0.1,
792
- value=3.5
793
- )
794
-
795
- seed = gr.Number(
796
- label="Seed",
797
- value=random.randint(0, MAX_SEED),
798
- precision=0
799
- )
800
-
801
- img_generate_btn = gr.Button(
802
- "Generate Image",
803
- variant="primary",
804
- elem_classes=["generate-btn"]
805
- )
806
-
807
- with gr.Column(scale=4):
808
- img_output = gr.Image(
809
- label="Generated Image",
810
- type="pil",
811
- format="png"
812
- )
813
- img_gallery = gr.Gallery(
814
- label="Image Gallery",
815
- show_label=True,
816
- elem_id="gallery",
817
- columns=[4],
818
- rows=[2],
819
- height="auto",
820
- object_fit="cover"
821
- )
822
-
823
- # Xora λΉ„λ””μ˜€ 생성 νƒ­
824
- with gr.Tab("Xora Video Generation"):
825
- with gr.Row():
826
- with gr.Column(scale=3):
827
- xora_prompt = gr.Textbox(
828
- label="Video Description",
829
- placeholder="λΉ„λ””μ˜€ μ„€λͺ…을 μž…λ ₯ν•˜μ„Έμš”... (μ΅œμ†Œ 50자)",
830
- lines=5
831
- )
832
-
833
- xora_enhance_toggle = Toggle(
834
- label="Enhance Prompt",
835
- value=False
836
- )
837
-
838
- xora_negative_prompt = gr.Textbox(
839
- label="Negative Prompt",
840
- value="low quality, worst quality, deformed, distorted",
841
- lines=2
842
- )
843
-
844
- xora_preset = gr.Dropdown(
845
- choices=[p["label"] for p in PRESET_OPTIONS],
846
- value="512x512, 160 frames",
847
- label="Resolution Preset"
848
- )
849
-
850
- xora_frame_rate = gr.Slider(
851
- label="Frame Rate",
852
- minimum=6,
853
- maximum=60,
854
- step=1,
855
- value=20
856
- )
857
-
858
- with gr.Accordion("Advanced Settings", open=False):
859
- xora_seed = gr.Slider(
860
- label="Seed",
861
- minimum=0,
862
- maximum=MAX_SEED,
863
- step=1,
864
- value=random.randint(0, MAX_SEED)
865
- )
866
- xora_steps = gr.Slider(
867
- label="Inference Steps",
868
- minimum=5,
869
- maximum=150,
870
- step=5,
871
- value=40
872
- )
873
- xora_guidance = gr.Slider(
874
- label="Guidance Scale",
875
- minimum=1.0,
876
- maximum=10.0,
877
- step=0.1,
878
- value=4.2
879
- )
880
-
881
- xora_generate_btn = gr.Button(
882
- "Generate Video",
883
- variant="primary",
884
- elem_classes=["generate-btn"]
885
- )
886
-
887
- with gr.Column(scale=4):
888
- xora_output = gr.Video(label="Generated Video")
889
- xora_gallery = gr.Gallery(
890
- label="Video Gallery",
891
- show_label=True,
892
- columns=[4],
893
- rows=[2],
894
- height="auto",
895
- object_fit="cover"
896
- )
897
-
898
- # Replicate λΉ„λ””μ˜€ 생성 νƒ­
899
- with gr.Tab("Image to Video"):
900
- with gr.Row():
901
- with gr.Column(scale=3):
902
- upload_image = gr.Image(
903
- type="filepath",
904
- label="Upload First Frame Image"
905
- )
906
- replicate_prompt = gr.Textbox(
907
- label="Video Description",
908
- placeholder="λΉ„λ””μ˜€ μ„€λͺ…을 μž…λ ₯ν•˜μ„Έμš”...",
909
- lines=3
910
- )
911
- replicate_generate_btn = gr.Button(
912
- "Generate Video",
913
- variant="primary",
914
- elem_classes=["generate-btn"]
915
- )
916
-
917
- with gr.Column(scale=4):
918
- replicate_output = gr.Video(label="Generated Video")
919
- replicate_gallery = gr.Gallery(
920
- label="Video Gallery",
921
- show_label=True,
922
- columns=[4],
923
- rows=[2],
924
- height="auto",
925
- object_fit="cover"
926
- )
927
-
928
- # 이벀트 ν•Έλ“€λŸ¬ μ—°κ²°
929
- img_generate_btn.click(
930
- fn=generate_image,
931
- inputs=[
932
- img_prompt,
933
- img_height,
934
- img_width,
935
- steps,
936
- scales,
937
- seed,
938
- img_enhance_toggle
939
- ],
940
- outputs=img_output
941
  )
942
 
943
- xora_generate_btn.click(
944
- fn=generate_video_xora,
945
- inputs=[
946
- xora_prompt,
947
- xora_enhance_toggle,
948
- xora_negative_prompt,
949
- xora_frame_rate,
950
- xora_seed,
951
- xora_steps,
952
- xora_guidance,
953
- img_height,
954
- img_width,
955
- gr.Slider(label="Number of Frames", value=60)
956
- ],
957
- outputs=xora_output
958
  )
959
 
960
- replicate_generate_btn.click(
961
- fn=generate_video_replicate,
962
- inputs=[upload_image, replicate_prompt],
963
- outputs=replicate_output
964
  )
965
 
966
- # 가러리 μžλ™ μ—…λ°μ΄νŠΈ
967
- demo.load(lambda: None, None, [img_gallery, xora_gallery, replicate_gallery], every=30)
968
-
969
- return demo
970
-
971
- if __name__ == "__main__":
972
- # μ΄ˆκΈ°ν™”
973
- init_directories()
974
- setup_cuda()
975
-
976
- # UI μ‹€ν–‰
977
- demo = create_ui()
978
- demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
979
- share=True,
980
- show_api=False,
981
- server_name="0.0.0.0",
982
- server_port=7860,
983
- debug=False
984
  )
985
 
1
  import spaces
2
+ from functools import lru_cache
3
  import gradio as gr
4
+ from gradio_toggle import Toggle
5
  import torch
6
+ from huggingface_hub import snapshot_download
7
+ from transformers import CLIPProcessor, CLIPModel, pipeline
8
+ import random
9
+ from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
10
+ from xora.models.transformers.transformer3d import Transformer3DModel
11
+ from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
12
+ from xora.schedulers.rf import RectifiedFlowScheduler
13
+ from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
14
+ from transformers import T5EncoderModel, T5Tokenizer
15
+ from xora.utils.conditioning_method import ConditioningMethod
16
  from pathlib import Path
17
+ import safetensors.torch
18
+ import json
19
  import numpy as np
20
+ import cv2
21
+ from PIL import Image
22
+ import tempfile
23
+ import os
24
+ import gc
25
+ import csv
26
+ from datetime import datetime
27
+ from openai import OpenAI
28
+
29
+ # Initialize the Korean-to-English translator
30
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
31
 
32
+ torch.backends.cuda.matmul.allow_tf32 = False
33
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
34
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
35
+ torch.backends.cudnn.allow_tf32 = False
36
+ torch.backends.cudnn.deterministic = False
37
+ torch.backends.cuda.preferred_blas_library="cublas"
38
+ torch.set_float32_matmul_precision("highest")
39
 
 
40
  MAX_SEED = np.iinfo(np.int32).max
41
 
42
+ # Load Hugging Face token if needed
43
+ hf_token = os.getenv("HF_TOKEN")
44
+ openai_api_key = os.getenv("OPENAI_API_KEY")
45
+ client = OpenAI(api_key=openai_api_key)
46
 
47
+ system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
48
+ with open(system_prompt_t2v_path, "r") as f:
49
+ system_prompt_t2v = f.read()
50
 
51
+ # Set model download directory within Hugging Face Spaces
52
+ model_path = "asset"
53
 
54
+ commit_hash='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
55
 
56
+ if not os.path.exists(model_path):
57
+ snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)
58
 
59
+ # Global variables to load components
60
+ vae_dir = Path(model_path) / "vae"
61
+ unet_dir = Path(model_path) / "unet"
62
+ scheduler_dir = Path(model_path) / "scheduler"
63
+
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+
66
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(torch.device("cuda:0"))
67
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
68
 
69
  def process_prompt(prompt):
70
+ # Check whether the prompt contains Hangul
71
  if any(ord('κ°€') <= ord(char) <= ord('힣') for char in prompt):
72
+ # Translate the Korean text to English
73
  translated = translator(prompt)[0]['translation_text']
 
74
  return translated
75
  return prompt
76
 
77
+ def compute_clip_embedding(text=None):
78
+ inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
79
+ outputs = clip_model.get_text_features(**inputs)
80
+ embedding = outputs.detach().cpu().numpy().flatten().tolist()
81
+ return embedding
82
+
83
+ def load_vae(vae_dir):
84
+ vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
85
+ vae_config_path = vae_dir / "config.json"
86
+ with open(vae_config_path, "r") as f:
87
+ vae_config = json.load(f)
88
+ vae = CausalVideoAutoencoder.from_config(vae_config)
89
+ vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
90
+ vae.load_state_dict(vae_state_dict)
91
+ return vae.to(device).to(torch.bfloat16)
92
+
93
+ def load_unet(unet_dir):
94
+ unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
95
+ unet_config_path = unet_dir / "config.json"
96
+ transformer_config = Transformer3DModel.load_config(unet_config_path)
97
+ transformer = Transformer3DModel.from_config(transformer_config)
98
+ unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
99
+ transformer.load_state_dict(unet_state_dict, strict=True)
100
+ return transformer.to(device).to(torch.bfloat16)
101
+
102
+ def load_scheduler(scheduler_dir):
103
+ scheduler_config_path = scheduler_dir / "scheduler_config.json"
104
+ scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
105
+ return RectifiedFlowScheduler.from_config(scheduler_config)
106
+
107
+ # Preset options for resolution and frame configuration
108
+ preset_options = [
109
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
110
  {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
111
  {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
 
123
  {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
124
  ]
125
 
126
+ def preset_changed(preset):
127
+ if preset != "Custom":
128
+ selected = next(item for item in preset_options if item["label"] == preset)
129
+ return (
130
+ selected["height"],
131
+ selected["width"],
132
+ selected["num_frames"],
133
+ gr.update(visible=False),
134
+ gr.update(visible=False),
135
+ gr.update(visible=False),
136
+ )
137
+ else:
138
+ return (
139
+ None,
140
+ None,
141
+ None,
142
+ gr.update(visible=True),
143
+ gr.update(visible=True),
144
+ gr.update(visible=True),
145
+ )
146
 
147
+ # Load models
148
+ vae = load_vae(vae_dir)
149
+ unet = load_unet(unet_dir)
150
+ scheduler = load_scheduler(scheduler_dir)
151
+ patchifier = SymmetricPatchifier(patch_size=1)
152
+ text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(torch.device("cuda:0"))
153
+ tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
154
+
155
+ pipeline = XoraVideoPipeline(
156
+ transformer=unet,
157
+ patchifier=patchifier,
158
+ text_encoder=text_encoder,
159
+ tokenizer=tokenizer,
160
+ scheduler=scheduler,
161
+ vae=vae,
162
+ ).to(torch.device("cuda:0"))
163
+
164
+ def enhance_prompt_if_enabled(prompt, enhance_toggle):
165
+ if not enhance_toggle:
166
+ print("Enhance toggle is off, Prompt: ", prompt)
167
+ return prompt
168
+
169
+ messages = [
170
+ {"role": "system", "content": system_prompt_t2v},
171
+ {"role": "user", "content": prompt},
172
+ ]
173
+
174
+ try:
175
+ response = client.chat.completions.create(
176
+ model="gpt-4-mini",
177
+ messages=messages,
178
+ max_tokens=200,
179
+ )
180
+ print("Enhanced Prompt: ", response.choices[0].message.content.strip())
181
+ return response.choices[0].message.content.strip()
182
  except Exception as e:
183
+ print(f"Error: {e}")
184
+ return prompt
185
 
186
  @spaces.GPU(duration=90)
187
+ def generate_video_from_text_90(
188
+ prompt="",
189
+ enhance_prompt_toggle=False,
190
+ negative_prompt="",
191
+ frame_rate=25,
192
+ seed=random.randint(0, MAX_SEED),
193
+ num_inference_steps=30,
194
+ guidance_scale=3.2,
195
+ height=768,
196
+ width=768,
197
+ num_frames=60,
198
+ progress=gr.Progress(),
199
  ):
200
+ # Prompt preprocessing (Korean -> English)
201
+ prompt = process_prompt(prompt)
202
+ negative_prompt = process_prompt(negative_prompt)
203
+
204
+ if len(prompt.strip()) < 50:
205
+ raise gr.Error(
206
+ "Prompt must be at least 50 characters long. Please provide more details for the best results.",
207
+ duration=5,
208
+ )
 
209
 
210
+ prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)
 
211
 
212
+ sample = {
213
+ "prompt": prompt,
214
+ "prompt_attention_mask": None,
215
+ "negative_prompt": negative_prompt,
216
+ "negative_prompt_attention_mask": None,
217
+ "media_items": None,
218
+ }
219
 
220
+ generator = torch.Generator(device="cuda").manual_seed(seed)
221
 
222
+ def gradio_progress_callback(self, step, timestep, kwargs):
223
+ progress((step + 1) / num_inference_steps)
224
 
225
+ try:
226
  with torch.no_grad():
227
  images = pipeline(
228
  num_inference_steps=num_inference_steps,
 
239
  vae_per_channel_normalize=True,
240
  conditioning_method=ConditioningMethod.UNCONDITIONAL,
241
  mixed_precision=True,
242
+ callback_on_step_end=gradio_progress_callback,
243
  ).images
244
+ except Exception as e:
245
+ raise gr.Error(
246
+ f"An error occurred while generating the video. Please try again. Error: {e}",
247
+ duration=5,
248
+ )
249
+ finally:
250
+ torch.cuda.empty_cache()
251
+ gc.collect()
252
 
253
+ output_path = tempfile.mktemp(suffix=".mp4")
254
+ video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
255
+ video_np = (video_np * 255).astype(np.uint8)
256
+ height, width = video_np.shape[1:3]
257
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
258
+ for frame in video_np[..., ::-1]:
259
+ out.write(frame)
260
+ out.release()
261
+ del images
262
+ del video_np
263
+ torch.cuda.empty_cache()
264
+ return output_path
265
+
266
+ def create_advanced_options():
267
+ with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
268
+ seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
269
+ inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
270
+ guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
271
 
272
+ height_slider = gr.Slider(
273
+ label="4.4 Height",
274
+ minimum=256,
275
+ maximum=1024,
276
+ step=64,
277
+ value=768,
278
+ visible=False,
279
+ )
280
+ width_slider = gr.Slider(
281
+ label="4.5 Width",
282
+ minimum=256,
283
+ maximum=1024,
284
+ step=64,
285
+ value=768,
286
+ visible=False,
287
+ )
288
+ num_frames_slider = gr.Slider(
289
+ label="4.5 Number of Frames",
290
+ minimum=1,
291
+ maximum=500,
292
+ step=1,
293
+ value=60,
294
+ visible=False,
295
  )
296
 
297
+ return [
298
+ seed,
299
+ inference_steps,
300
+ guidance_scale,
301
+ height_slider,
302
+ width_slider,
303
+ num_frames_slider,
304
+ ]
305
 
306
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
307
 
308
+ with gr.Column():
309
+ txt2vid_prompt = gr.Textbox(
310
+ label="Step 1: Enter Your Prompt (ν•œκΈ€ λ˜λŠ” μ˜μ–΄)",
311
+ placeholder="μƒμ„±ν•˜κ³  싢은 λΉ„λ””μ˜€λ₯Ό μ„€λͺ…ν•˜μ„Έμš” (μ΅œμ†Œ 50자)...",
312
+ value="κΈ΄ κ°ˆμƒ‰ 머리와 밝은 ν”ΌλΆ€λ₯Ό κ°€μ§„ 여성이 κΈ΄ 금발 머리λ₯Ό κ°€μ§„ λ‹€λ₯Έ 여성을 ν–₯ν•΄ λ―Έμ†Œ μ§“μŠ΅λ‹ˆλ‹€. κ°ˆμƒ‰ 머리 여성은 검은 μž¬ν‚·μ„ μž…κ³  있으며 였λ₯Έμͺ½ 뺨에 μž‘κ³  거의 λˆˆμ— 띄지 μ•ŠλŠ” 점이 μžˆμŠ΅λ‹ˆλ‹€. 카메라 액글은 κ°ˆμƒ‰ 머리 μ—¬μ„±μ˜ 얼꡴에 μ΄ˆμ μ„ 맞좘 ν΄λ‘œμ¦ˆμ—…μž…λ‹ˆλ‹€. μ‘°λͺ…은 λ”°λœ»ν•˜κ³  μžμ—°μŠ€λŸ¬μš°λ©°, μ•„λ§ˆλ„ μ§€λŠ” ν•΄μ—μ„œ λ‚˜μ˜€λŠ” 것 κ°™μ•„ μž₯면에 λΆ€λ“œλŸ¬μš΄ 빛을 λΉ„μΆ₯λ‹ˆλ‹€.",
313
+ lines=5,
314
  )
315
316
 
317
+ txt2vid_enhance_toggle = Toggle(
318
+ label="Enhance Prompt",
319
+ value=False,
320
+ interactive=True,
321
  )
322
 
323
+ txt2vid_negative_prompt = gr.Textbox(
324
+ label="Step 2: Enter Negative Prompt",
325
+ placeholder="λΉ„λ””μ˜€μ—μ„œ μ›ν•˜μ§€ μ•ŠλŠ” μš”μ†Œλ₯Ό μ„€λͺ…ν•˜μ„Έμš”...",
326
+ value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
327
+ lines=2,
328
  )
329
 
330
+ txt2vid_preset = gr.Dropdown(
331
+ choices=[p["label"] for p in preset_options],
332
+ value="512x512, 160 frames",
333
+ label="Step 3.1: Choose Resolution Preset",
334
  )
335
 
336
+ txt2vid_frame_rate = gr.Slider(
337
+ label="Step 3.2: Frame Rate",
338
+ minimum=6,
339
+ maximum=60,
340
+ step=1,
341
+ value=20,
342
+ )
343
+
344
+ txt2vid_advanced = create_advanced_options()
345
+ txt2vid_generate = gr.Button(
346
+ "Step 5: Generate Video",
347
+ variant="primary",
348
+ size="lg",
349
+ )
350
+
351
+ txt2vid_output = gr.Video(label="Generated Output")
352
+
353
+ txt2vid_preset.change(
354
+ fn=preset_changed,
355
+ inputs=[txt2vid_preset],
356
+ outputs=txt2vid_advanced[3:],
357
+ )
358
+
359
+ txt2vid_generate.click(
360
+ fn=generate_video_from_text_90,
361
+ inputs=[
362
+ txt2vid_prompt,
363
+ txt2vid_enhance_toggle,
364
+ txt2vid_negative_prompt,
365
+ txt2vid_frame_rate,
366
+ *txt2vid_advanced,
367
+ ],
368
+ outputs=txt2vid_output,
369
+ concurrency_limit=1,
370
+ concurrency_id="generate_video",
371
+ queue=True,
372
  )
373
 
374
+ iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
375
+ # ===== Application Startup at 2024-12-20 01:30:34 =====
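
For reference, below is a minimal standalone sketch of the Hangul-detection and translation step that the new process_prompt() relies on. It assumes only that the Helsinki-NLP/opus-mt-ko-en checkpoint referenced in the diff can be downloaded; the helper name to_english is illustrative and is not part of the committed file.

from transformers import pipeline

# Korean-to-English translator, mirroring the initialization near the top of the new app.py
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

def to_english(prompt: str) -> str:
    # Same check as process_prompt(): treat the text as Korean if it contains
    # any Hangul syllable in the U+AC00..U+D7A3 range, and translate it.
    if any(ord('가') <= ord(ch) <= ord('힣') for ch in prompt):
        return translator(prompt)[0]['translation_text']
    return prompt

print(to_english("노을 지는 해변을 걷는 여성"))        # translated to English
print(to_english("a woman walking along the beach"))  # returned unchanged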