llama2 forward pass seemingly not working with padded inputs, unless one element in batch is not padded
From this discussion thread [https://github.com/huggingface/transformers/issues/26601], moved here. Basically this seems to be an issue with padding, but only when `trust_remote_code=True`, so maybe it is related to FlashAttention?
Here's a script to reproduce:
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained(
    "togethercomputer/Llama-2-7B-32K-Instruct"
)

model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/Llama-2-7B-32K-Instruct",
    trust_remote_code=True,  # this works when this is False
    torch_dtype=torch.float16,
).cuda()

""" THIS works in both cases:
model = MT5ForConditionalGeneration.from_pretrained(
    'google/mt5-xl'
)
"""

# Two prompts with very different lengths, padded to the longest one in the batch.
encoded = tokenizer(
    [
        "[INST]\nWrite a poem about cats\n[/INST]\n\n",
        "[INST]\nWrite " + "a poem about" * 400 + " cats\n[/INST]\n\n",
    ],
    return_tensors="pt",
    padding="longest",
).to(model.device)

# Keep only the first (heavily padded) element of the batch.
encoded_firstelem = {
    "input_ids": encoded["input_ids"][:1, :],
    "attention_mask": encoded["attention_mask"][:1, :],
}

print(encoded_firstelem)
# {'input_ids': tensor([[ 0, 0, 0, ..., 29962, 13, 13]], device='cuda:0'), 'attention_mask': tensor([[0, 0, 0, ..., 1, 1, 1]], device='cuda:0')}

# works
print(model(**encoded))

# breaks
print(model(**encoded_firstelem))
Hi @joehakim and thanks for reporting this!
I think the error you see when feeding only the first element comes from a mismatch between `q_len` and `max_seqlen_q`, because of the unnecessary padding of the first element.
For your specific example, this is caused by the following steps in `modelling_flash_llama.py`:
1. `bsz, q_len, h_size = hidden_states.size()` (L311) -- this reads the sequence length from the padded input, which is 1215.
2. `unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])` (L371) -- here the padding gets removed and your `max_seqlen_q` becomes 18.
3. `attn_output = pad_input(attn_output, indices_q, bsz, max_seqlen_q).reshape(bsz, q_len, h_size)` (L380-382) -- this is where the error happens, due to the mismatch between `q_len` and `max_seqlen_q`.
So that means that you can't process a batch where the actual (unpadded) sequence length is smaller than the longest (padded) sequence in your batch.
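To make the mismatch concrete, here is a minimal, self-contained sketch in plain PyTorch. It does not call flash-attn itself; it only mimics the unpad/pad bookkeeping described above. The numbers 1215, 18 and 4096 come from the example; everything else is illustrative.

import torch

bsz, q_len, h_size = 1, 1215, 4096   # q_len read from the padded input (step 1)
real_len = 18                        # tokens that are actually attended to

attention_mask = torch.zeros(bsz, q_len, dtype=torch.bool)
attention_mask[:, -real_len:] = True  # left padding, as in the repro output

q = torch.randn(bsz, q_len, h_size)

# unpad_input-style bookkeeping: drop the padding positions (step 2)
indices_q = attention_mask.flatten().nonzero().squeeze(-1)
unpadded_q = q.reshape(-1, h_size)[indices_q]
max_seqlen_q = int(attention_mask.sum(dim=1).max())  # 18, not 1215

# pad_input-style re-padding only restores bsz * max_seqlen_q positions,
# so reshaping back to the padded q_len fails (step 3)
attn_output = torch.zeros(bsz * max_seqlen_q, h_size)
try:
    attn_output.reshape(bsz, q_len, h_size)
except RuntimeError as e:
    print(e)  # shape '[1, 1215, 4096]' is invalid for input of size 73728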
I am encountering the same error, i.e. a mismatch between `q_len` and `max_seqlen_q` gives `RuntimeError: shape '[4, 6400, 4096]' is invalid for input of size 14811136`.
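(For what it's worth, that error message is consistent with the mechanism described above: 14811136 = 4 × 904 × 4096, so `pad_input` appears to have produced 904 positions per row, i.e. `max_seqlen_q` = 904, while the padded `q_len` is 6400.)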
Is there a solution to this issue?
Hi @mauriceweber - Is there support for batches containing different lengths of unpadded sequences?
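For the repro above, one way to sidestep the error in the meantime is to drop the surplus padding columns before the forward pass, so that the padded width again equals the longest real sequence and `q_len == max_seqlen_q`. This is not an official fix, just a sketch assuming left padding as in the example; `trim_left_padding` is a hypothetical helper, and `model` / `encoded_firstelem` are reused from the script above.

# Hypothetical helper: trim padding columns so the padded width equals the
# longest real sequence in the (sub-)batch, i.e. q_len == max_seqlen_q again.
def trim_left_padding(batch):
    keep = int(batch["attention_mask"].sum(dim=1).max())
    return {k: v[:, -keep:] for k, v in batch.items()}

print(model(**trim_left_padding(encoded_firstelem)))  # no shape mismatch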