Update pipeline.py
pipeline.py  +19 -29
@@ -660,14 +660,14 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
     def __init__(
         self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
-        channel_mult=[1, 2, 4, 8],
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
         dropout=0,
+        channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         num_classes=None,
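The signature change moves `channel_mult` after `dropout` and swaps its default from the mutable list `[1, 2, 4, 8]` to the tuple `(1, 2, 4, 8)`. A list default is built once at function definition and shared by every call, so any in-place edit leaks into later calls; a tuple cannot be mutated. A minimal standalone illustration of that pitfall (not code from pipeline.py):

def scale_channels(mult=[1, 2, 4, 8]):   # one shared list for every call
    mult.append(16)                      # in-place edit mutates the default
    return mult

print(scale_channels())  # [1, 2, 4, 8, 16]
print(scale_channels())  # [1, 2, 4, 8, 16, 16]  <- the default has drifted

def scale_channels_safe(mult=(1, 2, 4, 8)):  # immutable default
    return list(mult) + [16]                 # copy before modifying

print(scale_channels_safe())  # [1, 2, 4, 8, 16]
print(scale_channels_safe())  # [1, 2, 4, 8, 16]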
@@ -688,15 +688,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
         assert context_dim is not None
-
-        # Add resolution validation
-        assert image_size in [256, 512], "Only 256/512 resolutions supported"
-        super().__init__()
-
-        # Modify attention resolutions for 512
-        if image_size == 512:
-            attention_resolutions = [16, 8, 4]
-
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
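Besides dropping the duplicated `super().__init__()` call, this revert removes the forced `attention_resolutions = [16, 8, 4]` override for 512-input models. For orientation: as the warning string in the next hunk states ("2**i not in attention_resolutions"), attention is attached at UNet level i only when the cumulative downsampling factor 2**i appears in `attention_resolutions`. A standalone sketch of that gating rule, with illustrative values only:

# Sketch of the gating rule quoted in the warning string below:
# level i carries attention iff its downsample factor 2**i is listed.
attention_resolutions = [4, 2, 1]   # illustrative, not the model's config
channel_mult = (1, 2, 4, 8)

for i, mult in enumerate(channel_mult):
    ds = 2 ** i  # cumulative downsampling factor at level i
    print(f"level {i}: downsample x{ds}, channels x{mult}, "
          f"attention={ds in attention_resolutions}")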
@@ -738,7 +730,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
             )
-
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -1426,18 +1418,20 @@ class MVDreamPipeline(DiffusionPipeline):
         return torch.zeros_like(image_embeds), image_embeds
 
     def encode_image_latents(self, image, device, num_images_per_prompt):
+
         dtype = next(self.image_encoder.parameters()).dtype
-
-
-
+
+        image = (
+            torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
+        )  # [1, 3, H, W]
         image = 2 * image - 1
-        image = F.interpolate(image, (512, 512), mode="bilinear", align_corners=False)
+        image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
         image = image.to(dtype=dtype)
 
         posterior = self.vae.encode(image).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor
+        latents = posterior.sample() * self.vae.config.scaling_factor  # [B, C, H, W]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
-
+
         return torch.zeros_like(latents), latents
 
     @torch.no_grad()
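The restored body pins down the input contract: `image` arrives as an H x W x 3 numpy array, is reshaped to [1, 3, H, W], rescaled to [-1, 1], and resized to the 256 x 256 the VAE consumes. A self-contained sketch of just that preprocessing, with the `vae.encode` step stubbed out (the float32 [0, 1] input is an assumption consistent with the `2 * image - 1` rescaling above):

import numpy as np
import torch
import torch.nn.functional as F

# Preprocessing from encode_image_latents in isolation; vae.encode omitted.
image_np = np.random.rand(480, 640, 3).astype(np.float32)            # [H, W, 3]

image = torch.from_numpy(image_np).unsqueeze(0).permute(0, 3, 1, 2)  # [1, 3, H, W]
image = 2 * image - 1                                                # [0, 1] -> [-1, 1]
image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)

print(image.shape)  # torch.Size([1, 3, 256, 256]), ready for vae.encode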
@@ -1445,8 +1439,8 @@ class MVDreamPipeline(DiffusionPipeline):
         self,
         prompt: str = "",
         image: Optional[np.ndarray] = None,
-        height: int = 512,
-        width: int = 512,
+        height: int = 256,
+        width: int = 256,
         elevation: float = 0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.0,
@@ -1460,10 +1454,6 @@ class MVDreamPipeline(DiffusionPipeline):
         num_frames: int = 4,
         device=torch.device("cuda:0"),
     ):
-        # Add resolution validation
-        if height != 512 or width != 512:
-            raise ValueError("Current implementation requires 512x512 resolution")
-
         self.unet = self.unet.to(device=device)
         self.vae = self.vae.to(device=device)
         self.text_encoder = self.text_encoder.to(device=device)
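With the hard 512 x 512 check gone and the defaults back at 256, a plain call runs end to end again. A hypothetical invocation for reference; the checkpoint id and import path below are assumptions, not part of this diff:

import torch
from mvdream.pipeline_mvdream import MVDreamPipeline  # import path is an assumption

pipe = MVDreamPipeline.from_pretrained(
    "ashawkey/mvdream-sd2.1-diffusers",  # hypothetical checkpoint id
    torch_dtype=torch.float16,
)
images = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.0,
    num_frames=4,                        # one image per camera view
    device=torch.device("cuda:0"),
)  # height/width default to 256 after this change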