Spaces:
Runtime error
Runtime error
Update opensora/serve/gradio_web_server.py
Browse files
opensora/serve/gradio_web_server.py
CHANGED
@@ -72,8 +72,23 @@ if __name__ == '__main__':
|
|
72 |
vae.latent_size = latent_size
|
73 |
transformer_model.force_images = args.force_images
|
74 |
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
# set eval mode
|
79 |
transformer_model.eval()
|
|
|
72 |
vae.latent_size = latent_size
|
73 |
transformer_model.force_images = args.force_images
|
74 |
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
|
75 |
+
|
76 |
+
load_8bit, load_4bit = True, False
|
77 |
+
kwargs = {"device_map": "auto"}
|
78 |
+
if load_8bit:
|
79 |
+
kwargs['load_in_8bit'] = True
|
80 |
+
elif load_4bit:
|
81 |
+
from transformers import BitsAndBytesConfig
|
82 |
+
kwargs['load_in_4bit'] = True
|
83 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
84 |
+
load_in_4bit=True,
|
85 |
+
bnb_4bit_compute_dtype=torch.float16,
|
86 |
+
bnb_4bit_use_double_quant=True,
|
87 |
+
bnb_4bit_quant_type='nf4'
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
kwargs['torch_dtype'] = torch.float16
|
91 |
+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", **kwargs)
|
92 |
|
93 |
# set eval mode
|
94 |
transformer_model.eval()
|