Support 8B
merge_vision_example.py  CHANGED  (+12 -5)
@@ -4,9 +4,8 @@ from transformers import MllamaForConditionalGeneration, MllamaProcessor, AutoMo
 
 # NOTE: You need sufficient DRAM to load both models at once (otherwise, need to process layer by layer which is not shown here)
 
-multimodal_model_path = "models/meta-llama-Llama-3.2-90B-Vision-Instruct" # Original Llama vision model (90B)
-text_model_path = "models/path_to_Llama3.1_70B" # Model to be merged (70B)
-
+multimodal_model_path = "models/meta-llama-Llama-3.2-90B-Vision-Instruct" # Original Llama vision model (11B or 90B)
+text_model_path = "models/path_to_Llama3.1_70B" # Model to be merged (8B or 70B)
 save_path = "models/merged_model"
 
 multimodal_model = MllamaForConditionalGeneration.from_pretrained(multimodal_model_path, device_map="cpu", torch_dtype=torch.bfloat16)
@@ -19,8 +18,16 @@ state_dict_text = text_model.state_dict()
 num_decoder_layers_text = text_model.config.num_hidden_layers
 num_decoder_layers_vision = multimodal_model.config.text_config.num_hidden_layers
 
-#
-inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}
+# Find the list of inserted layers in multimodal Llama
+inserted_layers = set()
+for key_multimodal in state_dict_multimodal.keys():
+    if "language_model" in key_multimodal and "cross_attn" in key_multimodal and ".layers." in key_multimodal:
+        layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0]) if ".layers." in key_multimodal else None
+        if layer_num_multimodal is not None: inserted_layers.add(layer_num_multimodal)
+# Here are the hard-coded lists of layers added:
+# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}  # For 90B
+# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38}  # For 11B
+
 assert len(inserted_layers) == num_decoder_layers_vision-num_decoder_layers_text, "# of added layers do not match"
 
 # Build decoder layer map from multimodal layer# to text layer#, skipping layers listed in inserted_layers
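The final context line above refers to a decoder-layer map that the rest of merge_vision_example.py builds; that code lies outside this hunk. For orientation, here is a minimal sketch, using a hypothetical helper rather than the script's actual implementation, of how such a map can be derived from inserted_layers: walk the multimodal decoder indices in order and assign the next text-model index to every layer that is not one of the inserted cross-attention layers.

# Sketch only (hypothetical helper, not part of merge_vision_example.py): map each
# multimodal decoder layer index to the text-model layer it should be merged with,
# skipping the cross-attention layers that exist only in the multimodal model.
def build_decoder_layer_map(num_decoder_layers_vision, inserted_layers):
    layer_map = {}  # multimodal layer # -> text layer #
    text_layer = 0
    for vision_layer in range(num_decoder_layers_vision):
        if vision_layer in inserted_layers:
            continue  # inserted cross-attention layer: no counterpart in the text model
        layer_map[vision_layer] = text_layer
        text_layer += 1
    return layer_map

# 11B/8B pairing: 40 multimodal decoder layers, 8 of them inserted, 32 text layers.
layer_map = build_decoder_layer_map(40, {3, 8, 13, 18, 23, 28, 33, 38})
assert len(layer_map) == 32 and layer_map[4] == 3  # layer 3 is skipped, so 4 maps to 3

The assertion in the diff enforces the same bookkeeping: the number of detected cross-attention layers must equal num_decoder_layers_vision - num_decoder_layers_text, i.e. 20 for the 90B/70B pairing and 8 for the 11B/8B pairing.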
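The NOTE at the top of the file warns that both models are loaded into DRAM at once and that a layer-by-layer variant is not shown. For the detection step alone, the full weights are not strictly required: a sharded Hugging Face checkpoint ships a model.safetensors.index.json whose weight_map lists every parameter name, so the same cross-attention scan can run against that file. The snippet below is a sketch under the assumption that the index file uses the same language_model / cross_attn parameter naming as the state dict loaded by the script; it only discovers the inserted layer indices and does not replace the merge itself.

# Sketch (assumptions: the checkpoint is sharded and its index file uses the same
# parameter names as multimodal_model.state_dict()).
import json
import os

multimodal_model_path = "models/meta-llama-Llama-3.2-90B-Vision-Instruct"
index_path = os.path.join(multimodal_model_path, "model.safetensors.index.json")

with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]  # parameter name -> shard file name

inserted_layers = set()
for key in weight_map:
    if "language_model" in key and "cross_attn" in key and ".layers." in key:
        inserted_layers.add(int(key.split(".layers.")[1].split(".")[0]))

print(sorted(inserted_layers))  # expect 3, 8, 13, ..., 98 for the 90B checkpoint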