Jordan Legg committed
Commit 409e82d
1 Parent(s): b54a3db

remove latent flattening

Files changed (1)
  1. app.py +11 -5
app.py CHANGED
@@ -18,7 +18,6 @@ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", tor
 
 def preprocess_image(image, image_size):
     print(f"Preprocessing image to size: {image_size}x{image_size}")
-    # Preprocess the image for the VAE
     preprocess = transforms.Compose([
         transforms.Resize((image_size, image_size)),  # Use model-specific size
         transforms.ToTensor(),
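
Note for readers: transforms.ToTensor() yields values in [0, 1], while diffusers VAEs are trained on inputs in [-1, 1]. The hunk is truncated before the end of the Compose list, so it is unclear whether a normalization step follows; a minimal sketch of the usual preprocessing, with the Normalize step added as an assumption:

    from torchvision import transforms

    def build_vae_preprocess(image_size):
        # Resize to the model's resolution, convert to a tensor in [0, 1],
        # then map to [-1, 1], the range diffusers VAEs expect.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # assumption: not visible in the truncated hunk
        ])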
@@ -30,7 +29,6 @@ def preprocess_image(image, image_size):
 
 def encode_image(image, vae):
     print("Encoding image using the VAE")
-    # Encode the image using the VAE
     with torch.no_grad():
         latents = vae.encode(image).latent_dist.sample() * 0.18215
     print(f"Latents shape after encoding: {latents.shape}")
@@ -72,9 +70,16 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
     latents = latents.view(1, 64, height // 8, width // 8)
     print(f"Latents shape after reshaping: {latents.shape}")
 
-    # Flatten the latents if required by the transformer
-    latents = latents.flatten(start_dim=1)
-    print(f"Latents shape after flattening: {latents.shape}")
+    # Avoid flattening, ensure latents are in the expected shape for the transformer
+    # Adding extra debug to understand what transformer expects
+    try:
+        print("Calling the transformer with latents")
+        # Dummy call to transformer to understand the shape requirement
+        _ = pipe.transformer(latents)
+        print("Transformer call succeeded")
+    except Exception as e:
+        print(f"Transformer call failed with error: {e}")
+        raise
 
     print("Calling the diffusion pipeline with latents")
     image = pipe(
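
Why the shape question matters: the FLUX transformer does not take a (B, C, H, W) feature map. The pipeline packs the 16-channel VAE latents into a sequence of 2x2 patches, i.e. (B, (H/16)*(W/16), 64) tokens for an H x W image, and a bare pipe.transformer(latents) probe would likely fail for more than the shape, since the forward pass also expects a timestep and the text-conditioning tensors. A sketch of that packing under the standard 16-channel FLUX latent layout (the function name is illustrative, not from this repo):

    import torch

    def pack_latents(latents: torch.Tensor) -> torch.Tensor:
        # latents: (B, C, H, W) from the VAE, e.g. (1, 16, height // 8, width // 8) for FLUX.
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        # One token per 2x2 latent patch: (B, (H/2)*(W/2), C*4), i.e. 64 features when C == 16.
        return latents.reshape(b, (h // 2) * (w // 2), c * 4)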
@@ -103,6 +108,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
 
 
 
+
 # Define example prompts
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
 