Update app.py
app.py CHANGED
@@ -1,3 +1,23 @@
+import sys
+import subprocess
+
+def install_required_packages():
+    packages = [
+        "git+https://github.com/black-forest-labs/diffusers",
+        "transformers>=4.25.1",
+        "safetensors>=0.3.1",
+        "accelerate>=0.16.0"
+    ]
+    for package in packages:
+        try:
+            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+        except subprocess.CalledProcessError as e:
+            print(f"Error installing {package}: {e}")
+            raise
+
+# Install the required packages
+install_required_packages()
+
 import spaces
 import argparse
 import os
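
Note on this hunk: shelling out to pip at import time re-installs the packages on every cold start, and the unpinned git URL for diffusers makes rebuilds non-reproducible. A minimal alternative sketch (the ensure_package helper is hypothetical, not part of this commit) that only installs when the module is actually missing:

import importlib.util
import subprocess
import sys

def ensure_package(pip_spec, module_name):
    # Hypothetical helper: install pip_spec only when module_name is not importable.
    if importlib.util.find_spec(module_name) is None:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_spec])

ensure_package("transformers>=4.25.1", "transformers")
ensure_package("safetensors>=0.3.1", "safetensors")
ensure_package("accelerate>=0.16.0", "accelerate")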
@@ -6,14 +26,16 @@ from os import path
 import shutil
 from datetime import datetime
 from safetensors.torch import load_file
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download
 import gradio as gr
-from gradio_toggle import Toggle
 import torch
-
+try:
+    from diffusers.pipelines.flux import FluxPipeline
+except ImportError:
+    from diffusers import StableDiffusionPipeline as FluxPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from PIL import Image
-from transformers import pipeline
+from transformers import pipeline
 import replicate
 import logging
 import requests
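
The try/except above silently binds StableDiffusionPipeline to the name FluxPipeline when the Flux module is unavailable, which is why later code probes capabilities with hasattr before using Flux-only arguments. A sketch of an explicit flag that call sites could branch on instead (an alternative, not what this commit does):

try:
    from diffusers.pipelines.flux import FluxPipeline
    IS_FLUX = True   # real Flux backend loaded
except ImportError:
    from diffusers import StableDiffusionPipeline as FluxPipeline
    IS_FLUX = False  # fallback backend; skip Flux-only kwargs like max_sequence_length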
@@ -22,18 +44,6 @@ import cv2
 import numpy as np
 import sys
 import io
-import json
-import gc
-import csv
-from openai import OpenAI
-from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
-from xora.models.transformers.transformer3d import Transformer3DModel
-from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
-from xora.schedulers.rf import RectifiedFlowScheduler
-from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
-from xora.utils.conditioning_method import ConditioningMethod
-from functools import lru_cache
-from diffusers.pipelines.flux import FluxPipeline
 
 # Logging setup
 logging.basicConfig(
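
The arguments to logging.basicConfig fall outside this hunk. A typical configuration consistent with the logger usage elsewhere in the file would be (an assumption; the file's actual arguments are not shown in the diff):

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",  # assumed format
)
logger = logging.getLogger(__name__)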
@@ -87,25 +97,39 @@ if not path.exists(cache_path):
     os.makedirs(cache_path, exist_ok=True)
 
 try:
+    # Attempt to initialize the FluxPipeline
+    model_id = "black-forest-labs/FLUX.1-dev"
     pipe = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
+        model_id,
         torch_dtype=torch.bfloat16,
-        cache_dir=cache_path
+        cache_dir=cache_path,
+        local_files_only=False
     )
+
+    # Download and apply the LoRA weights
     lora_path = hf_hub_download(
         "ByteDance/Hyper-SD",
         "Hyper-FLUX.1-dev-8steps-lora.safetensors",
         cache_dir=cache_path
     )
-
-    pipe
-
-
-
-
-    )
+
+    if hasattr(pipe, 'load_lora_weights'):
+        pipe.load_lora_weights(lora_path)
+        pipe.fuse_lora(lora_scale=0.125)
+
+    # Device setup
+    pipe = pipe.to("cuda")
+
+    # Safety checker setup
+    if hasattr(pipe, 'safety_checker'):
+        pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
+            "CompVis/stable-diffusion-safety-checker",
+            cache_dir=cache_path
+        )
+
+    logger.info("Model initialized successfully")
 except Exception as e:
-    logger.error(f"Error initializing
+    logger.error(f"Error initializing model: {str(e)}")
     raise
 
 # Model management class
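
With the Hyper-FLUX 8-step LoRA fused at lora_scale=0.125 (the scale commonly used with this 8-step variant), the pipeline should produce usable images in 8 inference steps. A quick smoke test, assuming pipe was initialized as above (illustrative only, not part of the commit):

import torch

with torch.inference_mode():
    image = pipe(
        prompt="a lighthouse at dusk",
        num_inference_steps=8,   # matches the 8-step LoRA
        guidance_scale=3.5,
        height=512,
        width=512,
    ).images[0]
image.save("smoke_test.png")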
@@ -611,6 +635,58 @@ def generate_video_replicate(image, prompt):
         raise gr.Error(f"An error occurred during video generation: {str(e)}")
 
 
+@spaces.GPU
+def process_and_save_image(height, width, steps, scales, prompt, seed):
+    is_safe, translated_prompt = process_prompt(prompt)
+    if not is_safe:
+        gr.Warning("The prompt contains inappropriate content.")
+        return None, load_gallery()
+
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
+        try:
+            # Adjusted model invocation
+            if hasattr(pipe, '__call__'):
+                output = pipe(
+                    prompt=[translated_prompt],
+                    generator=torch.Generator().manual_seed(int(seed)),
+                    num_inference_steps=int(steps),
+                    guidance_scale=float(scales),
+                    height=int(height),
+                    width=int(width),
+                    max_sequence_length=256
+                )
+                generated_image = output.images[0]
+            else:
+                generated_image = pipe.text2img(
+                    prompt=translated_prompt,
+                    generator=torch.Generator().manual_seed(int(seed)),
+                    num_inference_steps=int(steps),
+                    guidance_scale=float(scales),
+                    height=int(height),
+                    width=int(width)
+                )[0]
+
+            # Process and save the image
+            if not isinstance(generated_image, Image.Image):
+                generated_image = Image.fromarray(generated_image)
+
+            if generated_image.mode != 'RGB':
+                generated_image = generated_image.convert('RGB')
+
+            img_byte_arr = io.BytesIO()
+            generated_image.save(img_byte_arr, format='PNG')
+            img_byte_arr = img_byte_arr.getvalue()
+
+            saved_path = save_image(generated_image)
+            if saved_path is None:
+                logger.warning("Failed to save generated image")
+                return None, load_gallery()
+
+            return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
+        except Exception as e:
+            logger.error(f"Error in image generation: {str(e)}")
+            return None, load_gallery()
+
 # Gradio UI styles
 css = """
 .gradio-container {
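
The new @spaces.GPU handler returns an (image, gallery) pair, so one click can update both an image output and a gallery refresh. A rough wiring sketch (component names and value ranges are assumptions; the actual Blocks layout is outside this diff):

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    height = gr.Slider(256, 1152, value=1024, step=64, label="Height")
    width = gr.Slider(256, 1152, value=1024, step=64, label="Width")
    steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
    scales = gr.Slider(0.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
    seed = gr.Number(value=42, label="Seed")
    output = gr.Image(label="Result")
    gallery = gr.Gallery(label="Gallery")
    gr.Button("Generate").click(
        process_and_save_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=[output, gallery],
    )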