Update app.py
Browse files
app.py
CHANGED
|
@@ -45,13 +45,14 @@ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
|
| 45 |
|
| 46 |
# ํ
์คํธ ์ธ์ฝ๋๋ฅผ float16์ผ๋ก ๊ฐ์ ๋ณํ
|
| 47 |
pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
x
|
| 53 |
-
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
def can_expand(source_width, source_height, target_width, target_height, alignment):
|
| 57 |
"""Checks if the image can be expanded based on the alignment."""
|
|
|
|
| 45 |
|
| 46 |
# ํ
์คํธ ์ธ์ฝ๋๋ฅผ float16์ผ๋ก ๊ฐ์ ๋ณํ
|
| 47 |
pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
|
| 48 |
+
# ๋ง์ฝ text_projection ์์ฑ์ด ์๋ค๋ฉด, ์
๋ ฅ์ด float16์ด ์๋๋ฉด half๋ก ์บ์คํ
ํ๋๋ก ์ค๋ฒ๋ผ์ด๋ฉ
|
| 49 |
+
if hasattr(pipe.text_encoder, "text_projection"):
|
| 50 |
+
original_text_projection_forward = pipe.text_encoder.text_projection.forward
|
| 51 |
+
def fixed_text_projection_forward(x):
|
| 52 |
+
if x.dtype != torch.float16:
|
| 53 |
+
x = x.half()
|
| 54 |
+
return original_text_projection_forward(x)
|
| 55 |
+
pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
|
| 56 |
|
| 57 |
def can_expand(source_width, source_height, target_width, target_height, alignment):
|
| 58 |
"""Checks if the image can be expanded based on the alignment."""
|