svjack committed on
Commit 50fee82
1 Parent(s): 9612197

Upload zh_mt5_model.py with huggingface_hub

Files changed (1)
  1. zh_mt5_model.py +28 -0
zh_mt5_model.py ADDED
@@ -0,0 +1,28 @@
from transformers import T5Tokenizer, MT5ForConditionalGeneration


class T5_B(object):
    """Wrapper around an mT5 conditional-generation model for single and batched question answering."""

    def __init__(self, model: str = "google/t5-large-ssm", device='cuda:0'):
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(model)
        if device == 'multigpu':
            # Spread the model's layers across all visible GPUs.
            self.model = MT5ForConditionalGeneration.from_pretrained(model).eval()
            self.model.parallelize()
        else:
            self.model = MT5ForConditionalGeneration.from_pretrained(model).to(device).eval()

    def predict(self, question: str):
        # Under model parallelism, inputs are fed to the first GPU; otherwise use the configured device.
        device = 'cuda:0' if self.device == 'multigpu' else self.device
        encode = self.tokenizer(question, return_tensors='pt').to(device)
        answer = self.model.generate(encode.input_ids)[0]
        decoded = self.tokenizer.decode(answer, skip_special_tokens=True)
        return decoded

    def predict_batch(self, question_list):
        assert isinstance(question_list, list)
        device = 'cuda:0' if self.device == 'multigpu' else self.device
        # Pad to the longest question so the batch encodes as a single tensor.
        encode = self.tokenizer(question_list, return_tensors='pt', padding=True).to(device)
        answer = self.model.generate(**encode)
        decoded = [self.tokenizer.decode(ans, skip_special_tokens=True) for ans in answer]
        return decoded
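
A minimal usage sketch of the uploaded class. The checkpoint path below is a placeholder (this commit does not name the Chinese mT5 weights the repository actually loads), and a single CUDA device is assumed to be available.

from zh_mt5_model import T5_B

# Placeholder checkpoint: substitute the Chinese mT5 weights used by this repo.
qa = T5_B(model="path/to/zh-mt5-checkpoint", device="cuda:0")

# Single question ("Where is the capital of China?").
print(qa.predict("中国的首都是哪里?"))

# Batched questions; returns one decoded string per input.
print(qa.predict_batch(["中国的首都是哪里?", "美国的首都是哪里?"]))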