Update opensora/serve/gradio_web_server.py
opensora/serve/gradio_web_server.py
CHANGED
@@ -73,23 +73,7 @@ if __name__ == '__main__':
     vae.latent_size = latent_size
     transformer_model.force_images = args.force_images
     tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
-
-    load_8bit, load_4bit = True, False
-    kwargs = {"device_map": "auto"}
-    if load_8bit:
-        kwargs['load_in_8bit'] = True
-    elif load_4bit:
-        from transformers import BitsAndBytesConfig
-        kwargs['load_in_4bit'] = True
-        kwargs['quantization_config'] = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type='nf4'
-        )
-    else:
-        kwargs['torch_dtype'] = torch.float16
-    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", **kwargs)
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, torch_dtype=torch.float16).to(device)
 
     # set eval mode
     transformer_model.eval()