Jordan Legg committed on
Commit 69e75b1 · 1 Parent(s): da39f41

align image resize with VAE sample size

Files changed (1)
  1. app.py +9 -7
app.py CHANGED
@@ -16,12 +16,12 @@ MAX_IMAGE_SIZE = 2048
 # Load the diffusion pipeline
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
-def preprocess_image(image):
+def preprocess_image(image, image_size):
     # Preprocess the image for the VAE
     preprocess = transforms.Compose([
-        transforms.Resize((512, 512)),  # Adjust the size as needed
+        transforms.Resize((image_size, image_size)),  # Use model-specific size
         transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5])
+        transforms.Normalize([0.5], [0.5])  # Ensure this matches the VAE's training normalization
     ])
     image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
     return image
@@ -37,17 +37,17 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed)
+
+    # Get the expected image size for the VAE
+    vae_image_size = pipe.vae.config.sample_size
 
     if init_image is not None:
         # Process img2img
         init_image = init_image.convert("RGB")
-        init_image = preprocess_image(init_image)
+        init_image = preprocess_image(init_image, vae_image_size)
         latents = encode_image(init_image, pipe.vae)
         # Ensure latents are correctly shaped and adjusted
         latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
-        latents = latents * 0.18215  # Adjust latent scaling factor if necessary
-
-        # Ensure latents are reshaped to match the expected input dimensions of the model
         latents = latents.view(1, -1, height // 8, width // 8)
 
     image = pipe(
@@ -72,6 +72,8 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
 
     return image, seed
 
+
+
 # Define example prompts
 examples = [
     "a tiny astronaut hatching from an egg on the moon",