Abdualkader committed
Commit 27692f6 · verified · 1 Parent(s): ee14abd

Update pipeline.py

Files changed (1)
  1. pipeline.py +19 -29
pipeline.py CHANGED
@@ -660,14 +660,14 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
     def __init__(
         self,
-        image_size=512, # Force 512 resolution
-        in_channels=4,
-        model_channels=320,
-        out_channels=4,
-        num_res_blocks=[2, 2, 2, 2],
-        attention_resolutions=[8, 4, 2], # Adjusted for 512x512
-        channel_mult=[1, 2, 4, 8],
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
         dropout=0,
+        channel_mult=(1, 2, 4, 8),
         conv_resample=True,
         dims=2,
         num_classes=None,
@@ -688,15 +688,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
         assert context_dim is not None
-
-        # Add resolution validation
-        assert image_size in [256, 512], "Only 256/512 resolutions supported"
-        super().__init__()
-
-        # Modify attention resolutions for 512
-        if image_size == 512:
-            attention_resolutions = [16, 8, 4]
-
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
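The two hunks above restore the original constructor: the first six arguments become required again, channel_mult moves back to a keyword default, and the 512-only validation, the duplicate super().__init__() call, and the attention-resolution override are dropped. A hedged instantiation sketch (not part of the commit; the argument values simply mirror the defaults the removed lines had hard-coded, and context_dim is a placeholder since the constructor only asserts it is not None — real checkpoints supply their own config):

# Hypothetical construction of the UNet defined in pipeline.py, for illustration only.
unet = MultiViewUNetModel(
    image_size=256,
    in_channels=4,
    model_channels=320,
    out_channels=4,
    num_res_blocks=[2, 2, 2, 2],
    attention_resolutions=[8, 4, 2],
    channel_mult=(1, 2, 4, 8),   # restored keyword default
    context_dim=1024,            # placeholder value (assumed); only "is not None" is checked
)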
@@ -738,7 +730,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
             )
-
+
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -1426,18 +1418,20 @@ class MVDreamPipeline(DiffusionPipeline):
         return torch.zeros_like(image_embeds), image_embeds
 
     def encode_image_latents(self, image, device, num_images_per_prompt):
-
+
         dtype = next(self.image_encoder.parameters()).dtype
-
-        # Change interpolation size to match target resolution
-        image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
+
+        image = (
+            torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
+        )  # [1, 3, H, W]
         image = 2 * image - 1
-        image = F.interpolate(image, (512, 512), mode='bilinear', align_corners=False) # Changed from 256
+        image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
         image = image.to(dtype=dtype)
 
         posterior = self.vae.encode(image).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor
+        latents = posterior.sample() * self.vae.config.scaling_factor  # [B, C, H, W]
         latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
-
+
         return torch.zeros_like(latents), latents
 
     @torch.no_grad()
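The restored encode_image_latents converts the conditioning image to a [1, 3, H, W] tensor, maps it to [-1, 1], resizes it back to 256x256, and encodes it with the VAE. A standalone shape walk-through of that preprocessing (a sketch with a dummy array, not code from the commit; the [0, 1] input range and the 8x-downsampling VAE are assumptions):

import numpy as np
import torch
import torch.nn.functional as F

image = np.random.rand(320, 320, 3).astype(np.float32)        # HWC float image, assumed to lie in [0, 1]
x = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)   # -> [1, 3, 320, 320]
x = 2 * x - 1                                                  # -> values in [-1, 1] for the VAE
x = F.interpolate(x, (256, 256), mode="bilinear", align_corners=False)  # -> [1, 3, 256, 256]
print(x.shape)  # torch.Size([1, 3, 256, 256])
# vae.encode(x).latent_dist.sample() * vae.config.scaling_factor would then give
# latents of shape [1, 4, 32, 32] (assuming the usual 8x-downsampling SD VAE), and
# repeat_interleave(num_images_per_prompt, dim=0) expands the batch dimension.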
@@ -1445,8 +1439,8 @@ class MVDreamPipeline(DiffusionPipeline):
         self,
         prompt: str = "",
         image: Optional[np.ndarray] = None,
-        height: int = 512,
-        width: int = 512,
+        height: int = 256,
+        width: int = 256,
         elevation: float = 0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.0,
@@ -1460,10 +1454,6 @@ class MVDreamPipeline(DiffusionPipeline):
         num_frames: int = 4,
         device=torch.device("cuda:0"),
     ):
-        # Add resolution validation
-        if height != 512 or width != 512:
-            raise ValueError("Current implementation requires 512x512 resolution")
-
         self.unet = self.unet.to(device=device)
         self.vae = self.vae.to(device=device)
         self.text_encoder = self.text_encoder.to(device=device)
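With the height/width defaults back at 256 and the 512x512 guard removed, a call sketch for the reverted pipeline (hypothetical: the repo id is a placeholder, the conditioning image is a dummy array, and the return type is not visible in this diff):

import numpy as np
import torch

pipe = MVDreamPipeline.from_pretrained("your-org/your-mvdream-checkpoint")  # placeholder repo id
input_image = np.zeros((256, 256, 3), dtype=np.float32)  # dummy HWC float image
views = pipe(
    prompt="",
    image=input_image,
    height=256,          # restored default
    width=256,           # restored default
    elevation=0,
    num_inference_steps=50,
    guidance_scale=7.0,
    num_frames=4,
    device=torch.device("cuda:0"),
)
# `views` holds the generated multi-view output; its exact structure is not part of this diff.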
 