# HuggingFace Spaces app — Portuguese summarization demo.
# (Scraped page header reported "Runtime error" at capture time.)
"""Gradio demo: Portuguese text summarization via mT5, Luhn, or LexRank."""
import gradio as gr
import nltk
import pytextrank  # noqa: F401 — imported for its spaCy pipeline side effects
import spacy
import torch  # noqa: F401 — transformers inference backend
from sumy.nlp.tokenizers import Tokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.luhn import LuhnSummarizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Portuguese spaCy pipeline (must already be downloaded in the environment)
# and NLTK sentence-tokenizer data used by sumy.
nlp = spacy.load('pt_core_news_sm')
nltk.download('punkt')

# Multilingual T5 checkpoint fine-tuned for Brazilian-Portuguese summarization.
model_name = "GiordanoB/mT5_multilingual_XLSum-sumarizacao-PTBR"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def summarize_HUB(input, method, max_length, min_length, no_repeat_ngram_size, num_beams):
    """Dispatch to the summarizer selected in the Gradio UI.

    The extractive methods ("Luhn", "LexRank") ignore the generation
    parameters; only "Pure mT5" uses them.  Any other method value gets
    the placeholder string "tchau".
    """
    if method == "Luhn":
        return sumarize_Luhn(input)
    if method == "LexRank":
        return sumarize_LexRank(input)
    if method == "Pure mT5":
        return sumarize_mT5(input, max_length, min_length,
                            no_repeat_ngram_size, num_beams)
    # Unknown radio value — placeholder response.
    return "tchau"
def sumarize_Luhn(input):
    """Extractive 3-sentence summary using sumy's Luhn algorithm.

    Returns the selected sentences concatenated (each prefixed with a
    space, matching the original output shape) with newlines flattened.
    """
    document = PlaintextParser.from_string(input, Tokenizer("portuguese")).document
    top_sentences = LuhnSummarizer()(document, 3)
    joined = ''.join(' ' + str(sentence) for sentence in top_sentences)
    return joined.replace('\n', ' ').replace('\r', '')
def sumarize_LexRank(input):
    """Extractive 3-sentence summary using sumy's LexRank algorithm.

    Mirrors sumarize_Luhn: selected sentences are concatenated (each
    prefixed with a space) and newlines are flattened to spaces.
    """
    document = PlaintextParser.from_string(input, Tokenizer("portuguese")).document
    top_sentences = LexRankSummarizer()(document, 3)
    joined = ''.join(' ' + str(sentence) for sentence in top_sentences)
    return joined.replace('\n', ' ').replace('\r', '')
def sumarize_mT5(input, max_length, min_length, no_repeat_ngram_size, num_beams):
    """Abstractive summary of *input* with the fine-tuned mT5 model.

    Args:
        input: Source text; tokenized and truncated to 512 tokens.
        max_length: Upper bound (in tokens) on the generated summary.
        min_length: Lower bound (in tokens) on the generated summary.
        no_repeat_ngram_size: Forbid repeating n-grams of this size.
        num_beams: Beam-search width.

    Returns:
        The decoded summary string (special tokens stripped).
    """
    # BUG FIX: the original wrapped this body in `for i in range(0, 14):`,
    # running tokenization + beam-search generation 14 times with identical
    # inputs and discarding every result but the last.  Generation here is
    # deterministic for fixed inputs, so a single pass gives the same output
    # ~14x faster.
    input_ids = tokenizer(
        input,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )["input_ids"]
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_beams=num_beams,
        )[0]
    return tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
# Build and launch the demo UI: input text box, method selector, and the
# four generation sliders (max_length, min_length, no_repeat_ngram_size,
# num_beams) forwarded positionally to summarize_HUB.
ui_inputs = [
    "textbox",
    gr.Radio(["Pure mT5", "Luhn", "LexRank"]),
    gr.Slider(50, 200, step=1, value=200),
    gr.Slider(25, 100, step=1, value=75),
    gr.Slider(1, 10, step=1, value=1),
    gr.Slider(1, 10, step=1, value=1),
]
gr.Interface(fn=summarize_HUB, inputs=ui_inputs, outputs=["text"]).launch()