Vocab size vs. LM head size mismatch

#46
by harshil-shah - opened

Hi,

It seems there is a mismatch between the vocab size in the MllamaProcessor and the size of the lm_head weight matrix. Trying to call resize_token_embeddings doesn't fix this. This means that it is not possible to do training. Minimal example:

import requests
from PIL import Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Minimal reproduction: compare the tokenizer length with the lm_head size,
# try to resize the embeddings, then run a forward pass with labels.

# Checkpoint used to load both the processor and the model.
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"

processor = MllamaProcessor.from_pretrained(MODEL_NAME)
model = MllamaForConditionalGeneration.from_pretrained(MODEL_NAME)

# Show the tokenizer length next to the shape of the language model's
# lm_head weight (reported run: 128257 vs. torch.Size([128256, 4096])).
print(f"{len(processor.tokenizer) = }")
print(f"Before resize: {model.language_model.lm_head.weight.shape = }")

# Ask the model to resize its token embeddings to the tokenizer length.
model.resize_token_embeddings(len(processor.tokenizer))

# In the reported run the lm_head shape printed here is unchanged.
print(f"After resize: {model.language_model.lm_head.weight.shape = }")

# Download a sample image over HTTP and open it with PIL.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# A single user turn: one image entry followed by a text prompt.
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}
    ]}
]
# Render the conversation through the chat template, then have the processor
# build PyTorch-tensor inputs from the image and the rendered text and move
# them to the model's device.
# NOTE(review): add_special_tokens=False is presumably because the chat
# template already inserts the special tokens -- confirm.
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

# Forward pass using the unmodified input ids as labels. In the reported run
# this raises "IndexError: Target 128256 is out of bounds." from the
# cross-entropy loss.
output = model(**inputs, labels=inputs.input_ids)

This outputs:

len(processor.tokenizer) = 128257
Before resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])
After resize: model.language_model.lm_head.weight.shape = torch.Size([128256, 4096])

And then errors with:

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:2188, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   2185     cross_attention_mask = cross_attention_mask[:, :, cache_position]
   2186     full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2188 outputs = self.language_model(
   2189     input_ids=input_ids,
   2190     attention_mask=attention_mask,
   2191     position_ids=position_ids,
   2192     cross_attention_states=cross_attention_states,
   2193     cross_attention_mask=cross_attention_mask,
   2194     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   2195     past_key_values=past_key_values,
   2196     use_cache=use_cache,
   2197     inputs_embeds=inputs_embeds,
   2198     labels=labels,
   2199     output_hidden_states=output_hidden_states,
   2200     output_attentions=output_attentions,
   2201     return_dict=return_dict,
   2202     cache_position=cache_position,
   2203     num_logits_to_keep=num_logits_to_keep,
   2204 )
   2206 return outputs

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/transformers/models/mllama/modeling_mllama.py:1961, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   1959     # Enable model parallelism
   1960     shift_labels = shift_labels.to(shift_logits.device)
-> 1961     loss = loss_fct(shift_logits, shift_labels)
   1963 if not return_dict:
   1964     output = (logits,) + outputs[1:]

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py:1179, in CrossEntropyLoss.forward(self, input, target)
   1178 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1179     return F.cross_entropy(input, target, weight=self.weight,
   1180                            ignore_index=self.ignore_index, reduction=self.reduction,
   1181                            label_smoothing=self.label_smoothing)

File ~/.venv/lib/python3.11/site-packages/torch/nn/functional.py:3059, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3057 if size_average is not None or reduce is not None:
   3058     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Target 128256 is out of bounds.

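The <|image|> token, which has token ID 128256, is not intended to be trained on, so it has no corresponding row in the lm_head and cannot be used as a label. You should take care of replacing or masking out that token in the labels before the forward pass during training (for example, by setting those label positions to -100, which the cross-entropy loss ignores). Another thread has contributed some training code here: https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct/discussions/31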

Sign up or log in to comment