Deadmon commited on
Commit
6f042e6
·
verified ·
1 Parent(s): a5dfd22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import torch
 
3
  import gradio as gr
4
  import numpy as np
5
  from PIL import Image
@@ -13,17 +14,18 @@ from diffusers import FluxControlNetPipeline, FluxControlNetModel
13
 
14
  # Source: https://github.com/XLabs-AI/x-flux.git
15
  name = "flux-dev"
16
- device = torch.device("cuda")
17
- offload = False
18
- is_schnell = name == "flux-schnell"
 
19
 
20
  base_model = 'black-forest-labs/FLUX.1-dev'
21
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
22
 
23
- # Load the new ControlNet model and pipeline
24
- controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
25
  pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
26
- pipe.to(device)
27
 
28
  controlnet_conditioning_scale = 0.5
29
 
@@ -68,15 +70,23 @@ def preprocess_image(image, target_width, target_height, crop=True):
68
 
69
  return image
70
 
 
 
 
 
 
71
  @spaces.GPU(duration=120)
72
  def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
 
 
73
  if random_seed:
74
  seed = np.random.randint(0, 10000)
75
 
76
  if not os.path.isdir("./controlnet_results/"):
77
  os.makedirs("./controlnet_results/")
78
 
79
- torch_device = torch.device("cuda")
 
80
 
81
  control_image = preprocess_image(control_image, width, height)
82
 
@@ -93,6 +103,9 @@ def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4
93
  guidance_scale=guidance,
94
  ).images[0]
95
 
 
 
 
96
  return [control_image, image] # Return both images for slider
97
 
98
  interface = gr.Interface(
 
1
  import os
2
  import torch
3
+ import gc
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
 
14
 
15
  # Source: https://github.com/XLabs-AI/x-flux.git
16
  name = "flux-dev"
17
+
18
+ # Load the model on CPU
19
+ device_cpu = torch.device("cpu")
20
+ device_gpu = torch.device("cuda")
21
 
22
  base_model = 'black-forest-labs/FLUX.1-dev'
23
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
24
 
25
+ # Load the ControlNet model and pipeline on CPU
26
+ controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16, device_map="cpu")
27
  pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
28
+ pipe.to(device_cpu) # Keep on CPU initially
29
 
30
  controlnet_conditioning_scale = 0.5
31
 
 
70
 
71
  return image
72
 
73
+ def clear_cuda_memory():
74
+ gc.collect()
75
+ torch.cuda.empty_cache()
76
+ torch.cuda.ipc_collect()
77
+
78
  @spaces.GPU(duration=120)
79
  def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
80
+ clear_cuda_memory()
81
+
82
  if random_seed:
83
  seed = np.random.randint(0, 10000)
84
 
85
  if not os.path.isdir("./controlnet_results/"):
86
  os.makedirs("./controlnet_results/")
87
 
88
+ # Move model to GPU for inference
89
+ pipe.to(device_gpu)
90
 
91
  control_image = preprocess_image(control_image, width, height)
92
 
 
103
  guidance_scale=guidance,
104
  ).images[0]
105
 
106
+ # Move model back to CPU after inference
107
+ pipe.to(device_cpu)
108
+
109
  return [control_image, image] # Return both images for slider
110
 
111
  interface = gr.Interface(