Jordan Legg commited on
Commit
bc9da49
Β·
1 Parent(s): d53ee34

trying to match the latent shapes

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -34,22 +34,27 @@ def check_shapes(latents):
34
  print(f"Latent shape: {latent_shape}")
35
 
36
  # Get the expected shape for the transformer input
37
- expected_shape = (1, latent_shape[1] * latent_shape[2], latent_shape[3])
38
  print(f"Expected transformer input shape: {expected_shape}")
39
 
40
- # Get the shape of the transformer's weight matrix
41
- if hasattr(pipe.transformer, 'text_model'):
42
- weight_shape = pipe.transformer.text_model.encoder.layers[0].self_attn.q_proj.weight.shape
43
- else:
44
- weight_shape = pipe.transformer.encoder.layers[0].self_attn.q_proj.weight.shape
45
- print(f"Transformer weight shape: {weight_shape}")
 
 
 
46
 
47
- # Check if the shapes are compatible for matrix multiplication
48
- if expected_shape[1] == weight_shape[1]:
49
- print("Shapes are compatible for matrix multiplication.")
50
- else:
51
- print("Warning: Shapes are not compatible for matrix multiplication.")
52
- print(f"Expected: {expected_shape[1]}, Got: {weight_shape[1]}")
 
 
53
 
54
  @spaces.GPU()
55
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
@@ -83,7 +88,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
83
  check_shapes(latents)
84
 
85
  # Reshape latents to match the expected input shape of the transformer
86
- latents = latents.permute(0, 2, 3, 1).contiguous().view(1, -1, pipe.vae.config.latent_channels)
87
 
88
  # Check shapes after reshaping
89
  check_shapes(latents)
@@ -103,7 +108,6 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
103
  print(f"Error during inference: {e}")
104
  return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
105
 
106
-
107
  # Gradio interface setup
108
  with gr.Blocks() as demo:
109
  with gr.Row():
 
34
  print(f"Latent shape: {latent_shape}")
35
 
36
  # Get the expected shape for the transformer input
37
+ expected_shape = (1, latent_shape[1] * latent_shape[2] * latent_shape[3])
38
  print(f"Expected transformer input shape: {expected_shape}")
39
 
40
+ # Try to get the shape of the transformer's weight matrix
41
+ try:
42
+ # Assuming the first layer of the transformer has a linear projection
43
+ if hasattr(pipe.transformer, 'blocks'):
44
+ weight_shape = pipe.transformer.blocks[0].attn.to_q.weight.shape
45
+ else:
46
+ print("Unable to determine transformer weight shape.")
47
+ return
48
+ print(f"Transformer weight shape: {weight_shape}")
49
 
50
+ # Check if the shapes are compatible for matrix multiplication
51
+ if expected_shape[1] == weight_shape[1]:
52
+ print("Shapes are compatible for matrix multiplication.")
53
+ else:
54
+ print("Warning: Shapes are not compatible for matrix multiplication.")
55
+ print(f"Expected: {expected_shape[1]}, Got: {weight_shape[1]}")
56
+ except AttributeError as e:
57
+ print(f"Unable to access transformer weights: {e}")
58
 
59
  @spaces.GPU()
60
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
88
  check_shapes(latents)
89
 
90
  # Reshape latents to match the expected input shape of the transformer
91
+ latents = latents.reshape(1, -1)
92
 
93
  # Check shapes after reshaping
94
  check_shapes(latents)
 
108
  print(f"Error during inference: {e}")
109
  return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
110
 
 
111
  # Gradio interface setup
112
  with gr.Blocks() as demo:
113
  with gr.Row():