saily committed on
Commit
fb15db5
·
1 Parent(s): 599a980

change model to randeng

Browse files
Files changed (1) hide show
  1. ebart.py +20 -16
ebart.py CHANGED
@@ -1,30 +1,33 @@
1
  import spaces
2
  import torch
3
- from transformers import PegasusForConditionalGeneration
4
- # 从 Fengshenbang-LM 下载 tokenizers_pegasus.py 和其他 Python 脚本
5
- from tokenizers_pegasus import PegasusTokenizer
6
 
7
  @spaces.GPU
8
  def generate_summary(text, max_length=180, min_length=64):
9
  # 加载标记器和模型
10
- model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-238M-Summary-Chinese")
11
- tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-238M-Summary-Chinese")
12
-
13
 
14
  # 将模型移动到GPU
15
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- #model.to(device)
17
 
18
- # 进行标记化
19
- inputs = tokenizer(text, max_length=1024, return_tensors="pt")#.to(device)
20
 
21
  # 生成摘要
22
  summary_ids = model.generate(
23
- inputs["input_ids"]
24
- )
25
- clean_summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
26
- print(clean_summary)
27
-
 
 
 
 
 
28
 
29
  if __name__ == "__main__":
30
  text = (
@@ -39,4 +42,5 @@ if __name__ == "__main__":
39
  "从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
40
  "还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
41
  )
42
- generate_summary(text, max_length=128, min_length=64)
 
 
1
  import spaces
2
  import torch
3
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
+ #from tokenizers_pegasus import PegasusTokenizer
 
5
 
6
  @spaces.GPU
7
  def generate_summary(text, max_length=180, min_length=64):
8
  # 加载标记器和模型
9
+ model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1")
10
+ tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1")
 
11
 
12
  # 将模型移动到GPU
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
 
16
+ # 进行标记化并将输入数据移动到GPU
17
+ inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
18
 
19
  # 生成摘要
20
  summary_ids = model.generate(
21
+ inputs["input_ids"],
22
+ max_length=max_length,
23
+ min_length=min_length,
24
+ num_beams=4,
25
+ early_stopping=True
26
+ )
27
+
28
+ # 解码并返回摘要
29
+ clean_summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
30
+ return clean_summary
31
 
32
  if __name__ == "__main__":
33
  text = (
 
42
  "从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
43
  "还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
44
  )
45
+ summary = generate_summary(text, max_length=128, min_length=64)
46
+ print(summary)