jiuface commited on
Commit
b8d7aba
·
1 Parent(s): 6ae9ba8

update control

Browse files
Files changed (1) hide show
  1. app.py +41 -46
app.py CHANGED
@@ -30,7 +30,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
30
  login(token=HF_TOKEN)
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
- IMAGE_SIZE = 512
34
 
35
  # init
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -44,16 +44,6 @@ pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=cont
44
 
45
  # pipe.enable_model_cpu_offload() # for saving memory
46
 
47
- control_mode_ids = {
48
- "canny": 0, # supported
49
- "tile": 1, # supported
50
- "depth": 2, # supported
51
- "blur": 3, # supported
52
- "pose": 4, # supported
53
- "gray": 5, # supported
54
- "lq": 6, # supported
55
- }
56
-
57
  def clear_cuda_cache():
58
  torch.cuda.empty_cache()
59
 
@@ -154,6 +144,8 @@ def run_flux(
154
  randomize_seed_checkbox: bool,
155
  strength_slider: float,
156
  num_inference_steps_slider: int,
 
 
157
  resolution_wh: Tuple[int, int],
158
  progress
159
  ) -> Image.Image:
@@ -173,11 +165,11 @@ def run_flux(
173
  image=image,
174
  mask_image=mask,
175
  control_image=control_image,
176
- control_mode=control_mode,
177
- controlnet_conditioning_scale=0.55,
 
178
  width=width,
179
  height=height,
180
- strength=strength_slider,
181
  generator=generator,
182
  num_inference_steps=num_inference_steps_slider,
183
  ).images[0]
@@ -212,33 +204,16 @@ def load_loras(lora_strings_json:str):
212
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
213
 
214
 
215
- def generate_control_image(image, mask, control_mode, width, height):
216
  # generated control_
217
  with calculateDuration("Generate control image"):
218
  preprocessor = Preprocessor()
219
- if control_mode == "depth":
220
- preprocessor.load("Midas")
221
- control_image = preprocessor(
222
- image=image,
223
- image_resolution=width,
224
- detect_resolution=512,
225
- )
226
- if control_mode == "pose":
227
- preprocessor.load("Openpose")
228
- control_image = preprocessor(
229
- image=image,
230
- hand_and_face=False,
231
- image_resolution=width,
232
- detect_resolution=512,
233
- )
234
- if control_mode == "canny":
235
- preprocessor.load("Canny")
236
- control_image = preprocessor(
237
- image=image,
238
- image_resolution=width,
239
- detect_resolution=512,
240
- )
241
-
242
  control_image = control_image.resize((width, height), Image.LANCZOS)
243
  return control_image
244
 
@@ -248,10 +223,11 @@ def process(
248
  inpainting_prompt_text: str,
249
  mask_inflation_slider: int,
250
  mask_blur_slider: int,
251
- control_mode: str,
252
  seed_slicer: int,
253
  randomize_seed_checkbox: bool,
254
  strength_slider: float,
 
 
255
  num_inference_steps_slider: int,
256
  lora_strings_json: str,
257
  upload_to_r2: bool,
@@ -290,8 +266,7 @@ def process(
290
  mask = mask.resize((width, height), Image.LANCZOS)
291
  mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
292
 
293
- control_image = generate_control_image(image, mask, control_mode, width, height)
294
- control_mode_id = control_mode_ids[control_mode]
295
  clear_cuda_cache()
296
 
297
  load_loras(lora_strings_json=lora_strings_json)
@@ -301,12 +276,13 @@ def process(
301
  image=image,
302
  mask=mask,
303
  control_image=control_image,
304
- control_mode=control_mode_id,
305
  prompt=inpainting_prompt_text,
306
  seed_slicer=seed_slicer,
307
  randomize_seed_checkbox=randomize_seed_checkbox,
308
  strength_slider=strength_slider,
309
  num_inference_steps_slider=num_inference_steps_slider,
 
 
310
  resolution_wh=(width, height),
311
  progress=progress
312
  )
@@ -366,9 +342,7 @@ with gr.Blocks() as demo:
366
  container=False,
367
  )
368
 
369
- control_mode = gr.Dropdown(
370
- [ "canny"], label="Controlnet Model", info="choose controlnet model!", value="canny"
371
- )
372
  lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
373
 
374
  submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
@@ -408,6 +382,26 @@ with gr.Blocks() as demo:
408
  randomize_seed_checkbox_component = gr.Checkbox(
409
  label="Randomize seed", value=True)
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  with gr.Row():
412
 
413
  strength_slider_component = gr.Slider(
@@ -453,10 +447,11 @@ with gr.Blocks() as demo:
453
  inpainting_prompt_text_component,
454
  mask_inflation_slider_component,
455
  mask_blur_slider_component,
456
- control_mode,
457
  seed_slicer_component,
458
  randomize_seed_checkbox_component,
459
  strength_slider_component,
 
 
460
  num_inference_steps_slider_component,
461
  lora_strings_json,
462
  upload_to_r2,
 
30
  login(token=HF_TOKEN)
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
+ IMAGE_SIZE = 1024
34
 
35
  # init
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
44
 
45
  # pipe.enable_model_cpu_offload() # for saving memory
46
 
 
 
 
 
 
 
 
 
 
 
47
  def clear_cuda_cache():
48
  torch.cuda.empty_cache()
49
 
 
144
  randomize_seed_checkbox: bool,
145
  strength_slider: float,
146
  num_inference_steps_slider: int,
147
+ controlnet_conditioning_scale: float,
148
+ guidance_scale: float,
149
  resolution_wh: Tuple[int, int],
150
  progress
151
  ) -> Image.Image:
 
165
  image=image,
166
  mask_image=mask,
167
  control_image=control_image,
168
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
169
+ strength=strength_slider,
170
+ guidance_scale=guidance_scale,
171
  width=width,
172
  height=height,
 
173
  generator=generator,
174
  num_inference_steps=num_inference_steps_slider,
175
  ).images[0]
 
204
  pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
205
 
206
 
207
+ def generate_control_image(image, mask, width, height):
208
  # generated control_
209
  with calculateDuration("Generate control image"):
210
  preprocessor = Preprocessor()
211
+ preprocessor.load("Canny")
212
+ control_image = preprocessor(
213
+ image=image,
214
+ image_resolution=width,
215
+ detect_resolution=512,
216
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  control_image = control_image.resize((width, height), Image.LANCZOS)
218
  return control_image
219
 
 
223
  inpainting_prompt_text: str,
224
  mask_inflation_slider: int,
225
  mask_blur_slider: int,
 
226
  seed_slicer: int,
227
  randomize_seed_checkbox: bool,
228
  strength_slider: float,
229
+ guidance_scale: float,
230
+ controlnet_conditioning_scale: float,
231
  num_inference_steps_slider: int,
232
  lora_strings_json: str,
233
  upload_to_r2: bool,
 
266
  mask = mask.resize((width, height), Image.LANCZOS)
267
  mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
268
 
269
+ control_image = generate_control_image(image, mask, width, height)
 
270
  clear_cuda_cache()
271
 
272
  load_loras(lora_strings_json=lora_strings_json)
 
276
  image=image,
277
  mask=mask,
278
  control_image=control_image,
 
279
  prompt=inpainting_prompt_text,
280
  seed_slicer=seed_slicer,
281
  randomize_seed_checkbox=randomize_seed_checkbox,
282
  strength_slider=strength_slider,
283
  num_inference_steps_slider=num_inference_steps_slider,
284
+ guidance_scale=guidance_scale,
285
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
286
  resolution_wh=(width, height),
287
  progress=progress
288
  )
 
342
  container=False,
343
  )
344
 
345
+
 
 
346
  lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
347
 
348
  submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
 
382
  randomize_seed_checkbox_component = gr.Checkbox(
383
  label="Randomize seed", value=True)
384
 
385
+ with gr.Row():
386
+
387
+ guidance_scale = gr.Slider(
388
+ label="guidance_scale",
389
+ info="guidance_scale`.",
390
+ minimum=0,
391
+ maximum=1,
392
+ step=0.1,
393
+ value=3.5,
394
+ )
395
+
396
+ controlnet_conditioning_scale = gr.Slider(
397
+ label="controlnet_conditioning_scale",
398
+ info="controlnet_conditioning_scale",
399
+ minimum=1,
400
+ maximum=10,
401
+ step=0.1,
402
+ value=4,
403
+ )
404
+
405
  with gr.Row():
406
 
407
  strength_slider_component = gr.Slider(
 
447
  inpainting_prompt_text_component,
448
  mask_inflation_slider_component,
449
  mask_blur_slider_component,
 
450
  seed_slicer_component,
451
  randomize_seed_checkbox_component,
452
  strength_slider_component,
453
+ guidance_scale,
454
+ controlnet_conditioning_scale,
455
  num_inference_steps_slider_component,
456
  lora_strings_json,
457
  upload_to_r2,