aiqcamp committed on
Commit
a955f9f
·
verified ·
1 Parent(s): fc668b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -29
app.py CHANGED
@@ -43,16 +43,31 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
- # ν…μŠ€νŠΈ 인코더λ₯Ό float16으둜 κ°•μ œ λ³€ν™˜
47
- pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
- # 만약 text_projection 속성이 있다면, 입력이 float16이 아니면 half로 캐스팅하도록 오버라이딩
49
- if hasattr(pipe.text_encoder, "text_projection"):
50
- original_text_projection_forward = pipe.text_encoder.text_projection.forward
51
- def fixed_text_projection_forward(x):
52
- if x.dtype != torch.float16:
53
- x = x.half()
54
- return original_text_projection_forward(x)
55
- pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def can_expand(source_width, source_height, target_width, target_height, alignment):
58
  """Checks if the image can be expanded based on the alignment."""
@@ -153,24 +168,25 @@ def infer(image, width, height, overlap_width, num_inference_steps, resize_optio
153
  cnet_image = background.copy()
154
  cnet_image.paste(0, (0, 0), mask)
155
 
156
- final_prompt = f"{prompt_input} , high quality, 4k"
157
-
158
- (
159
- prompt_embeds,
160
- negative_prompt_embeds,
161
- pooled_prompt_embeds,
162
- negative_pooled_prompt_embeds,
163
- ) = pipe.encode_prompt(final_prompt, "cuda", True)
164
-
165
- for image in pipe(
166
- prompt_embeds=prompt_embeds,
167
- negative_prompt_embeds=negative_prompt_embeds,
168
- pooled_prompt_embeds=pooled_prompt_embeds,
169
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
170
- image=cnet_image,
171
- num_inference_steps=num_inference_steps
172
- ):
173
- yield cnet_image, image
 
174
 
175
  image = image.convert("RGBA")
176
  cnet_image.paste(image, (0, 0), mask)
@@ -371,4 +387,4 @@ with gr.Blocks(css=css) as demo:
371
  outputs=use_as_input_button,
372
  )
373
 
374
- demo.queue(max_size=12).launch(share=False)
 
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
+ # Make sure all text encoder components use the same dtype
47
+ pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float16)
48
+ pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=torch.float16)
49
+
50
+ # Patch the text encoder forward methods to ensure consistent dtype
51
+ def patch_text_encoder_forward(encoder):
52
+ original_forward = encoder.forward
53
+
54
+ def patched_forward(*args, **kwargs):
55
+ # Convert input tensors to float16
56
+ if len(args) > 0 and isinstance(args[0], torch.Tensor):
57
+ args = list(args)
58
+ args[0] = args[0].to(dtype=torch.float16)
59
+
60
+ for key in kwargs:
61
+ if isinstance(kwargs[key], torch.Tensor):
62
+ kwargs[key] = kwargs[key].to(dtype=torch.float16)
63
+
64
+ return original_forward(*args, **kwargs)
65
+
66
+ encoder.forward = patched_forward
67
+
68
+ # Apply the patch to both encoders
69
+ patch_text_encoder_forward(pipe.text_encoder)
70
+ patch_text_encoder_forward(pipe.text_encoder_2)
71
 
72
  def can_expand(source_width, source_height, target_width, target_height, alignment):
73
  """Checks if the image can be expanded based on the alignment."""
 
168
  cnet_image = background.copy()
169
  cnet_image.paste(0, (0, 0), mask)
170
 
171
+ final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
172
+
173
+ with torch.cuda.amp.autocast(dtype=torch.float16):
174
+ (
175
+ prompt_embeds,
176
+ negative_prompt_embeds,
177
+ pooled_prompt_embeds,
178
+ negative_pooled_prompt_embeds,
179
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
180
+
181
+ for image in pipe(
182
+ prompt_embeds=prompt_embeds,
183
+ negative_prompt_embeds=negative_prompt_embeds,
184
+ pooled_prompt_embeds=pooled_prompt_embeds,
185
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
186
+ image=cnet_image,
187
+ num_inference_steps=num_inference_steps
188
+ ):
189
+ yield cnet_image, image
190
 
191
  image = image.convert("RGBA")
192
  cnet_image.paste(image, (0, 0), mask)
 
387
  outputs=use_as_input_button,
388
  )
389
 
390
+ demo.queue(max_size=12).launch(share=False)