Make sure the hidden state and the wte weights are on the same device when the model is split across devices (model parallelism).

Based on this:
https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1093

This should fix the following crash when running qlora:

Traceback (most recent call last):
  File "/code/qlora/qlora.py", line 758, in <module>
    train()
  File "/code/qlora/qlora.py", line 720, in train
    train_result = trainer.train(resume_from_checkpoint=checkpoint_dir)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1696, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1973, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2787, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2819, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 686, in forward
    return self.base_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 294, in forward
    logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
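For illustration, the failing line could be made device-safe in the same spirit as the GPT-2 code linked above, which moves the hidden states onto the LM head weight's device before projecting. The following is a minimal PyTorch sketch of that idea only, not the actual patch to modeling_mpt.py; `tied_lm_logits` is a hypothetical helper name, and the example assumes a tied output projection that reuses the `wte` embedding matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tied_lm_logits(last_hidden_state: torch.Tensor, wte: nn.Embedding) -> torch.Tensor:
    """Project hidden states onto a tied token-embedding matrix.

    Hypothetical helper. When layers are spread across GPUs (e.g. via a
    device_map), the final hidden state can end up on a different device
    than wte.weight, and F.linear then fails with a device mismatch.
    Moving the activation to the weight's device (rather than moving the
    large embedding matrix) mirrors what GPT-2 does for its lm_head.
    """
    weight = wte.weight
    hidden = last_hidden_state.to(weight.device)  # no-op if already there
    return F.linear(hidden, weight)  # (..., d_model) -> (..., vocab_size)


if __name__ == "__main__":
    vocab_size, d_model = 50, 16
    wte = nn.Embedding(vocab_size, d_model)
    hidden = torch.randn(2, 5, d_model)
    print(tied_lm_logits(hidden, wte).shape)  # torch.Size([2, 5, 50])

    # Cross-device case, only exercised when at least two GPUs are present.
    if torch.cuda.device_count() >= 2:
        wte_gpu = wte.to("cuda:0")
        logits = tied_lm_logits(hidden.to("cuda:1"), wte_gpu)
        assert logits.device == wte_gpu.weight.device
```

Moving the activation rather than the weight keeps the per-step transfer small (a batch of hidden vectors instead of the full vocabulary-by-d_model matrix), and `.to()` returns the same tensor when no move is needed, so single-device runs are unaffected.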
muelletm changed pull request status to open

Hi, I've made the requested changes; please try it now. I'll also update the README. 👍

Thanks! (I don't know if you saw the PR attached to this; I guess we can close it now?)

muelletm changed pull request status to closed
