Fix error in `get_input_embeddings()`

#17

See the code in transformers: this will affect any trainer built on transformers.
Thanks to all the devs who review this issue and PR 🤗

By the way, the code

# Line ~127 in modeling_internvl_chat.py
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

may cause

RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

in torch>=2.4. Why not use the following instead?

        try:
            input_embeds[selected] = vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = vit_embeds[:n_token]
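
For anyone hitting the same RuntimeError, here is a minimal sketch of how it can arise and one way to sidestep it. It assumes `input_embeds` is a leaf tensor with `requires_grad=True` (for example, when input gradients are enabled for training). The shapes and values are made up, and the `.clone()` workaround is only an illustration, not the change made in this PR.

```python
import torch

B, N, C = 1, 4, 3

# Assumption: the text embeddings arrive as a leaf tensor that requires grad.
input_embeds = torch.zeros(B, N, C, requires_grad=True)
# Hypothetical image-token embeddings and the mask of positions they replace.
vit_embeds = torch.ones(2, C)
selected = torch.tensor([False, True, True, False])

# reshape() on a contiguous tensor returns a view of the leaf, so an
# in-place masked assignment on it is rejected by autograd.
flat_view = input_embeds.reshape(B * N, C)
try:
    flat_view[selected] = vit_embeds
    raised = False
    message = ""
except RuntimeError as e:
    raised = True
    message = str(e)

# One possible workaround: write into a clone, which is neither a leaf
# nor a view, so the in-place assignment is allowed.
merged = input_embeds.reshape(B * N, C).clone()
merged[selected] = vit_embeds

# Gradients still flow to the positions that were not overwritten.
merged.sum().backward()
```

With this approach the overwritten rows receive zero gradient in `input_embeds.grad`, while the untouched rows keep their gradient.
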
OpenGVLab org

Thanks for your feedback, I will add base_model_prefix to all InternVL models.
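
To illustrate why `base_model_prefix` matters here, the sketch below uses toy classes rather than the real transformers or InternVL code. It assumes the base class resolves `get_input_embeddings()` through the attribute named by `base_model_prefix`, and that the chat model keeps its language model in an attribute called `language_model`. All class names are hypothetical.

```python
class ToyPreTrainedModel:
    """Simplified stand-in for a base class that delegates via base_model_prefix."""

    base_model_prefix = ""

    def get_input_embeddings(self):
        # Look up the wrapped model by name; fall back to self if it is missing.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError


class ToyLanguageModel(ToyPreTrainedModel):
    def __init__(self):
        self.embed_tokens = "embedding-table"  # placeholder for an embedding layer

    def get_input_embeddings(self):
        return self.embed_tokens


class ChatModelWithoutPrefix(ToyPreTrainedModel):
    """Wraps a language model but does not say where to find it."""

    def __init__(self):
        self.language_model = ToyLanguageModel()


class ChatModelWithPrefix(ChatModelWithoutPrefix):
    """Same wrapper, but the prefix points at the wrapped language model."""

    base_model_prefix = "language_model"


try:
    ChatModelWithoutPrefix().get_input_embeddings()
    failed_without_prefix = False
except NotImplementedError:
    failed_without_prefix = True

embeddings = ChatModelWithPrefix().get_input_embeddings()
```

In this toy setup the wrapper without a prefix raises `NotImplementedError`, while the one with the prefix set delegates to the wrapped language model.
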

czczup changed pull request status to merged
