jamesliu1217 commited on
Commit
cfb0d74
·
verified ·
1 Parent(s): e403b1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -42
app.py CHANGED
@@ -12,42 +12,20 @@ from src.pipeline import FluxPipeline
12
  from src.transformer_flux import FluxTransformer2DModel
13
  from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
 
15
-
16
- class ImageProcessor:
17
- def __init__(self, path):
18
- device = "cuda"
19
- self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)
20
- transformer = FluxTransformer2DModel.from_pretrained(path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device)
21
- self.pipe.transformer = transformer
22
- self.pipe.to(device)
23
-
24
- def clear_cache(self, transformer):
25
- for name, attn_processor in transformer.attn_processors.items():
26
- attn_processor.bank_kv.clear()
27
-
28
- def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height=768, width=768, output_path=None, seed=42):
29
- image = self.pipe(
30
- prompt,
31
- height=int(height),
32
- width=int(width),
33
- guidance_scale=3.5,
34
- num_inference_steps=25,
35
- max_sequence_length=512,
36
- generator=torch.Generator("cpu").manual_seed(seed),
37
- subject_images=subject_imgs,
38
- spatial_images=spatial_imgs,
39
- cond_size=512,
40
- ).images[0]
41
- self.clear_cache(self.pipe.transformer)
42
- if output_path:
43
- image.save(output_path)
44
- return image
45
-
46
  # Initialize the image processor
47
  base_path = "black-forest-labs/FLUX.1-dev"
48
  lora_base_path = "./models"
49
  style_lora_base_path = "Shakker-Labs"
50
- processor = ImageProcessor(base_path)
 
 
 
 
 
 
 
 
 
51
 
52
  # Define the Gradio interface
53
  @spaces.GPU()
@@ -59,29 +37,41 @@ def single_condition_generate_image(prompt, subject_img, spatial_img, height, wi
59
  lora_path = os.path.join(lora_base_path, "pose.safetensors")
60
  elif control_type == "inpainting":
61
  lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
62
- set_single_lora(processor.pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
63
 
64
  # Set the style LoRA
65
  if style_lora=="None":
66
  pass
67
  else:
68
  if style_lora == "Simple_Sketch":
69
- processor.pipe.unload_lora_weights()
70
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
71
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
72
  if style_lora == "Text_Poster":
73
- processor.pipe.unload_lora_weights()
74
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
75
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors")
76
  if style_lora == "Vector_Style":
77
- processor.pipe.unload_lora_weights()
78
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
79
- processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors")
80
 
81
  # Process the image
82
  subject_imgs = [subject_img] if subject_img else []
83
  spatial_imgs = [spatial_img] if spatial_img else []
84
- image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
 
 
 
 
 
 
 
 
 
 
 
 
85
  return image
86
 
87
  # Define the Gradio interface
@@ -89,12 +79,24 @@ def single_condition_generate_image(prompt, subject_img, spatial_img, height, wi
89
  def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
90
  subject_path = os.path.join(lora_base_path, "subject.safetensors")
91
  inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
92
- set_multi_lora(processor.pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512)
93
 
94
  # Process the image
95
  subject_imgs = [subject_img] if subject_img else []
96
  spatial_imgs = [spatial_img] if spatial_img else []
97
- image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
 
 
 
 
 
 
 
 
 
 
 
 
98
  return image
99
 
100
  # Define the Gradio interface components
 
12
  from src.transformer_flux import FluxTransformer2DModel
13
  from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Initialize the image processor
16
  base_path = "black-forest-labs/FLUX.1-dev"
17
  lora_base_path = "./models"
18
  style_lora_base_path = "Shakker-Labs"
19
+
20
+
21
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
22
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
23
+ pipe.transformer = transformer
24
+ pipe.to(device)
25
+
26
+ def clear_cache(transformer):
27
+ for name, attn_processor in transformer.attn_processors.items():
28
+ attn_processor.bank_kv.clear()
29
 
30
  # Define the Gradio interface
31
  @spaces.GPU()
 
37
  lora_path = os.path.join(lora_base_path, "pose.safetensors")
38
  elif control_type == "inpainting":
39
  lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
40
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
41
 
42
  # Set the style LoRA
43
  if style_lora=="None":
44
  pass
45
  else:
46
  if style_lora == "Simple_Sketch":
47
+ pipe.unload_lora_weights()
48
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
49
+ pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
50
  if style_lora == "Text_Poster":
51
+ pipe.unload_lora_weights()
52
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
53
+ pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors")
54
  if style_lora == "Vector_Style":
55
+ pipe.unload_lora_weights()
56
  style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
57
+ pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors")
58
 
59
  # Process the image
60
  subject_imgs = [subject_img] if subject_img else []
61
  spatial_imgs = [spatial_img] if spatial_img else []
62
+ image = pipe(
63
+ prompt,
64
+ height=int(height),
65
+ width=int(width),
66
+ guidance_scale=3.5,
67
+ num_inference_steps=25,
68
+ max_sequence_length=512,
69
+ generator=torch.Generator("cpu").manual_seed(seed),
70
+ subject_images=subject_imgs,
71
+ spatial_images=spatial_imgs,
72
+ cond_size=512,
73
+ ).images[0]
74
+ clear_cache(pipe.transformer)
75
  return image
76
 
77
  # Define the Gradio interface
 
79
  def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
80
  subject_path = os.path.join(lora_base_path, "subject.safetensors")
81
  inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
82
+ set_multi_lora(pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512)
83
 
84
  # Process the image
85
  subject_imgs = [subject_img] if subject_img else []
86
  spatial_imgs = [spatial_img] if spatial_img else []
87
+ image = pipe(
88
+ prompt,
89
+ height=int(height),
90
+ width=int(width),
91
+ guidance_scale=3.5,
92
+ num_inference_steps=25,
93
+ max_sequence_length=512,
94
+ generator=torch.Generator("cpu").manual_seed(seed),
95
+ subject_images=subject_imgs,
96
+ spatial_images=spatial_imgs,
97
+ cond_size=512,
98
+ ).images[0]
99
+ clear_cache(pipe.transformer)
100
  return image
101
 
102
  # Define the Gradio interface components