Jordan Legg committed on
Commit
bf5cb46
·
1 Parent(s): 3ae9c83

trying to fix mat1 and mat2

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -12,7 +12,7 @@ MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = 2048
13
  MIN_IMAGE_SIZE = 256
14
  DEFAULT_IMAGE_SIZE = 1024
15
- MAX_PROMPT_LENGTH = 500
16
 
17
  # Check for GPU availability
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,7 +23,10 @@ dtype = torch.float16 if device == "cuda" else torch.float32
23
 
24
  def load_model():
25
  try:
26
- return DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
 
 
27
  except Exception as e:
28
  raise RuntimeError(f"Failed to load the model: {str(e)}")
29
 
@@ -65,19 +68,23 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_
65
  seed = random.randint(0, MAX_SEED)
66
  generator = torch.Generator(device=device).manual_seed(seed)
67
 
 
 
 
68
  if init_image is not None:
69
  init_image = init_image.convert("RGB")
70
  init_image = preprocess_image(init_image, (height, width))
71
  latents = encode_image(init_image, pipe.vae)
72
- latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
73
  image = pipe(
74
  prompt=prompt,
 
75
  height=height,
76
  width=width,
77
  num_inference_steps=num_inference_steps,
78
  generator=generator,
79
  guidance_scale=0.0,
80
- latents=latents
81
  ).images[0]
82
  else:
83
  image = pipe(
@@ -86,7 +93,8 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_
86
  width=width,
87
  num_inference_steps=num_inference_steps,
88
  generator=generator,
89
- guidance_scale=0.0
 
90
  ).images[0]
91
 
92
  return image, seed
 
12
  MAX_IMAGE_SIZE = 2048
13
  MIN_IMAGE_SIZE = 256
14
  DEFAULT_IMAGE_SIZE = 1024
15
+ MAX_PROMPT_LENGTH = 256 # Changed to 256 as per FLUX.1 schnell requirements
16
 
17
  # Check for GPU availability
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
 
24
  def load_model():
25
  try:
26
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
27
+ pipe.enable_model_cpu_offload()
28
+ pipe.enable_attention_slicing()
29
+ return pipe
30
  except Exception as e:
31
  raise RuntimeError(f"Failed to load the model: {str(e)}")
32
 
 
68
  seed = random.randint(0, MAX_SEED)
69
  generator = torch.Generator(device=device).manual_seed(seed)
70
 
71
+ # Ensure max_sequence_length is not more than 256
72
+ max_sequence_length = min(MAX_PROMPT_LENGTH, len(prompt))
73
+
74
  if init_image is not None:
75
  init_image = init_image.convert("RGB")
76
  init_image = preprocess_image(init_image, (height, width))
77
  latents = encode_image(init_image, pipe.vae)
78
+ latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear', align_corners=False)
79
  image = pipe(
80
  prompt=prompt,
81
+ image=latents, # Changed from latents=latents to image=latents
82
  height=height,
83
  width=width,
84
  num_inference_steps=num_inference_steps,
85
  generator=generator,
86
  guidance_scale=0.0,
87
+ max_sequence_length=max_sequence_length
88
  ).images[0]
89
  else:
90
  image = pipe(
 
93
  width=width,
94
  num_inference_steps=num_inference_steps,
95
  generator=generator,
96
+ guidance_scale=0.0,
97
+ max_sequence_length=max_sequence_length
98
  ).images[0]
99
 
100
  return image, seed