maxoul committed (verified)
Commit 746b6af · Parent: 9689211

Upload COCOM

Files changed (1):
  modelling_pisco.py  +4 -4
modelling_pisco.py CHANGED
@@ -985,7 +985,7 @@ class COCOM(PreTrainedModel):
         # Creating encoder inputs:
         if query_dependent:
             # We provide the question for compression:
-            input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128, q_texts=questions)
+            input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128, q_texts=[question for question, docs in zip(questions, documents) for _ in docs])
         else:
             input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
 
@@ -1036,11 +1036,11 @@ class COCOM(PreTrainedModel):
         """
         Compress a list of documents
         if questions is not None, assumes compression is done query-dependently !
-        """
+        """
         if questions is None:
             input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
-        else:
-            input_encoder = self.prepare_encoder_inputs(documents, max_length=128, q_texts=questions)
+        else:  # we assume query-dependent here:
+            input_encoder = self.prepare_encoder_inputs(documents, max_length=128, q_texts=[question for question, docs in zip(questions, documents) for _ in docs])
         enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
         attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
         return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
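
The change makes the q_texts argument line up with the flattened document list: documents arrives as one list of passages per question, flat_documents flattens those lists, so each question has to be repeated once per passage it owns before being handed to prepare_encoder_inputs. A minimal sketch of that expansion, using made-up questions/documents values purely for illustration:

    # Hypothetical inputs: one list of passages per question.
    questions = ["Who wrote Hamlet?", "What is the capital of France?"]
    documents = [
        ["Hamlet is a tragedy by William Shakespeare.", "It was written around 1600."],
        ["Paris is the capital of France."],
    ]

    # Flatten the nested passage lists, as the model does before encoding.
    flat_documents = [doc for docs in documents for doc in docs]

    # Repeat each question once per passage in its group so that
    # q_texts[i] is the question paired with flat_documents[i].
    q_texts = [question for question, docs in zip(questions, documents) for _ in docs]

    assert len(q_texts) == len(flat_documents)  # 3 == 3
    # q_texts == ["Who wrote Hamlet?", "Who wrote Hamlet?", "What is the capital of France?"]

Passing the un-expanded questions list (the previous behaviour) only lines up when every question has exactly one passage; with several passages per question the lengths diverge, which appears to be what this commit addresses.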