Jordan Legg committed
Commit 3be64a5 · 1 Parent(s): 383a90d

target the text encoder, merge latent space before the pipeline
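
In short: the change drops the 16→64 projection onto the transformer's input channels and instead projects the VAE latents to the CLIP text-embedding width (16→768), pads or truncates the resulting sequence to 77 tokens, and concatenates it with the prompt's text embeddings before the pipeline call. Hedged sketches of the two key steps follow the diff.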

Files changed (1): app.py (+23, -8)
app.py CHANGED
@@ -14,7 +14,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 LATENT_CHANNELS = 16
-TRANSFORMER_IN_CHANNELS = 64
+TEXT_EMBED_DIM = 768
+MAX_TEXT_EMBEDDINGS = 77
 SCALING_FACTOR = 0.3611
 
 # Load FLUX model
@@ -23,8 +24,8 @@ pipe.enable_model_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
 
-# Add a projection layer to match transformer input
-projection = nn.Linear(LATENT_CHANNELS, TRANSFORMER_IN_CHANNELS).to(device).to(dtype)
+# Add a projection layer to match text embedding dimension
+projection = nn.Linear(LATENT_CHANNELS, TEXT_EMBED_DIM).to(device).to(dtype)
 
 def preprocess_image(image, image_size):
     preprocess = transforms.Compose([
@@ -47,10 +48,19 @@ def process_latents(latents, height, width):
     latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
     print(f"Reshaped latent shape: {latents.shape}")
 
-    # Project latents from 16 to 64 dimensions
+    # Project latents to match text embedding dimension
     latents = projection(latents)
     print(f"Projected latent shape: {latents.shape}")
 
+    # Adjust sequence length to match text embeddings
+    seq_len = latents.shape[1]
+    if seq_len > MAX_TEXT_EMBEDDINGS:
+        latents = latents[:, :MAX_TEXT_EMBEDDINGS, :]
+    elif seq_len < MAX_TEXT_EMBEDDINGS:
+        pad_len = MAX_TEXT_EMBEDDINGS - seq_len
+        latents = torch.nn.functional.pad(latents, (0, 0, 0, pad_len, 0, 0))
+    print(f"Final latent shape: {latents.shape}")
+
     return latents
 
 @spaces.GPU()
@@ -79,11 +89,16 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
     latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
     print(f"Initial latent shape from VAE: {latents.shape}")
 
-    # Process latents to match transformer input
+    # Process latents to match text embedding format
     latents = process_latents(latents, height, width)
 
-    print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
-    print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
+    # Get text embeddings
+    text_embeddings = pipe.transformer.text_encoder([prompt])
+    print(f"Text embedding shape: {text_embeddings.shape}")
+
+    # Combine image latents and text embeddings
+    combined_embeddings = torch.cat([latents, text_embeddings], dim=1)
+    print(f"Combined embedding shape: {combined_embeddings.shape}")
 
     image = pipe(
         prompt=prompt,
@@ -92,7 +107,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
         num_inference_steps=num_inference_steps,
         generator=generator,
         guidance_scale=0.0,
-        latents=latents
+        latents=combined_embeddings
     ).images[0]
 
     return image, seed
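
For reference, a minimal, self-contained sketch of the projection-and-padding step the new process_latents performs. The helper name project_and_pad and the example input shape are illustrative, not taken from the Space:

import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_CHANNELS = 16      # FLUX VAE latent channels (from the diff)
TEXT_EMBED_DIM = 768      # CLIP text-embedding width (from the diff)
MAX_TEXT_EMBEDDINGS = 77  # CLIP sequence length (from the diff)

projection = nn.Linear(LATENT_CHANNELS, TEXT_EMBED_DIM)

def project_and_pad(latents):
    # Flatten the spatial grid into a token sequence: (b, c, h, w) -> (b, h*w, c)
    b, c, h, w = latents.shape
    tokens = latents.permute(0, 2, 3, 1).reshape(b, h * w, c)
    # Project each 16-dim latent "token" to the 768-dim text-embedding width
    tokens = projection(tokens)
    # Truncate or zero-pad the sequence to exactly 77 tokens
    seq_len = tokens.shape[1]
    if seq_len > MAX_TEXT_EMBEDDINGS:
        tokens = tokens[:, :MAX_TEXT_EMBEDDINGS, :]
    elif seq_len < MAX_TEXT_EMBEDDINGS:
        tokens = F.pad(tokens, (0, 0, 0, MAX_TEXT_EMBEDDINGS - seq_len))
    return tokens

# A hypothetical 1024x1024 image encodes to a (1, 16, 128, 128) latent if the
# VAE downsamples by 8, as FLUX's does: 128 * 128 = 16384 spatial tokens.
out = project_and_pad(torch.randn(1, LATENT_CHANNELS, 128, 128))
print(out.shape)  # torch.Size([1, 77, 768])

Note the truncation is severe on this path: 16384 spatial tokens are cut down to 77, discarding almost all of the encoded image.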
 
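One caveat on the new infer hunk: in stock diffusers the text encoders hang off the pipeline itself (pipe.text_encoder for CLIP, pipe.text_encoder_2 for T5), not off pipe.transformer, so pipe.transformer.text_encoder([prompt]) should raise AttributeError. A hedged sketch of the usual route, assuming a recent diffusers FluxPipeline whose encode_prompt returns (prompt_embeds, pooled_prompt_embeds, text_ids); the exact signature can differ between releases:

# Hedged sketch (not the committed code): assumes `pipe` is the FluxPipeline
# loaded in app.py and `prompt`/`device` come from infer's scope.
def get_text_embeddings(pipe, prompt, device):
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,  # FLUX feeds a second prompt to the T5 encoder
        device=device,
        num_images_per_prompt=1,
    )
    # prompt_embeds typically comes from T5 at width 4096, while the 768-wide
    # CLIP vector is pooled_prompt_embeds, so the 768-wide projected latents
    # would not concatenate with prompt_embeds without a further projection.
    return prompt_embeds, pooled_prompt_embeds

A related caveat: as far as I can tell, FluxPipeline's latents argument expects packed image latents (FLUX packs 2x2 latent patches into 64-channel tokens), so passing the 768-wide combined_embeddings there would likely fail the transformer's x_embedder shape check rather than be used as intended.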