Upload COCOM
Browse files- modelling_pisco.py +4 -4
modelling_pisco.py
CHANGED
@@ -985,7 +985,7 @@ class COCOM(PreTrainedModel):
|
|
985 |
# Creating encoder inputs:
|
986 |
if query_dependent:
|
987 |
# We provide the question for compression:
|
988 |
-
input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128, q_texts=questions)
|
989 |
else:
|
990 |
input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
|
991 |
|
@@ -1036,11 +1036,11 @@ class COCOM(PreTrainedModel):
|
|
1036 |
"""
|
1037 |
Compress a list of documents
|
1038 |
if questions is not None, assumes compression is done query-dependently !
|
1039 |
-
"""
|
1040 |
if questions is None:
|
1041 |
input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
|
1042 |
-
else:
|
1043 |
-
input_encoder = self.prepare_encoder_inputs(documents, max_length=128, q_texts=questions)
|
1044 |
enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
|
1045 |
attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
|
1046 |
return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
|
|
|
985 |
# Creating encoder inputs:
|
986 |
if query_dependent:
|
987 |
# We provide the question for compression:
|
988 |
+
input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128, q_texts=[question for question, docs in zip(questions, documents) for _ in docs])
|
989 |
else:
|
990 |
input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
|
991 |
|
|
|
1036 |
"""
|
1037 |
Compress a list of documents
|
1038 |
if questions is not None, assumes compression is done query-dependently !
|
1039 |
+
"""
|
1040 |
if questions is None:
|
1041 |
input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
|
1042 |
+
else: # we assume query-dependent here:
|
1043 |
+
input_encoder = self.prepare_encoder_inputs(documents, max_length=128, q_texts=[question for question, docs in zip(questions, documents) for _ in docs])
|
1044 |
enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
|
1045 |
attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
|
1046 |
return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
|