import spaces
import torch
from transformers import PegasusForConditionalGeneration
from tokenizers_pegasus import PegasusTokenizer


class PegasusSummarizer:
    """Singleton wrapper around the Randeng-Pegasus Chinese summarization model.

    The model (~523M parameters) is loaded exactly once, on first
    instantiation, and shared by every subsequent ``PegasusSummarizer()``
    call via the singleton ``__new__``.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Lazily create the one shared instance so the expensive model
        # load in _init_model() happens only on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        """Load the tokenizer and model, then move the model to the GPU
        when one is available (falls back to CPU otherwise)."""
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_summary(self, text, max_length=180, min_length=64):
        """Generate an abstractive summary of ``text``.

        Args:
            text: Source text (Chinese) to summarize.
            max_length: Maximum length of the generated summary, in tokens.
            min_length: Minimum length of the generated summary, in tokens.

        Returns:
            The decoded summary string with special tokens stripped.
        """
        # Tokenize (truncated to a 1024-token window) and move the
        # encoding to the same device as the model.
        inputs = self.tokenizer(
            text, max_length=1024, truncation=True, return_tensors="pt"
        ).to(self.device)

        # Pass the full encoding (input_ids AND attention_mask) so
        # generate() does not have to infer the mask itself; the original
        # passed only input_ids, which triggers a warning and would mask
        # incorrectly if padded batches were ever used.
        summary_ids = self.model.generate(
            **inputs,
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True,
        )

        # Decode the single returned sequence to text.
        clean_summary = self.tokenizer.batch_decode(
            summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]

        # NOTE(review): these entries are empty strings, so the replace
        # loop below is a no-op — the token texts (e.g. pad/eos markers)
        # appear to have been lost from the source. Preserved as-is to
        # keep behavior identical; restore the intended tokens here.
        special_tokens = ['', '', '']
        for token in special_tokens:
            clean_summary = clean_summary.replace(token, '')

        return clean_summary


@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Module-level entry point used by the Spaces GPU runtime.

    Short inputs are returned unchanged; longer ones are summarized by
    the shared PegasusSummarizer singleton.
    """
    # Heuristic: compares CHARACTER count against a TOKEN length budget.
    # Cheap way to skip texts already shorter than the requested summary,
    # though the two units are not strictly comparable.
    if len(text) < max_length:
        return text
    summarizer = PegasusSummarizer()
    return summarizer.generate_summary(text, max_length, min_length)


if __name__ == "__main__":
    text = (
        "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    )
    summary = generate_summary(text, max_length=128, min_length=64)
    print(summary)