numb3r3 committed on
Commit
4c9340f
·
verified ·
1 Parent(s): 1d1bb09

fix: add missing module

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. modeling_xlm_roberta.py +214 -0
README.md CHANGED
@@ -90,7 +90,7 @@ results = model.rank(query, documents, return_documents=True, top_k=3)
90
  from transformers import AutoModelForSequenceClassification
91
 
92
  model = AutoModelForSequenceClassification.from_pretrained(
93
- 'jinaai/jina-reranker-v2-base-multilingual', num_labels=1, trust_remote_code=True
94
  )
95
 
96
  # Example query and documents
 
90
  from transformers import AutoModelForSequenceClassification
91
 
92
  model = AutoModelForSequenceClassification.from_pretrained(
93
+ 'jinaai/jina-reranker-v2-base-multilingual', trust_remote_code=True
94
  )
95
 
96
  # Example query and documents
modeling_xlm_roberta.py CHANGED
@@ -902,3 +902,217 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
902
  hidden_states=outputs.hidden_states,
903
  attentions=outputs.attentions,
904
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902
  hidden_states=outputs.hidden_states,
903
  attentions=outputs.attentions,
904
  )
905
+
906
+
907
@torch.inference_mode()
def compute_score(
    self,
    sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
    batch_size: int = 32,
    max_length: Optional[int] = None,
) -> List[float]:
    """Score (query, document) pairs with the model's raw logits.

    Args:
        sentence_pairs: Either one pair of strings or a sequence of such
            pairs. A single pair may be given as a list or a tuple.
        batch_size: Number of pairs tokenized and forwarded per step.
        max_length: Forwarded to the tokenizer as the truncation length.

    Returns:
        The flattened logits as Python floats, one per pair. When exactly
        one score was produced, that bare float is returned instead of a
        one-element list. An empty input yields an empty list.
    """
    if not hasattr(self, "_tokenizer"):
        # Load the tokenizer lazily and cache it on the instance so later
        # calls reuse it.
        from transformers import AutoTokenizer

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.name_or_path, trust_remote_code=True
        )

    # The signature allows a single pair passed as a tuple, so tuples must be
    # accepted here as well as lists.
    assert isinstance(sentence_pairs, (list, tuple))
    if len(sentence_pairs) == 0:
        # Nothing to score; avoids indexing into an empty sequence below.
        return []
    if isinstance(sentence_pairs[0], str):
        # A lone pair: wrap it so the batching loop sees a list of pairs.
        sentence_pairs = [sentence_pairs]
    else:
        sentence_pairs = list(sentence_pairs)

    all_scores = []
    for start_index in range(0, len(sentence_pairs), batch_size):
        sentences_batch = sentence_pairs[start_index : start_index + batch_size]
        inputs = self._tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(self.device)
        logits = self.forward(**inputs, return_dict=True).logits
        # Flatten to one value per pair and upcast before leaving the device.
        scores = logits.view(-1).float()
        all_scores.extend(scores.cpu().tolist())

    if len(all_scores) == 1:
        return all_scores[0]
    return all_scores
952
+
953
def predict(
    self,
    sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
    batch_size: int = 32,
    max_length: Optional[int] = None,
) -> List[float]:
    """Delegate to ``compute_score`` with the same arguments.

    Kept as a separate entry point because it is used for BEIR evaluation.
    """
    scoring_options = {"batch_size": batch_size, "max_length": max_length}
    return self.compute_score(sentence_pairs, **scoring_options)
961
+
962
def rerank(
    self,
    query: str,
    documents: List[str],
    batch_size: int = 32,
    max_length: int = 1024,
    max_query_length: int = 512,
    overlap_tokens: int = 80,
    top_n: Optional[int] = None,
    **kwargs,
):
    """Rank ``documents`` by relevance to ``query``.

    Documents that do not fit next to the query within ``max_length`` tokens
    are split into overlapping chunks by ``reranker_tokenize_preproc``. Each
    chunk is scored with a sigmoid over the model logit, and a document's
    score is the maximum over its chunks.

    Args:
        query: The query text.
        documents: Candidate documents to rank.
        batch_size: Number of (query, chunk) inputs forwarded per step.
        max_length: Token budget for one merged query + chunk input.
        max_query_length: Truncation length applied to the query.
        overlap_tokens: Requested overlap between consecutive chunks.
        top_n: Maximum number of results; ``None`` or ``0`` returns all.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        A list of dicts sorted by descending score, each with keys
        ``'document'``, ``'relevance_score'`` and ``'index'`` (the position
        of the document in ``documents``, as a plain ``int``).

    Raises:
        AssertionError: If ``max_length < 2 * max_query_length``.
    """
    assert max_length >= max_query_length * 2, (
        f'max_length ({max_length}) must be greater than or equal to '
        f'max_query_length ({max_query_length}) * 2'
    )

    if not hasattr(self, "_tokenizer"):
        # Load the tokenizer lazily and cache it on the instance.
        from transformers import AutoTokenizer

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.name_or_path, trust_remote_code=True
        )

    # Tokenize up front: one merged input per (query, document chunk), plus
    # the index of the document each chunk came from.
    sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
        query,
        documents,
        tokenizer=self._tokenizer,
        max_length=max_length,
        max_query_length=max_query_length,
        overlap_tokens=overlap_tokens,
    )

    tot_scores = []
    with torch.no_grad():
        for batch_start in range(0, len(sentence_pairs), batch_size):
            batch = self._tokenizer.pad(
                sentence_pairs[batch_start : batch_start + batch_size],
                padding=True,
                max_length=max_length,
                pad_to_multiple_of=None,
                return_tensors="pt",
            )
            batch_on_device = {
                name: tensor.to(self.device) for name, tensor in batch.items()
            }
            logits = self.forward(**batch_on_device, return_dict=True).logits
            scores = torch.sigmoid(logits.view(-1).float())
            tot_scores.extend(scores.cpu().tolist())

    # A document's score is the best score among its chunks; a document that
    # produced no chunk keeps the initial 0.
    merge_scores = [0 for _ in range(len(documents))]
    for pid, score in zip(sentence_pairs_pids, tot_scores):
        merge_scores[pid] = max(merge_scores[pid], score)

    # Convert to built-in ints so the result does not carry numpy scalar
    # types (which, for example, the json module cannot serialise).
    ranked_indices = [int(i) for i in np.argsort(merge_scores)[::-1]]

    top_n = min(top_n or len(ranked_indices), len(ranked_indices))

    return [
        {
            'document': documents[doc_index],
            'relevance_score': merge_scores[doc_index],
            'index': doc_index,
        }
        # max(..., 0) keeps a negative top_n yielding an empty result.
        for doc_index in ranked_indices[: max(top_n, 0)]
    ]
1038
+
1039
+
1040
def reranker_tokenize_preproc(
    query: str,
    passages: List[str],
    tokenizer=None,
    max_length: int = 1024,
    max_query_length: int = 512,
    overlap_tokens: int = 80,
):
    """Build tokenized query + passage inputs, chunking long passages.

    The query is tokenized once (truncated to ``max_query_length``). Each
    passage is tokenized without special tokens and merged with the query as
    ``query_ids + [sep] + passage_ids + [sep]``. A passage that does not fit
    in the remaining budget is split into overlapping windows; the last
    window is always the final ``max_passage_inputs_length`` tokens.

    Args:
        query: The query text.
        passages: Passages to pair with the query.
        tokenizer: Tokenizer providing ``sep_token_id`` and ``encode_plus``.
        max_length: Token budget for one merged input.
        max_query_length: Truncation length for the query.
        overlap_tokens: Requested overlap between consecutive windows; capped
            at a quarter of the per-passage budget.

    Returns:
        A tuple ``(inputs, pids)`` where ``inputs`` is the list of merged
        encodings and ``pids[i]`` is the index in ``passages`` that
        ``inputs[i]`` was built from.

    Raises:
        AssertionError: If ``tokenizer`` is ``None``.
        ValueError: If the tokenized query leaves no room for passage tokens.
    """
    from copy import deepcopy

    assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
    sep_id = tokenizer.sep_token_id

    def _merge_inputs(chunk1_raw, chunk2):
        # Work on a copy so the shared query encoding is never mutated.
        chunk1 = deepcopy(chunk1_raw)
        chunk1['input_ids'].append(sep_id)
        chunk1['input_ids'].extend(chunk2['input_ids'])
        chunk1['input_ids'].append(sep_id)
        # The separators are real tokens, so their mask is 1. Not reading the
        # value from chunk2 keeps an empty passage from raising IndexError.
        chunk1['attention_mask'].append(1)
        chunk1['attention_mask'].extend(chunk2['attention_mask'])
        chunk1['attention_mask'].append(1)
        if 'token_type_ids' in chunk1:
            token_type_ids = [1] * (len(chunk2['token_type_ids']) + 2)
            chunk1['token_type_ids'].extend(token_type_ids)
        return chunk1

    # Note: a long query is truncated to `max_query_length` tokens
    # (512 by default).
    query_inputs = tokenizer.encode_plus(
        query, truncation=True, padding=False, max_length=max_query_length
    )

    # Budget left for passage tokens after the query and the two separators.
    max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
    if max_passage_inputs_length <= 0:
        # With no budget the sliding window below would never advance.
        raise ValueError(
            f'max_length ({max_length}) leaves no room for passage tokens '
            f"after a query of {len(query_inputs['input_ids'])} tokens"
        )

    overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)

    res_merge_inputs = []
    res_merge_inputs_pids = []
    for pid, passage in enumerate(passages):
        # NOTE(review): max_length=0 together with truncation=False is kept
        # from the original; presumably it has no truncating effect - confirm.
        passage_inputs = tokenizer.encode_plus(
            passage,
            truncation=False,
            padding=False,
            add_special_tokens=False,
            max_length=0,
        )
        passage_inputs_length = len(passage_inputs['input_ids'])

        if passage_inputs_length <= max_passage_inputs_length:
            qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
            res_merge_inputs.append(qp_merge_inputs)
            res_merge_inputs_pids.append(pid)
            continue

        start_id = 0
        while start_id < passage_inputs_length:
            end_id = start_id + max_passage_inputs_length
            if end_id >= passage_inputs_length:
                # Make sure the last chunk has `max_passage_inputs_length`
                # tokens by taking it from the end of the passage.
                window = slice(-max_passage_inputs_length, None)
                start_id = end_id
            else:
                window = slice(start_id, end_id)
                # Step back so consecutive windows share some context.
                start_id = end_id - overlap_tokens_implt

            sub_passage_inputs = {k: v[window] for k, v in passage_inputs.items()}
            qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
            res_merge_inputs.append(qp_merge_inputs)
            res_merge_inputs_pids.append(pid)

    return res_merge_inputs, res_merge_inputs_pids