fantaxy commited on
Commit
b459565
ยท
verified ยท
1 Parent(s): fafc380

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -26
app.py CHANGED
@@ -1,3 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import argparse
3
  import os
@@ -6,14 +26,16 @@ from os import path
6
  import shutil
7
  from datetime import datetime
8
  from safetensors.torch import load_file
9
- from huggingface_hub import hf_hub_download, snapshot_download
10
  import gradio as gr
11
- from gradio_toggle import Toggle
12
  import torch
13
- from diffusers import FluxPipeline
 
 
 
14
  from diffusers.pipelines.stable_diffusion import safety_checker
15
  from PIL import Image
16
- from transformers import pipeline, CLIPProcessor, CLIPModel, T5EncoderModel, T5Tokenizer
17
  import replicate
18
  import logging
19
  import requests
@@ -22,18 +44,6 @@ import cv2
22
  import numpy as np
23
  import sys
24
  import io
25
- import json
26
- import gc
27
- import csv
28
- from openai import OpenAI
29
- from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
30
- from xora.models.transformers.transformer3d import Transformer3DModel
31
- from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
32
- from xora.schedulers.rf import RectifiedFlowScheduler
33
- from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
34
- from xora.utils.conditioning_method import ConditioningMethod
35
- from functools import lru_cache
36
- from diffusers.pipelines.flux import FluxPipeline
37
 
38
  # ๋กœ๊น… ์„ค์ •
39
  logging.basicConfig(
@@ -87,25 +97,39 @@ if not path.exists(cache_path):
87
  os.makedirs(cache_path, exist_ok=True)
88
 
89
  try:
 
 
90
  pipe = FluxPipeline.from_pretrained(
91
- "black-forest-labs/FLUX.1-dev",
92
  torch_dtype=torch.bfloat16,
93
- cache_dir=cache_path
 
94
  )
 
 
95
  lora_path = hf_hub_download(
96
  "ByteDance/Hyper-SD",
97
  "Hyper-FLUX.1-dev-8steps-lora.safetensors",
98
  cache_dir=cache_path
99
  )
100
- pipe.load_lora_weights(lora_path)
101
- pipe.fuse_lora(lora_scale=0.125)
102
- pipe.to(device="cuda", dtype=torch.bfloat16)
103
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
104
- "CompVis/stable-diffusion-safety-checker",
105
- cache_dir=cache_path
106
- )
 
 
 
 
 
 
 
 
 
107
  except Exception as e:
108
- logger.error(f"Error initializing FluxPipeline: {str(e)}")
109
  raise
110
 
111
  # ๋ชจ๋ธ ๊ด€๋ฆฌ ํด๋ž˜์Šค
@@ -611,6 +635,58 @@ def generate_video_replicate(image, prompt):
611
  raise gr.Error(f"๋น„๋””์˜ค ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}")
612
 
613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614
  # Gradio UI ์Šคํƒ€์ผ
615
  css = """
616
  .gradio-container {
 
1
+ import sys
2
+ import subprocess
3
+
4
+ def install_required_packages():
5
+ packages = [
6
+ "git+https://github.com/black-forest-labs/diffusers",
7
+ "transformers>=4.25.1",
8
+ "safetensors>=0.3.1",
9
+ "accelerate>=0.16.0"
10
+ ]
11
+ for package in packages:
12
+ try:
13
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
14
+ except subprocess.CalledProcessError as e:
15
+ print(f"Error installing {package}: {e}")
16
+ raise
17
+
18
+ # ํ•„์š”ํ•œ ํŒจํ‚ค์ง€ ์„ค์น˜
19
+ install_required_packages()
20
+
21
  import spaces
22
  import argparse
23
  import os
 
26
  import shutil
27
  from datetime import datetime
28
  from safetensors.torch import load_file
29
+ from huggingface_hub import hf_hub_download
30
  import gradio as gr
 
31
  import torch
32
+ try:
33
+ from diffusers.pipelines.flux import FluxPipeline
34
+ except ImportError:
35
+ from diffusers import StableDiffusionPipeline as FluxPipeline
36
  from diffusers.pipelines.stable_diffusion import safety_checker
37
  from PIL import Image
38
+ from transformers import pipeline
39
  import replicate
40
  import logging
41
  import requests
 
44
  import numpy as np
45
  import sys
46
  import io
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # ๋กœ๊น… ์„ค์ •
49
  logging.basicConfig(
 
97
  os.makedirs(cache_path, exist_ok=True)
98
 
99
  try:
100
+ # FluxPipeline ์ดˆ๊ธฐํ™” ์‹œ๋„
101
+ model_id = "black-forest-labs/FLUX.1-dev"
102
  pipe = FluxPipeline.from_pretrained(
103
+ model_id,
104
  torch_dtype=torch.bfloat16,
105
+ cache_dir=cache_path,
106
+ local_files_only=False
107
  )
108
+
109
+ # LoRA ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ ๋ฐ ์ ์šฉ
110
  lora_path = hf_hub_download(
111
  "ByteDance/Hyper-SD",
112
  "Hyper-FLUX.1-dev-8steps-lora.safetensors",
113
  cache_dir=cache_path
114
  )
115
+
116
+ if hasattr(pipe, 'load_lora_weights'):
117
+ pipe.load_lora_weights(lora_path)
118
+ pipe.fuse_lora(lora_scale=0.125)
119
+
120
+ # ๋””๋ฐ”์ด์Šค ์„ค์ •
121
+ pipe = pipe.to("cuda")
122
+
123
+ # ์•ˆ์ „์„ฑ ๊ฒ€์‚ฌ๊ธฐ ์„ค์ •
124
+ if hasattr(pipe, 'safety_checker'):
125
+ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
126
+ "CompVis/stable-diffusion-safety-checker",
127
+ cache_dir=cache_path
128
+ )
129
+
130
+ logger.info("Model initialized successfully")
131
  except Exception as e:
132
+ logger.error(f"Error initializing model: {str(e)}")
133
  raise
134
 
135
  # ๋ชจ๋ธ ๊ด€๋ฆฌ ํด๋ž˜์Šค
 
635
  raise gr.Error(f"๋น„๋””์˜ค ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}")
636
 
637
 
638
+ @spaces.GPU
639
+ def process_and_save_image(height, width, steps, scales, prompt, seed):
640
+ is_safe, translated_prompt = process_prompt(prompt)
641
+ if not is_safe:
642
+ gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
643
+ return None, load_gallery()
644
+
645
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
646
+ try:
647
+ # ๋ชจ๋ธ ํ˜ธ์ถœ ๋ฐฉ์‹ ์ˆ˜์ •
648
+ if hasattr(pipe, '__call__'):
649
+ output = pipe(
650
+ prompt=[translated_prompt],
651
+ generator=torch.Generator().manual_seed(int(seed)),
652
+ num_inference_steps=int(steps),
653
+ guidance_scale=float(scales),
654
+ height=int(height),
655
+ width=int(width),
656
+ max_sequence_length=256
657
+ )
658
+ generated_image = output.images[0]
659
+ else:
660
+ generated_image = pipe.text2img(
661
+ prompt=translated_prompt,
662
+ generator=torch.Generator().manual_seed(int(seed)),
663
+ num_inference_steps=int(steps),
664
+ guidance_scale=float(scales),
665
+ height=int(height),
666
+ width=int(width)
667
+ )[0]
668
+
669
+ # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ๋ฐ ์ €์žฅ
670
+ if not isinstance(generated_image, Image.Image):
671
+ generated_image = Image.fromarray(generated_image)
672
+
673
+ if generated_image.mode != 'RGB':
674
+ generated_image = generated_image.convert('RGB')
675
+
676
+ img_byte_arr = io.BytesIO()
677
+ generated_image.save(img_byte_arr, format='PNG')
678
+ img_byte_arr = img_byte_arr.getvalue()
679
+
680
+ saved_path = save_image(generated_image)
681
+ if saved_path is None:
682
+ logger.warning("Failed to save generated image")
683
+ return None, load_gallery()
684
+
685
+ return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
686
+ except Exception as e:
687
+ logger.error(f"Error in image generation: {str(e)}")
688
+ return None, load_gallery()
689
+
690
  # Gradio UI ์Šคํƒ€์ผ
691
  css = """
692
  .gradio-container {