"""Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading. |
|
|
|
Usage: |
|
|
|
PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
from glob import glob |
|
|
|
import torch |
|
from huggingface_hub import snapshot_download |
|
from safetensors.torch import load_file |
|
|
|
|
|
def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str): |
|
""" |
|
Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint. |
|
|
|
Args: |
|
checkpoint_dir (str): Path to the checkpoint directory |
|
checkpoint_name (str): Name of the checkpoint |
|
vit_type (str): Type of ViT used in the Pixtral model |
|
|
|
This function performs the following steps: |
|
0. Download the checkpoint from Hugging Face |
|
1. Loads the original Pixtral checkpoint |
|
2. Splits the checkpoint into vision encoder, projector, and LLM weights |
|
3. Reorganizes the weights to match the expected format |
|
4. Extracts and verifies the vision encoder configuration |
|
5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer |
|
6. Optionally saves the converted checkpoint and configuration |
|
""" |
    save_dir = os.path.join(checkpoint_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
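
    # Skip the conversion if a non-empty converted checkpoint already exists.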
    save_path = os.path.join(save_dir, "model.pt")
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"Checkpoint {save_path} already exists and is not empty")
        return
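
    # Step 0: download params.json and the consolidated safetensors checkpoint from Hugging Face.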
    pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
    os.makedirs(pixtral_ckpt_dir, exist_ok=True)
    repo_id = "mistralai/Pixtral-12B-2409"
    print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=["params.json", "consolidated.safetensors"],
        local_dir=pixtral_ckpt_dir,
        local_dir_use_symlinks=False,
    )
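    # Set bfloat16 as the default dtype for the duration of the conversion; the original default is restored before returning.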
    orig_dtype = torch.get_default_dtype()
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)
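
    # Step 1: load the original consolidated Pixtral checkpoint.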
    ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
    assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
    ckpt_path = ckpt_files[0]
    ckpt = load_file(ckpt_path)
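
    # Step 2: split the checkpoint into vision encoder, projector, and LLM weights.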
    vit_key_prefix = "vision_encoder."
    vit_ckpt = {}
    for key, value in ckpt.items():
        if key.startswith(vit_key_prefix):
            # str.lstrip() strips a set of characters rather than a prefix, so slice the prefix off instead.
            vit_ckpt[key[len(vit_key_prefix):]] = value

    projector_key_prefix = "vision_language_adapter."
    projector_ckpt = {}
    substring_replacement_map = {
        "w_in.": "projector.0.",
        "w_out.": "projector.2.",
    }
    for key, value in ckpt.items():
        if key.startswith(projector_key_prefix):
            key = key[len(projector_key_prefix):]
            for old, new in substring_replacement_map.items():
                key = key.replace(old, new)
            projector_ckpt[key] = value

    llm_ckpt = {}
    for key, value in ckpt.items():
        if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix):
            continue
        llm_ckpt[key] = value
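
    # Step 3: reorganize the weights under the "model.", "mm_projector.", and "vision_encoder." prefixes.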
    vlm_ckpt = {}
    for key, value in llm_ckpt.items():
        vlm_ckpt["model." + key] = value
    for key, value in projector_ckpt.items():
        vlm_ckpt["mm_projector." + key] = value
    for key, value in vit_ckpt.items():
        vlm_ckpt["vision_encoder." + key] = value
    config_path = os.path.join(pixtral_ckpt_dir, "params.json")
    with open(config_path, "r") as f:
        pixtral_config = json.load(f)
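
    # Map the vision_encoder entries in params.json to the field names used by the ViT configuration.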
    vision_encoder_config = {
        "dim": pixtral_config["vision_encoder"]["hidden_size"],
        "num_channels": pixtral_config["vision_encoder"]["num_channels"],
        "image_size": pixtral_config["vision_encoder"]["image_size"],
        "patch_size": pixtral_config["vision_encoder"]["patch_size"],
        "rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
        "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
        "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
        "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
        "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
        "norm_type": "rmsnorm",
        "norm_eps": pixtral_config["norm_eps"],
        "image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
    }
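
    # Expected configuration for the pixtral-12b-vit vision encoder.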
    vit_config = dict(
        dim=1024,
        num_channels=3,
        image_size=1024,
        patch_size=16,
        rope_theta=10000,
        ffn_hidden_size=4096,
        n_layers=24,
        n_heads=16,
        n_kv_heads=16,
        norm_type="rmsnorm",
        norm_eps=1e-5,
        image_token_id=10,
    )
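
    # Verify that the downloaded configuration matches the expected values.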
    for key, value in vit_config.items():
        assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"
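
    # Build the LLM configuration, renaming hidden_dim to ffn_hidden_size and recording the ViT type.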
    llm_config_keys = [
        "dim",
        "n_layers",
        "head_dim",
        "hidden_dim",
        "n_heads",
        "n_kv_heads",
        "rope_theta",
        "norm_eps",
        "vocab_size",
    ]
    assert set(pixtral_config.keys()) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
    replace_map = {
        "hidden_dim": "ffn_hidden_size",
    }
    llm_config = {}
    for k, v in pixtral_config.items():
        if k in llm_config_keys:
            llm_config[replace_map.get(k, k)] = v
        elif k == "vision_encoder":
            llm_config["vision_encoder"] = vit_type
        else:
            raise ValueError(f"Unknown key: {k}")
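
    # Step 5: save the converted checkpoint and its configuration.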
    ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
    torch.save(ckpt_to_save, save_path)
    print(f"Model saved to {save_path}")
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(llm_config, f)

    torch.set_default_dtype(orig_dtype)
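
    # Step 6: remove the downloaded Pixtral checkpoint.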
    shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
    print(f"Removed {pixtral_ckpt_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
    )
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory")
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default="Pixtral-12B",
        help="Name of the checkpoint",
    )
    parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
    args = parser.parse_args()
    convert_pixtral_checkpoint(
        checkpoint_dir=args.checkpoint_dir, checkpoint_name=args.checkpoint_name, vit_type=args.vit_type
    )