gokaygokay committed
Commit 3880b98 · 1 Parent(s): 44d2484
Files changed (2)
  1. app.py +26 -2
  2. stable_diffusion_model.py +0 -0
app.py CHANGED
@@ -3,10 +3,11 @@ import numpy as np
 import random
 import spaces
 import torch
-from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
 from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
 from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
 from huggingface_hub import hf_hub_download
+from optimum.quanto import freeze, qfloat8, quantize
 import os
 
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
@@ -23,13 +24,36 @@ MAX_IMAGE_SIZE = 2048
 
 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
+# Load base model first (before quantization)
+pipe = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=dtype,
+    vae=taef1,
+    token=huggingface_token
+)
+
 # Load and fuse LoRA BEFORE quantizing
 print('Loading and fusing lora, please wait...')
 lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
 pipe.load_lora_weights(lora_path)
 pipe.fuse_lora(lora_scale=0.125)
 pipe.unload_lora_weights()
-pipe.transformer.to(device, dtype=torch.bfloat16)
+
+# Quantize the transformer
+print("Quantizing transformer")
+quantize(pipe.transformer, weights=qfloat8)
+freeze(pipe.transformer)
+pipe.transformer.to(device)
+
+# Quantize T5 encoder
+print("Quantizing T5")
+quantize(pipe.text_encoder_2, weights=qfloat8)
+freeze(pipe.text_encoder_2)
+pipe.text_encoder_2.to(device)
+
+# Move other components to device
+pipe.text_encoder.to(device, dtype=dtype)
+torch.cuda.empty_cache()
 
 @spaces.GPU(duration=75)
 def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
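
For reference, below is a minimal, self-contained sketch of the setup this commit arrives at, followed by a single-image generation call. Only the base-model loading, LoRA fusing, and qfloat8 quantization steps come from the diff above; the dtype, device, and taef1 definitions, the example prompt, and the final generation call are illustrative assumptions about the unchanged parts of app.py, not code from this commit.

import os
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
from huggingface_hub import hf_hub_download
from optimum.quanto import freeze, qfloat8, quantize

dtype = torch.bfloat16                                   # assumed, mirroring the old pipe.transformer.to(..., dtype=torch.bfloat16)
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Tiny VAE for fast previews; assumed to be what the AutoencoderTiny import is used for.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)

# Load the FLUX.1-dev base pipeline (as in the diff).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    vae=taef1,
    token=huggingface_token,
)

# Fuse the game-assets LoRA into the base weights, then drop the adapter.
lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=0.125)
pipe.unload_lora_weights()

# Quantize the two largest components to 8-bit float, freeze the quantized
# weights, and move them to the GPU.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
pipe.transformer.to(device)

quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
pipe.text_encoder_2.to(device)

# Remaining components stay in bfloat16.
pipe.text_encoder.to(device, dtype=dtype)
pipe.vae.to(device, dtype=dtype)
torch.cuda.empty_cache()

# Illustrative single-image call; parameters mirror infer()'s defaults.
image = pipe(
    prompt="a wooden treasure chest game asset, white background",
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=28,
    generator=torch.Generator(device).manual_seed(42),
).images[0]
image.save("game_asset.png")

The ordering matters: fuse_lora merges the adapter into the ordinary linear weights, and once those weights are quantized and frozen the adapter could no longer be merged the same way, which is why the commit keeps the existing "Load and fuse LoRA BEFORE quantizing" step ahead of the new quantize/freeze calls.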
stable_diffusion_model.py ADDED
The diff for this file is too large to render. See raw diff