FantasticGNU committed
Commit f2de29b · 1 Parent(s): ff98170

Update model/openllama.py

Files changed (1): model/openllama.py (+5 −5)
model/openllama.py CHANGED
@@ -172,16 +172,16 @@ class OpenLLAMAPEFTModel(nn.Module):
        print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')

        self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
-       self.visual_encoder.to(self.device)
+       self.visual_encoder.to(torch.bfloat16).to(self.device)
        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)


        self.iter = 0

-       self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
+       self.image_decoder = LinearLayer(1280, 1024, 4).to(torch.bfloat16).to(self.device)

-       self.prompt_learner = PromptLearner(1, 4096).to(self.device)
+       self.prompt_learner = PromptLearner(1, 4096).to(torch.bfloat16).to(self.device)

        self.loss_focal = FocalLoss()
        self.loss_dice = BinaryDiceLoss()
@@ -215,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
        # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
        # # self.llama_model.to(torch.float16)
        # # try:
-       self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.bfloat16, device_map='auto', load_in_4bit=True, offload_folder="offload1", offload_state_dict = True)
+       self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.bfloat16, device_map='auto', load_in_8bit=True, offload_folder="offload1")
        # # except:
        # pass
        # finally:
@@ -225,7 +225,7 @@ class OpenLLAMAPEFTModel(nn.Module):
        self.llama_model.load_state_dict(delta_ckpt, strict=False)
        self.llama_model.print_trainable_parameters()

-       self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload2", offload_state_dict = True)
+       self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.bfloat16, device_map='auto', offload_folder="offload2")
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.padding_side = "right"
        print ('Language decoder initialized.')
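The pattern behind the three .to(torch.bfloat16).to(self.device) changes: every module that exchanges tensors with the bfloat16 LLaMA model must itself be cast to bfloat16, otherwise the first matmul raises a dtype mismatch. A minimal sketch of that pattern in plain PyTorch, with nn.Linear standing in for the repo's LinearLayer (the stand-in shape is illustrative only):

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stand-in for the repo's LinearLayer(1280, 1024, 4); the real class is
# defined elsewhere in the project, this nn.Linear is only illustrative.
decoder = nn.Linear(1280, 1024)

# Cast weights to bfloat16 first, then move to the target device,
# mirroring the commit's .to(torch.bfloat16).to(self.device) chain.
decoder = decoder.to(torch.bfloat16).to(device)

# Inputs must carry the same dtype, or the matmul fails with
# "expected scalar type BFloat16 but found Float".
x = torch.randn(2, 1280, device=device, dtype=torch.bfloat16)
y = decoder(x)
assert y.dtype == torch.bfloat16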
 
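The other change switches the base model load from load_in_4bit to load_in_8bit and drops offload_state_dict=True. The commit passes load_in_8bit=True directly to from_pretrained; on recent transformers versions the same request is usually expressed through BitsAndBytesConfig. A hedged sketch of that equivalent call, assuming transformers with bitsandbytes installed (the checkpoint path is a placeholder):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# "vicuna_ckpt_path" mirrors the variable used in the commit; the path
# itself is a placeholder here.
vicuna_ckpt_path = "path/to/vicuna_ckpt"

model = AutoModelForCausalLM.from_pretrained(
    vicuna_ckpt_path,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",           # let accelerate place / offload layers
    torch_dtype=torch.bfloat16,  # dtype of the non-quantized modules
    offload_folder="offload1",   # disk offload target, as in the commit
)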