Jordan Legg committed on
Commit
878ec45
·
1 Parent(s): 69e75b1

align latents to transformer

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -46,10 +46,24 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
46
  init_image = init_image.convert("RGB")
47
  init_image = preprocess_image(init_image, vae_image_size)
48
  latents = encode_image(init_image, pipe.vae)
 
 
 
 
49
  # Ensure latents are correctly shaped and adjusted
50
  latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
51
- latents = latents.view(1, -1, height // 8, width // 8)
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  image = pipe(
54
  prompt=prompt,
55
  height=height,
@@ -73,7 +87,6 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
73
  return image, seed
74
 
75
 
76
-
77
  # Define example prompts
78
  examples = [
79
  "a tiny astronaut hatching from an egg on the moon",
 
46
  init_image = init_image.convert("RGB")
47
  init_image = preprocess_image(init_image, vae_image_size)
48
  latents = encode_image(init_image, pipe.vae)
49
+
50
+ # Debug: Print the shape of the latents after encoding
51
+ print(f"Latents shape after encoding: {latents.shape}")
52
+
53
  # Ensure latents are correctly shaped and adjusted
54
  latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
 
55
 
56
+ # Convert latent channels to 64 as expected by the transformer
57
+ latent_channels = pipe.vae.config.latent_channels
58
+ if latent_channels != 64:
59
+ latents = torch.nn.Conv2d(latent_channels, 64, kernel_size=1).to(device)(latents)
60
+
61
+ # Reshape latents to match the transformer's input expectations
62
+ latents = latents.view(1, 64, height // 8, width // 8)
63
+
64
+ # Debug: Print the shape of the latents after reshaping
65
+ print(f"Latents shape after reshaping: {latents.shape}")
66
+
67
  image = pipe(
68
  prompt=prompt,
69
  height=height,
 
87
  return image, seed
88
 
89
 
 
90
  # Define example prompts
91
  examples = [
92
  "a tiny astronaut hatching from an egg on the moon",