gokaygokay committed
Commit fc6a8b1 · 1 Parent(s): 10929b7
Files changed (1)
  1. app.py +36 -7
app.py CHANGED
@@ -2,22 +2,51 @@ import spaces
 import torch
 import gradio as gr
 import os
-from diffusers import FluxPipeline
+from diffusers import FluxPipeline, FluxTransformer2DModel, BitsAndBytesConfig
+from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
 
 # Initialize model outside the function
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.bfloat16
 file_url = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v31.safetensors"
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
 
-# Load the pipeline with proper configuration
-flux_pipeline = FluxPipeline.from_single_file(
+# Initialize text encoder
+quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
+text_encoder_2 = T5EncoderModel.from_pretrained(
+    single_file_base_model,
+    subfolder="text_encoder_2",
+    torch_dtype=dtype,
+    config=single_file_base_model,
+    quantization_config=quantization_config_tf,
+    token=huggingface_token
+)
+
+# Initialize transformer
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    token=huggingface_token
+)
+transformer = FluxTransformer2DModel.from_single_file(
     file_url,
+    subfolder="transformer",
+    torch_dtype=dtype,
+    config=single_file_base_model,
+    quantization_config=quantization_config,
+    token=huggingface_token
+)
+
+# Load the pipeline with proper configuration
+flux_pipeline = FluxPipeline.from_pretrained(
+    single_file_base_model,
+    transformer=transformer,
+    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
-    token=huggingface_token,
-    use_safetensors=True,
-    local_files_only=False,
-    config_file="model_index.json"
+    token=huggingface_token
 )
 flux_pipeline.to(device)
 
52