|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
@st.cache_resource()
def load_model(model_name: str = "google/pegasus-xsum"):
    """Load the tokenizer and seq2seq model for a pretrained checkpoint.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are reused across Streamlit reruns instead of being loaded again on
    every widget interaction.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both the tokenizer and the model. Defaults to
            ``"google/pegasus-xsum"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Use the same identifier for both loads so the tokenizer vocabulary
    # always matches the model that consumes its output.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
|
|
|
# Load the tokenizer and model used by the summarization UI below.
tokenizer, model = load_model()

st.title("Text Summarization with Pegasus-XSum")
st.write("Enter text below and get a summarized version using the Pegasus model.")

text_input = st.text_area("Enter text to summarize:", "")

if st.button("Summarize"):
    # Strip first so whitespace-only input is treated as empty instead of
    # being sent to the model.
    source_text = text_input.strip()
    if source_text:
        # Generation blocks the script run, so show a spinner while it works.
        with st.spinner("Summarizing..."):
            # Input longer than 512 tokens is truncated by the tokenizer.
            inputs = tokenizer(
                source_text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            # Beam search with 4 beams; output bounded to 10-60 tokens.
            summary_ids = model.generate(
                **inputs,
                max_length=60,
                min_length=10,
                length_penalty=2.0,
                num_beams=4,
            )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        st.subheader("Summary:")
        st.write(summary)
    else:
        st.warning("Please enter some text to summarize.")
|
|