tianzhuotao committed on
Commit
c39e06d
·
1 Parent(s): 0ecb81f

Fix bug in loading weights for visual_model and text_hidden_fcs when using cached directory

Browse files

Former-commit-id: efb481d6aa21b540025d29f424b42b3ff26fabea

Files changed (2) hide show
  1. README.md +1 -0
  2. chat.py +26 -2
README.md CHANGED
@@ -196,6 +196,7 @@ deepspeed --master_port=24999 train_ds.py \
196
  To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
197
  ```
198
  CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
 
199
  ```
200
  To use `bf16` or `fp16` data type for inference:
201
  ```
 
196
  To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
197
  ```
198
  CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
199
+ CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0-explanatory'
200
  ```
201
  To use `bf16` or `fp16` data type for inference:
202
  ```
chat.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import sys
4
 
5
  import cv2
 
6
  import numpy as np
7
  import torch
8
  import torch.nn.functional as F
@@ -82,12 +83,35 @@ def main(args):
82
  load_in_4bit=args.load_in_4bit,
83
  )
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  weight = {}
86
  visual_model_weight = torch.load(
87
- os.path.join(args.version, "pytorch_model-visual_model.bin")
88
  )
89
  text_hidden_fcs_weight = torch.load(
90
- os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin")
91
  )
92
  weight.update(visual_model_weight)
93
  weight.update(text_hidden_fcs_weight)
 
3
  import sys
4
 
5
  import cv2
6
+ import glob
7
  import numpy as np
8
  import torch
9
  import torch.nn.functional as F
 
83
  load_in_4bit=args.load_in_4bit,
84
  )
85
 
86
+ if os.path.exists(args.version):
87
+ model_dir = args.version
88
+ else: # hack for cached pre-trained weights
89
+ user_name, model_name = args.version.split("/")
90
+ cache_dir = "{}/.cache/huggingface/hub/models--{}--{}".format(os.environ['HOME'], user_name, model_name)
91
+ if os.path.exists(cache_dir):
92
+ model1_dir = glob.glob("{}/snapshots/*/pytorch_model-visual_model.bin".format(cache_dir))
93
+ model2_dir = glob.glob("{}/snapshots/*/pytorch_model-text_hidden_fcs.bin".format(cache_dir))
94
+ if len(model1_dir) == 0 or len(model2_dir) == 0:
95
+ raise ValueError("Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
96
+ cache_dir
97
+ ))
98
+ model1_dir = ["/".join(x.split("/")[:-1]) for x in model1_dir]
99
+ model2_dir = ["/".join(x.split("/")[:-1]) for x in model2_dir]
100
+ model_dir = list(set(model1_dir).intersection(set(model2_dir)))
101
+ if len(model_dir) == 0:
102
+ raise ValueError("Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
103
+ cache_dir
104
+ ))
105
+ model_dir = model_dir[0]
106
+ else:
107
+ raise ValueError("The path {} does not exists.".format(cache_dir))
108
+
109
  weight = {}
110
  visual_model_weight = torch.load(
111
+ os.path.join(model_dir, "pytorch_model-visual_model.bin")
112
  )
113
  text_hidden_fcs_weight = torch.load(
114
+ os.path.join(model_dir, "pytorch_model-text_hidden_fcs.bin")
115
  )
116
  weight.update(visual_model_weight)
117
  weight.update(text_hidden_fcs_weight)