sudy-super commited on
Commit
8959561
·
verified ·
1 Parent(s): b7c4310

Update modeling_c_cubed.py

Browse files
Files changed (1) hide show
  1. modeling_c_cubed.py +27 -1
modeling_c_cubed.py CHANGED
@@ -707,4 +707,30 @@ class CcubedForConditionalGeneration(CcubedPreTrainedModel):
707
  hidden_states=outputs.hidden_states,
708
  attentions=outputs.attentions,
709
  context_hidden_states=context_features,
710
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  hidden_states=outputs.hidden_states,
708
  attentions=outputs.attentions,
709
  context_hidden_states=context_features,
710
+ )
711
+
712
+ def prepare_inputs_for_generation(
713
+ self,
714
+ input_ids,
715
+ inputs_embeds=None,
716
+ past_key_values=None,
717
+ attention_mask=None,
718
+ context_attention_mask=None,
719
+ **kwargs
720
+ ):
721
+ if past_key_values:
722
+ input_ids = input_ids[:, -1:]
723
+
724
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
725
+ if inputs_embeds is not None and past_key_values is None:
726
+ model_inputs = {"inputs_embeds": inputs_embeds}
727
+ else:
728
+ model_inputs = {"input_ids": input_ids}
729
+
730
+ model_inputs.update({
731
+ "past_key_values": past_key_values,
732
+ "use_cache": kwargs.get("use_cache"),
733
+ "attention_mask": attention_mask,
734
+ "context_attention_mask": context_attention_mask
735
+ })
736
+ return model_inputs