# MiniCPM-Reranker-Light / scripts/flagembedding_demo.py
# Kaguya-19's picture
# Update scripts/flagembedding_demo.py
# 03cf8e8 verified
from FlagEmbedding import FlagReranker
# Demo: score (query, passage) pairs with MiniCPM-Reranker-Light through
# FlagEmbedding's FlagReranker wrapper and print the resulting scores.

# Model identifier handed to FlagReranker.
model_name = "OpenBMB/MiniCPM-Reranker-Light"

# Build the reranker. "Query: " is supplied as the query instruction, fp16 is
# requested, and trust_remote_code=True is passed along (the snippet below
# shows it being forwarded to from_pretrained).
model = FlagReranker(
    model_name,
    use_fp16=True,
    query_instruction_for_rerank="Query: ",
    trust_remote_code=True,
)

# Optional speed-up: the __init__() method of FlagEmbedding's BaseReranker
# class can be patched to load the model with flash_attention_2 for faster
# inference. The two doubly-commented lines are the ones to add:
# self.model = AutoModelForSequenceClassification.from_pretrained(
#     model_name_or_path,
#     trust_remote_code=trust_remote_code,
#     cache_dir=cache_dir,
#     # torch_dtype=torch.float16,  # add this line to use fp16
#     # attn_implementation="flash_attention_2",  # add this line to use flash_attention_2
# )

# Pad on the right-hand side.
# NOTE(review): presumably required by this model; confirm against the model card.
model.tokenizer.padding_side = "right"

query = "中国的首都是哪里?"  # "Where is the capital of China?"
passages = ["beijing", "shanghai"]  # Beijing, Shanghai

# One [query, passage] pair per candidate passage, in passage order.
sentence_pairs = [[query, passage] for passage in passages]

# normalize=True requests normalized scores; the sample output below lies
# in the [0, 1] range.
scores = model.compute_score(sentence_pairs, normalize=True)
print(scores)  # [0.01791734476747132, 0.0002472934613244585]