|
import spaces |
|
import torch |
|
from transformers import PegasusForConditionalGeneration |
|
from tokenizers_pegasus import PegasusTokenizer |
|
|
|
class PegasusSummarizer:
    """Singleton wrapper around the Randeng-Pegasus Chinese summarization model.

    The ~523M-parameter model is expensive to load, so it is loaded exactly
    once on first instantiation and every later ``PegasusSummarizer()`` call
    returns the same shared instance.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Lazily build and initialize the single shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        """Load model + tokenizer and move the model to GPU when available."""
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_summary(self, text, max_length=180, min_length=64):
        """Summarize ``text`` and return the cleaned summary string.

        Args:
            text: Input document (Chinese) to summarize.
            max_length: Maximum summary length in tokens.
            min_length: Minimum summary length in tokens.

        Returns:
            The decoded summary with special tokens removed and surrounding
            whitespace stripped.
        """
        # Truncate long inputs to the encoder's 1024-token limit.
        inputs = self.tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(self.device)

        # Fix: forward the attention mask as well. Previously only input_ids
        # was passed, which makes transformers infer the mask (with a warning)
        # and can degrade generation quality on padded input.
        summary_ids = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True
        )

        # skip_special_tokens=True already drops <pad>/<unk>/</s>; the replace
        # pass below is kept only as a defensive guard in case the custom
        # tokenizer's special-token config is incomplete.
        clean_summary = self.tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

        special_tokens = ['<pad>', '<unk>', '</s>']
        for token in special_tokens:
            clean_summary = clean_summary.replace(token, '')

        # Trim stray whitespace left behind by decoding / token stripping.
        return clean_summary.strip()
|
|
|
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Summarize ``text``, skipping the model entirely for short inputs.

    A text with fewer characters than ``max_length`` is returned unchanged,
    since running the summarizer could not make it any shorter.

    Args:
        text: Input document to summarize.
        max_length: Maximum summary length passed to the model.
        min_length: Minimum summary length passed to the model.
    """
    if len(text) >= max_length:
        summarizer = PegasusSummarizer()
        return summarizer.generate_summary(text, max_length, min_length)
    return text
|
|
|
if __name__ == "__main__":
    # Manual smoke test: a short Chinese complaint about revoked parking spots.
    sample_text = (
        "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    )
    print(generate_summary(sample_text, max_length=128, min_length=64))
|
|