Jordan Legg committed
Commit f071803 · 1 Parent(s): 817a141
Files changed (1)
  1. app.py +24 -13
app.py CHANGED
@@ -12,7 +12,7 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 MIN_IMAGE_SIZE = 256
 DEFAULT_IMAGE_SIZE = 1024
-MAX_PROMPT_LENGTH = 500
+MAX_PROMPT_LENGTH = 256  # Changed to 256 as per FLUX.1-schnell requirements
 
 # Check for GPU availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,14 +23,19 @@ dtype = torch.float16 if device == "cuda" else torch.float32
 
 def load_model():
     try:
-        return DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
+        pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
+        pipe.to(device)
+        pipe.enable_model_cpu_offload()
+        pipe.vae.enable_slicing()
+        pipe.vae.enable_tiling()
+        return pipe
     except Exception as e:
         raise RuntimeError(f"Failed to load the model: {str(e)}")
 
 # Load the diffusion pipeline
 pipe = load_model()
 
-def preprocess_image(image, target_size=(512, 512)):
+def preprocess_image(image, target_size):
     # Preprocess the image for the VAE
     preprocess = transforms.Compose([
         transforms.Resize(target_size, interpolation=transforms.InterpolationMode.LANCZOS),
@@ -57,7 +62,7 @@ def validate_inputs(prompt, width, height, num_inference_steps):
         raise ValueError("Number of inference steps must be between 1 and 50.")
 
 @spaces.GPU()
-def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_IMAGE_SIZE, height=DEFAULT_IMAGE_SIZE, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_IMAGE_SIZE, height=DEFAULT_IMAGE_SIZE, num_inference_steps=4, strength=0.8, progress=gr.Progress(track_tqdm=True)):
     try:
         validate_inputs(prompt, width, height, num_inference_steps)
 
@@ -74,13 +79,15 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_
            init_image = preprocess_image(init_image, (height, width))
 
            # Encode the image using the VAE
-           with torch.no_grad():
-               init_latents = pipe.vae.encode(init_image).latent_dist.sample(generator=generator)
-               init_latents = 0.18215 * init_latents
+           init_latents = encode_image(init_image, pipe.vae)
 
            # Ensure latents are correctly shaped
            init_latents = torch.nn.functional.interpolate(init_latents, size=(height // 8, width // 8), mode='bilinear', align_corners=False)
 
+           # Add noise to latents
+           noise = torch.randn_like(init_latents)
+           latents = noise + strength * (init_latents - noise)
+
            image = pipe(
                prompt=prompt,
                height=height,
@@ -88,7 +95,7 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
-               latents=init_latents,  # Use latents instead of image
+               latents=latents,
                max_sequence_length=max_sequence_length
            ).images[0]
        else:
@@ -209,6 +216,13 @@ with gr.Blocks(css=css) as demo:
                step=1,
                value=4,
            )
+           strength = gr.Slider(
+               label="Strength (for img2img)",
+               minimum=0.0,
+               maximum=1.0,
+               step=0.01,
+               value=0.8,
+           )
 
        gr.Examples(
            examples=examples,
@@ -221,12 +235,9 @@ with gr.Blocks(css=css) as demo:
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
-       inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
+       inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps, strength],
        outputs=[result, seed]
    )
 
 if __name__ == "__main__":
-    demo.launch()
-
-
-
+    demo.launch()
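A note on the prompt limit: MAX_PROMPT_LENGTH caps the prompt in characters, while the max_sequence_length=256 limit that FLUX.1-schnell's text encoder enforces is measured in T5 tokens, so the two only loosely correspond. A minimal sketch of a token-level check, assuming the loaded pipeline exposes its T5 tokenizer as tokenizer_2 (as diffusers' FluxPipeline does):

```python
def validate_prompt_tokens(pipe, prompt, max_tokens=256):
    # Count T5 tokens rather than characters. Assumes `pipe.tokenizer_2`
    # is the T5 tokenizer of the FLUX pipeline loaded above.
    token_count = len(pipe.tokenizer_2(prompt).input_ids)
    if token_count > max_tokens:
        raise ValueError(
            f"Prompt is {token_count} T5 tokens; FLUX.1-schnell supports at most {max_tokens}."
        )
```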
 
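On the load_model changes: in diffusers, enable_model_cpu_offload() manages device placement itself and is normally used instead of moving the whole pipeline to CUDA first, so calling pipe.to(device) immediately before it is likely redundant. A minimal sketch of the usual memory-saving setup, under that assumption:

```python
import torch
from diffusers import DiffusionPipeline

dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
)
pipe.enable_model_cpu_offload()  # moves each submodule to the GPU only while it runs
pipe.vae.enable_slicing()        # decode batched images one at a time
pipe.vae.enable_tiling()         # decode large images tile by tile
```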
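The inline VAE encoding under torch.no_grad() is replaced by a call to encode_image(init_image, pipe.vae), whose definition is not shown in this diff. A hypothetical helper matching that call site could look like the following; the name, signature, and the use of vae.config.scaling_factor in place of the previously hard-coded 0.18215 are assumptions, not part of the commit:

```python
import torch

def encode_image(image, vae):
    # Hypothetical helper, not shown in this commit. Assumes `vae` is a
    # diffusers AutoencoderKL and `image` is a normalized (1, 3, H, W)
    # tensor on the VAE's device and dtype.
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()
    # Use the VAE's configured scaling factor instead of a hard-coded constant.
    return latents * vae.config.scaling_factor
```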
 
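The new noise blend is a plain linear interpolation: noise + strength * (init_latents - noise) equals (1 - strength) * noise + strength * init_latents, so strength=1.0 keeps the image latents untouched and strength=0.0 yields pure noise. That is the reverse of the usual img2img convention, where higher strength means more noise and more change. For comparison, a sketch of the conventional scheduler-driven initialization, assuming a scheduler that implements add_noise() (for example DDPMScheduler; FLUX's flow-matching scheduler exposes scale_noise instead):

```python
import torch

def img2img_latents(scheduler, init_latents, strength, num_inference_steps, generator=None):
    # Conventional img2img: pick a start timestep from `strength`, noise the
    # image latents to that point, and denoise over the remaining timesteps.
    scheduler.set_timesteps(num_inference_steps)
    init_step = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_step, 0)
    timesteps = scheduler.timesteps[t_start:]
    noise = torch.randn(
        init_latents.shape, generator=generator,
        device=init_latents.device, dtype=init_latents.dtype,
    )
    latents = scheduler.add_noise(init_latents, noise, timesteps[:1])
    return latents, timesteps
```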