grimulkan committed on
Commit e409e32
1 Parent(s): 0d87dfc

Support 8B

Files changed (1)
  1. merge_vision_example.py +12 -5
merge_vision_example.py CHANGED
@@ -4,9 +4,8 @@ from transformers import MllamaForConditionalGeneration, MllamaProcessor, AutoMo
 
 # NOTE: You need sufficient DRAM to load both models at once (otherwise, need to process layer by layer which is not shown here)
 
-multimodal_model_path = "models/meta-llama-Llama-3.2-90B-Vision-Instruct" # Original Llama vision model (90B)
-text_model_path = "models/path_to_Llama3.1_70B" # Model to be merged (70B)
-
+multimodal_model_path = "models/meta-llama-Llama-3.2-90B-Vision-Instruct" # Original Llama vision model (11B or 90B)
+text_model_path = "models/path_to_Llama3.1_70B" # Model to be merged (8B or 70B)
 save_path = "models/merged_model"
 
 multimodal_model = MllamaForConditionalGeneration.from_pretrained(multimodal_model_path, device_map="cpu", torch_dtype=torch.bfloat16)
@@ -19,8 +18,16 @@ state_dict_text = text_model.state_dict()
 num_decoder_layers_text = text_model.config.num_hidden_layers
 num_decoder_layers_vision = multimodal_model.config.text_config.num_hidden_layers
 
-# Hard-coded list of inserted layers in multimodal Llama
-inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}
+# Find the list of inserted layers in multimodal Llama
+inserted_layers = set()
+for key_multimodal in state_dict_multimodal.keys():
+    if "language_model" in key_multimodal and "cross_attn" in key_multimodal and ".layers." in key_multimodal:
+        layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0]) if ".layers." in key_multimodal else None
+        if layer_num_multimodal is not None: inserted_layers.add(layer_num_multimodal)
+# For reference, the hard-coded lists of added layers:
+# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}  # For 90B
+# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38}  # For 11B
+
 assert len(inserted_layers) == num_decoder_layers_vision-num_decoder_layers_text, "# of added layers do not match"
 
 # Build decoder layer map from multimodal layer# to text layer#, skipping layers listed in inserted_layers
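
Not part of this commit, but for context: the last context line above refers to the layer-map step that follows in the script. Below is a minimal sketch of how such a map could be built from the inserted_layers set computed in the diff. The function name build_decoder_layer_map and the example values (40 multimodal decoder layers with 8 inserted cross-attention layers, i.e. the 11B case implied by the hard-coded list) are illustrative assumptions, not code from the repository.

def build_decoder_layer_map(num_decoder_layers_vision, inserted_layers):
    # Sketch only: map multimodal decoder layer index -> text-model layer index,
    # skipping the inserted cross-attention layers, which have no counterpart
    # in the text-only model.
    layer_map = {}
    text_layer = 0
    for mm_layer in range(num_decoder_layers_vision):
        if mm_layer in inserted_layers:
            continue  # inserted cross-attention layer: nothing to copy from the text model
        layer_map[mm_layer] = text_layer
        text_layer += 1
    return layer_map

# Example using the 11B-style values from the diff comments:
layer_map = build_decoder_layer_map(40, {3, 8, 13, 18, 23, 28, 33, 38})
assert len(layer_map) == 32  # 32 mapped layers, matching the 8B text model's decoder depth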