SmerkyG commited on
Commit
b972c5e
1 Parent(s): 468db1e

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +3 -0
modeling_rwkv5.py CHANGED
@@ -789,6 +789,9 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
789
  # only last token for inputs_ids if the state is passed along.
790
  if state is not None:
791
  input_ids = input_ids[:, -1].unsqueeze(-1)
 
 
 
792
 
793
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
794
  if inputs_embeds is not None and state is None:
 
789
  # only last token for inputs_ids if the state is passed along.
790
  if state is not None:
791
  input_ids = input_ids[:, -1].unsqueeze(-1)
792
+ else:
793
+ # add in \n at the beginning
794
+ input_ids = torch.cat([torch.full([1,1],11,device=input_ids.device,dtype=input_ids.dtype), input_ids])
795
 
796
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
797
  if inputs_embeds is not None and state is None: