# dusense/ebart.py
import spaces  # Hugging Face Spaces GPU decorator
import torch
from transformers import PegasusForConditionalGeneration
from tokenizers_pegasus import PegasusTokenizer  # project-local module with the custom tokenizer this model expects


class PegasusSummarizer:
    _instance = None

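    # Singleton: the model is loaded once on first instantiation and reused afterwards.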
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        # Load the tokenizer and model.
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        # Move the model to the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

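    # Note: max_length / min_length below bound the generated summary in tokens,
    # not in characters of the input text.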
    def generate_summary(self, text, max_length=180, min_length=64):
        # Tokenize the input and move it to the model's device.
        inputs = self.tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(self.device)
        # Generate the summary with sampled beam search; temperature, top_k and
        # top_p only take effect because do_sample=True.
        summary_ids = self.model.generate(
            inputs["input_ids"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True,
        )
        # Decode and return the summary.
        clean_summary = self.tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        # Strip any special tokens that survive decoding.
        special_tokens = ['<pad>', '<unk>', '</s>']
        for token in special_tokens:
            clean_summary = clean_summary.replace(token, '')
        return clean_summary

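
# On Hugging Face Spaces, the @spaces.GPU decorator requests a GPU for the
# duration of each call to the decorated function.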
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    # Texts already shorter than the requested summary length are returned unchanged.
    if len(text) < max_length:
        return text
    summarizer = PegasusSummarizer()
    return summarizer.generate_summary(text, max_length, min_length)


if __name__ == "__main__":
    # Sample input (Chinese): a complaint asking why the free parking spaces on the
    # west side of Dongsi Road were removed while the east side still has them,
    # making parking hard for residents of nearby communities.
    text = (
        "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    )
    summary = generate_summary(text, max_length=128, min_length=64)
    print(summary)