Ryukijano committed on
Commit
3d9f174
·
verified ·
1 Parent(s): eef4bc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -20,7 +20,9 @@ DEFAULT_INFERENCE_STEPS = 1
20
 
21
  # Device and model setup
22
  dtype = torch.float16
23
- # Download the LoRA weights using hf_hub_download
 
 
24
  lora_weights_path = hf_hub_download(
25
  repo_id="hugovntr/flux-schnell-realism",
26
  filename="schnell-realism_v2.3.safetensors",
@@ -30,17 +32,17 @@ pipe = FluxWithCFGPipeline.from_pretrained(
30
  "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
31
  )
32
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
33
- pipe.to("cuda")
34
 
35
- # Load the LoRA weights using the downloaded path
36
  pipe.load_lora_weights(lora_weights_path, adapter_name="better")
37
  pipe.set_adapters(["better"], adapter_weights=[1.0])
38
  pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
39
  pipe.unload_lora_weights()
40
 
41
  # Memory optimizations
42
- pipe.transformer.to(memory_format=torch.channels_last) # Channels last
43
- pipe.enable_xformers_memory_efficient_attention() # Flash Attention
44
 
45
  # CUDA Graph setup
46
  static_inputs = None
@@ -51,7 +53,6 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
51
  global static_inputs, static_model, graph
52
 
53
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
54
- device = "cuda"
55
  num_images_per_prompt = 1
56
 
57
  prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
@@ -91,11 +92,11 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
91
  guidance = torch.full([1], 3.5, device=device, dtype=torch.float16).expand(latents.shape[0]) if pipe.transformer.config.guidance_embeds else None
92
 
93
  static_inputs = {
94
- "hidden_states": latents,
95
- "timestep": timesteps,
96
- "guidance": guidance,
97
- "pooled_projections": pooled_prompt_embeds,
98
- "encoder_hidden_states": prompt_embeds,
99
  "txt_ids": text_ids,
100
  "img_ids": latent_image_ids,
101
  "joint_attention_kwargs": None,
@@ -108,7 +109,7 @@ def setup_cuda_graph(prompt, height, width, num_inference_steps):
108
  static_output = static_model(**static_inputs)
109
 
110
  # Inference function
111
- @spaces.GPU(duration=25)
112
  def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
113
  global static_inputs, graph
114
 
 
20
 
21
  # Device and model setup
22
  dtype = torch.float16
23
+ device = "cuda" # Explicitly set device to CUDA
24
+
25
+ # Download the LoRA weights
26
  lora_weights_path = hf_hub_download(
27
  repo_id="hugovntr/flux-schnell-realism",
28
  filename="schnell-realism_v2.3.safetensors",
 
32
  "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
33
  )
34
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
35
+ pipe.to(device) # Move the pipeline to CUDA
36
 
37
+ # Load the LoRA weights
38
  pipe.load_lora_weights(lora_weights_path, adapter_name="better")
39
  pipe.set_adapters(["better"], adapter_weights=[1.0])
40
  pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
41
  pipe.unload_lora_weights()
42
 
43
  # Memory optimizations
44
+ pipe.transformer.to(memory_format=torch.channels_last)
45
+ pipe.enable_xformers_memory_efficient_attention()
46
 
47
  # CUDA Graph setup
48
  static_inputs = None
 
53
  global static_inputs, static_model, graph
54
 
55
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
 
56
  num_images_per_prompt = 1
57
 
58
  prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
 
92
  guidance = torch.full([1], 3.5, device=device, dtype=torch.float16).expand(latents.shape[0]) if pipe.transformer.config.guidance_embeds else None
93
 
94
  static_inputs = {
95
+ "hidden_states": latents.to(device),
96
+ "timestep": timesteps.to(device),
97
+ "guidance": guidance.to(device) if guidance is not None else None,
98
+ "pooled_projections": pooled_prompt_embeds.to(device),
99
+ "encoder_hidden_states": prompt_embeds.to(device),
100
  "txt_ids": text_ids,
101
  "img_ids": latent_image_ids,
102
  "joint_attention_kwargs": None,
 
109
  static_output = static_model(**static_inputs)
110
 
111
  # Inference function
112
+ # @spaces.GPU(duration=25) # Remove decorator
113
  def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
114
  global static_inputs, graph
115