Spaces:
Runtime error
Runtime error
pixart special
Browse files
main.py
CHANGED
|
@@ -202,8 +202,12 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pip
|
|
| 202 |
if args.task == "single":
|
| 203 |
# Attempt to move the model to GPU if model is not Flux
|
| 204 |
if args.model != "flux":
|
| 205 |
-
if
|
| 206 |
-
pipe.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
else:
|
| 208 |
print(f"PIPE:{pipe}")
|
| 209 |
|
|
|
|
| 202 |
if args.task == "single":
|
| 203 |
# Attempt to move the model to GPU if model is not Flux
|
| 204 |
if args.model != "flux":
|
| 205 |
+
if args.model != "pixart":
|
| 206 |
+
if pipe.device != torch.device('cuda'):
|
| 207 |
+
pipe.to(device, dtype)
|
| 208 |
+
else:
|
| 209 |
+
if pipe.device != torch.device('cuda'):
|
| 210 |
+
pipe.to(device)
|
| 211 |
else:
|
| 212 |
print(f"PIPE:{pipe}")
|
| 213 |
|