Spaces:
Runtime error
Runtime error
Jordan Legg
committed on
Commit
·
409e82d
1
Parent(s):
b54a3db
remove latent flattening
Browse files
app.py
CHANGED
@@ -18,7 +18,6 @@ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", tor
|
|
18 |
|
19 |
def preprocess_image(image, image_size):
|
20 |
print(f"Preprocessing image to size: {image_size}x{image_size}")
|
21 |
-
# Preprocess the image for the VAE
|
22 |
preprocess = transforms.Compose([
|
23 |
transforms.Resize((image_size, image_size)), # Use model-specific size
|
24 |
transforms.ToTensor(),
|
@@ -30,7 +29,6 @@ def preprocess_image(image, image_size):
|
|
30 |
|
31 |
def encode_image(image, vae):
|
32 |
print("Encoding image using the VAE")
|
33 |
-
# Encode the image using the VAE
|
34 |
with torch.no_grad():
|
35 |
latents = vae.encode(image).latent_dist.sample() * 0.18215
|
36 |
print(f"Latents shape after encoding: {latents.shape}")
|
@@ -72,9 +70,16 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
|
|
72 |
latents = latents.view(1, 64, height // 8, width // 8)
|
73 |
print(f"Latents shape after reshaping: {latents.shape}")
|
74 |
|
75 |
-
#
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
print("Calling the diffusion pipeline with latents")
|
80 |
image = pipe(
|
@@ -103,6 +108,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
|
|
103 |
|
104 |
|
105 |
|
|
|
106 |
# Define example prompts
|
107 |
examples = [
|
108 |
"a tiny astronaut hatching from an egg on the moon",
|
|
|
18 |
|
19 |
def preprocess_image(image, image_size):
|
20 |
print(f"Preprocessing image to size: {image_size}x{image_size}")
|
|
|
21 |
preprocess = transforms.Compose([
|
22 |
transforms.Resize((image_size, image_size)), # Use model-specific size
|
23 |
transforms.ToTensor(),
|
|
|
29 |
|
30 |
def encode_image(image, vae):
|
31 |
print("Encoding image using the VAE")
|
|
|
32 |
with torch.no_grad():
|
33 |
latents = vae.encode(image).latent_dist.sample() * 0.18215
|
34 |
print(f"Latents shape after encoding: {latents.shape}")
|
|
|
70 |
latents = latents.view(1, 64, height // 8, width // 8)
|
71 |
print(f"Latents shape after reshaping: {latents.shape}")
|
72 |
|
73 |
+
# Avoid flattening, ensure latents are in the expected shape for the transformer
|
74 |
+
# Adding extra debug to understand what transformer expects
|
75 |
+
try:
|
76 |
+
print("Calling the transformer with latents")
|
77 |
+
# Dummy call to transformer to understand the shape requirement
|
78 |
+
_ = pipe.transformer(latents)
|
79 |
+
print("Transformer call succeeded")
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Transformer call failed with error: {e}")
|
82 |
+
raise
|
83 |
|
84 |
print("Calling the diffusion pipeline with latents")
|
85 |
image = pipe(
|
|
|
108 |
|
109 |
|
110 |
|
111 |
+
|
112 |
# Define example prompts
|
113 |
examples = [
|
114 |
"a tiny astronaut hatching from an egg on the moon",
|