Spaces:
Runtime error
Runtime error
Commit
·
b812142
1
Parent(s):
36de41f
force the device
Browse files
app.py
CHANGED
|
@@ -119,6 +119,9 @@ class ImageGenerator:
|
|
| 119 |
max_length=max_length,
|
| 120 |
dtype=dtype,
|
| 121 |
)
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
| 124 |
bs, _, h, w = img.shape
|
|
@@ -314,6 +317,7 @@ class ImageGenerator:
|
|
| 314 |
|
| 315 |
ref_images_raw = self.load_image(ref_images_raw)
|
| 316 |
ref_images_raw = ref_images_raw.to(self.device)
|
|
|
|
| 317 |
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
|
| 318 |
|
| 319 |
seed = int(seed)
|
|
@@ -398,7 +402,7 @@ def prepare_infer_func():
|
|
| 398 |
|
| 399 |
return image_edit.generate_image
|
| 400 |
|
| 401 |
-
@spaces.GPU
|
| 402 |
def inference(prompt, ref_images, seed, size_level, infer_func=None):
|
| 403 |
start_time = time.time()
|
| 404 |
|
|
|
|
| 119 |
max_length=max_length,
|
| 120 |
dtype=dtype,
|
| 121 |
)
|
| 122 |
+
self.ae = self.ae.to(device=self.device, dtype=torch.float32)
|
| 123 |
+
self.dit = self.dit.to(device=self.device, dtype=dtype)
|
| 124 |
+
self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
|
| 125 |
|
| 126 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
| 127 |
bs, _, h, w = img.shape
|
|
|
|
| 317 |
|
| 318 |
ref_images_raw = self.load_image(ref_images_raw)
|
| 319 |
ref_images_raw = ref_images_raw.to(self.device)
|
| 320 |
+
print(f'self.ae, self.dit device: {self.ae.device}, {self.dit.device}')
|
| 321 |
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
|
| 322 |
|
| 323 |
seed = int(seed)
|
|
|
|
| 402 |
|
| 403 |
return image_edit.generate_image
|
| 404 |
|
| 405 |
+
@spaces.GPU(duration=240)
|
| 406 |
def inference(prompt, ref_images, seed, size_level, infer_func=None):
|
| 407 |
start_time = time.time()
|
| 408 |
|