keisuke-kiryu
committed on
Commit
·
7665c5f
1
Parent(s):
ba37e17
Update README.md
Browse files
README.md
CHANGED
@@ -22,27 +22,57 @@ widget:
|
|
22 |
|
23 |
|
24 |
# モデルの使い方
|
|
|
25 |
```python
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
```
|
47 |
|
48 |
# 学習データ
|
|
|
22 |
|
23 |
|
24 |
# モデルの使い方
|
25 |
+
## サンプルコード
|
26 |
```python
|
27 |
+
from transformers import AutoTokenizer,AutoModelForTokenClassification
|
28 |
+
import torch
|
29 |
+
import numpy as np
|
30 |
+
|
31 |
+
model_name = 'recruit-jp/japanese-typo-detector-roberta-base'
|
32 |
+
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
34 |
+
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
35 |
+
|
36 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
37 |
+
model = model.to(device)
|
38 |
+
|
39 |
+
in_text = "これは日本語の誤植を検出する真相学習モデルです。"
|
40 |
+
|
41 |
+
test_inputs = tokenizer(in_text, return_tensors='pt').get('input_ids')
|
42 |
+
test_outputs = model(test_inputs.to(torch.device(device)))
|
43 |
+
|
44 |
+
for chara, logit in zip(list(in_text), test_outputs.logits.squeeze().tolist()[1:-1]):
|
45 |
+
    err_type_ind = np.argmax(logit)
|
46 |
+
    err_name = model.config.id2label[err_type_ind]
|
47 |
+
    err_desc = f"Detected!(err_index={err_type_ind}, err_name={err_name})" if err_type_ind > 0 else f""
|
48 |
+
    print(f"{chara} : {err_desc}")
|
49 |
+
```
|
50 |
+
## サンプルコードの出力例
|
51 |
+
```
|
52 |
+
こ :
|
53 |
+
れ :
|
54 |
+
は :
|
55 |
+
日 :
|
56 |
+
本 :
|
57 |
+
語 :
|
58 |
+
の :
|
59 |
+
誤 :
|
60 |
+
植 :
|
61 |
+
を :
|
62 |
+
検 :
|
63 |
+
出 :
|
64 |
+
す :
|
65 |
+
る :
|
66 |
+
真 : Detected!(err_index=4, err_name=kanji-conversion_a)
|
67 |
+
相 : Detected!(err_index=4, err_name=kanji-conversion_a)
|
68 |
+
学 :
|
69 |
+
習 :
|
70 |
+
モ :
|
71 |
+
デ :
|
72 |
+
ル :
|
73 |
+
で :
|
74 |
+
す :
|
75 |
+
。 :
|
76 |
```
|
77 |
|
78 |
# 学習データ
|