LanguageBind commited on
Commit
5644f6f
·
verified ·
1 Parent(s): 98142d6

Update opensora/serve/gradio_web_server.py

Browse files
Files changed (1) hide show
  1. opensora/serve/gradio_web_server.py +1 -17
opensora/serve/gradio_web_server.py CHANGED
@@ -73,23 +73,7 @@ if __name__ == '__main__':
73
  vae.latent_size = latent_size
74
  transformer_model.force_images = args.force_images
75
  tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
76
-
77
- load_8bit, load_4bit = True, False
78
- kwargs = {"device_map": "auto"}
79
- if load_8bit:
80
- kwargs['load_in_8bit'] = True
81
- elif load_4bit:
82
- from transformers import BitsAndBytesConfig
83
- kwargs['load_in_4bit'] = True
84
- kwargs['quantization_config'] = BitsAndBytesConfig(
85
- load_in_4bit=True,
86
- bnb_4bit_compute_dtype=torch.float16,
87
- bnb_4bit_use_double_quant=True,
88
- bnb_4bit_quant_type='nf4'
89
- )
90
- else:
91
- kwargs['torch_dtype'] = torch.float16
92
- text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", **kwargs)
93
 
94
  # set eval mode
95
  transformer_model.eval()
 
73
  vae.latent_size = latent_size
74
  transformer_model.force_images = args.force_images
75
  tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
76
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, torch_dtype=torch.float16).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # set eval mode
79
  transformer_model.eval()