Deadmon commited on
Commit
4fd60a2
·
verified ·
1 Parent(s): 6f042e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  import numpy as np
6
  from PIL import Image
7
  from einops import rearrange
 
8
  import requests
9
  import spaces
10
  from huggingface_hub import login
@@ -12,20 +13,17 @@ from gradio_imageslider import ImageSlider
12
  from diffusers.utils import load_image
13
  from diffusers import FluxControlNetPipeline, FluxControlNetModel
14
 
15
- # Source: https://github.com/XLabs-AI/x-flux.git
16
- name = "flux-dev"
17
-
18
- # Load the model on CPU
19
  device_cpu = torch.device("cpu")
20
- device_gpu = torch.device("cuda")
21
 
 
22
  base_model = 'black-forest-labs/FLUX.1-dev'
23
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
24
 
25
  # Load the ControlNet model and pipeline on CPU
26
- controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16, device_map="cpu")
27
- pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
28
- pipe.to(device_cpu) # Keep on CPU initially
29
 
30
  controlnet_conditioning_scale = 0.5
31
 
@@ -40,6 +38,7 @@ control_modes = {
40
  }
41
 
42
  def load_and_convert_image(image):
 
43
  if isinstance(image, str):
44
  image = Image.open(image)
45
  elif isinstance(image, bytes):
@@ -50,6 +49,7 @@ def load_and_convert_image(image):
50
  return image
51
 
52
  def preprocess_image(image, target_width, target_height, crop=True):
 
53
  image = load_and_convert_image(image)
54
  if crop:
55
  original_width, original_height = image.size
@@ -71,12 +71,14 @@ def preprocess_image(image, target_width, target_height, crop=True):
71
  return image
72
 
73
  def clear_cuda_memory():
 
74
  gc.collect()
75
  torch.cuda.empty_cache()
76
  torch.cuda.ipc_collect()
77
 
78
  @spaces.GPU(duration=120)
79
  def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
 
80
  clear_cuda_memory()
81
 
82
  if random_seed:
 
5
  import numpy as np
6
  from PIL import Image
7
  from einops import rearrange
8
+ import io
9
  import requests
10
  import spaces
11
  from huggingface_hub import login
 
13
  from diffusers.utils import load_image
14
  from diffusers import FluxControlNetPipeline, FluxControlNetModel
15
 
16
+ # Device settings: CPU for loading, GPU for inference
 
 
 
17
  device_cpu = torch.device("cpu")
18
+ device_gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
+ # Model identifiers
21
  base_model = 'black-forest-labs/FLUX.1-dev'
22
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
23
 
24
  # Load the ControlNet model and pipeline on CPU
25
+ controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16).to(device_cpu)
26
+ pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device_cpu)
 
27
 
28
  controlnet_conditioning_scale = 0.5
29
 
 
38
  }
39
 
40
  def load_and_convert_image(image):
41
+ """Load and convert images to a format that PIL can handle."""
42
  if isinstance(image, str):
43
  image = Image.open(image)
44
  elif isinstance(image, bytes):
 
49
  return image
50
 
51
  def preprocess_image(image, target_width, target_height, crop=True):
52
+ """Preprocess image to match the target dimensions."""
53
  image = load_and_convert_image(image)
54
  if crop:
55
  original_width, original_height = image.size
 
71
  return image
72
 
73
  def clear_cuda_memory():
74
+ """Clear CUDA memory."""
75
  gc.collect()
76
  torch.cuda.empty_cache()
77
  torch.cuda.ipc_collect()
78
 
79
  @spaces.GPU(duration=120)
80
  def generate_image(prompt, control_image, control_mode, num_steps=50, guidance=4, width=512, height=512, seed=42, random_seed=False):
81
+ """Generate image using the FLUX.1 ControlNet model."""
82
  clear_cuda_memory()
83
 
84
  if random_seed: