Jordan Legg committed on
Commit d2c614b · 1 Parent(s): 12f1d71

simplified pipeline

Files changed (1)
  1. app.py +36 -69
app.py CHANGED
@@ -1,108 +1,75 @@
  import spaces
  import gradio as gr
- import numpy as np
- import random
  import torch
- import torch.nn as nn
  from PIL import Image
- from torchvision import transforms
  from diffusers import DiffusionPipeline

  # Constants
- dtype = torch.bfloat16
- device = "cuda" if torch.cuda.is_available() else "cpu"
- MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 2048
- LATENT_CHANNELS = 16
- TRANSFORMER_IN_CHANNELS = 64
- SCALING_FACTOR = 0.3611

  # Load FLUX model
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
  pipe.enable_model_cpu_offload()
- pipe.vae.enable_slicing()
- pipe.vae.enable_tiling()
-
- # Add a projection layer to match transformer input
- projection = nn.Linear(LATENT_CHANNELS, TRANSFORMER_IN_CHANNELS).to(device).to(dtype)

- def preprocess_image(image, image_size):
-     preprocess = transforms.Compose([
-         transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
-         transforms.ToTensor(),
-         transforms.Normalize([0.5], [0.5])
-     ])
-     image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
-     return image

- def process_latents(latents, height, width):
-     print(f"Input latent shape: {latents.shape}")
-
-     # Ensure latents are the correct shape
-     if latents.shape[2:] != (height // 8, width // 8):
-         latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
-     print(f"Latent shape after potential interpolation: {latents.shape}")
-
-     # Reshape latents to [batch_size, seq_len, channels]
-     latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
-     print(f"Reshaped latent shape: {latents.shape}")
-
-     # Project latents from 16 to 64 dimensions
-     latents = projection(latents)
-     print(f"Projected latent shape: {latents.shape}")
-
-     return latents

  @spaces.GPU()
- def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device=device).manual_seed(seed)

      try:
          if init_image is None:
              # text2img case
              image = pipe(
                  prompt=prompt,
                  height=height,
                  width=width,
                  num_inference_steps=num_inference_steps,
                  generator=generator,
-                 guidance_scale=0.0
              ).images[0]
          else:
              # img2img case
-             init_image = init_image.convert("RGB")
-             init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
-
-             # Encode the image using FLUX VAE
-             latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
-             print(f"Initial latent shape from VAE: {latents.shape}")
-
-             # Process latents to match transformer input
-             latents = process_latents(latents, height, width)
-
-             print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
-             print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
-
              image = pipe(
                  prompt=prompt,
-                 height=height,
-                 width=width,
                  num_inference_steps=num_inference_steps,
                  generator=generator,
-                 guidance_scale=0.0,
-                 latents=latents
              ).images[0]

          return image, seed
      except Exception as e:
-         print(f"Error during inference: {e}")
          import traceback
          traceback.print_exc()
-         return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image

- # Gradio interface setup
  with gr.Blocks() as demo:
      with gr.Row():
          prompt = gr.Textbox(label="Prompt")
@@ -116,15 +83,15 @@ with gr.Blocks() as demo:
          seed_output = gr.Number(label="Seed")

      with gr.Accordion("Advanced Settings", open=False):
-         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
          width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)

      generate.click(
          infer,
-         inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
          outputs=[result, seed_output]
      )
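Note on the removed code: the deleted img2img path tried to bridge the VAE's 16 latent channels to the transformer's 64 input channels with a learned nn.Linear. In FLUX that 16 → 64 gap comes from patchifying, not from a projection: the pipeline packs each 2x2 patch of the 16-channel latent grid into one token of 16 x 4 = 64 features. A minimal illustrative sketch of that packing (roughly what diffusers' Flux pipeline does internally in _pack_latents), assuming a (B, 16, H/8, W/8) latent from the FLUX VAE; the name pack_latents is hypothetical:

import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # Group the latent grid into 2x2 patches and flatten each patch
    # into one 64-channel token: (B, 16, h, w) -> (B, h/2 * w/2, 64).
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

The new version of app.py, shown below, drops the hand-rolled projection entirely and lets the pipeline handle latents itself.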
 
 
  import spaces
  import gradio as gr
  import torch
  from PIL import Image
  from diffusers import DiffusionPipeline

+
  # Constants
+ MAX_SEED = 2**32 - 1
  MAX_IMAGE_SIZE = 2048

  # Load FLUX model
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
+ pipe.to("cuda")
  pipe.enable_model_cpu_offload()
+ pipe.enable_vae_slicing()

+ def print_model_shapes(pipe):
+     print("Model component shapes:")
+     print(f"VAE Encoder: {pipe.vae.encoder}")
+     print(f"VAE Decoder: {pipe.vae.decoder}")
+     print(f"x_embedder shape: {pipe.transformer.x_embedder.weight.shape}")
+     print(f"First transformer block shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")

+ print_model_shapes(pipe)

  @spaces.GPU()
+ def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
+     generator = torch.Generator(device="cuda").manual_seed(seed) if seed is not None else None

      try:
          if init_image is None:
              # text2img case
+             print("Running text-to-image generation")
              image = pipe(
                  prompt=prompt,
                  height=height,
                  width=width,
                  num_inference_steps=num_inference_steps,
                  generator=generator,
+                 guidance_scale=guidance_scale
              ).images[0]
          else:
              # img2img case
+             print("Running image-to-image generation")
+             init_image = init_image.convert("RGB").resize((width, height))
              image = pipe(
                  prompt=prompt,
+                 image=init_image,
                  num_inference_steps=num_inference_steps,
                  generator=generator,
+                 guidance_scale=guidance_scale
              ).images[0]

          return image, seed
+     except RuntimeError as e:
+         if "mat1 and mat2 shapes cannot be multiplied" in str(e):
+             print("Matrix multiplication error detected. Tensor shapes:")
+             print(e)
+             # Here you could add code to print shapes of specific tensors if needed
+         else:
+             print(f"RuntimeError during inference: {e}")
+         import traceback
+         traceback.print_exc()
+         return Image.new("RGB", (width, height), (255, 0, 0)), seed
      except Exception as e:
+         print(f"Unexpected error during inference: {e}")
          import traceback
          traceback.print_exc()
+         return Image.new("RGB", (width, height), (255, 0, 0)), seed

+ # Gradio interface
  with gr.Blocks() as demo:
      with gr.Row():
          prompt = gr.Textbox(label="Prompt")

          seed_output = gr.Number(label="Seed")

      with gr.Accordion("Advanced Settings", open=False):
+         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
          width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
          num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
+         guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0)

      generate.click(
          infer,
+         inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
          outputs=[result, seed_output]
      )
97