keisuke-kiryu committed on
Commit
7665c5f
·
1 Parent(s): ba37e17

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +50 -20
README.md CHANGED
@@ -22,27 +22,57 @@ widget:
22
 
23
 
24
  # モデルの使い方
 
25
  ```python
26
- from transformers import AutoTokenizer,AutoModelForTokenClassification
27
-
28
- model_name('recruit-jp/japanese-typo-detector-roberta-base')
29
-
30
- tokenizer = AutoTokenizer.from_pretrained(model_name)
31
- model = AutoModelForTokenClassification.from_pretrained(model_name)
32
-
33
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
34
- model = model.to(device)
35
-
36
- in_text = "これは日本語の誤植を検出する真相学習モデルです。"
37
-
38
- test_inputs = tokenizer(in_text, return_tensors='pt').get('input_ids')
39
- test_outputs = model(test_inputs.to(torch.device(device)))
40
-
41
- for chara, logit in zip(["[CLS]"] + list(in_text) + ["[SEP]"], test_outputs.logits.squeeze().tolist()):
42
- err_type_ind = np.argmax(logit)
43
- err_name = model.config.id2label[err_type_ind]
44
- err_desc = f"★誤字(err_index={err_type_ind}, err_name={err_name})" if err_type_ind > 0 else f""
45
- print(f"{chara} : {err_desc}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ```
47
 
48
  # 学習データ
 
22
 
23
 
24
  # モデルの使い方
25
+ ## サンプルコード
26
  ```python
27
+ from transformers import AutoTokenizer,AutoModelForTokenClassification
28
+ import torch
29
+ import numpy as np
30
+
31
+ model_name = 'recruit-jp/japanese-typo-detector-roberta-base'
32
+
33
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
34
+ model = AutoModelForTokenClassification.from_pretrained(model_name)
35
+
36
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
37
+ model = model.to(device)
38
+
39
+ in_text = "これは日本語の誤植を検出する真相学習モデルです。"
40
+
41
+ test_inputs = tokenizer(in_text, return_tensors='pt').get('input_ids')
42
+ test_outputs = model(test_inputs.to(torch.device(device)))
43
+
44
+ for chara, logit in zip(list(in_text), test_outputs.logits.squeeze().tolist()[1:-1]):
45
+ err_type_ind = np.argmax(logit)
46
+ err_name = model.config.id2label[err_type_ind]
47
+ err_desc = f"Detected!(err_index={err_type_ind}, err_name={err_name})" if err_type_ind > 0 else f""
48
+ print(f"{chara} : {err_desc}")
49
+ ```
50
+ ## サンプルコードの出力例
51
+ ```
52
+ こ :
53
+ れ :
54
+ は :
55
+ 日 :
56
+ 本 :
57
+ 語 :
58
+ の :
59
+ 誤 :
60
+ 植 :
61
+ を :
62
+ 検 :
63
+ 出 :
64
+ す :
65
+ る :
66
+ 真 : Detected!(err_index=4, err_name=kanji-conversion_a)
67
+ 相 : Detected!(err_index=4, err_name=kanji-conversion_a)
68
+ 学 :
69
+ 習 :
70
+ モ :
71
+ デ :
72
+ ル :
73
+ で :
74
+ す :
75
+ 。 :
76
  ```
77
 
78
  # 学習データ