Jordan Legg committed on
Commit
22e5a11
·
1 Parent(s): 29a504c

log the tensor shape

Browse files
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -5,7 +5,7 @@ import spaces
5
  import torch
6
  from PIL import Image
7
  from torchvision import transforms
8
- from diffusers import DiffusionPipeline
9
 
10
  # Define constants
11
  dtype = torch.bfloat16
@@ -34,6 +34,10 @@ def encode_image(image, vae):
34
  print(f"Latents shape after encoding: {latents.shape}")
35
  return latents
36
 
 
 
 
 
37
  @spaces.GPU()
38
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
39
  print(f"Inference started with prompt: {prompt}")
@@ -44,7 +48,6 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
44
 
45
  if init_image is None:
46
  print("No initial image provided, processing text2img")
47
- # Process text2img
48
  try:
49
  print("Calling the diffusion pipeline for text2img")
50
  result = pipe(
@@ -58,12 +61,16 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
58
  image = result.images[0]
59
  print(f"Generated image shape: {image.size}")
60
 
61
- # Since the 'latents' attribute is not present, we need to inspect other attributes
62
- print(f"Result attributes: {dir(result)}")
 
 
 
 
 
63
  except Exception as e:
64
  print(f"Pipeline call failed with error: {e}")
65
  raise
66
-
67
  else:
68
  print("Initial image provided, processing img2img")
69
  vae_image_size = pipe.vae.config.sample_size
@@ -72,28 +79,21 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
72
  init_image = preprocess_image(init_image, vae_image_size)
73
  latents = encode_image(init_image, pipe.vae)
74
 
75
- # Interpolating latents
76
- print(f"Interpolating latents to size: {(height // 8, width // 8)}")
77
  latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
78
- print(f"Latents shape after interpolation: {latents.shape}")
79
-
80
- # Convert latent channels to 64 as expected by the transformer
81
  latent_channels = pipe.vae.config.latent_channels
82
  print(f"Expected latent channels: 64, current latent channels: {latent_channels}")
83
  if latent_channels != 64:
84
  print(f"Converting latent channels from {latent_channels} to 64")
85
  conv = torch.nn.Conv2d(latent_channels, 64, kernel_size=1).to(device, dtype=dtype)
86
  latents = conv(latents)
87
- print(f"Latents shape after channel conversion: {latents.shape}")
88
 
89
- # Debugging input shape before calling transformer
90
- print(f"Latents shape before reshaping for transformer: {latents.shape}")
91
 
92
- # Reshape latents to match the transformer's input expectations
93
- latents = latents.permute(0, 2, 3, 1).contiguous().view(-1, 64) # Assuming the transformer expects (batch, sequence, feature)
94
- print(f"Latents shape after reshaping for transformer: {latents.shape}")
95
-
96
- # Adding extra debug to understand what transformer expects
97
  try:
98
  print("Calling the transformer with latents")
99
  # Dummy call to transformer to understand the shape requirement
@@ -103,8 +103,8 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
103
  print(f"Transformer call failed with error: {e}")
104
  raise
105
 
106
- print("Calling the diffusion pipeline with latents")
107
  try:
 
108
  image = pipe(
109
  prompt=prompt,
110
  height=height,
@@ -121,6 +121,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
121
  print("Inference complete")
122
  return image, seed
123
 
 
124
  # Define example prompts
125
  examples = [
126
  "a tiny astronaut hatching from an egg on the moon",
 
5
  import torch
6
  from PIL import Image
7
  from torchvision import transforms
8
+ from diffusers import DiffusionPipeline, AutoencoderKL
9
 
10
  # Define constants
11
  dtype = torch.bfloat16
 
34
  print(f"Latents shape after encoding: {latents.shape}")
35
  return latents
36
 
37
+ # A utility function to log shapes and other relevant information
38
+ def log_tensor_info(tensor, name):
39
+ print(f"{name} shape: {tensor.shape} dtype: {tensor.dtype} device: {tensor.device}")
40
+
41
  @spaces.GPU()
42
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
43
  print(f"Inference started with prompt: {prompt}")
 
48
 
49
  if init_image is None:
50
  print("No initial image provided, processing text2img")
 
51
  try:
52
  print("Calling the diffusion pipeline for text2img")
53
  result = pipe(
 
61
  image = result.images[0]
62
  print(f"Generated image shape: {image.size}")
63
 
64
+ # Inspect the output and log relevant details
65
+ print("Logging detailed information for text2img:")
66
+ for name, param in pipe.named_parameters():
67
+ if 'weight' in name:
68
+ log_tensor_info(param, name)
69
+
70
+ print("Logging complete.")
71
  except Exception as e:
72
  print(f"Pipeline call failed with error: {e}")
73
  raise
 
74
  else:
75
  print("Initial image provided, processing img2img")
76
  vae_image_size = pipe.vae.config.sample_size
 
79
  init_image = preprocess_image(init_image, vae_image_size)
80
  latents = encode_image(init_image, pipe.vae)
81
 
82
+ print("Interpolating latents to match model's input size...")
 
83
  latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
84
+ log_tensor_info(latents, "Latents after interpolation")
85
+
 
86
  latent_channels = pipe.vae.config.latent_channels
87
  print(f"Expected latent channels: 64, current latent channels: {latent_channels}")
88
  if latent_channels != 64:
89
  print(f"Converting latent channels from {latent_channels} to 64")
90
  conv = torch.nn.Conv2d(latent_channels, 64, kernel_size=1).to(device, dtype=dtype)
91
  latents = conv(latents)
92
+ log_tensor_info(latents, "Latents after channel conversion")
93
 
94
+ latents = latents.permute(0, 2, 3, 1).contiguous().view(-1, 64)
95
+ log_tensor_info(latents, "Latents after reshaping for transformer")
96
 
 
 
 
 
 
97
  try:
98
  print("Calling the transformer with latents")
99
  # Dummy call to transformer to understand the shape requirement
 
103
  print(f"Transformer call failed with error: {e}")
104
  raise
105
 
 
106
  try:
107
+ print("Calling the diffusion pipeline with latents")
108
  image = pipe(
109
  prompt=prompt,
110
  height=height,
 
121
  print("Inference complete")
122
  return image, seed
123
 
124
+
125
  # Define example prompts
126
  examples = [
127
  "a tiny astronaut hatching from an egg on the moon",