gokaygokay commited on
Commit
824bcc3
·
1 Parent(s): fc6a8b1
Files changed (1) hide show
  1. app.py +7 -23
app.py CHANGED
@@ -2,13 +2,13 @@ import spaces
2
  import torch
3
  import gradio as gr
4
  import os
5
- from diffusers import FluxPipeline, FluxTransformer2DModel, BitsAndBytesConfig
6
  from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
7
 
8
  # Initialize model outside the function
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  dtype = torch.bfloat16
11
- file_url = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v31.safetensors"
12
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
13
  single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
14
 
@@ -23,30 +23,14 @@ text_encoder_2 = T5EncoderModel.from_pretrained(
23
  token=huggingface_token
24
  )
25
 
26
- # Initialize transformer
27
- quantization_config = BitsAndBytesConfig(
28
- load_in_4bit=True,
29
- bnb_4bit_quant_type="nf4",
30
- bnb_4bit_use_double_quant=True,
31
- bnb_4bit_compute_dtype=torch.bfloat16,
32
- token=huggingface_token
33
- )
34
- transformer = FluxTransformer2DModel.from_single_file(
35
- file_url,
36
- subfolder="transformer",
37
- torch_dtype=dtype,
38
- config=single_file_base_model,
39
- quantization_config=quantization_config,
40
- token=huggingface_token
41
- )
42
-
43
  # Load the pipeline with proper configuration
44
- flux_pipeline = FluxPipeline.from_pretrained(
45
- single_file_base_model,
46
- transformer=transformer,
47
  text_encoder_2=text_encoder_2,
48
  torch_dtype=dtype,
49
- token=huggingface_token
 
 
50
  )
51
  flux_pipeline.to(device)
52
 
 
2
  import torch
3
  import gradio as gr
4
  import os
5
+ from diffusers import FluxPipeline
6
  from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
7
 
8
  # Initialize model outside the function
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  dtype = torch.bfloat16
11
+ file_url = "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v31.safetensors"
12
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
13
  single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
14
 
 
23
  token=huggingface_token
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Load the pipeline with proper configuration
27
+ flux_pipeline = FluxPipeline.from_single_file(
28
+ file_url,
 
29
  text_encoder_2=text_encoder_2,
30
  torch_dtype=dtype,
31
+ token=huggingface_token,
32
+ use_safetensors=True,
33
+ variant="fp16"
34
  )
35
  flux_pipeline.to(device)
36