Add _set_gradient_checkpointing for compatibility

#22
by vriveras - opened
Files changed (1)
  1. modeling_mixformer_sequential.py +4 -0
modeling_mixformer_sequential.py CHANGED
@@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
711
  "past_key_values": past_key_values,
712
  "attention_mask": attention_mask,
713
  }
 
 
 
 
714
 
715
 
716
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
 
711
  "past_key_values": past_key_values,
712
  "attention_mask": attention_mask,
713
  }
714
+
715
+ def _set_gradient_checkpointing(self, module, value=False):
716
+ if isinstance(module, MixFormerSequentialPreTrainedModel):
717
+ module.gradient_checkpointing = value
718
 
719
 
720
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):