terapyon committed on
Commit e88a707 · 1 Parent(s): 3227cde

modify inference for model file

Files changed (1)
  1. inference.py +18 -15
inference.py CHANGED
@@ -1,4 +1,3 @@
-import copy
 import re
 from pathlib import Path
 from typing import Generator
@@ -10,6 +9,7 @@ import tomotopy as tp  # type: ignore
 import torch
 import torch.nn as nn
 import transformers as T  # type: ignore
+from huggingface_hub import PyTorchModelHubMixin  # type: ignore
 from scipy import stats  # type: ignore
 from sudachipy import dictionary, tokenizer  # type: ignore
 
@@ -29,7 +29,7 @@ else:
     gpu = -1  # use -1 (run on the CPU) when no GPU is available
 
 
-cls_num = 3
+# cls_num = 3
 max_length = 512
 k_folds = 10
 bert_model_name = "cl-tohoku/bert-base-japanese-v3"
@@ -37,11 +37,10 @@ device = torch.device(f"cuda:{gpu}" if gpu>=0 else "cpu")
 
 
 # BERT model definition
-class BertClassifier(nn.Module):
-    def __init__(self, model_name, cls_num=3):
-        super(BertClassifier, self).__init__()
-        # model_name = "cl-tohoku/bert-base-japanese"
-        self.bert = T.BertModel.from_pretrained(model_name, output_attentions=True)
+class BertClassifier(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, cls_num: int):
+        super().__init__()
+        self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
         self.fc = nn.Linear(768, cls_num, bias=True)
 
         nn.init.normal_(self.fc.weight, std=0.02)
@@ -115,27 +114,31 @@ class SudachiTokenizer:
         return token_list
 
 
-def make_traind_model(bert_model):
+def make_traind_model():
     trained_models = []
     for k in range(k_folds):
         k = k + 1
-        model_path = model_base_path / f"trained_model{k}.pt"
-        trained_model = copy.deepcopy(bert_model)
-        trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+        # model_path = model_base_path / f"trained_model{k}.pt"
+        # trained_model = copy.deepcopy(bert_model)
+        # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+        # trained_models.append(trained_model)
+        model_name = MODEL_BASE + str(k)
+        trained_model = BertClassifier.from_pretrained(model_name).to(device)
         trained_models.append(trained_model)
     return trained_models
 
 
 @st.cache_resource
 def init_models():
-    bert_model = BertClassifier(bert_model_name, cls_num=1)  # set the number of output nodes to 1
-    bert_model.eval()
-    bert_model.to(device)
-
+    # bert_model = BertClassifier(cls_num=1)  # set the number of output nodes to 1
+    # bert_model.eval()
+    # bert_model.to(device)
+
     tokenizer_sudachi = SudachiTokenizer(split_mode="C")
     # Tokenizer setup (here, the tokenizer is tokenizer_c2)
     tokenizer_c2 = T.BertJapaneseTokenizer.from_pretrained(bert_model_name)
-    trained_models = make_traind_model(bert_model)
+    # trained_models = make_traind_model(bert_model)
+    trained_models = make_traind_model()
     return tokenizer_sudachi, tokenizer_c2, trained_models
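The core of this change is PyTorchModelHubMixin: mixing it into an nn.Module subclass gives that class save_pretrained, push_to_hub, and from_pretrained, so each fold's weights can be fetched by name instead of being shipped as local .pt files (MODEL_BASE, presumably a Hub repo-id prefix, is defined elsewhere in inference.py). A minimal sketch of that round trip, independent of this repo: TinyClassifier, the tiny-fold paths, and the fold count are hypothetical, and it assumes a recent huggingface_hub in which the mixin automatically serializes JSON-serializable __init__ kwargs to config.json.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cls_num: int = 1):
        super().__init__()
        self.fc = nn.Linear(768, cls_num)

    def forward(self, x):
        return self.fc(x)

# Save one checkpoint per fold; each directory gets the weights plus a
# config.json recording cls_num (hypothetical paths, untrained weights).
for k in range(1, 3 + 1):
    model = TinyClassifier(cls_num=1)       # in practice: the trained fold-k model
    model.save_pretrained(f"tiny-fold{k}")  # or: model.push_to_hub(f"user/tiny-fold{k}")

# Reload them the way the new make_traind_model does; from_pretrained
# re-instantiates the class with the kwargs stored in config.json.
trained = [TinyClassifier.from_pretrained(f"tiny-fold{k}") for k in range(1, 3 + 1)]

from_pretrained accepts either a local directory or a Hub repo id, which is why the rewritten loop can simply build MODEL_BASE + str(k) and pull each of the k_folds models straight from the Hub.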