Update pipeline.py
pipeline.py  CHANGED  (+29 -19)
@@ -660,14 +660,14 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
     def __init__(
         self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        image_size=512,                   # Force 512 resolution
+        in_channels=4,
+        model_channels=320,
+        out_channels=4,
+        num_res_blocks=[2, 2, 2, 2],
+        attention_resolutions=[8, 4, 2],  # Adjusted for 512x512
+        channel_mult=[1, 2, 4, 8],
         dropout=0,
-        channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         num_classes=None,
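A quick sanity check on the new defaults: in this UNet family the width at each resolution level is model_channels * channel_mult[level], so the commit's 320 base channels and [1, 2, 4, 8] multiplier imply the widths below. This is a standalone sketch, not code from the repository.

model_channels = 320
channel_mult = [1, 2, 4, 8]
widths = [model_channels * m for m in channel_mult]
print(widths)  # [320, 640, 1280, 2560] channels at the four resolution levels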
@@ -688,7 +688,15 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
         assert context_dim is not None
-
+
+        # Add resolution validation
+        assert image_size in [256, 512], "Only 256/512 resolutions supported"
+        super().__init__()
+
+        # Modify attention resolutions for 512
+        if image_size == 512:
+            attention_resolutions = [16, 8, 4]
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
@@ -730,7 +738,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
             )
-
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
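For context on the attention_resolutions override: in LDM-style UNets such as this one, each entry is a downsample factor of the input grid, and a block gets self-attention when its current factor ds is in attention_resolutions. The sketch below assumes the model runs on an 8x-downsampled VAE latent (64x64 for a 512x512 image); the latent size and level count are assumptions for illustration, not taken from this commit.

latent_size = 512 // 8               # assumed: VAE latent for a 512x512 image
channel_mult = [1, 2, 4, 8]          # one entry per resolution level
attention_resolutions = [8, 4, 2]    # the new default from this commit

ds = 1
for level in range(len(channel_mult)):
    size = latent_size // ds
    print(f"level {level}: {size}x{size}, ds={ds}, attention={ds in attention_resolutions}")
    if level < len(channel_mult) - 1:
        ds *= 2                      # downsampled once between levels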
@@ -1418,20 +1426,18 @@ class MVDreamPipeline(DiffusionPipeline):
         return torch.zeros_like(image_embeds), image_embeds
 
     def encode_image_latents(self, image, device, num_images_per_prompt):
-
         dtype = next(self.image_encoder.parameters()).dtype
-
-        image = (
-            torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
-        )  # [1, 3, H, W]
+
+        # Change interpolation size to match target resolution
+        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
         image = 2 * image - 1
-        image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
+        image = F.interpolate(image, (512, 512), mode='bilinear', align_corners=False)  # Changed from 256
         image = image.to(dtype=dtype)
 
         posterior = self.vae.encode(image).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor
+        latents = posterior.sample() * self.vae.config.scaling_factor
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
-
+
         return torch.zeros_like(latents), latents
 
     @torch.no_grad()
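To see what the new interpolation size means for tensor shapes, here is a standalone walk-through of encode_image_latents. The VAE call is replaced by the standard 8x downsample factor of SD-style VAEs, which is an assumption rather than code from this file.

import numpy as np
import torch
import torch.nn.functional as F

image = np.random.rand(512, 512, 3).astype(np.float32)        # HWC image in [0, 1]
x = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)   # [1, 3, 512, 512]
x = 2 * x - 1                                                  # VAE expects inputs in [-1, 1]
x = F.interpolate(x, (512, 512), mode='bilinear', align_corners=False)
assert x.shape == (1, 3, 512, 512)

latent_hw = 512 // 8   # assumed 8x VAE downsample -> latents of shape [1, 4, 64, 64]
print(f"expected latent shape: [1, 4, {latent_hw}, {latent_hw}]")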
@@ -1439,8 +1445,8 @@ class MVDreamPipeline(DiffusionPipeline):
         self,
         prompt: str = "",
         image: Optional[np.ndarray] = None,
-        height: int = 256,
-        width: int = 256,
+        height: int = 512,
+        width: int = 512,
         elevation: float = 0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.0,
@@ -1454,6 +1460,10 @@ class MVDreamPipeline(DiffusionPipeline):
         num_frames: int = 4,
         device=torch.device("cuda:0"),
     ):
+        # Add resolution validation
+        if height != 512 or width != 512:
+            raise ValueError("Current implementation requires 512x512 resolution")
+
         self.unet = self.unet.to(device=device)
         self.vae = self.vae.to(device=device)
         self.text_encoder = self.text_encoder.to(device=device)
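Finally, a usage sketch for the modified __call__ signature. The repository id and the trust_remote_code loading path are illustrative assumptions, not taken from this commit; only the height/width/num_frames/device arguments mirror the signature in this diff.

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; substitute the repo this pipeline.py ships with.
pipe = DiffusionPipeline.from_pretrained(
    "your-org/mvdream-512-diffusers",
    torch_dtype=torch.float16,
    trust_remote_code=True,          # loads the custom MVDreamPipeline from pipeline.py
)

images = pipe(
    prompt="a photo of an astronaut riding a horse",
    height=512,                      # anything other than 512x512 now raises ValueError
    width=512,
    num_frames=4,
    num_inference_steps=50,
    guidance_scale=7.0,
    device=torch.device("cuda:0"),
)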