Spaces:
Runtime error
Runtime error
Jordan Legg
commited on
Commit
Β·
bc9da49
1
Parent(s):
d53ee34
trying to match the latent shapes
Browse files
app.py
CHANGED
@@ -34,22 +34,27 @@ def check_shapes(latents):
|
|
34 |
print(f"Latent shape: {latent_shape}")
|
35 |
|
36 |
# Get the expected shape for the transformer input
|
37 |
-
expected_shape = (1, latent_shape[1] * latent_shape[2]
|
38 |
print(f"Expected transformer input shape: {expected_shape}")
|
39 |
|
40 |
-
#
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
|
54 |
@spaces.GPU()
|
55 |
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
@@ -83,7 +88,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
|
|
83 |
check_shapes(latents)
|
84 |
|
85 |
# Reshape latents to match the expected input shape of the transformer
|
86 |
-
latents = latents.
|
87 |
|
88 |
# Check shapes after reshaping
|
89 |
check_shapes(latents)
|
@@ -103,7 +108,6 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
|
|
103 |
print(f"Error during inference: {e}")
|
104 |
return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
|
105 |
|
106 |
-
|
107 |
# Gradio interface setup
|
108 |
with gr.Blocks() as demo:
|
109 |
with gr.Row():
|
|
|
34 |
print(f"Latent shape: {latent_shape}")
|
35 |
|
36 |
# Get the expected shape for the transformer input
|
37 |
+
expected_shape = (1, latent_shape[1] * latent_shape[2] * latent_shape[3])
|
38 |
print(f"Expected transformer input shape: {expected_shape}")
|
39 |
|
40 |
+
# Try to get the shape of the transformer's weight matrix
|
41 |
+
try:
|
42 |
+
# Assuming the first layer of the transformer has a linear projection
|
43 |
+
if hasattr(pipe.transformer, 'blocks'):
|
44 |
+
weight_shape = pipe.transformer.blocks[0].attn.to_q.weight.shape
|
45 |
+
else:
|
46 |
+
print("Unable to determine transformer weight shape.")
|
47 |
+
return
|
48 |
+
print(f"Transformer weight shape: {weight_shape}")
|
49 |
|
50 |
+
# Check if the shapes are compatible for matrix multiplication
|
51 |
+
if expected_shape[1] == weight_shape[1]:
|
52 |
+
print("Shapes are compatible for matrix multiplication.")
|
53 |
+
else:
|
54 |
+
print("Warning: Shapes are not compatible for matrix multiplication.")
|
55 |
+
print(f"Expected: {expected_shape[1]}, Got: {weight_shape[1]}")
|
56 |
+
except AttributeError as e:
|
57 |
+
print(f"Unable to access transformer weights: {e}")
|
58 |
|
59 |
@spaces.GPU()
|
60 |
def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
|
|
88 |
check_shapes(latents)
|
89 |
|
90 |
# Reshape latents to match the expected input shape of the transformer
|
91 |
+
latents = latents.reshape(1, -1)
|
92 |
|
93 |
# Check shapes after reshaping
|
94 |
check_shapes(latents)
|
|
|
108 |
print(f"Error during inference: {e}")
|
109 |
return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
|
110 |
|
|
|
111 |
# Gradio interface setup
|
112 |
with gr.Blocks() as demo:
|
113 |
with gr.Row():
|