Question about lm_head weights in Gemma-2-9b-it model

#34
by mjkmain - opened

I've noticed some inconsistencies regarding the lm_head component in the google/gemma-2-9b-it model:

  • The model.safetensors.index.json file does not contain an lm_head.

  • When I load the model directly and save it using model.save_pretrained(), the resulting safetensors file also lacks an lm_head.

  • However, when I print the model structure, the lm_head is present, and the inference results are good.

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (up_proj): Linear(in_features=3584, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=3584, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre_feedforward_layernorm): Gemma2RMSNorm()
        (post_feedforward_layernorm): Gemma2RMSNorm()
      )
    )
    (norm): Gemma2RMSNorm()
  )
  (lm_head): Linear(in_features=3584, out_features=256000, bias=False)
)

This suggests that the lm_head is not using freshly initialized (random) weights, so I'm curious where its values actually come from.

Questions:

  • Where are the lm_head weights coming from?
  • Why don't they appear in the safetensors files?
  • Is this behavior intended?

Any clarification on this matter would be greatly appreciated. Thank you!
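
For completeness, here is a minimal sketch of the index check described above (the local path is just a placeholder; adjust to wherever the checkpoint was downloaded):

import json

# Inspect the sharded-checkpoint index that ships with the model.
with open("gemma-2-9b-it/model.safetensors.index.json") as f:  # placeholder local path
    weight_map = json.load(f)["weight_map"]

# No lm_head.* tensor is stored on disk, only the input embeddings.
print(any(name.startswith("lm_head") for name in weight_map))             # False
print(any(name.startswith("model.embed_tokens") for name in weight_map))  # True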

@mjkmain The default tie_word_embeddings in Gemma2Config (defined in transformers/models/gemma2/configuration_gemma2.py) is True, so the output embeddings (lm_head) are tied to the input embeddings (embed_tokens): both modules share a single weight matrix. Because of this, only model.embed_tokens.weight needs to be stored in the checkpoint, which is why no lm_head entry appears in the safetensors files. You can refer to https://paperswithcode.com/method/weight-tying for more on weight tying.
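
You can also see the tying directly after loading the model. A minimal sketch (the bfloat16 dtype is just an assumption to keep memory manageable):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it", torch_dtype=torch.bfloat16
)

# Weight tying is enabled by default for Gemma 2.
print(model.config.tie_word_embeddings)  # True

# The output projection reuses the input embedding matrix, which is why
# only model.embed_tokens.weight is serialized in the checkpoint.
print(model.get_output_embeddings().weight is model.get_input_embeddings().weight)  # True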

@JaronTHU Thank you for the clear explanation!

It's helpful to know that the default is True, so the input and output embeddings are indeed tied. The link you shared on weight tying is also very useful.
Thank you for sharing this knowledge.


mjkmain changed discussion status to closed
