Jordan Legg committed
Commit b54a3db · 1 Parent(s): 2811e7f

added console logging

Files changed (1): app.py (+22 -8)
app.py CHANGED
@@ -17,6 +17,7 @@ MAX_IMAGE_SIZE = 2048
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)

 def preprocess_image(image, image_size):
+    print(f"Preprocessing image to size: {image_size}x{image_size}")
     # Preprocess the image for the VAE
     preprocess = transforms.Compose([
         transforms.Resize((image_size, image_size)),  # Use model-specific size
@@ -24,47 +25,58 @@ def preprocess_image(image, image_size):
         transforms.Normalize([0.5], [0.5])  # Ensure this matches the VAE's training normalization
     ])
     image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
+    print(f"Image shape after preprocessing: {image.shape}")
     return image

 def encode_image(image, vae):
+    print("Encoding image using the VAE")
     # Encode the image using the VAE
     with torch.no_grad():
         latents = vae.encode(image).latent_dist.sample() * 0.18215
+    print(f"Latents shape after encoding: {latents.shape}")
     return latents

 @spaces.GPU()
 def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+    print(f"Inference started with prompt: {prompt}")
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+    print(f"Using seed: {seed}")
     generator = torch.Generator().manual_seed(seed)

     # Get the expected image size for the VAE
     vae_image_size = pipe.vae.config.sample_size
+    print(f"Expected VAE image size: {vae_image_size}")

     if init_image is not None:
-        # Process img2img
+        print("Initial image provided, processing img2img")
         init_image = init_image.convert("RGB")
         init_image = preprocess_image(init_image, vae_image_size)
         latents = encode_image(init_image, pipe.vae)

-        # Debug: Print the shape of the latents after encoding
-        print(f"Latents shape after encoding: {latents.shape}")
-
-        # Ensure latents are correctly shaped and adjusted
+        # Interpolating latents
+        print(f"Interpolating latents to size: {(height // 8, width // 8)}")
         latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
+        print(f"Latents shape after interpolation: {latents.shape}")

         # Convert latent channels to 64 as expected by the transformer
         latent_channels = pipe.vae.config.latent_channels
+        print(f"Expected latent channels: 64, current latent channels: {latent_channels}")
         if latent_channels != 64:
+            print(f"Converting latent channels from {latent_channels} to 64")
             conv = torch.nn.Conv2d(latent_channels, 64, kernel_size=1).to(device, dtype=dtype)
             latents = conv(latents)
+            print(f"Latents shape after channel conversion: {latents.shape}")

         # Reshape latents to match the transformer's input expectations
         latents = latents.view(1, 64, height // 8, width // 8)
-
-        # Debug: Print the shape of the latents after reshaping
         print(f"Latents shape after reshaping: {latents.shape}")

+        # Flatten the latents if required by the transformer
+        latents = latents.flatten(start_dim=1)
+        print(f"Latents shape after flattening: {latents.shape}")
+
+        print("Calling the diffusion pipeline with latents")
         image = pipe(
             prompt=prompt,
             height=height,
@@ -75,7 +87,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
             latents=latents
         ).images[0]
     else:
-        # Process text2img
+        print("No initial image provided, processing text2img")
         image = pipe(
             prompt=prompt,
             height=height,
@@ -85,10 +97,12 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
             guidance_scale=0.0
         ).images[0]

+    print("Inference complete")
     return image, seed




+
 # Define example prompts
 examples = [
     "a tiny astronaut hatching from an egg on the moon",