Abdualkader committed on
Commit 55b3544 · verified · 1 Parent(s): 02c3350

Update pipeline.py

Files changed (1)
  1. pipeline.py +29 -19
pipeline.py CHANGED
@@ -660,14 +660,14 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
     def __init__(
         self,
-        image_size,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        image_size=512,  # Force 512 resolution
+        in_channels=4,
+        model_channels=320,
+        out_channels=4,
+        num_res_blocks=[2, 2, 2, 2],
+        attention_resolutions=[8, 4, 2],  # Adjusted for 512x512
+        channel_mult=[1, 2, 4, 8],
         dropout=0,
-        channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         num_classes=None,
@@ -688,7 +688,15 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
         assert context_dim is not None
-
+
+        # Add resolution validation
+        assert image_size in [256, 512], "Only 256/512 resolutions supported"
+        super().__init__()
+
+        # Modify attention resolutions for 512
+        if image_size == 512:
+            attention_resolutions = [16, 8, 4]
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
@@ -730,7 +738,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
             )
-
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -1418,20 +1426,18 @@ class MVDreamPipeline(DiffusionPipeline):
         return torch.zeros_like(image_embeds), image_embeds
 
     def encode_image_latents(self, image, device, num_images_per_prompt):
-
         dtype = next(self.image_encoder.parameters()).dtype
-
-        image = (
-            torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
-        ) # [1, 3, H, W]
+
+        # Change interpolation size to match target resolution
+        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
         image = 2 * image - 1
-        image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
+        image = F.interpolate(image, (512, 512), mode='bilinear', align_corners=False)  # Changed from 256
         image = image.to(dtype=dtype)
 
         posterior = self.vae.encode(image).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
+        latents = posterior.sample() * self.vae.config.scaling_factor
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
-
+
         return torch.zeros_like(latents), latents
 
     @torch.no_grad()
@@ -1439,8 +1445,8 @@ class MVDreamPipeline(DiffusionPipeline):
         self,
         prompt: str = "",
         image: Optional[np.ndarray] = None,
-        height: int = 256,
-        width: int = 256,
+        height: int = 512,
+        width: int = 512,
         elevation: float = 0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.0,
@@ -1454,6 +1460,10 @@ class MVDreamPipeline(DiffusionPipeline):
         num_frames: int = 4,
         device=torch.device("cuda:0"),
     ):
+        # Add resolution validation
+        if height != 512 or width != 512:
+            raise ValueError("Current implementation requires 512x512 resolution")
+
         self.unet = self.unet.to(device=device)
         self.vae = self.vae.to(device=device)
         self.text_encoder = self.text_encoder.to(device=device)
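
For context on the attention_resolutions default ([8, 4, 2]) and the 512 override ([16, 8, 4]) above, the sketch below shows how attention_resolutions is conventionally consumed in LDM-style UNets (the ds-counter pattern this model family inherits from guided-diffusion). It is an assumption about the repository's internals, not a quote of MultiViewUNetModel; it only illustrates at which downsample factors attention blocks would be attached for a given config.

# Minimal sketch, assuming the upstream ds-counter pattern; not this repo's code.
def attended_downsample_factors(channel_mult, num_res_blocks, attention_resolutions):
    attended = set()
    ds = 1  # cumulative downsample factor of the current feature map
    for level in range(len(channel_mult)):
        for _ in range(num_res_blocks[level]):
            if ds in attention_resolutions:
                attended.add(ds)  # an attention block follows the res block at this scale
        if level != len(channel_mult) - 1:
            ds *= 2  # downsample between resolution levels
    return sorted(attended)

# With channel_mult=[1, 2, 4, 8] the reachable factors are 1, 2, 4 and 8, so the new
# default attention_resolutions=[8, 4, 2] enables attention on the three coarsest levels.
print(attended_downsample_factors([1, 2, 4, 8], [2, 2, 2, 2], [8, 4, 2]))  # [2, 4, 8]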
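
The encode_image_latents change mainly moves the pre-VAE resize from 256 to 512. A minimal shape walk-through follows, assuming the standard Stable Diffusion VAE (8x spatial downsampling, 4 latent channels) and a float HWC numpy input in [0, 1]; a zero tensor stands in for the actual VAE call so the snippet runs without weights.

import numpy as np
import torch
import torch.nn.functional as F

image_np = np.random.rand(512, 512, 3).astype(np.float32)            # [H, W, 3] in [0, 1]

image = torch.from_numpy(image_np).unsqueeze(0).permute(0, 3, 1, 2)  # [1, 3, 512, 512]
image = 2 * image - 1                                                 # map to [-1, 1] for the VAE
image = F.interpolate(image, (512, 512), mode="bilinear", align_corners=False)

# self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
# would yield shape [1, 4, 64, 64] here (512 / 8 per spatial dim); zeros stand in.
latents = torch.zeros(1, 4, 64, 64)
num_images_per_prompt = 2
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
print(latents.shape)                                                  # torch.Size([2, 4, 64, 64])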
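
Finally, a usage sketch of the updated __call__ defaults. The checkpoint id, the custom-pipeline loading path (trust_remote_code), and the handling of the returned frames are assumptions; only the argument names and the 512x512 requirement come from the diff above.

import numpy as np
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint id; loading this custom pipeline via trust_remote_code is an assumption.
pipe = DiffusionPipeline.from_pretrained(
    "your-org/your-mvdream-checkpoint",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

cond = np.random.rand(512, 512, 3).astype(np.float32)  # conditioning image, HWC in [0, 1]

# height/width now default to 512; any other value raises ValueError per the validation above.
frames = pipe(
    prompt="a photo of an astronaut riding a horse",
    image=cond,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.0,
    num_frames=4,
    device=torch.device("cuda:0"),
)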