Spaces:
Paused
Paused
tianzhuotao
commited on
Commit
·
c39e06d
1
Parent(s):
0ecb81f
Fix bug in loading weights for visual_model and text_hidden_fcs when using cached directory
Browse filesFormer-commit-id: efb481d6aa21b540025d29f424b42b3ff26fabea
README.md
CHANGED
@@ -196,6 +196,7 @@ deepspeed --master_port=24999 train_ds.py \
|
|
196 |
To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
|
197 |
```
|
198 |
CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
|
|
|
199 |
```
|
200 |
To use `bf16` or `fp16` data type for inference:
|
201 |
```
|
|
|
196 |
To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explanatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explanatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
|
197 |
```
|
198 |
CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
|
199 |
+
CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0-explanatory'
|
200 |
```
|
201 |
To use `bf16` or `fp16` data type for inference:
|
202 |
```
|
chat.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import sys
|
4 |
|
5 |
import cv2
|
|
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
@@ -82,12 +83,35 @@ def main(args):
|
|
82 |
load_in_4bit=args.load_in_4bit,
|
83 |
)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
weight = {}
|
86 |
visual_model_weight = torch.load(
|
87 |
-
os.path.join(
|
88 |
)
|
89 |
text_hidden_fcs_weight = torch.load(
|
90 |
-
os.path.join(
|
91 |
)
|
92 |
weight.update(visual_model_weight)
|
93 |
weight.update(text_hidden_fcs_weight)
|
|
|
3 |
import sys
|
4 |
|
5 |
import cv2
|
6 |
+
import glob
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
|
|
83 |
load_in_4bit=args.load_in_4bit,
|
84 |
)
|
85 |
|
86 |
+
if os.path.exists(args.version):
|
87 |
+
model_dir = args.version
|
88 |
+
else: # hack for cached pre-trained weights
|
89 |
+
user_name, model_name = args.version.split("/")
|
90 |
+
cache_dir = "{}/.cache/huggingface/hub/models--{}--{}".format(os.environ['HOME'], user_name, model_name)
|
91 |
+
if os.path.exists(cache_dir):
|
92 |
+
model1_dir = glob.glob("{}/snapshots/*/pytorch_model-visual_model.bin".format(cache_dir))
|
93 |
+
model2_dir = glob.glob("{}/snapshots/*/pytorch_model-text_hidden_fcs.bin".format(cache_dir))
|
94 |
+
if len(model1_dir) == 0 or len(model2_dir) == 0:
|
95 |
+
raise ValueError("Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
|
96 |
+
cache_dir
|
97 |
+
))
|
98 |
+
model1_dir = ["/".join(x.split("/")[:-1]) for x in model1_dir]
|
99 |
+
model2_dir = ["/".join(x.split("/")[:-1]) for x in model2_dir]
|
100 |
+
model_dir = list(set(model1_dir).intersection(set(model2_dir)))
|
101 |
+
if len(model_dir) == 0:
|
102 |
+
raise ValueError("Pre-trained weights for visual_model or text_hidden_fcs do not exist in {}.".format(
|
103 |
+
cache_dir
|
104 |
+
))
|
105 |
+
model_dir = model_dir[0]
|
106 |
+
else:
|
107 |
+
raise ValueError("The path {} does not exists.".format(cache_dir))
|
108 |
+
|
109 |
weight = {}
|
110 |
visual_model_weight = torch.load(
|
111 |
+
os.path.join(model_dir, "pytorch_model-visual_model.bin")
|
112 |
)
|
113 |
text_hidden_fcs_weight = torch.load(
|
114 |
+
os.path.join(model_dir, "pytorch_model-text_hidden_fcs.bin")
|
115 |
)
|
116 |
weight.update(visual_model_weight)
|
117 |
weight.update(text_hidden_fcs_weight)
|