gokaygokay committed
Commit 9470add · Parent(s): 692642f
Files changed (1)
  1. app.py +29 -23
app.py CHANGED
@@ -19,6 +19,8 @@ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
 import gradio as gr
 import shutil
 import tempfile
+from functools import partial
+from optimum.quanto import quantize, qfloat8, freeze
 from flux_8bit_lora import FluxPipeline
 
 from src.utils.train_util import instantiate_from_config
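For context, the new optimum.quanto imports implement weight-only quantization: quantize() replaces a module's weights with quantized tensors and freeze() makes the replacement permanent. A minimal, self-contained sketch of the same pattern on a toy module (the toy model is illustrative, not from app.py; qint8 is used here so the sketch also runs on CPU):

```python
import torch
from optimum.quanto import quantize, freeze, qint8

# Toy stand-in for the Flux transformer (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

quantize(model, weights=qint8)  # swap weights for int8 quantized tensors
freeze(model)                   # bake the quantized weights in permanently

with torch.inference_mode():
    out = model(torch.randn(1, 64))
print(out.shape)  # torch.Size([1, 64])
```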
@@ -72,21 +74,26 @@ else:
 
 device = torch.device('cuda')
 
-# Load Flux pipeline
-flux_pipe = FluxPipeline.from_pretrained(
-    "Freepik/flux.1-lite-8B-alpha",
-    torch_dtype=torch.bfloat16
-)
-flux_pipe.load_lora_weights(hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors"))
-flux_pipe.fuse_lora(lora_scale=1)
-flux_pipe.to(device="cuda", dtype=torch.bfloat16)
+base_model = "black-forest-labs/FLUX.1-dev"
+pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=huggingface_token)
 
+# Load and fuse LoRA BEFORE quantizing
 print('Loading and fusing lora, please wait...')
-flux_pipe.load_lora_weights(hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors"))
-# We need this scaling because SimpleTuner fixes the alpha to 16, might be fixed later in diffusers
-# See https://github.com/huggingface/diffusers/issues/9134
-flux_pipe.fuse_lora(lora_scale=1.)
-flux_pipe.unload_lora_weights()
+lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
+pipe.load_lora_weights(lora_path)
+pipe.fuse_lora(lora_scale=1.0)
+pipe.unload_lora_weights()
+pipe.transformer.to(device, dtype=torch.bfloat16)
+
+# Now quantize after LoRA is fused
+print('Quantizing, please wait...')
+# Try qint8 if qfloat8 produces invalid values
+quantize(pipe.transformer, qfloat8)  # Consider changing to qint8 if you get invalid values
+freeze(pipe.transformer)
+pipe.transformer.to(device, dtype=torch.bfloat16)
+print('Model quantized!')
+pipe.enable_model_cpu_offload()
+
 
 # Load 3D generation models
 config_path = 'configs/instant-mesh-large.yaml'
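The "try qint8" comment refers to float8 weights occasionally producing NaN or otherwise invalid outputs. As far as I know, quanto does not support switching weight dtypes on an already frozen model, so the practical fallback is to rebuild the pipeline, re-fuse the LoRA as above, and quantize again. A hedged sketch of that fallback (quantize_transformer is a hypothetical helper, not in app.py):

```python
from optimum.quanto import quantize, freeze, qint8

def quantize_transformer(pipe, weight_dtype):
    # Hypothetical helper: weight-only quantization of the
    # (already LoRA-fused) transformer.
    quantize(pipe.transformer, weights=weight_dtype)
    freeze(pipe.transformer)
    return pipe

# If qfloat8 misbehaves: reload the base model, re-fuse the LoRA,
# then quantize with int8 weights instead.
# pipe = quantize_transformer(pipe, qint8)
```

Note also that enable_model_cpu_offload() takes over device placement itself, so the explicit pipe.transformer.to(device, ...) calls just before it are largely redundant.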
@@ -150,16 +157,15 @@ ts_cutoff = 2
 
 @spaces.GPU
 def generate_flux_image(prompt, height, width, steps, scales, seed):
-    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("Flux inference"):
-        return flux_pipe(
-            prompt=[prompt],
-            generator=torch.Generator().manual_seed(int(seed)),
-            num_inference_steps=int(steps),
-            guidance_scale=float(scales),
-            height=int(height),
-            width=int(width),
-            max_sequence_length=256
-        ).images[0]
+    return pipe(
+        prompt=prompt,
+        width=int(width),
+        height=int(height),
+        num_inference_steps=int(steps),
+        generator=torch.Generator().manual_seed(int(seed)),
+        guidance_scale=float(scales),
+        timestep_to_start_cfg=ts_cutoff,
+    ).images[0]
 
 
 @spaces.GPU
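Two things worth knowing about the rewritten handler: diffusers-style pipelines typically decorate __call__ with @torch.no_grad(), so dropping the explicit torch.inference_mode()/autocast context is likely safe; and timestep_to_start_cfg comes from the repo-local flux_8bit_lora.FluxPipeline, since the stock diffusers FluxPipeline does not accept that argument. A quick smoke test outside Gradio (the prompt and all argument values are illustrative, not the Space's defaults):

```python
# Hypothetical local smoke test; prompt and settings are illustrative.
image = generate_flux_image(
    prompt="a low-poly treasure chest, white background",
    height=1024,
    width=1024,
    steps=8,
    scales=3.5,
    seed=42,
)
image.save("flux_asset.png")
```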
 