Jordan Legg committed
Commit 12f1d71 · 1 Parent(s): 6cbacd9

mapped projection

Files changed (1)
1. app.py +8 -5
app.py CHANGED
@@ -13,6 +13,9 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
+LATENT_CHANNELS = 16
+TRANSFORMER_IN_CHANNELS = 64
+SCALING_FACTOR = 0.3611
 
 # Load FLUX model
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
@@ -20,8 +23,8 @@ pipe.enable_model_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
 
-# Add a projection layer to match x_embedder input
-projection = nn.Linear(16, 64).to(device).to(dtype)
+# Add a projection layer to match transformer input
+projection = nn.Linear(LATENT_CHANNELS, TRANSFORMER_IN_CHANNELS).to(device).to(dtype)
 
 def preprocess_image(image, image_size):
     preprocess = transforms.Compose([
@@ -41,7 +44,7 @@ def process_latents(latents, height, width):
     print(f"Latent shape after potential interpolation: {latents.shape}")
 
     # Reshape latents to [batch_size, seq_len, channels]
-    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, latents.shape[1])
+    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
     print(f"Reshaped latent shape: {latents.shape}")
 
     # Project latents from 16 to 64 dimensions
@@ -73,10 +76,10 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
         init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
 
         # Encode the image using FLUX VAE
-        latents = pipe.vae.encode(init_image).latent_dist.sample() * 0.18215
+        latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
         print(f"Initial latent shape from VAE: {latents.shape}")
 
-        # Process latents to match x_embedder input
+        # Process latents to match transformer input
         latents = process_latents(latents, height, width)
 
         print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
 