Jordan Legg committed on
Commit 044186b · 1 Parent(s): b11c213

fix: using the VAE directly

Files changed (1): app.py (+25 -5)
app.py CHANGED
@@ -3,6 +3,8 @@ import numpy as np
 import random
 import spaces
 import torch
+from PIL import Image
+from torchvision import transforms
 from diffusers import DiffusionPipeline
 
 # Define constants
@@ -14,6 +16,22 @@ MAX_IMAGE_SIZE = 2048
 # Load the diffusion pipeline
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
+def preprocess_image(image):
+    # Preprocess the image for the VAE
+    preprocess = transforms.Compose([
+        transforms.Resize((512, 512)),  # Adjust the size as needed
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5])
+    ])
+    image = preprocess(image).unsqueeze(0).to(device)
+    return image
+
+def encode_image(image, vae):
+    # Encode the image using the VAE
+    with torch.no_grad():
+        latents = vae.encode(image).latent_dist.sample() * 0.18215
+    return latents
+
 @spaces.GPU()
 def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
@@ -23,22 +41,23 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
     if init_image is not None:
         # Process img2img
         init_image = init_image.convert("RGB")
-        init_image = pipe.preprocess(init_image).unsqueeze(0).to(device, dtype)
+        init_image = preprocess_image(init_image)
+        latents = encode_image(init_image, pipe.vae)
         image = pipe(
             prompt=prompt,
-            init_image=init_image,
-            width=width,
             height=height,
+            width=width,
             num_inference_steps=num_inference_steps,
             generator=generator,
-            guidance_scale=0.0
+            guidance_scale=0.0,
+            latents=latents
         ).images[0]
     else:
         # Process text2img
         image = pipe(
             prompt=prompt,
-            width=width,
             height=height,
+            width=width,
             num_inference_steps=num_inference_steps,
             generator=generator,
             guidance_scale=0.0
@@ -164,3 +183,4 @@
 
 demo.launch()
 
+
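
For context, below is a minimal standalone sketch of the encode path this commit introduces: preprocess a PIL image and push it through the pipeline's VAE to obtain latents. The device, dtype, and test image path are illustrative assumptions, not part of the commit. One caveat worth flagging: the hard-coded 0.18215 in encode_image is the Stable Diffusion 1.x VAE scaling factor, whereas the FLUX VAE publishes its own scaling_factor (and a shift_factor) in vae.config, so the sketch reads the factors from the config instead.

import torch
from PIL import Image
from torchvision import transforms
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16  # dtype assumed
).to(device)

# Same preprocessing as the commit's preprocess_image helper.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),       # fixed size, as in the commit
    transforms.ToTensor(),               # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # rescale to [-1, 1], the VAE's input range
])

init_image = Image.open("example.jpg").convert("RGB")  # hypothetical test image
pixels = preprocess(init_image).unsqueeze(0).to(device, dtype=pipe.vae.dtype)

with torch.no_grad():
    # Encode to the latent distribution and draw a sample, as encode_image does.
    latents = pipe.vae.encode(pixels).latent_dist.sample()
    # Read the normalization constants from the VAE config rather than
    # hard-coding 0.18215; shift_factor may be None or absent on older VAEs.
    shift = getattr(pipe.vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents - shift) * pipe.vae.config.scaling_factor

print(latents.shape)  # e.g. (1, 16, 64, 64) for FLUX's 16-channel VAE at 512x512

Whether the pipeline will accept these raw VAE latents through its latents= argument as the commit does is worth verifying: the diffusers text-to-image FLUX pipeline generally expects latents in its own packed layout, so this call path may need an explicit packing step to behave as true img2img.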