nftblackmagic committed on
Commit 7b183da · unverified · 1 Parent(s): ddc1268

Update app.py

Files changed (1)
  1. app.py +13 -46
app.py CHANGED
@@ -8,55 +8,22 @@ from PIL import Image
 import tempfile
 import torch
 from diffusers import FluxTransformer2DModel, FluxFillPipeline
+import subprocess
 
-import shutil
+subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def find_cuda():
-    # Check if CUDA_HOME or CUDA_PATH environment variables are set
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
-
-    if cuda_home and os.path.exists(cuda_home):
-        return cuda_home
-
-    # Search for the nvcc executable in the system's PATH
-    nvcc_path = shutil.which('nvcc')
-
-    if nvcc_path:
-        # Remove the 'bin/nvcc' part to get the CUDA installation path
-        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
-        return cuda_path
-
-    return None
-
-cuda_path = find_cuda()
-
-if cuda_path:
-    print(f"CUDA installation found at: {cuda_path}")
-else:
-    print("CUDA installation not found")
-
-device = torch.device('cuda')
-
-print("Start loading LoRA weights")
-state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
-    pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha",  ## The tryon Lora weights
-    weight_name="pytorch_lora_weights.safetensors",
-    return_alphas=True
-)
-is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
-if not is_correct_format:
-    raise ValueError("Invalid LoRA checkpoint.")
 print('Loading diffusion model ...')
+transformer = FluxTransformer2DModel.from_pretrained(
+    "xiaozaa/catvton-flux-alpha",
+    torch_dtype=dtype
+)
 pipe = FluxFillPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-Fill-dev",
-    torch_dtype=torch.bfloat16
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer,
+    torch_dtype=dtype
 ).to(device)
-FluxFillPipeline.load_lora_into_transformer(
-    state_dict=state_dict,
-    network_alphas=network_alphas,
-    transformer=pipe.transformer,
-)
-
 print('Loading Finished!')
 
 @spaces.GPU
@@ -109,7 +76,7 @@ def gradio_inference(
 
 with gr.Blocks() as demo:
     gr.Markdown("""
-    # CATVTON FLUX Virtual Try-On Demo (by using LoRA weights)
+    # CATVTON FLUX Virtual Try-On Demo
     Upload a model image, draw a mask, and a garment image to generate virtual try-on results.
 
     [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha)
@@ -222,4 +189,4 @@ with gr.Blocks() as demo:
     )
 
 
-demo.launch()
+demo.launch()
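
For context, a minimal sketch of how the pipe loaded above might be invoked for a fill/inpainting-style try-on call. The prompt text, input file names, resolution, and sampler settings are illustrative assumptions and are not part of this commit; the actual composition of the person image, mask, and garment happens inside gradio_inference, which this diff does not change.

import torch
from PIL import Image

# Hypothetical inputs -- the real app builds these from the Gradio image editor.
person = Image.open("person.jpg").convert("RGB").resize((768, 1024))
mask = Image.open("mask.png").convert("L").resize((768, 1024))

with torch.inference_mode():
    result = pipe(
        prompt="a person wearing the garment",   # placeholder prompt
        image=person,                            # image to be partially regenerated
        mask_image=mask,                         # white region is repainted
        height=1024,
        width=768,
        guidance_scale=30.0,
        num_inference_steps=50,
        max_sequence_length=512,
        generator=torch.Generator(device="cpu").manual_seed(0),
    ).images[0]

result.save("tryon_result.png")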