Safetensors
Japanese
bert
hotchpotch commited on
Commit
fa0ad02
1 Parent(s): dfba912

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +170 -0
README.md ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - hpprc/emb
5
+ - hotchpotch/japanese-splade-v1-hard-negatives
6
+ - hpprc/msmarco-ja
7
+ language:
8
+ - ja
9
+ base_model:
10
+ - hotchpotch/japanese-splade-base-v1_5
11
+ ---
12
+
13
+ 高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。
14
+
15
+ - ⭐️URLの差し替えを行う
16
+
17
+ また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。
18
+
19
+
20
+ # 利用方法
21
+
22
+ ## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem)
23
+
24
+ ```bash
25
+ pip install yasem
26
+ ```
27
+
28
+ ```python
29
+ from yasem import SpladeEmbedder
30
+
31
+ model_name = "hotchpotch/japanese-splade-base-v2"
32
+ embedder = SpladeEmbedder(model_name)
33
+
34
+ query = "車の燃費を向上させる方法は?"
35
+ docs = [
36
+ "急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
37
+ "車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
38
+ "車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
39
+ ]
40
+
41
+ print(embedder.rank(query, docs, return_documents=True))
42
+ ```
43
+ ```
44
+
45
+ ```
46
+ ```json
47
+ [
48
+ { 'corpus_id': 0
49
+ , 'score': 4.28
50
+ , 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
51
+ ,
52
+ { 'corpus_id': 2
53
+ , 'score': 2.47
54
+ , 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
55
+ ,
56
+ { 'corpus_id': 1
57
+ , 'score': 2.34
58
+ , 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
59
+ ]
60
+ ```
61
+
62
+ ```python
63
+ sentences = [query] + docs
64
+
65
+ embeddings = embedder.encode(sentences)
66
+ similarity = embedder.similarity(embeddings, embeddings)
67
+
68
+ print(similarity)
69
+ ```
70
+
71
+ ```json
72
+ [[5.19151189, 4.28027662, 2.34164901, 2.47221905],
73
+ [4.28027662, 11.64426784, 5.00328318, 2.15031016],
74
+ [2.34164901, 5.00328318, 6.05594296, 1.33752085],
75
+ [2.47221905, 2.15031016, 1.33752085, 9.39414744]]
76
+ ```
77
+
78
+
79
+ ```python
80
+ token_values = embedder.get_token_values(embeddings[0])
81
+ print(token_values)
82
+ ```
83
+
84
+ ```json
85
+ {
86
+ '燃費': 1.13,
87
+ '方法': 1.07,
88
+ '車': 1.05,
89
+ '高める': 0.67,
90
+ '向上': 0.56,
91
+ '増加': 0.52,
92
+ '都市': 0.44,
93
+ 'ガソリン': 0.32,
94
+ '改善': 0.30,
95
+ ...
96
+ ```
97
+
98
+ ## transformers からの利用
99
+
100
+ ```python
101
+
102
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
103
+ import torch
104
+
105
+ model = AutoModelForMaskedLM.from_pretrained(model_name)
106
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
107
+
108
+ def splade_max_pooling(logits, attention_mask):
109
+ relu_log = torch.log(1 + torch.relu(logits))
110
+ weighted_log = relu_log * attention_mask.unsqueeze(-1)
111
+ max_val, _ = torch.max(weighted_log, dim=1)
112
+ return max_val
113
+
114
+ tokens = tokenizer(
115
+ sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
116
+ )
117
+ tokens = {k: v.to(model.device) for k, v in tokens.items()}
118
+
119
+ with torch.no_grad():
120
+ outputs = model(**tokens)
121
+ embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])
122
+
123
+ similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
124
+ print(similarity)
125
+ ```
126
+
127
+ ```python
128
+ tensor([
129
+ [5.1872, 4.2792, 2.3440, 2.4680],
130
+ [4.2792, 11.6327, 4.9983, 2.1470],
131
+ [2.3440, 4.9983, 6.0517, 1.3377],
132
+ [2.4680, 2.1470, 1.3377, 9.3801]
133
+ ])
134
+ ```
135
+
136
+ # ベンチマークスコア
137
+
138
+ ## retrieval (JMTEB)
139
+
140
+ [JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。japanese-splade-base-v1 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。
141
+ なお、japanese-splade-base-v2 は jaqket, mrtydi, jagovfaqs nlp_jornal のドメインを**学習していません**。
142
+
143
+
144
+ | モデル名 | jagovfaqs | jaqket | mrtydi | nlp_journal <br>title_abs | nlp_journal <br>abs_intro | nlp_journal <br>title_intro | Avg <br><512 | Avg <br>ALL |
145
+ | ----------------------------- | --------- | ------ | ---------- | ------------------------- | ------------------------- | --------------------------- | ------------- | ------------ |
146
+ | japanese-splade-base-v2 | 0.7313 | 0.6986 | **0.5106** | **0.9831** | 0.9067 | 0.8026 | **0.7309** | **0.7722** |
147
+ | GLuCoSE-base-ja-v2 | 0.6979 | 0.6729 | 0.4186 | 0.9511 | 0.9029 | 0.7580 | 0.6851 | 0.7336 |
148
+ | multilingual-e5-large | 0.7030 | 0.5878 | 0.4363 | 0.9470 | 0.8600 | 0.7248 | 0.6685 | 0.7098 |
149
+ | ruri-large | 0.7668 | 0.6174 | 0.3803 | 0.9658 | 0.8712 | 0.7797 | 0.6826 | 0.7302 |
150
+ | jinaai/jina-embeddings-v3 | 0.7150 | 0.4648 | 0.4545 | 0.9562 | 0.9843 | 0.9385 | 0.6476 | 0.7522 |
151
+ | sarashina-embedding-v1-1b | 0.7168 | 0.7279 | 0.4195 | 0.9696 | 0.9394 | 0.8833 | 0.7085 | 0.7761 |
152
+ | OpenAI/text-embedding-3-large | 0.7241 | 0.4821 | 0.3488 | 0.9655 | 0.9933 | 0.9547 | 0.6301 | 0.7448 |
153
+
154
+ ## 学習元データセット
155
+
156
+ - [hpprc/emb](https://huggingface.co/datasets/hpprc/emb)
157
+ - auto-wiki-qa
158
+ - jsquad
159
+ - jaquad
160
+ - auto-wiki-qa-nemotron
161
+ - quiz-works
162
+ - quiz-no-mori
163
+ - baobab-wiki-retrieval
164
+ - mkqa
165
+ - [hotchpotch/japanese-splade-v1-hard-negatives](https://huggingface.co/datasets/hotchpotch/japanese-splade-v1-hard-negatives)
166
+ - mmarco
167
+ - mqa
168
+ - msmarco-ja
169
+ -
170
+ また英語データセットとして、MS Marcoを利用しています。