hotchpotch
commited on
Commit
•
fa0ad02
1
Parent(s):
dfba912
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- hpprc/emb
|
5 |
+
- hotchpotch/japanese-splade-v1-hard-negatives
|
6 |
+
- hpprc/msmarco-ja
|
7 |
+
language:
|
8 |
+
- ja
|
9 |
+
base_model:
|
10 |
+
- hotchpotch/japanese-splade-base-v1_5
|
11 |
+
---
|
12 |
+
|
13 |
+
高性能な日本語 [SPLADE](https://github.com/naver/splade) (Sparse Lexical and Expansion Model) モデルです。[テキストからスパースベクトルへの変換デモ](https://huggingface.co/spaces/hotchpotch/japanese-splade-demo-streamlit)で、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。
|
14 |
+
|
15 |
+
- ⭐️URLの差し替えを行う
|
16 |
+
|
17 |
+
また、モデルの学習には[YAST - Yet Another SPLADE or Sparse Trainer](https://github.com/hotchpotch/yast)を使っています。
|
18 |
+
|
19 |
+
|
20 |
+
# 利用方法
|
21 |
+
|
22 |
+
## [YASEM (Yet Another Splade|Sparse Embedder)](https://github.com/hotchpotch/yasem)
|
23 |
+
|
24 |
+
```bash
|
25 |
+
pip install yasem
|
26 |
+
```
|
27 |
+
|
28 |
+
```python
|
29 |
+
from yasem import SpladeEmbedder
|
30 |
+
|
31 |
+
model_name = "hotchpotch/japanese-splade-base-v2"
|
32 |
+
embedder = SpladeEmbedder(model_name)
|
33 |
+
|
34 |
+
query = "車の燃費を向上させる方法は?"
|
35 |
+
docs = [
|
36 |
+
"急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
|
37 |
+
"車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
|
38 |
+
"車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
|
39 |
+
]
|
40 |
+
|
41 |
+
print(embedder.rank(query, docs, return_documents=True))
|
42 |
+
```
|
43 |
+
```
|
44 |
+
|
45 |
+
```
|
46 |
+
```json
|
47 |
+
[
|
48 |
+
{ 'corpus_id': 0
|
49 |
+
, 'score': 4.28
|
50 |
+
, 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
|
51 |
+
,
|
52 |
+
{ 'corpus_id': 2
|
53 |
+
, 'score': 2.47
|
54 |
+
, 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
|
55 |
+
,
|
56 |
+
{ 'corpus_id': 1
|
57 |
+
, 'score': 2.34
|
58 |
+
, 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
|
59 |
+
]
|
60 |
+
```
|
61 |
+
|
62 |
+
```python
|
63 |
+
sentences = [query] + docs
|
64 |
+
|
65 |
+
embeddings = embedder.encode(sentences)
|
66 |
+
similarity = embedder.similarity(embeddings, embeddings)
|
67 |
+
|
68 |
+
print(similarity)
|
69 |
+
```
|
70 |
+
|
71 |
+
```json
|
72 |
+
[[5.19151189, 4.28027662, 2.34164901, 2.47221905],
|
73 |
+
[4.28027662, 11.64426784, 5.00328318, 2.15031016],
|
74 |
+
[2.34164901, 5.00328318, 6.05594296, 1.33752085],
|
75 |
+
[2.47221905, 2.15031016, 1.33752085, 9.39414744]]
|
76 |
+
```
|
77 |
+
|
78 |
+
|
79 |
+
```python
|
80 |
+
token_values = embedder.get_token_values(embeddings[0])
|
81 |
+
print(token_values)
|
82 |
+
```
|
83 |
+
|
84 |
+
```json
|
85 |
+
{
|
86 |
+
'燃費': 1.13,
|
87 |
+
'方法': 1.07,
|
88 |
+
'車': 1.05,
|
89 |
+
'高める': 0.67,
|
90 |
+
'向上': 0.56,
|
91 |
+
'増加': 0.52,
|
92 |
+
'都市': 0.44,
|
93 |
+
'ガソリン': 0.32,
|
94 |
+
'改善': 0.30,
|
95 |
+
...
|
96 |
+
```
|
97 |
+
|
98 |
+
## transformers からの利用
|
99 |
+
|
100 |
+
```python
|
101 |
+
|
102 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
103 |
+
import torch
|
104 |
+
|
105 |
+
model = AutoModelForMaskedLM.from_pretrained(model_name)
|
106 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
107 |
+
|
108 |
+
def splade_max_pooling(logits, attention_mask):
|
109 |
+
relu_log = torch.log(1 + torch.relu(logits))
|
110 |
+
weighted_log = relu_log * attention_mask.unsqueeze(-1)
|
111 |
+
max_val, _ = torch.max(weighted_log, dim=1)
|
112 |
+
return max_val
|
113 |
+
|
114 |
+
tokens = tokenizer(
|
115 |
+
sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
|
116 |
+
)
|
117 |
+
tokens = {k: v.to(model.device) for k, v in tokens.items()}
|
118 |
+
|
119 |
+
with torch.no_grad():
|
120 |
+
outputs = model(**tokens)
|
121 |
+
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])
|
122 |
+
|
123 |
+
similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
|
124 |
+
print(similarity)
|
125 |
+
```
|
126 |
+
|
127 |
+
```python
|
128 |
+
tensor([
|
129 |
+
[5.1872, 4.2792, 2.3440, 2.4680],
|
130 |
+
[4.2792, 11.6327, 4.9983, 2.1470],
|
131 |
+
[2.3440, 4.9983, 6.0517, 1.3377],
|
132 |
+
[2.4680, 2.1470, 1.3377, 9.3801]
|
133 |
+
])
|
134 |
+
```
|
135 |
+
|
136 |
+
# ベンチマークスコア
|
137 |
+
|
138 |
+
## retrieval (JMTEB)
|
139 |
+
|
140 |
+
[JMTEB](https://github.com/sbintuitions/JMTEB) の評価結果です。japanese-splade-base-v1 は [JMTEB をスパースベクトルで評価できるように変更したコード](https://github.com/hotchpotch/JMTEB/tree/add_splade)での評価となっています。
|
141 |
+
なお、japanese-splade-base-v2 は jaqket, mrtydi, jagovfaqs nlp_jornal のドメインを**学習していません**。
|
142 |
+
|
143 |
+
|
144 |
+
| モデル名 | jagovfaqs | jaqket | mrtydi | nlp_journal <br>title_abs | nlp_journal <br>abs_intro | nlp_journal <br>title_intro | Avg <br><512 | Avg <br>ALL |
|
145 |
+
| ----------------------------- | --------- | ------ | ---------- | ------------------------- | ------------------------- | --------------------------- | ------------- | ------------ |
|
146 |
+
| japanese-splade-base-v2 | 0.7313 | 0.6986 | **0.5106** | **0.9831** | 0.9067 | 0.8026 | **0.7309** | **0.7722** |
|
147 |
+
| GLuCoSE-base-ja-v2 | 0.6979 | 0.6729 | 0.4186 | 0.9511 | 0.9029 | 0.7580 | 0.6851 | 0.7336 |
|
148 |
+
| multilingual-e5-large | 0.7030 | 0.5878 | 0.4363 | 0.9470 | 0.8600 | 0.7248 | 0.6685 | 0.7098 |
|
149 |
+
| ruri-large | 0.7668 | 0.6174 | 0.3803 | 0.9658 | 0.8712 | 0.7797 | 0.6826 | 0.7302 |
|
150 |
+
| jinaai/jina-embeddings-v3 | 0.7150 | 0.4648 | 0.4545 | 0.9562 | 0.9843 | 0.9385 | 0.6476 | 0.7522 |
|
151 |
+
| sarashina-embedding-v1-1b | 0.7168 | 0.7279 | 0.4195 | 0.9696 | 0.9394 | 0.8833 | 0.7085 | 0.7761 |
|
152 |
+
| OpenAI/text-embedding-3-large | 0.7241 | 0.4821 | 0.3488 | 0.9655 | 0.9933 | 0.9547 | 0.6301 | 0.7448 |
|
153 |
+
|
154 |
+
## 学習元データセット
|
155 |
+
|
156 |
+
- [hpprc/emb](https://huggingface.co/datasets/hpprc/emb)
|
157 |
+
- auto-wiki-qa
|
158 |
+
- jsquad
|
159 |
+
- jaquad
|
160 |
+
- auto-wiki-qa-nemotron
|
161 |
+
- quiz-works
|
162 |
+
- quiz-no-mori
|
163 |
+
- baobab-wiki-retrieval
|
164 |
+
- mkqa
|
165 |
+
- [hotchpotch/japanese-splade-v1-hard-negatives](https://huggingface.co/datasets/hotchpotch/japanese-splade-v1-hard-negatives)
|
166 |
+
- mmarco
|
167 |
+
- mqa
|
168 |
+
- msmarco-ja
|
169 |
+
-
|
170 |
+
また英語データセットとして、MS Marcoを利用しています。
|