Borcherding committed · Commit d071a59 · verified · 1 Parent(s): 16eb805

Update app.py

Files changed (1)
  1. app.py +14 -13
app.py CHANGED
@@ -10,41 +10,38 @@ from image_gen_aux import DepthPreprocessor
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize the pipeline and move it to GPU
+# Initialize models without moving to CUDA yet
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Depth-dev",
     torch_dtype=torch.bfloat16
-).to(device)
-
+)
 processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
 
+@spaces.GPU
 def load_lora(lora_path):
     if not lora_path.strip():
         return "Please provide a valid LoRA path"
     try:
+        # Move to GPU within the wrapped function
+        pipe.to("cuda")
+
         # Unload any existing LoRA weights first
         try:
             pipe.unload_lora_weights()
         except:
             pass
 
-        # Load new LoRA weights and move to the same device
+        # Load new LoRA weights
         pipe.load_lora_weights(lora_path)
-
-        # Ensure all model components are on the correct device
-        pipe.to(device)
-
         return f"Successfully loaded LoRA weights from {lora_path}"
     except Exception as e:
         return f"Error loading LoRA weights: {str(e)}"
 
+@spaces.GPU
 def unload_lora():
     try:
+        pipe.to("cuda")
         pipe.unload_lora_weights()
-        # Ensure model is on correct device after unloading
-        pipe.to(device)
         return "Successfully unloaded LoRA weights"
     except Exception as e:
         return f"Error unloading LoRA weights: {str(e)}"
@@ -57,6 +54,9 @@ def infer(control_image, prompt, seed=42, randomize_seed=False, width=1024, heig
         seed = random.randint(0, MAX_SEED)
 
     try:
+        # Move pipeline to GPU within the wrapped function
+        pipe.to("cuda")
+
         # Process control image
         control_image = processor(control_image)[0].convert("RGB")
 
@@ -68,7 +68,7 @@ def infer(control_image, prompt, seed=42, randomize_seed=False, width=1024, heig
             width=width,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
-            generator=torch.Generator(device=device).manual_seed(seed),
+            generator=torch.Generator("cuda").manual_seed(seed),
         ).images[0]
 
         return image, seed
@@ -113,6 +113,7 @@ with gr.Blocks(css=css) as demo:
             run_button = gr.Button("Run", scale=0)
 
         result = gr.Image(label="Result", show_label=False)
+        error_message = gr.Textbox(label="Error", visible=False)
 
         with gr.Accordion("Advanced Settings", open=False):
             seed = gr.Slider(
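
This change follows the ZeroGPU pattern used on Hugging Face Spaces: the pipeline is loaded on CPU at import time (ZeroGPU hardware exposes no CUDA device while the process starts), and each function that needs the GPU is decorated with @spaces.GPU, which attaches a GPU only for the duration of the call, so pipe.to("cuda") happens inside the wrapped function. Below is a minimal sketch of the same pattern, assuming the spaces package that ZeroGPU Spaces provide; the generate function, the step/guidance defaults, and the stripped-down Blocks wiring are illustrative rather than the app's actual infer signature.

import gradio as gr
import spaces  # ZeroGPU helper provided on Hugging Face Spaces
import torch
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor

# Load everything on CPU at startup; no .to("cuda") here because the GPU
# does not exist yet on ZeroGPU hardware.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev",
    torch_dtype=torch.bfloat16,
)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")

@spaces.GPU  # a GPU is attached only while this function runs
def generate(prompt, control_image, seed=42):
    pipe.to("cuda")  # safe now that the decorator has provided a device
    depth = processor(control_image)[0].convert("RGB")
    generator = torch.Generator("cuda").manual_seed(seed)
    return pipe(
        prompt=prompt,
        control_image=depth,
        num_inference_steps=30,   # illustrative defaults, not the app's sliders
        guidance_scale=10.0,
        generator=generator,
    ).images[0]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    control = gr.Image(type="pil", label="Control image")
    result = gr.Image(label="Result")
    gr.Button("Run").click(generate, inputs=[prompt, control], outputs=result)

demo.launch()

The actual app keeps its full infer signature, the LoRA load/unload helpers, and the advanced-settings sliders; only the CPU-load plus per-call pipe.to("cuda") structure shown here is what this commit introduces.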