fantaxy committed on
Commit 3fbccb1 · verified · 1 Parent(s): 6545e86

Update app.py

Files changed (1)
  1. app.py +794 -288
app.py CHANGED
@@ -1,111 +1,386 @@
1
  import spaces
2
- from functools import lru_cache
3
  import gradio as gr
4
  from gradio_toggle import Toggle
5
  import torch
6
- from huggingface_hub import snapshot_download
7
- from transformers import CLIPProcessor, CLIPModel, pipeline
8
- import random
9
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
10
  from xora.models.transformers.transformer3d import Transformer3DModel
11
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
12
  from xora.schedulers.rf import RectifiedFlowScheduler
13
  from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
14
- from transformers import T5EncoderModel, T5Tokenizer
15
  from xora.utils.conditioning_method import ConditioningMethod
16
- from pathlib import Path
17
- import safetensors.torch
18
- import json
19
- import numpy as np
20
- import cv2
21
- from PIL import Image
22
- import tempfile
23
- import os
24
- import gc
25
- import csv
26
- from datetime import datetime
27
- from openai import OpenAI
28
-
29
- # Initialize the Korean-to-English translator
30
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
31
 
32
- torch.backends.cuda.matmul.allow_tf32 = False
33
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
34
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
35
- torch.backends.cudnn.allow_tf32 = False
36
- torch.backends.cudnn.deterministic = False
37
- torch.backends.cuda.preferred_blas_library="cublas"
38
- torch.set_float32_matmul_precision("highest")
39
 
 
40
  MAX_SEED = np.iinfo(np.int32).max
 
 
41
 
42
- # Load Hugging Face token if needed
43
- hf_token = os.getenv("HF_TOKEN")
44
- openai_api_key = os.getenv("OPENAI_API_KEY")
45
- client = OpenAI(api_key=openai_api_key)
46
-
47
- system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
48
- with open(system_prompt_t2v_path, "r") as f:
49
- system_prompt_t2v = f.read()
50
-
51
- # Set model download directory within Hugging Face Spaces
52
- model_path = "asset"
 
 
53
 
54
- commit_hash='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
 
 
 
 
55
 
56
- if not os.path.exists(model_path):
57
- snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)
 
 
 
58
 
59
- # Global variables to load components
60
- vae_dir = Path(model_path) / "vae"
61
- unet_dir = Path(model_path) / "unet"
62
- scheduler_dir = Path(model_path) / "scheduler"
63
 
64
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
 
66
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(torch.device("cuda:0"))
67
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
 
 
68
 
69
  def process_prompt(prompt):
70
- # Check whether the prompt contains Korean
71
  if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
72
- # Translate Korean to English
73
  translated = translator(prompt)[0]['translation_text']
 
74
  return translated
75
  return prompt
76
 
77
- def compute_clip_embedding(text=None):
78
- inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
79
- outputs = clip_model.get_text_features(**inputs)
80
- embedding = outputs.detach().cpu().numpy().flatten().tolist()
81
- return embedding
82
-
83
- def load_vae(vae_dir):
84
- vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
85
- vae_config_path = vae_dir / "config.json"
86
- with open(vae_config_path, "r") as f:
87
- vae_config = json.load(f)
88
- vae = CausalVideoAutoencoder.from_config(vae_config)
89
- vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
90
- vae.load_state_dict(vae_state_dict)
91
- return vae.to(device).to(torch.bfloat16)
92
-
93
- def load_unet(unet_dir):
94
- unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
95
- unet_config_path = unet_dir / "config.json"
96
- transformer_config = Transformer3DModel.load_config(unet_config_path)
97
- transformer = Transformer3DModel.from_config(transformer_config)
98
- unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
99
- transformer.load_state_dict(unet_state_dict, strict=True)
100
- return transformer.to(device).to(torch.bfloat16)
101
-
102
- def load_scheduler(scheduler_dir):
103
- scheduler_config_path = scheduler_dir / "scheduler_config.json"
104
- scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
105
- return RectifiedFlowScheduler.from_config(scheduler_config)
106
-
107
- # Preset options for resolution and frame configuration
108
- preset_options = [
 
 
109
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
110
  {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
111
  {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
@@ -123,106 +398,99 @@ preset_options = [
123
  {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
124
  ]
125
 
126
- def preset_changed(preset):
127
- if preset != "Custom":
128
- selected = next(item for item in preset_options if item["label"] == preset)
129
- return (
130
- selected["height"],
131
- selected["width"],
132
- selected["num_frames"],
133
- gr.update(visible=False),
134
- gr.update(visible=False),
135
- gr.update(visible=False),
136
- )
137
- else:
138
- return (
139
- None,
140
- None,
141
- None,
142
- gr.update(visible=True),
143
- gr.update(visible=True),
144
- gr.update(visible=True),
145
- )
146
-
147
- # Load models
148
- vae = load_vae(vae_dir)
149
- unet = load_unet(unet_dir)
150
- scheduler = load_scheduler(scheduler_dir)
151
- patchifier = SymmetricPatchifier(patch_size=1)
152
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(torch.device("cuda:0"))
153
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
154
-
155
- pipeline = XoraVideoPipeline(
156
- transformer=unet,
157
- patchifier=patchifier,
158
- text_encoder=text_encoder,
159
- tokenizer=tokenizer,
160
- scheduler=scheduler,
161
- vae=vae,
162
- ).to(torch.device("cuda:0"))
163
-
164
- def enhance_prompt_if_enabled(prompt, enhance_toggle):
165
- if not enhance_toggle:
166
- print("Enhance toggle is off, Prompt: ", prompt)
167
- return prompt
168
-
169
- messages = [
170
- {"role": "system", "content": system_prompt_t2v},
171
- {"role": "user", "content": prompt},
172
- ]
173
-
174
  try:
175
- response = client.chat.completions.create(
176
- model="gpt-4-mini",
177
- messages=messages,
178
- max_tokens=200,
179
- )
180
- print("Enhanced Prompt: ", response.choices[0].message.content.strip())
181
- return response.choices[0].message.content.strip()
 
 
182
  except Exception as e:
183
- print(f"Error: {e}")
184
- return prompt
 
 
 
 
185
 
186
  @spaces.GPU(duration=90)
187
- def generate_video_from_text_90(
188
- prompt="",
189
- enhance_prompt_toggle=False,
190
- negative_prompt="",
191
- frame_rate=25,
192
- seed=random.randint(0, MAX_SEED),
193
- num_inference_steps=30,
194
- guidance_scale=3.2,
195
- height=768,
196
- width=768,
197
- num_frames=60,
198
- progress=gr.Progress(),
199
  ):
200
- # Preprocess the prompt (Korean -> English)
201
- prompt = process_prompt(prompt)
202
- negative_prompt = process_prompt(negative_prompt)
203
-
204
- if len(prompt.strip()) < 50:
205
- raise gr.Error(
206
- "Prompt must be at least 50 characters long. Please provide more details for the best results.",
207
- duration=5,
208
- )
209
 
210
- prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)
 
211
 
212
- sample = {
213
- "prompt": prompt,
214
- "prompt_attention_mask": None,
215
- "negative_prompt": negative_prompt,
216
- "negative_prompt_attention_mask": None,
217
- "media_items": None,
218
- }
219
 
220
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
221
 
222
- def gradio_progress_callback(self, step, timestep, kwargs):
223
- progress((step + 1) / num_inference_steps)
 
 
224
 
225
- try:
226
  with torch.no_grad():
227
  images = pipeline(
228
  num_inference_steps=num_inference_steps,
@@ -239,137 +507,375 @@ def generate_video_from_text_90(
239
  vae_per_channel_normalize=True,
240
  conditioning_method=ConditioningMethod.UNCONDITIONAL,
241
  mixed_precision=True,
242
- callback_on_step_end=gradio_progress_callback,
243
  ).images
244
- except Exception as e:
245
- raise gr.Error(
246
- f"An error occurred while generating the video. Please try again. Error: {e}",
247
- duration=5,
248
- )
249
- finally:
250
- torch.cuda.empty_cache()
251
- gc.collect()
252
 
253
- output_path = tempfile.mktemp(suffix=".mp4")
254
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
255
- video_np = (video_np * 255).astype(np.uint8)
256
- height, width = video_np.shape[1:3]
257
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
258
- for frame in video_np[..., ::-1]:
259
- out.write(frame)
260
- out.release()
261
- del images
262
- del video_np
263
- torch.cuda.empty_cache()
264
- return output_path
265
-
266
- def create_advanced_options():
267
- with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
268
- seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
269
- inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
270
- guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
271
 
272
- height_slider = gr.Slider(
273
- label="4.4 Height",
274
- minimum=256,
275
- maximum=1024,
276
- step=64,
277
- value=768,
278
- visible=False,
279
  )
280
- width_slider = gr.Slider(
281
- label="4.5 Width",
282
- minimum=256,
283
- maximum=1024,
284
- step=64,
285
- value=768,
286
- visible=False,
287
- )
288
- num_frames_slider = gr.Slider(
289
- label="4.5 Number of Frames",
290
- minimum=1,
291
- maximum=500,
292
- step=1,
293
- value=60,
294
- visible=False,
295
- )
296
-
297
- return [
298
- seed,
299
- inference_steps,
300
- guidance_scale,
301
- height_slider,
302
- width_slider,
303
- num_frames_slider,
304
- ]
305
-
306
- with gr.Blocks(theme=gr.themes.Soft()) as iface:
307
 
308
- with gr.Column():
309
- txt2vid_prompt = gr.Textbox(
310
- label="Step 1: Enter Your Prompt (ํ•œ๊ธ€ ๋˜๋Š” ์˜์–ด)",
311
- placeholder="์ƒ์„ฑํ•˜๊ณ  ์‹ถ์€ ๋น„๋””์˜ค๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š” (์ตœ์†Œ 50์ž)...",
312
- value="๊ธด ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ์™€ ๋ฐ์€ ํ”ผ๋ถ€๋ฅผ ๊ฐ€์ง„ ์—ฌ์„ฑ์ด ๊ธด ๊ธˆ๋ฐœ ๋จธ๋ฆฌ๋ฅผ ๊ฐ€์ง„ ๋‹ค๋ฅธ ์—ฌ์„ฑ์„ ํ–ฅํ•ด ๋ฏธ์†Œ ์ง“์Šต๋‹ˆ๋‹ค. ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์€ ๊ฒ€์€ ์žฌํ‚ท์„ ์ž…๊ณ  ์žˆ์œผ๋ฉฐ ์˜ค๋ฅธ์ชฝ ๋บจ์— ์ž‘๊ณ  ๊ฑฐ์˜ ๋ˆˆ์— ๋„์ง€ ์•Š๋Š” ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์นด๋ฉ”๋ผ ์•ต๊ธ€์€ ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์˜ ์–ผ๊ตด์— ์ดˆ์ ์„ ๋งž์ถ˜ ํด๋กœ์ฆˆ์—…์ž…๋‹ˆ๋‹ค. ์กฐ๋ช…์€ ๋”ฐ๋œปํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šฐ๋ฉฐ, ์•„๋งˆ๋„ ์ง€๋Š” ํ•ด์—์„œ ๋‚˜์˜ค๋Š” ๊ฒƒ ๊ฐ™์•„ ์žฅ๋ฉด์— ๋ถ€๋“œ๋Ÿฌ์šด ๋น›์„ ๋น„์ถฅ๋‹ˆ๋‹ค.",
313
- lines=5,
314
- )
315
 
 
 
 
 
 
 
 
316
 
317
- txt2vid_enhance_toggle = Toggle(
318
- label="Enhance Prompt",
319
- value=False,
320
- interactive=True,
 
 
 
321
  )
322
 
323
- txt2vid_negative_prompt = gr.Textbox(
324
- label="Step 2: Enter Negative Prompt",
325
- placeholder="๋น„๋””์˜ค์—์„œ ์›ํ•˜์ง€ ์•Š๋Š” ์š”์†Œ๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”...",
326
- value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
327
- lines=2,
328
- )
 
 
329
 
330
- txt2vid_preset = gr.Dropdown(
331
- choices=[p["label"] for p in preset_options],
332
- value="512x512, 160 frames",
333
- label="Step 3.1: Choose Resolution Preset",
 
 
334
  )
335
 
336
- txt2vid_frame_rate = gr.Slider(
337
- label="Step 3.2: Frame Rate",
338
- minimum=6,
339
- maximum=60,
340
- step=1,
341
- value=20,
 
 
342
  )
343
 
344
- txt2vid_advanced = create_advanced_options()
345
- txt2vid_generate = gr.Button(
346
- "Step 5: Generate Video",
347
- variant="primary",
348
- size="lg",
349
  )
350
 
351
- txt2vid_output = gr.Video(label="Generated Output")
352
-
353
- txt2vid_preset.change(
354
- fn=preset_changed,
355
- inputs=[txt2vid_preset],
356
- outputs=txt2vid_advanced[3:],
357
- )
358
-
359
- txt2vid_generate.click(
360
- fn=generate_video_from_text_90,
361
- inputs=[
362
- txt2vid_prompt,
363
- txt2vid_enhance_toggle,
364
- txt2vid_negative_prompt,
365
- txt2vid_frame_rate,
366
- *txt2vid_advanced,
367
- ],
368
- outputs=txt2vid_output,
369
- concurrency_limit=1,
370
- concurrency_id="generate_video",
371
- queue=True,
372
  )
373
 
374
- iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
375
- # ===== Application Startup at 2024-12-20 01:30:34 =====
 
1
  import spaces
2
+ import argparse
3
+ import os
4
+ import time
5
+ from os import path
6
+ import shutil
7
+ from datetime import datetime
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download, snapshot_download
10
  import gradio as gr
11
  from gradio_toggle import Toggle
12
  import torch
13
+ from diffusers import FluxPipeline
14
+ from diffusers.pipelines.stable_diffusion import safety_checker
15
+ from PIL import Image
16
+ from transformers import pipeline, CLIPProcessor, CLIPModel, T5EncoderModel, T5Tokenizer
17
+ import replicate
18
+ import logging
19
+ import requests
20
+ from pathlib import Path
21
+ import cv2
22
+ import numpy as np
23
+ import sys
24
+ import io
25
+ import json
26
+ import gc
27
+ import csv
+ import random
28
+ from openai import OpenAI
29
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
30
  from xora.models.transformers.transformer3d import Transformer3DModel
31
  from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
32
  from xora.schedulers.rf import RectifiedFlowScheduler
33
  from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
 
34
  from xora.utils.conditioning_method import ConditioningMethod
35
+ from functools import lru_cache
 
 
36
 
37
+ # Logging configuration
38
+ logging.basicConfig(
39
+ level=logging.INFO,
40
+ format='%(asctime)s - %(levelname)s - %(message)s'
41
+ )
42
+ logger = logging.getLogger(__name__)
 
43
 
44
+ # Constants and environment variables
45
  MAX_SEED = np.iinfo(np.int32).max
46
+ PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
47
+ MODEL_PATH = "asset"
48
+ CACHE_PATH = path.join(path.dirname(path.abspath(__file__)), "models")
49
+ GALLERY_PATH = path.join(PERSISTENT_DIR, "gallery")
50
+ VIDEO_GALLERY_PATH = path.join(PERSISTENT_DIR, "video_gallery")
51
+
52
+ # API keys
53
+ HF_TOKEN = os.getenv("HF_TOKEN")
54
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
55
+ CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
56
+ REPLICATE_API_TOKEN = os.getenv("API_KEY")
57
+
58
+ # Load the system prompt
59
+ SYSTEM_PROMPT_PATH = "assets/system_prompt_t2v.txt"
60
+ with open(SYSTEM_PROMPT_PATH, "r") as f:
61
+ SYSTEM_PROMPT = f.read()
62
+
63
+ # Directory initialization
64
+ def init_directories():
65
+ """ํ•„์š”ํ•œ ๋””๋ ‰ํ† ๋ฆฌ๋“ค์„ ์ƒ์„ฑ"""
66
+ directories = [GALLERY_PATH, VIDEO_GALLERY_PATH, CACHE_PATH]
67
+ for directory in directories:
68
+ os.makedirs(directory, exist_ok=True)
69
+ logger.info(f"Directory initialized: {directory}")
70
+
71
+ # CUDA settings
72
+ def setup_cuda():
73
+ """CUDA ๊ด€๋ จ ์„ค์ • ์ดˆ๊ธฐํ™”"""
74
+ torch.backends.cuda.matmul.allow_tf32 = False
75
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
76
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
77
+ torch.backends.cudnn.allow_tf32 = False
78
+ torch.backends.cudnn.deterministic = False
79
+ torch.backends.cuda.preferred_blas_library = "cublas"
80
+ torch.set_float32_matmul_precision("highest")
81
+ logger.info("CUDA settings initialized")
82
+
83
+ # Model management class
84
+ class ModelManager:
85
+ def __init__(self):
86
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
+ self.models = {}
88
+ self.current_model = None
89
+ logger.info(f"ModelManager initialized with device: {self.device}")
90
+
91
+ def load_model(self, model_name):
92
+ """๋ชจ๋ธ์„ ๋™์ ์œผ๋กœ ๋กœ๋“œ"""
93
+ if self.current_model == model_name and model_name in self.models:
94
+ return self.models[model_name]
95
+
96
+ # Unload the currently loaded model
97
+ self.unload_current_model()
98
+
99
+ logger.info(f"Loading model: {model_name}")
100
+ try:
101
+ if model_name == "flux":
102
+ model = self._load_flux_model()
103
+ elif model_name == "xora":
104
+ model = self._load_xora_model()
105
+ elif model_name == "clip":
106
+ model = self._load_clip_model()
107
+ else:
108
+ raise ValueError(f"Unknown model: {model_name}")
109
+
110
+ self.models[model_name] = model
111
+ self.current_model = model_name
112
+ return model
113
+
114
+ except Exception as e:
115
+ logger.error(f"Error loading model {model_name}: {str(e)}")
116
+ raise
117
+
118
+ def unload_current_model(self):
119
+ """ํ˜„์žฌ ๏ฟฝ๏ฟฝ๋“œ๋œ ๋ชจ๋ธ ์–ธ๋กœ๋“œ"""
120
+ if self.current_model:
121
+ logger.info(f"Unloading model: {self.current_model}")
122
+ if self.current_model in self.models:
123
+ del self.models[self.current_model]
124
+ self.current_model = None
125
+ torch.cuda.empty_cache()
126
+ gc.collect()
127
+
128
+ def _load_flux_model(self):
129
+ """Flux ๋ชจ๋ธ ๋กœ๋“œ"""
130
+ pipe = FluxPipeline.from_pretrained(
131
+ "black-forest-labs/FLUX.1-dev",
132
+ torch_dtype=torch.bfloat16
133
+ )
134
+ pipe.load_lora_weights(
135
+ hf_hub_download(
136
+ "ByteDance/Hyper-SD",
137
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors"
138
+ )
139
+ )
140
+ pipe.fuse_lora(lora_scale=0.125)
141
+ pipe.to(device=self.device, dtype=torch.bfloat16)
142
+ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
143
+ "CompVis/stable-diffusion-safety-checker"
144
+ )
145
+ return pipe
146
+
147
+ def _load_xora_model(self):
148
+ """Xora ๋ชจ๋ธ ๋กœ๋“œ"""
149
+ if not path.exists(MODEL_PATH):
150
+ snapshot_download(
151
+ "Lightricks/LTX-Video",
152
+ revision='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc',
153
+ local_dir=MODEL_PATH,
154
+ repo_type="model",
155
+ token=HF_TOKEN
156
+ )
157
+
158
+ vae = load_vae(Path(MODEL_PATH) / "vae")
159
+ unet = load_unet(Path(MODEL_PATH) / "unet")
160
+ scheduler = load_scheduler(Path(MODEL_PATH) / "scheduler")
161
+ patchifier = SymmetricPatchifier(patch_size=1)
162
+ text_encoder = T5EncoderModel.from_pretrained(
163
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
164
+ subfolder="text_encoder"
165
+ ).to(self.device)
166
+ tokenizer = T5Tokenizer.from_pretrained(
167
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
168
+ subfolder="tokenizer"
169
+ )
170
 
171
+ return XoraVideoPipeline(
172
+ transformer=unet,
173
+ patchifier=patchifier,
174
+ text_encoder=text_encoder,
175
+ tokenizer=tokenizer,
176
+ scheduler=scheduler,
177
+ vae=vae
178
+ ).to(self.device)
179
+
180
+ def _load_clip_model(self):
181
+ """CLIP ๋ชจ๋ธ ๋กœ๋“œ"""
182
+ model = CLIPModel.from_pretrained(
183
+ "openai/clip-vit-base-patch32",
184
+ cache_dir=MODEL_PATH
185
+ ).to(self.device)
186
+ processor = CLIPProcessor.from_pretrained(
187
+ "openai/clip-vit-base-patch32",
188
+ cache_dir=MODEL_PATH
189
+ )
190
+ return {"model": model, "processor": processor}
191
 
192
+ # Translator initialization
193
+ @lru_cache(maxsize=None)
194
+ def get_translator():
195
+ """๋ฒˆ์—ญ๊ธฐ๋ฅผ lazy loading์œผ๋กœ ์ดˆ๊ธฐํ™”"""
196
+ return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
197
 
198
+ # OpenAI client initialization
199
+ @lru_cache(maxsize=None)
200
+ def get_openai_client():
201
+ """OpenAI ํด๋ผ์ด์–ธํŠธ๋ฅผ lazy loading์œผ๋กœ ์ดˆ๊ธฐํ™”"""
202
+ return OpenAI(api_key=OPENAI_API_KEY)
203
 
 
 
 
 
204
 
 
205
 
206
+ # Utility functions
207
+ class Timer:
208
+ """์ž‘์—… ์‹œ๊ฐ„ ์ธก์ •์„ ์œ„ํ•œ ์ปจํ…์ŠคํŠธ ๋งค๋‹ˆ์ €"""
209
+ def __init__(self, method_name="timed process"):
210
+ self.method = method_name
211
+
212
+ def __enter__(self):
213
+ self.start = time.time()
214
+ logger.info(f"{self.method} starts")
215
+
216
+ def __exit__(self, exc_type, exc_val, exc_tb):
217
+ end = time.time()
218
+ logger.info(f"{self.method} took {str(round(end - self.start, 2))}s")
219
 
220
  def process_prompt(prompt):
221
+ """ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ (ํ•œ๊ธ€ ๋ฒˆ์—ญ ๋ฐ ํ•„ํ„ฐ๋ง)"""
222
  if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
223
+ translator = get_translator()
224
  translated = translator(prompt)[0]['translation_text']
225
+ logger.info(f"Translated prompt: {translated}")
226
  return translated
227
  return prompt
228
 
229
+ def filter_prompt(prompt):
230
+ """๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ ํ•„ํ„ฐ๋ง"""
231
+ inappropriate_keywords = [
232
+ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult",
233
+ "xxx", "erotic", "sensual", "seductive", "provocative",
234
+ "intimate", "violence", "gore", "blood", "death", "kill",
235
+ "murder", "torture", "drug", "suicide", "abuse", "hate",
236
+ "discrimination"
237
+ ]
238
+
239
+ prompt_lower = prompt.lower()
240
+ for keyword in inappropriate_keywords:
241
+ if keyword in prompt_lower:
242
+ logger.warning(f"Inappropriate content detected: {keyword}")
243
+ return False, "The prompt contains inappropriate content."
244
+ return True, prompt
245
+
246
+ def enhance_prompt(prompt, enhance_toggle):
247
+ """GPT๋ฅผ ์‚ฌ์šฉํ•œ ํ”„๋กฌํ”„ํŠธ ๊ฐœ์„ """
248
+ if not enhance_toggle:
249
+ logger.info("Prompt enhancement disabled")
250
+ return prompt
251
+
252
+ try:
253
+ client = get_openai_client()
254
+ messages = [
255
+ {"role": "system", "content": SYSTEM_PROMPT},
256
+ {"role": "user", "content": prompt},
257
+ ]
258
+
259
+ response = client.chat.completions.create(
260
+ model="gpt-4-mini",
261
+ messages=messages,
262
+ max_tokens=200,
263
+ )
264
+
265
+ enhanced_prompt = response.choices[0].message.content.strip()
266
+ logger.info(f"Enhanced prompt: {enhanced_prompt}")
267
+ return enhanced_prompt
268
+ except Exception as e:
269
+ logger.error(f"Prompt enhancement failed: {str(e)}")
270
+ return prompt
271
+
272
+ def save_image(image, directory=GALLERY_PATH):
273
+ """์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ์ €์žฅ"""
274
+ try:
275
+ os.makedirs(directory, exist_ok=True)
276
+
277
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
278
+ random_suffix = os.urandom(4).hex()
279
+ filename = f"generated_{timestamp}_{random_suffix}.png"
280
+ filepath = os.path.join(directory, filename)
281
+
282
+ if not isinstance(image, Image.Image):
283
+ image = Image.fromarray(image)
284
+
285
+ if image.mode != 'RGB':
286
+ image = image.convert('RGB')
287
+
288
+ image.save(filepath, format='PNG', optimize=True, quality=100)
289
+ logger.info(f"Image saved: {filepath}")
290
+ return filepath
291
+ except Exception as e:
292
+ logger.error(f"Error saving image: {str(e)}")
293
+ return None
294
+
295
+ def add_watermark(video_path):
296
+ """๋น„๋””์˜ค์— ์›Œํ„ฐ๋งˆํฌ ์ถ”๊ฐ€"""
297
+ try:
298
+ cap = cv2.VideoCapture(video_path)
299
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
300
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
301
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
302
+
303
+ text = "GiniGEN.AI"
304
+ font = cv2.FONT_HERSHEY_SIMPLEX
305
+ font_scale = height * 0.05 / 30
306
+ thickness = 2
307
+ color = (255, 255, 255)
308
+
309
+ (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
310
+ margin = int(height * 0.02)
311
+ x_pos = width - text_width - margin
312
+ y_pos = height - margin
313
+
314
+ output_path = os.path.join(VIDEO_GALLERY_PATH, f"watermarked_{os.path.basename(video_path)}")
315
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
316
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
317
+
318
+ while cap.isOpened():
319
+ ret, frame = cap.read()
320
+ if not ret:
321
+ break
322
+ cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
323
+ out.write(frame)
324
+
325
+ cap.release()
326
+ out.release()
327
+ logger.info(f"Video watermarked: {output_path}")
328
+ return output_path
329
+
330
+ except Exception as e:
331
+ logger.error(f"Error adding watermark: {str(e)}")
332
+ return video_path
333
+
334
+ def upload_to_catbox(file_path):
335
+ """ํŒŒ์ผ์„ catbox.moe์— ์—…๋กœ๋“œ"""
336
+ try:
337
+ logger.info(f"Uploading file: {file_path}")
338
+ url = "https://catbox.moe/user/api.php"
339
+
340
+ file_extension = Path(file_path).suffix.lower()
341
+ supported_extensions = {
342
+ '.jpg': 'image/jpeg',
343
+ '.jpeg': 'image/jpeg',
344
+ '.png': 'image/png',
345
+ '.gif': 'image/gif',
346
+ '.mp4': 'video/mp4'
347
+ }
348
+
349
+ if file_extension not in supported_extensions:
350
+ logger.error(f"Unsupported file type: {file_extension}")
351
+ return None
352
+
353
+ files = {
354
+ 'fileToUpload': (
355
+ os.path.basename(file_path),
356
+ open(file_path, 'rb'),
357
+ supported_extensions[file_extension]
358
+ )
359
+ }
360
+
361
+ data = {
362
+ 'reqtype': 'fileupload',
363
+ 'userhash': CATBOX_USER_HASH
364
+ }
365
+
366
+ response = requests.post(url, files=files, data=data)
367
+
368
+ if response.status_code == 200 and response.text.startswith('http'):
369
+ logger.info(f"Upload successful: {response.text}")
370
+ return response.text
371
+ else:
372
+ raise Exception(f"Upload failed: {response.text}")
373
+
374
+ except Exception as e:
375
+ logger.error(f"Upload error: {str(e)}")
376
+ return None
377
+
378
+ # Create the model manager instance
379
+ model_manager = ModelManager()
380
+
381
+
382
+ # Constants and settings for the Gradio interface
383
+ PRESET_OPTIONS = [
384
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
385
  {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
386
  {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
 
398
  {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
399
  ]
400
 
401
+ # Main processing functions
402
+ @spaces.GPU(duration=90)
403
+ def generate_image(
404
+ prompt,
405
+ height,
406
+ width,
407
+ steps,
408
+ scales,
409
+ seed,
410
+ enhance_prompt_toggle=False,
411
+ progress=gr.Progress()
412
+ ):
413
+ """์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜"""
 
 
414
  try:
415
+ # Prompt preprocessing
416
+ processed_prompt = process_prompt(prompt)
417
+ is_safe, filtered_prompt = filter_prompt(processed_prompt)
418
+ if not is_safe:
419
+ raise gr.Error("The prompt contains inappropriate content.")
420
+
421
+ if enhance_prompt_toggle:
422
+ filtered_prompt = enhance_prompt(filtered_prompt, True)
423
+
424
+ # Load the Flux model
425
+ pipe = model_manager.load_model("flux")
426
+
427
+ with Timer("Image generation"), torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
428
+ generated_image = pipe(
429
+ prompt=[filtered_prompt],
430
+ generator=torch.Generator().manual_seed(int(seed)),
431
+ num_inference_steps=int(steps),
432
+ guidance_scale=float(scales),
433
+ height=int(height),
434
+ width=int(width),
435
+ max_sequence_length=256
436
+ ).images[0]
437
+
438
+ # Save and return the image
439
+ saved_path = save_image(generated_image)
440
+ if saved_path is None:
441
+ raise gr.Error("Failed to save the image.")
442
+
443
+ return Image.open(saved_path)
444
+
445
  except Exception as e:
446
+ logger.error(f"Image generation error: {str(e)}")
447
+ raise gr.Error(f"An error occurred during image generation: {str(e)}")
448
+ finally:
449
+ model_manager.unload_current_model()
450
+ torch.cuda.empty_cache()
451
+ gc.collect()
452
 
453
  @spaces.GPU(duration=90)
454
+ def generate_video_xora(
455
+ prompt,
456
+ enhance_prompt_toggle,
457
+ negative_prompt,
458
+ frame_rate,
459
+ seed,
460
+ num_inference_steps,
461
+ guidance_scale,
462
+ height,
463
+ width,
464
+ num_frames,
465
+ progress=gr.Progress()
466
  ):
467
+ """Xora ๋น„๋””์˜ค ์ƒ์„ฑ ํ•จ์ˆ˜"""
468
+ try:
469
+ # Prompt processing
470
+ prompt = process_prompt(prompt)
471
+ negative_prompt = process_prompt(negative_prompt)
 
 
 
 
472
 
473
+ if len(prompt.strip()) < 50:
474
+ raise gr.Error("The prompt must be at least 50 characters long.")
475
 
476
+ prompt = enhance_prompt(prompt, enhance_prompt_toggle)
 
 
 
 
 
 
477
 
478
+ # Load the Xora model
479
+ pipeline = model_manager.load_model("xora")
480
 
481
+ sample = {
482
+ "prompt": prompt,
483
+ "prompt_attention_mask": None,
484
+ "negative_prompt": negative_prompt,
485
+ "negative_prompt_attention_mask": None,
486
+ "media_items": None,
487
+ }
488
+
489
+ generator = torch.Generator(device="cuda").manual_seed(seed)
490
+
491
+ def progress_callback(step, timestep, kwargs):
492
+ progress((step + 1) / num_inference_steps)
493
 
 
494
  with torch.no_grad():
495
  images = pipeline(
496
  num_inference_steps=num_inference_steps,
 
507
  vae_per_channel_normalize=True,
508
  conditioning_method=ConditioningMethod.UNCONDITIONAL,
509
  mixed_precision=True,
510
+ callback_on_step_end=progress_callback,
511
  ).images
 
 
 
 
 
 
 
 
512
 
513
+ # Save the video
514
+ output_path = os.path.join(VIDEO_GALLERY_PATH, f"generated_{int(time.time())}.mp4")
515
+ video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
516
+ video_np = (video_np * 255).astype(np.uint8)
 
 
517
 
518
+ out = cv2.VideoWriter(
519
+ output_path,
520
+ cv2.VideoWriter_fourcc(*"mp4v"),
521
+ frame_rate,
522
+ (width, height)
 
 
523
  )
524
+
525
+ for frame in video_np[..., ::-1]:
526
+ out.write(frame)
527
+ out.release()
 
 
528
 
529
+ # Add a watermark
530
+ final_path = add_watermark(output_path)
531
+ return final_path
 
 
 
 
532
 
533
+ except Exception as e:
534
+ logger.error(f"Video generation error: {str(e)}")
535
+ raise gr.Error(f"An error occurred during video generation: {str(e)}")
536
+ finally:
537
+ model_manager.unload_current_model()
538
+ torch.cuda.empty_cache()
539
+ gc.collect()
540
 
541
+ def generate_video_replicate(image, prompt):
542
+ """Replicate API๋ฅผ ์‚ฌ์šฉํ•œ ๋น„๋””์˜ค ์ƒ์„ฑ ํ•จ์ˆ˜"""
543
+ try:
544
+ is_safe, filtered_prompt = filter_prompt(prompt)
545
+ if not is_safe:
546
+ raise gr.Error("The prompt contains inappropriate content.")
547
+
548
+ if not image:
549
+ raise gr.Error("Please upload an image.")
550
+
551
+ # Upload the image
552
+ image_url = upload_to_catbox(image)
553
+ if not image_url:
554
+ raise gr.Error("Failed to upload the image.")
555
+
556
+ # Call the Replicate API
557
+ client = replicate.Client(api_token=REPLICATE_API_TOKEN)
558
+ output = client.run(
559
+ "minimax/video-01-live",
560
+ input={
561
+ "prompt": filtered_prompt,
562
+ "first_frame_image": image_url
563
+ }
564
  )
565
 
566
+ # Save the resulting video
567
+ output_path = os.path.join(VIDEO_GALLERY_PATH, f"replicate_{int(time.time())}.mp4")
568
+
569
+ if hasattr(output, 'read'):
570
+ with open(output_path, "wb") as f:
571
+ f.write(output.read())
572
+ elif isinstance(output, str):
573
+ response = requests.get(output)
574
+ with open(output_path, "wb") as f:
575
+ f.write(response.content)
576
+
577
+ # Add a watermark
578
+ final_path = add_watermark(output_path)
579
+ return final_path
580
 
581
+ except Exception as e:
582
+ logger.error(f"Replicate video generation error: {str(e)}")
583
+ raise gr.Error(f"An error occurred during video generation: {str(e)}")
584
+
585
+
586
+ # Gradio UI styles
587
+ css = """
588
+ .gradio-container {
589
+ font-family: 'Pretendard', 'Noto Sans KR', sans-serif !important;
590
+ }
591
+
592
+ .title {
593
+ text-align: center;
594
+ font-size: 2.5rem;
595
+ font-weight: bold;
596
+ color: #2a9d8f;
597
+ margin: 1rem 0;
598
+ padding: 1rem;
599
+ background: linear-gradient(to right, #264653, #2a9d8f);
600
+ -webkit-background-clip: text;
601
+ -webkit-text-fill-color: transparent;
602
+ }
603
+
604
+ .generate-btn {
605
+ background: linear-gradient(to right, #2a9d8f, #264653) !important;
606
+ border: none !important;
607
+ color: white !important;
608
+ font-weight: bold !important;
609
+ transition: all 0.3s ease !important;
610
+ }
611
+
612
+ .generate-btn:hover {
613
+ transform: translateY(-2px) !important;
614
+ box-shadow: 0 5px 15px rgba(42, 157, 143, 0.4) !important;
615
+ }
616
+
617
+ .gallery {
618
+ display: grid;
619
+ grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
620
+ gap: 1rem;
621
+ padding: 1rem;
622
+ }
623
+
624
+ .gallery img {
625
+ width: 100%;
626
+ height: auto;
627
+ border-radius: 8px;
628
+ transition: transform 0.3s ease;
629
+ }
630
+
631
+ .gallery img:hover {
632
+ transform: scale(1.05);
633
+ }
634
+ """
635
+
636
+ # Build the Gradio interface
637
+ def create_ui():
638
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
639
+ gr.HTML('<div class="title">AI Image & Video Generator</div>')
640
+
641
+ with gr.Tabs():
642
+ # Image generation tab
643
+ with gr.Tab("Image Generation"):
644
+ with gr.Row():
645
+ with gr.Column(scale=3):
646
+ img_prompt = gr.Textbox(
647
+ label="Image Description",
648
+ placeholder="์ด๋ฏธ์ง€ ์„ค๋ช…์„ ์ž…๋ ฅํ•˜์„ธ์š”... (ํ•œ๊ธ€ ์ž…๋ ฅ ๊ฐ€๋Šฅ)",
649
+ lines=3
650
+ )
651
+
652
+ img_enhance_toggle = Toggle(
653
+ label="Enhance Prompt",
654
+ value=False,
655
+ interactive=True,
656
+ )
657
+
658
+ with gr.Accordion("Advanced Settings", open=False):
659
+ with gr.Row():
660
+ img_height = gr.Slider(
661
+ label="Height",
662
+ minimum=256,
663
+ maximum=1024,
664
+ step=64,
665
+ value=768
666
+ )
667
+ img_width = gr.Slider(
668
+ label="Width",
669
+ minimum=256,
670
+ maximum=1024,
671
+ step=64,
672
+ value=768
673
+ )
674
+
675
+ with gr.Row():
676
+ steps = gr.Slider(
677
+ label="Inference Steps",
678
+ minimum=6,
679
+ maximum=25,
680
+ step=1,
681
+ value=8
682
+ )
683
+ scales = gr.Slider(
684
+ label="Guidance Scale",
685
+ minimum=0.0,
686
+ maximum=5.0,
687
+ step=0.1,
688
+ value=3.5
689
+ )
690
+
691
+ seed = gr.Number(
692
+ label="Seed",
693
+ value=random.randint(0, MAX_SEED),
694
+ precision=0
695
+ )
696
+
697
+ img_generate_btn = gr.Button(
698
+ "Generate Image",
699
+ variant="primary",
700
+ elem_classes=["generate-btn"]
701
+ )
702
+
703
+ with gr.Column(scale=4):
704
+ img_output = gr.Image(
705
+ label="Generated Image",
706
+ type="pil",
707
+ format="png"
708
+ )
709
+ img_gallery = gr.Gallery(
710
+ label="Image Gallery",
711
+ show_label=True,
712
+ elem_id="gallery",
713
+ columns=[4],
714
+ rows=[2],
715
+ height="auto",
716
+ object_fit="cover"
717
+ )
718
+
719
+ # Xora video generation tab
720
+ with gr.Tab("Xora Video Generation"):
721
+ with gr.Row():
722
+ with gr.Column(scale=3):
723
+ xora_prompt = gr.Textbox(
724
+ label="Video Description",
725
+ placeholder="๋น„๋””์˜ค ์„ค๋ช…์„ ์ž…๋ ฅํ•˜์„ธ์š”... (์ตœ์†Œ 50์ž)",
726
+ lines=5
727
+ )
728
+
729
+ xora_enhance_toggle = Toggle(
730
+ label="Enhance Prompt",
731
+ value=False
732
+ )
733
+
734
+ xora_negative_prompt = gr.Textbox(
735
+ label="Negative Prompt",
736
+ value="low quality, worst quality, deformed, distorted",
737
+ lines=2
738
+ )
739
+
740
+ xora_preset = gr.Dropdown(
741
+ choices=[p["label"] for p in PRESET_OPTIONS],
742
+ value="512x512, 160 frames",
743
+ label="Resolution Preset"
744
+ )
745
+
746
+ xora_frame_rate = gr.Slider(
747
+ label="Frame Rate",
748
+ minimum=6,
749
+ maximum=60,
750
+ step=1,
751
+ value=20
752
+ )
753
+
754
+ with gr.Accordion("Advanced Settings", open=False):
755
+ xora_seed = gr.Slider(
756
+ label="Seed",
757
+ minimum=0,
758
+ maximum=MAX_SEED,
759
+ step=1,
760
+ value=random.randint(0, MAX_SEED)
761
+ )
762
+ xora_steps = gr.Slider(
763
+ label="Inference Steps",
764
+ minimum=5,
765
+ maximum=150,
766
+ step=5,
767
+ value=40
768
+ )
769
+ xora_guidance = gr.Slider(
770
+ label="Guidance Scale",
771
+ minimum=1.0,
772
+ maximum=10.0,
773
+ step=0.1,
774
+ value=4.2
775
+ )
776
+
777
+ xora_generate_btn = gr.Button(
778
+ "Generate Video",
779
+ variant="primary",
780
+ elem_classes=["generate-btn"]
781
+ )
782
+
783
+ with gr.Column(scale=4):
784
+ xora_output = gr.Video(label="Generated Video")
785
+ xora_gallery = gr.Gallery(
786
+ label="Video Gallery",
787
+ show_label=True,
788
+ columns=[4],
789
+ rows=[2],
790
+ height="auto",
791
+ object_fit="cover"
792
+ )
793
+
794
+ # Replicate (image-to-video) tab
795
+ with gr.Tab("Image to Video"):
796
+ with gr.Row():
797
+ with gr.Column(scale=3):
798
+ upload_image = gr.Image(
799
+ type="filepath",
800
+ label="Upload First Frame Image"
801
+ )
802
+ replicate_prompt = gr.Textbox(
803
+ label="Video Description",
804
+ placeholder="๋น„๋””์˜ค ์„ค๋ช…์„ ์ž…๋ ฅํ•˜์„ธ์š”...",
805
+ lines=3
806
+ )
807
+ replicate_generate_btn = gr.Button(
808
+ "Generate Video",
809
+ variant="primary",
810
+ elem_classes=["generate-btn"]
811
+ )
812
+
813
+ with gr.Column(scale=4):
814
+ replicate_output = gr.Video(label="Generated Video")
815
+ replicate_gallery = gr.Gallery(
816
+ label="Video Gallery",
817
+ show_label=True,
818
+ columns=[4],
819
+ rows=[2],
820
+ height="auto",
821
+ object_fit="cover"
822
+ )
823
+
824
+ # Wire up event handlers
825
+ img_generate_btn.click(
826
+ fn=generate_image,
827
+ inputs=[
828
+ img_prompt,
829
+ img_height,
830
+ img_width,
831
+ steps,
832
+ scales,
833
+ seed,
834
+ img_enhance_toggle
835
+ ],
836
+ outputs=img_output
837
  )
838
 
839
+ xora_generate_btn.click(
840
+ fn=generate_video_xora,
841
+ inputs=[
842
+ xora_prompt,
843
+ xora_enhance_toggle,
844
+ xora_negative_prompt,
845
+ xora_frame_rate,
846
+ xora_seed,
847
+ xora_steps,
848
+ xora_guidance,
849
+ img_height,
850
+ img_width,
851
+ gr.Slider(label="Number of Frames", value=60)
852
+ ],
853
+ outputs=xora_output
854
  )
855
 
856
+ replicate_generate_btn.click(
857
+ fn=generate_video_replicate,
858
+ inputs=[upload_image, replicate_prompt],
859
+ outputs=replicate_output
 
860
  )
861
 
862
+ # Auto-refresh the galleries
863
+ demo.load(lambda: None, None, [img_gallery, xora_gallery, replicate_gallery], every=30)
864
+
865
+ return demo
866
+
867
+ if __name__ == "__main__":
868
+ # Initialization
869
+ init_directories()
870
+ setup_cuda()
871
+
872
+ # Launch the UI
873
+ demo = create_ui()
874
+ demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
875
+ share=True,
876
+ show_api=False,
877
+ server_name="0.0.0.0",
878
+ server_port=7860,
879
+ debug=False
 
 
 
880
  )
881