File size: 2,485 Bytes
599a980 495e47a 599a980 61b70f9 599a980 495e47a 599a980 fb15db5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import spaces
import torch
from transformers import PegasusForConditionalGeneration
from tokenizers_pegasus import PegasusTokenizer
class PegasusSummarizer:
    """Lazily-created singleton around the Randeng Pegasus Chinese summarizer.

    The first ``PegasusSummarizer()`` call loads the tokenizer and model for
    ``IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1`` and moves the model
    to CUDA when ``torch.cuda.is_available()``, otherwise to CPU. Every later
    call returns that same instance.
    """

    # The one shared instance; None until a load has completed successfully.
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are accepted but ignored.
        if cls._instance is None:
            instance = super().__new__(cls)
            # Load first, publish second. If loading raises, _instance stays
            # None, so the next call retries instead of handing out a
            # half-initialised object with no model/tokenizer attributes.
            instance._init_model()
            cls._instance = instance
        return cls._instance

    def _init_model(self):
        """Load the tokenizer and model, then place the model on a device."""
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        # Prefer the GPU when one is visible; fall back to the CPU otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_summary(self, text, max_length=180, min_length=64):
        """Summarise *text* and return the decoded summary string.

        Args:
            text: Input text. It is tokenized with truncation to 1024 tokens.
            max_length: Upper length bound forwarded to ``model.generate``.
            min_length: Lower length bound forwarded to ``model.generate``.

        Returns:
            The first decoded sequence, with any literal ``<pad>``, ``<unk>``
            and ``</s>`` markers removed.

        Decoding combines 4-beam search with sampling (``do_sample=True``),
        so repeated calls on the same text may return different summaries.
        """
        # Tokenize and move the input tensors to the model's device.
        inputs = self.tokenizer(
            text, max_length=1024, truncation=True, return_tensors="pt"
        ).to(self.device)
        summary_ids = self.model.generate(
            inputs["input_ids"],
            # Pass the tokenizer's mask explicitly rather than letting
            # generate() infer it from the pad token.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True,
        )
        summary = self.tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        # Belt-and-braces: strip marker strings that survive decoding as plain text.
        for token in ("<pad>", "<unk>", "</s>"):
            summary = summary.replace(token, "")
        return summary
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Return a summary of *text*, or *text* itself when it is already short.

    Inputs with fewer than ``max_length`` characters are returned unchanged,
    and the model is never loaded for them. Longer inputs are summarised by
    the shared ``PegasusSummarizer`` instance using the same
    ``max_length``/``min_length``. The ``spaces.GPU`` decorator presumably
    requests a Hugging Face Spaces GPU for the call (confirm against the
    ``spaces`` package docs).
    """
    # NOTE(review): len() counts characters, whereas max_length is the
    # generation length bound handed to the model.
    if len(text) < max_length:
        return text
    return PegasusSummarizer().generate_summary(text, max_length, min_length)
if __name__ == "__main__":
    # Manual smoke test with a short Chinese sample sentence.
    # This sample is shorter than max_length=128 characters, so
    # generate_summary returns it unchanged without loading the model.
    sample_text = "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    result = generate_summary(sample_text, max_length=128, min_length=64)
    print(result)
|