The current checkpoint doesn't use group query attention.

#3
by yaya-sy - opened

When I tried to load the model using:

from transformers import AutoModelForCausalLM
import torch

llm = AutoModelForCausalLM.from_pretrained("lelapa/InkubaLM-0.4B", torch_dtype=torch.float16)

I encountered the following error:

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
    size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.2.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.2.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.3.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.3.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.4.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.4.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.5.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.5.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.6.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.6.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.7.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.7.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

This error suggests that the checkpoint was trained with standard multi-head attention rather than grouped-query attention: the k_proj and v_proj weight matrices are square (2048 × 2048), so the number of key/value heads equals the number of attention heads, whereas the published config makes the model expect smaller 256 × 2048 projections. To fix this, I edited config.json so that num_key_value_heads = num_attention_heads = 32.
That is the purpose of this pull request.
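
For reference, the same workaround can be applied at load time without editing config.json. This is only a minimal sketch using the standard AutoConfig API, with the values taken from the error above; if the repo ships custom modeling code, trust_remote_code=True may also be needed:

from transformers import AutoConfig, AutoModelForCausalLM
import torch

# Override the grouped-query setting so the model expects the square
# (2048 x 2048) k_proj / v_proj weights actually stored in the checkpoint.
config = AutoConfig.from_pretrained("lelapa/InkubaLM-0.4B")
config.num_key_value_heads = config.num_attention_heads  # i.e. 32

llm = AutoModelForCausalLM.from_pretrained(
    "lelapa/InkubaLM-0.4B",
    config=config,
    torch_dtype=torch.float16,
)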

yaya-sy changed pull request title from The actual checkpoint doesn't use group query attention. to The current checkpoint doesn't use group query attention.

Hello, when loading the model, add trust_remote_code=True.

e.g.:

llm = AutoModelForCausalLM.from_pretrained("lelapa/InkubaLM-0.4B", torch_dtype=torch.float16, trust_remote_code=True)
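
As a quick sanity check after loading (only a sketch; the prompt and generation settings are illustrative, and the tokenizer is assumed to live in the same repo):

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("lelapa/InkubaLM-0.4B")
llm = AutoModelForCausalLM.from_pretrained(
    "lelapa/InkubaLM-0.4B",
    torch_dtype=torch.float16,  # as above; switch to float32 if running on CPU
    trust_remote_code=True,
)

# Generate a few tokens to confirm the weights loaded without shape errors.
inputs = tokenizer("Hello", return_tensors="pt")
outputs = llm.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))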
Atnafu changed pull request status to merged