Jordan Legg committed
Commit 817a141 · 1 Parent(s): bf5cb46

change image to latents

Files changed (1)
  1. app.py +14 -8
app.py CHANGED
@@ -12,7 +12,7 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 MIN_IMAGE_SIZE = 256
 DEFAULT_IMAGE_SIZE = 1024
-MAX_PROMPT_LENGTH = 256  # Changed to 256 as per FLUX.1 schnell requirements
+MAX_PROMPT_LENGTH = 500
 
 # Check for GPU availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,10 +23,7 @@ dtype = torch.float16 if device == "cuda" else torch.float32
 
 def load_model():
     try:
-        pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
-        pipe.enable_model_cpu_offload()
-        pipe.enable_attention_slicing()
-        return pipe
+        return DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
     except Exception as e:
         raise RuntimeError(f"Failed to load the model: {str(e)}")
 
@@ -72,21 +69,30 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_
     max_sequence_length = min(MAX_PROMPT_LENGTH, len(prompt))
 
     if init_image is not None:
+        # Process img2img
         init_image = init_image.convert("RGB")
         init_image = preprocess_image(init_image, (height, width))
-        latents = encode_image(init_image, pipe.vae)
-        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear', align_corners=False)
+
+        # Encode the image using the VAE
+        with torch.no_grad():
+            init_latents = pipe.vae.encode(init_image).latent_dist.sample(generator=generator)
+            init_latents = 0.18215 * init_latents
+
+        # Ensure latents are correctly shaped
+        init_latents = torch.nn.functional.interpolate(init_latents, size=(height // 8, width // 8), mode='bilinear', align_corners=False)
+
         image = pipe(
             prompt=prompt,
-            image=latents,  # Changed from latents=latents to image=latents
             height=height,
             width=width,
             num_inference_steps=num_inference_steps,
             generator=generator,
             guidance_scale=0.0,
+            latents=init_latents,  # Use latents instead of image
             max_sequence_length=max_sequence_length
         ).images[0]
     else:
+        # Process text2img
        image = pipe(
             prompt=prompt,
             height=height,
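
Note (not part of this commit): the encode step added in the img2img branch follows the usual diffusers AutoencoderKL pattern, and it can be exercised on its own as a sanity check. The sketch below is illustrative only: it loads just the VAE from the same FLUX.1-schnell checkpoint, the helper pil_to_tensor_batch is a hypothetical stand-in for the app's preprocess_image, and it reads the scale (and optional shift) from vae.config rather than hard-coding 0.18215, which is the Stable Diffusion 1.x value.

import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load only the VAE from the checkpoint used in app.py.
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=dtype
).to(device)

def pil_to_tensor_batch(img: Image.Image, height: int, width: int) -> torch.Tensor:
    # Hypothetical stand-in for preprocess_image(): resize, scale to [-1, 1], NCHW batch.
    img = img.convert("RGB").resize((width, height))
    x = torch.from_numpy(np.asarray(img, dtype=np.float32))  # (H, W, 3) in [0, 255]
    x = x / 127.5 - 1.0
    return x.permute(2, 0, 1).unsqueeze(0).to(device, dtype)

@torch.no_grad()
def encode_to_latents(img: Image.Image, height: int = 1024, width: int = 1024) -> torch.Tensor:
    pixels = pil_to_tensor_batch(img, height, width)
    latents = vae.encode(pixels).latent_dist.sample()
    # Use the config's shift/scale instead of a hard-coded 0.18215 (the SD-1.x factor).
    shift = getattr(vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents - shift) * vae.config.scaling_factor
    return latents  # (1, C, height // 8, width // 8)

Whether the base pipeline accepts these latents directly depends on the installed diffusers version; this sketch only demonstrates the encode-and-scale step that the diff introduces.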