Model keeps cache of generation in Transformers (fixed using torch.no_grad())

#14
by Pietroferr - opened

Hi,
I have noticed some strange behaviour running on an A10G 24GB.

I implemented the code you shared for Transformers in my codebase, wrapping the call to the model in a class method to perform batch inference (my dataset was too big to process all at once).
Doing so, GPU memory usage increased linearly with the number of processed inputs. I thought the model might be caching by default, so I set model.use_cache = False and called torch.cuda.empty_cache() where needed. Unfortunately, that did not fix the issue.

What worked was wrapping the generation in torch.no_grad(), so my final function is something like:

    import gc

    import torch

    def _generate_embedding_one_batch(self, model, tokenizer, batch_texts, max_length):
        # no_grad() stops autograd from building a graph for every forward pass,
        # which is what was accumulating GPU memory across batches
        with torch.no_grad():
            batch_dict = tokenizer(batch_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(self.device)
            outputs = model(**batch_dict)
            # last_token_pool is the pooling helper from the shared example code
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

            # free the intermediate tensors before moving on to the next batch
            del batch_dict, outputs
            torch.cuda.empty_cache()
            gc.collect()
        return embeddings
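
For completeness, the calling loop looks roughly like this (the batch size, variable names, and max_length are just illustrative, not my exact code); moving each batch of embeddings to the CPU also keeps them from piling up on the GPU:

    # Illustrative batching loop; batch_size, texts and max_length are placeholders
    batch_size = 32
    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        embeddings = self._generate_embedding_one_batch(model, tokenizer, batch_texts, max_length=512)
        # move the result off the GPU so only the current batch lives in GPU memory
        all_embeddings.append(embeddings.cpu())
    all_embeddings = torch.cat(all_embeddings, dim=0)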

I am wondering whether this is expected behaviour; I suspect it is not.

I have the same problem and have not solved it yet. I guess the context is kept on the GPU as more inputs are processed, which would mean the output for the current input contains knowledge of the previous inputs. I suppose the previous information should be deleted before feeding the current text to the model, but I do not know how to do this. Any methods?
