gokaygokay commited on
Commit
6f75c6d
·
1 Parent(s): 154712f
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -7,16 +7,15 @@ from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfi
7
  # Initialize model outside the function
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  dtype = torch.bfloat16
10
- base_model = "lodestones/Chroma"
11
  file_url = "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v31.safetensors"
12
 
13
  quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
14
- text_encoder_2 = T5EncoderModel.from_pretrained(base_model, subfolder="text_encoder_2", torch_dtype=dtype, config=base_model, quantization_config=quantization_config_tf)
15
 
16
  quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
17
- transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, config=base_model, quantization_config=quantization_config)
18
 
19
- flux_pipeline = FluxPipeline.from_pretrained(base_model, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype)
20
  flux_pipeline.to(device)
21
 
22
  @spaces.GPU()
 
7
  # Initialize model outside the function
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  dtype = torch.bfloat16
 
10
  file_url = "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v31.safetensors"
11
 
12
  quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
13
+ text_encoder_2 = T5EncoderModel.from_single_file(file_url, subfolder="text_encoder_2", torch_dtype=dtype, quantization_config=quantization_config_tf)
14
 
15
  quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
16
+ transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, quantization_config=quantization_config)
17
 
18
+ flux_pipeline = FluxPipeline.from_single_file(file_url, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype)
19
  flux_pipeline.to(device)
20
 
21
  @spaces.GPU()