jiaxin-wen commited on
Commit
230e52a
·
1 Parent(s): 7bfd512

initial commit

Browse files
Files changed (1) hide show
  1. README.md +32 -0
README.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This model has been trained on massive Chinese plain-text open-domain dialogues following the approach described in [Re$^3$Dial: Retrieve, Reorganize and Rescale Conversations for Long-Turn Open-Domain Dialogue Pre-training](https://arxiv.org/abs/2305.02606). The associated Github repository is available here https://github.com/thu-coai/Re3Dial.
2
+
3
+ ### Usage
4
+
5
+ ```python
6
+ from transformers import BertTokenizer, BertModel
7
+ import torch
8
+
9
+
10
+ def get_embedding(encoder, inputs):
11
+ outputs = encoder(**inputs)
12
+ pooled_output = outputs[0][:, 0, :]
13
+ return pooled_output
14
+
15
+ tokenizer = BertTokenizer.from_pretrained('xwwwww/bert-chinese-dialogue-retriever-query')
16
+ tokenizer.add_tokens(['<uttsep>'])
17
+ query_encoder = BertModel.from_pretrained('xwwwww/bert-chinese-dialogue-retriever-query')
18
+ context_encoder = BertModel.from_pretrained('xwwwww/bert-chinese-dialogue-retriever-context')
19
+
20
+ query = '你好<uttsep>好久不见,最近在干嘛'
21
+ context = '正在准备考试<uttsep>是什么考试呀,很辛苦吧'
22
+
23
+ query_inputs = tokenizer([query], return_tensors='pt')
24
+ context_inputs = tokenizer([context], return_tensors='pt')
25
+
26
+ query_embedding = get_embedding(query_encoder, query_inputs)
27
+ context_embedding = get_embedding(context_encoder, context_inputs)
28
+
29
+ score = torch.cosine_similarity(query_embedding, context_embedding, dim=1)
30
+
31
+ print('similarity score = ', score)
32
+ ```