Jordan Legg committed
Commit 13ab5d1 · 1 Parent(s): f0decf0

twin model

Files changed (1): app.py (+32 -13)
app.py CHANGED
@@ -1,11 +1,10 @@
 import gradio as gr
 import numpy as np
 import random
-import spaces
 import torch
 from PIL import Image
 from torchvision import transforms
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderKL
 
 # Define constants
 dtype = torch.bfloat16
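Note: this hunk removes `import spaces`, but the `@spaces.GPU()` decorator still appears unchanged later in the file, so the app would raise a NameError at import time on Spaces. A minimal guard, assuming the decorator should simply degrade to a no-op when the `spaces` package is absent (the fallback class below is hypothetical, not part of the commit):

    try:
        import spaces  # available on Hugging Face ZeroGPU hardware
    except ImportError:
        class spaces:  # hypothetical fallback: @spaces.GPU() becomes a no-op
            @staticmethod
            def GPU(*args, **kwargs):
                def wrap(fn):
                    return fn
                return wrap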
@@ -13,7 +12,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
-# Load the diffusion pipeline with optimizations
+# Load the initial VAE model for preprocessing
+vae_model_name = "CompVis/stable-diffusion-v1-4"  # Example VAE model
+vae = AutoencoderKL.from_pretrained(vae_model_name).to(device)
+
+# Load the FLUX diffusion pipeline with optimizations
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
 pipe.enable_model_cpu_offload()
 pipe.vae.enable_slicing()
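Note: this pairs a Stable Diffusion v1.4 VAE (4 latent channels, loaded in float32 by default) with FLUX.1-schnell, whose own VAE uses 16 latent channels, so every encoded image will hit the channel-adjustment path further down. Also, `preprocess_image` casts inputs to the app-wide bfloat16 dtype, so the external VAE should probably be loaded in the same dtype; a sketch:

    # Assumption: the external VAE should match the app-wide bfloat16 dtype
    # to avoid a float32/bfloat16 mismatch when encoding preprocessed images.
    vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=dtype).to(device)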
@@ -27,11 +30,13 @@ def preprocess_image(image, image_size):
         transforms.Normalize([0.5], [0.5])
     ])
     image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
+    print("Image processed successfully.")
     return image
 
 def encode_image(image, vae):
     with torch.no_grad():
         latents = vae.encode(image).latent_dist.sample() * 0.18215
+    print("Image encoded successfully.")
     return latents
 
 @spaces.GPU()
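Note: the hard-coded 0.18215 in `encode_image` is the SD-1.x latent scaling factor. Reading the factor (and, for VAEs like FLUX's that define one, the shift) from the model's own config generalizes across VAEs; a minimal sketch, assuming a diffusers `AutoencoderKL`:

    def encode_image(image, vae):
        # Scale latents per the VAE's own config rather than hard-coding
        # the SD-1.x constant; shift_factor is absent/None for SD-1.x VAEs.
        with torch.no_grad():
            latents = vae.encode(image).latent_dist.sample()
        shift = getattr(vae.config, "shift_factor", None) or 0.0
        return (latents - shift) * vae.config.scaling_factor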
@@ -61,20 +66,33 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
             return fallback_image, seed
     else:
         # img2img case
-        vae_image_size = pipe.vae.config.sample_size  # Ensure this is correct
+        print("Initial image provided, starting preprocessing...")
+        vae_image_size = 1024  # Using FLUX VAE sample size for preprocessing
         init_image = init_image.convert("RGB")
         init_image = preprocess_image(init_image, vae_image_size)
-        latents = encode_image(init_image, pipe.vae)
-
+
+        print("Starting encoding of the image...")
+        latents = encode_image(init_image, vae)
+
+        print(f"Latents shape after encoding: {latents.shape}")
+
+        # Ensure the latents size matches the expected input size for the FLUX model
+        print("Interpolating latents to match model's input size...")
         latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
-        latent_channels = pipe.vae.config.latent_channels  # Ensure this is correct
-        if latent_channels != 64:
-            conv = torch.nn.Conv2d(latent_channels, 64, kernel_size=1).to(device, dtype=dtype)
+
+        latent_channels = 16  # Using FLUX VAE latent channels
+        print(f"Latent channels from VAE: {latent_channels}, expected by FLUX model: {pipe.vae.config.latent_channels}")
+
+        if latent_channels != pipe.vae.config.latent_channels:
+            print(f"Adjusting latent channels from {latent_channels} to {pipe.vae.config.latent_channels}")
+            conv = torch.nn.Conv2d(latent_channels, pipe.vae.config.latent_channels, kernel_size=1).to(device, dtype=dtype)
             latents = conv(latents)
 
-        latents = latents.permute(0, 2, 3, 1).contiguous().view(-1, 64)
+        latents = latents.permute(0, 2, 3, 1).contiguous().view(-1, pipe.vae.config.latent_channels)
+        print(f"Latents shape after permutation: {latents.shape}")
 
         try:
+            print("Sending latents to the FLUX transformer...")
             # Determine if 'timesteps' is required for the transformer
             if hasattr(pipe.transformer, 'forward') and hasattr(pipe.transformer.forward, '__code__') and 'timesteps' in pipe.transformer.forward.__code__.co_varnames:
                 timestep = torch.tensor([num_inference_steps], device=device, dtype=dtype)
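Note: flattening with `view(-1, pipe.vae.config.latent_channels)` does not produce the layout FLUX's transformer consumes. diffusers' FluxPipeline packs 16-channel latents by 2x2 patchification into shape (batch, (H/16)·(W/16), 64); a sketch of that packing (`pack_latents` is a hypothetical helper name), assuming `latents` is (B, 16, height//8, width//8):

    def pack_latents(latents):
        # (B, C, H, W) -> (B, H/2 * W/2, C*4), mirroring FluxPipeline._pack_latents
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(b, (h // 2) * (w // 2), c * 4)

Separately, the 1x1 `Conv2d` that widens 4-channel SD latents to 16 channels is freshly constructed, and therefore randomly initialized, on every call, so it injects untrained weights into the latents rather than performing a meaningful projection.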
@@ -86,6 +104,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
             return fallback_image, seed
 
     try:
+        print("Generating final image with the FLUX pipeline...")
         image = pipe(
             prompt=prompt,
             height=height,
@@ -95,6 +114,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
             guidance_scale=0.0,
             latents=latents
         ).images[0]
+        print("Image generation completed.")
     except Exception as e:
         print(f"Pipeline call with latents failed with error: {e}")
         return fallback_image, seed
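Note: `seed` is returned to the UI but never influences sampling here, because no `generator` is passed to the pipeline. A sketch of tying them together, assuming the surrounding code has already applied `randomize_seed` (when `latents` is None, as in the text-to-image path, the generator controls the initial noise):

    # Seed the pipeline explicitly so the reported seed matches the sampling.
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,
        generator=generator,
        latents=latents,
    ).images[0]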
@@ -210,9 +230,8 @@ with gr.Blocks(css=css) as demo:
         cache_examples="lazy"
     )
 
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
+    run_button.click(
+        infer,
         inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
         outputs=[result, seed]
     )