aiqcamp commited on
Commit
fc668b9
Β·
verified Β·
1 Parent(s): 9aac7f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -45,13 +45,14 @@ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
  # ν…μŠ€νŠΈ 인코더λ₯Ό float16으둜 κ°•μ œ λ³€ν™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
- # μΆ”κ°€: text_projection의 forwardλ₯Ό μ˜€λ²„λΌμ΄λ”©ν•˜μ—¬ μž…λ ₯이 float16이 μ•„λ‹ˆλ©΄ half둜 μΊμŠ€νŒ…
49
- original_text_projection_forward = pipe.text_encoder.text_projection.forward
50
- def fixed_text_projection_forward(x):
51
- if x.dtype != torch.float16:
52
- x = x.half()
53
- return original_text_projection_forward(x)
54
- pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
 
55
 
56
  def can_expand(source_width, source_height, target_width, target_height, alignment):
57
  """Checks if the image can be expanded based on the alignment."""
 
45
 
46
  # ν…μŠ€νŠΈ 인코더λ₯Ό float16으둜 κ°•μ œ λ³€ν™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
+ # λ§Œμ•½ text_projection 속성이 μžˆλ‹€λ©΄, μž…λ ₯이 float16이 μ•„λ‹ˆλ©΄ half둜 μΊμŠ€νŒ…ν•˜λ„λ‘ μ˜€λ²„λΌμ΄λ”©
49
+ if hasattr(pipe.text_encoder, "text_projection"):
50
+ original_text_projection_forward = pipe.text_encoder.text_projection.forward
51
+ def fixed_text_projection_forward(x):
52
+ if x.dtype != torch.float16:
53
+ x = x.half()
54
+ return original_text_projection_forward(x)
55
+ pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
56
 
57
  def can_expand(source_width, source_height, target_width, target_height, alignment):
58
  """Checks if the image can be expanded based on the alignment."""