Jordan Legg committed
Commit 383a90d · 1 Parent(s): d9f1205

move back to complex code

Files changed (1)
  1. app.py +65 -34
app.py CHANGED
@@ -1,15 +1,21 @@
 import spaces
 import gradio as gr
+import numpy as np
+import random
 import torch
+import torch.nn as nn
 from PIL import Image
+from torchvision import transforms
 from diffusers import DiffusionPipeline

-
 # Constants
-MAX_SEED = 2**32 - 1
-MAX_IMAGE_SIZE = 2048
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 2048
+LATENT_CHANNELS = 16
+TRANSFORMER_IN_CHANNELS = 64
+SCALING_FACTOR = 0.3611

 # Load FLUX model
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
@@ -17,61 +23,86 @@ pipe.enable_model_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()

-def print_model_shapes(pipe):
-    print("Model component shapes:")
-    print(f"VAE Encoder: {pipe.vae.encoder}")
-    print(f"VAE Decoder: {pipe.vae.decoder}")
-    print(f"x_embedder shape: {pipe.transformer.x_embedder.weight.shape}")
-    print(f"First transformer block shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
+# Add a projection layer to match transformer input
+projection = nn.Linear(LATENT_CHANNELS, TRANSFORMER_IN_CHANNELS).to(device).to(dtype)

-print_model_shapes(pipe)
+def preprocess_image(image, image_size):
+    preprocess = transforms.Compose([
+        transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5])
+    ])
+    image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
+    return image
+
+def process_latents(latents, height, width):
+    print(f"Input latent shape: {latents.shape}")
+
+    # Ensure latents are the correct shape
+    if latents.shape[2:] != (height // 8, width // 8):
+        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
+    print(f"Latent shape after potential interpolation: {latents.shape}")
+
+    # Reshape latents to [batch_size, seq_len, channels]
+    latents = latents.permute(0, 2, 3, 1).reshape(1, -1, LATENT_CHANNELS)
+    print(f"Reshaped latent shape: {latents.shape}")
+
+    # Project latents from 16 to 64 dimensions
+    latents = projection(latents)
+    print(f"Projected latent shape: {latents.shape}")
+
+    return latents

 @spaces.GPU()
-def infer(prompt, init_image=None, seed=None, width=1024, height=1024, num_inference_steps=4, guidance_scale=0.0):
-    generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
+def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device=device).manual_seed(seed)

     try:
         if init_image is None:
             # text2img case
-            print("Running text-to-image generation")
             image = pipe(
                 prompt=prompt,
                 height=height,
                 width=width,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
-                guidance_scale=guidance_scale
+                guidance_scale=0.0
             ).images[0]
         else:
             # img2img case
-            print("Running image-to-image generation")
-            init_image = init_image.convert("RGB").resize((width, height))
+            init_image = init_image.convert("RGB")
+            init_image = preprocess_image(init_image, 1024)  # Using 1024 as FLUX VAE sample size
+
+            # Encode the image using FLUX VAE
+            latents = pipe.vae.encode(init_image).latent_dist.sample() * SCALING_FACTOR
+            print(f"Initial latent shape from VAE: {latents.shape}")
+
+            # Process latents to match transformer input
+            latents = process_latents(latents, height, width)
+
+            print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
+            print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
+
             image = pipe(
                 prompt=prompt,
-                image=init_image,
+                height=height,
+                width=width,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
-                guidance_scale=guidance_scale
+                guidance_scale=0.0,
+                latents=latents
             ).images[0]

         return image, seed
-    except RuntimeError as e:
-        if "mat1 and mat2 shapes cannot be multiplied" in str(e):
-            print("Matrix multiplication error detected. Tensor shapes:")
-            print(e)
-            # Here you could add code to print shapes of specific tensors if needed
-        else:
-            print(f"RuntimeError during inference: {e}")
-        import traceback
-        traceback.print_exc()
-        return Image.new("RGB", (width, height), (255, 0, 0)), seed
     except Exception as e:
-        print(f"Unexpected error during inference: {e}")
+        print(f"Error during inference: {e}")
         import traceback
         traceback.print_exc()
-        return Image.new("RGB", (width, height), (255, 0, 0)), seed
+        return Image.new("RGB", (width, height), (255, 0, 0)), seed  # Red fallback image

-# Gradio interface
+# Gradio interface setup
 with gr.Blocks() as demo:
     with gr.Row():
         prompt = gr.Textbox(label="Prompt")
@@ -85,15 +116,15 @@ with gr.Blocks() as demo:
     seed_output = gr.Number(label="Seed")

     with gr.Accordion("Advanced Settings", open=False):
-        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=None)
+        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
         width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
         height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
         num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
-        guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=0.0)

     generate.click(
         infer,
-        inputs=[prompt, init_image, seed, width, height, num_inference_steps, guidance_scale],
+        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
         outputs=[result, seed_output]
     )
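
For context on the 16-vs-64 channel mismatch this commit works around with a learned nn.Linear projection: the FLUX transformer's x_embedder expects 64 input features because the pipeline normally packs the 16-channel VAE latents into 2x2 spatial patches (16 x 2 x 2 = 64) rather than projecting them. Below is a minimal sketch of that packing, assuming the standard 2x2 patchify; the pack_latents helper is illustrative and is not part of this commit.

import torch

LATENT_CHANNELS = 16  # channels produced by the FLUX VAE encoder

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # Pack [B, 16, H/8, W/8] VAE latents into [B, (H/16)*(W/16), 64] tokens
    # by grouping each 2x2 spatial patch, so no learned projection is needed.
    b, c, h, w = latents.shape
    assert c == LATENT_CHANNELS and h % 2 == 0 and w % 2 == 0
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)            # [B, H/16, W/16, C, 2, 2]
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)  # [B, tokens, 64]

# Example: a 1024x1024 image gives 128x128 latents -> 64*64 = 4096 tokens of width 64
dummy = torch.randn(1, LATENT_CHANNELS, 128, 128)
print(pack_latents(dummy).shape)  # torch.Size([1, 4096, 64])

Under that scheme a 1024x1024 input maps to 4096 tokens of width 64, consistent with the x_embedder expecting 64 input features. Note also that the FLUX VAE config defines a shift_factor (roughly 0.1159) that is normally subtracted from the encoded latents before multiplying by the 0.3611 scaling factor used above.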