change model to randeng
Browse files
ebart.py
CHANGED
@@ -1,30 +1,33 @@
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
-
from transformers import PegasusForConditionalGeneration
|
4 |
-
#
|
5 |
-
from tokenizers_pegasus import PegasusTokenizer
|
6 |
|
7 |
@spaces.GPU
|
8 |
def generate_summary(text, max_length=180, min_length=64):
|
9 |
# 加载标记器和模型
|
10 |
-
model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-
|
11 |
-
tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-
|
12 |
-
|
13 |
|
14 |
# 将模型移动到GPU
|
15 |
-
|
16 |
-
|
17 |
|
18 |
-
#
|
19 |
-
inputs = tokenizer(text, max_length=1024, return_tensors="pt")
|
20 |
|
21 |
# 生成摘要
|
22 |
summary_ids = model.generate(
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
if __name__ == "__main__":
|
30 |
text = (
|
@@ -39,4 +42,5 @@ if __name__ == "__main__":
|
|
39 |
"从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
|
40 |
"还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
|
41 |
)
|
42 |
-
generate_summary(text, max_length=128, min_length=64)
|
|
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
+
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
4 |
+
#from tokenizers_pegasus import PegasusTokenizer
|
|
|
5 |
|
6 |
# Process-wide cache so the checkpoint is loaded once instead of on every call.
_SUMMARIZER_CACHE = {}


def _load_summarizer(model_name, device):
    """Return a cached ``(model, tokenizer)`` pair for *model_name* on *device*."""
    cache_key = (model_name, str(device))
    if cache_key not in _SUMMARIZER_CACHE:
        model = PegasusForConditionalGeneration.from_pretrained(model_name)
        # NOTE(review): this checkpoint was previously loaded with the custom
        # `tokenizers_pegasus.PegasusTokenizer`; confirm against the model card
        # that the stock transformers `PegasusTokenizer` can load its vocabulary.
        tokenizer = PegasusTokenizer.from_pretrained(model_name)
        # Move the model to the target device once, at load time.
        model.to(device)
        _SUMMARIZER_CACHE[cache_key] = (model, tokenizer)
    return _SUMMARIZER_CACHE[cache_key]


@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Summarize *text* with the Randeng-Pegasus Chinese summarization model.

    Args:
        text: Source text to summarize. It is truncated to 1024 tokens.
        max_length: Upper bound on the generated summary length, in tokens.
        min_length: Lower bound on the generated summary length, in tokens.

    Returns:
        The decoded summary string, or ``""`` when *text* is empty or
        contains only whitespace.
    """
    # Nothing to summarize: skip model loading and generation entirely rather
    # than letting `min_length` force output with no source content.
    if not text or not text.strip():
        return ""

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load (or reuse) the tokenizer and model.
    model, tokenizer = _load_summarizer(
        "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1", device
    )

    # Tokenize and move the input tensors to the same device as the model.
    inputs = tokenizer(
        text, max_length=1024, truncation=True, return_tensors="pt"
    ).to(device)

    # Generate the summary with beam search. The attention mask is forwarded
    # explicitly so generation does not have to infer it from the token ids.
    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        min_length=min_length,
        num_beams=4,
        early_stopping=True,
    )

    # Decode the single generated sequence and return it.
    decoded = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded[0]
|
31 |
|
32 |
if __name__ == "__main__":
|
33 |
text = (
|
|
|
42 |
"从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
|
43 |
"还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
|
44 |
)
|
45 |
+
summary = generate_summary(text, max_length=128, min_length=64)
|
46 |
+
print(summary)
|