Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,19 +12,29 @@ import random
|
|
| 12 |
import time
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
from diffusers import FluxTransformer2DModel, FluxPipeline
|
|
|
|
| 15 |
import safetensors.torch
|
| 16 |
from safetensors.torch import load_file
|
| 17 |
import gc
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
| 20 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
| 21 |
os.environ["HF_HUB_CACHE"] = cache_path
|
| 22 |
os.environ["HF_HOME"] = cache_path
|
| 23 |
|
| 24 |
-
|
| 25 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
pipe.to(device="cuda", dtype=torch.bfloat16)
|
| 29 |
|
| 30 |
# Load LoRAs from JSON file
|
|
|
|
| 12 |
import time
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
from diffusers import FluxTransformer2DModel, FluxPipeline
|
| 15 |
+
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
|
| 16 |
import safetensors.torch
|
| 17 |
from safetensors.torch import load_file
|
| 18 |
import gc
|
| 19 |
+
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
|
| 20 |
+
from tea_model import TeaDecoder
|
| 21 |
+
from text_encoder import t5_config, T5EncoderModel, PretrainedTextEncoder
|
| 22 |
|
| 23 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
| 24 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
| 25 |
os.environ["HF_HUB_CACHE"] = cache_path
|
| 26 |
os.environ["HF_HOME"] = cache_path
|
| 27 |
|
|
|
|
| 28 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 29 |
|
| 30 |
+
class Flux2DModel(QuantizedDiffusersModel):
|
| 31 |
+
base_class = FluxTransformer2DModel
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
t5 = PretrainedTextEncoder(t5_config, T5EncoderModel(t5_config)).to(dtype=torch.float16)
|
| 35 |
+
t5.load_model('text_encoder_2.safetensors')
|
| 36 |
+
|
| 37 |
+
pipe = FluxPipeline.from_pretrained("John6666/fastflux-unchained-t5f16-fp8-flux", torch_dtype=torch.bfloat16, text_encoder_2=t5)
|
| 38 |
pipe.to(device="cuda", dtype=torch.bfloat16)
|
| 39 |
|
| 40 |
# Load LoRAs from JSON file
|