LanguageBind commited on
Commit
171d4cc
1 Parent(s): 2ba42e0

Update opensora/serve/gradio_web_server.py

Browse files
Files changed (1) hide show
  1. opensora/serve/gradio_web_server.py +17 -2
opensora/serve/gradio_web_server.py CHANGED
@@ -72,8 +72,23 @@ if __name__ == '__main__':
72
  vae.latent_size = latent_size
73
  transformer_model.force_images = args.force_images
74
  tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
75
- text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name,
76
- torch_dtype=torch.float16).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # set eval mode
79
  transformer_model.eval()
 
72
  vae.latent_size = latent_size
73
  transformer_model.force_images = args.force_images
74
  tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
75
+
76
+ load_8bit, load_4bit = True, False
77
+ kwargs = {"device_map": "auto"}
78
+ if load_8bit:
79
+ kwargs['load_in_8bit'] = True
80
+ elif load_4bit:
81
+ from transformers import BitsAndBytesConfig
82
+ kwargs['load_in_4bit'] = True
83
+ kwargs['quantization_config'] = BitsAndBytesConfig(
84
+ load_in_4bit=True,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4'
88
+ )
89
+ else:
90
+ kwargs['torch_dtype'] = torch.float16
91
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", **kwargs)
92
 
93
  # set eval mode
94
  transformer_model.eval()