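"""Gradio demo that summarizes Brazilian Portuguese text with one of three
methods: the fine-tuned abstractive model
GiordanoB/mT5_multilingual_XLSum-sumarizacao-PTBR, or the extractive Luhn and
LexRank summarizers from sumy."""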
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.luhn import LuhnSummarizer
from sumy.summarizers.lex_rank import LexRankSummarizer
import nltk

# sumy's Tokenizer relies on NLTK's "punkt" sentence-tokenizer data.
nltk.download('punkt')
model_name="GiordanoB/mT5_multilingual_XLSum-sumarizacao-PTBR"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
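
# Note: inference runs on the CPU by default. If a GPU is available, you could
# move the model once at startup (model.to("cuda")) and move input_ids to the
# same device before calling generate(); this is optional and untested here.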
def summarize_HUB(text, method, max_length, min_length, no_repeat_ngram_size, num_beams):
    """Dispatch to the summarization method selected in the UI."""
    if method == "Pure mT5":
        return summarize_mT5(text, max_length, min_length, no_repeat_ngram_size, num_beams)
    if method == "Luhn":
        return summarize_Luhn(text)
    if method == "LexRank":
        return summarize_LexRank(text)
    return "Unknown summarization method."
def summarize_Luhn(text):
    """Extractive summary: the 3 sentences ranked highest by Luhn's
    significant-word frequency heuristic."""
    summarizer = LuhnSummarizer()
    parser = PlaintextParser.from_string(text, Tokenizer("portuguese"))
    sentences = summarizer(parser.document, 3)
    summary = ' '.join(str(sentence) for sentence in sentences)
    return summary.replace('\n', ' ').replace('\r', '')
def summarize_LexRank(text):
    """Extractive summary: the 3 sentences ranked highest by the graph-based
    LexRank centrality measure."""
    summarizer = LexRankSummarizer()
    parser = PlaintextParser.from_string(text, Tokenizer("portuguese"))
    sentences = summarizer(parser.document, 3)
    summary = ' '.join(str(sentence) for sentence in sentences)
    return summary.replace('\n', ' ').replace('\r', '')
def summarize_mT5(text, max_length, min_length, no_repeat_ngram_size, num_beams):
    """Abstractive summary with the fine-tuned mT5 model."""
    input_ids = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )["input_ids"]
    # Gradio sliders deliver floats, so cast the generation arguments to int.
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=int(max_length),
        min_length=int(min_length),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_beams=int(num_beams),
    )[0]
    return tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
gr.Interface(
    fn=summarize_HUB,
    inputs=[
        gr.Textbox(label="Text to summarize"),
        gr.Radio(["Pure mT5", "Luhn", "LexRank"], label="Method"),
        gr.Slider(50, 200, step=1, value=200, label="max_length (mT5)"),
        gr.Slider(25, 100, step=1, value=75, label="min_length (mT5)"),
        gr.Slider(1, 10, step=1, value=1, label="no_repeat_ngram_size (mT5)"),
        gr.Slider(1, 10, step=1, value=1, label="num_beams (mT5)"),
    ],
    outputs=["text"],
).launch()