from typing import Optional

import streamlit as st
from transformers import AutoTokenizer, pipeline

from qa.qa import file_to_doc

@st.cache_resource
def summarization_model(
    model_name: str = "facebook/bart-large-cnn",
    custom_tokenizer: Optional[AutoTokenizer] = None,
):
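    """Load a Hugging Face summarization pipeline (cached across reruns).

    `st.cache_resource` keeps a single pipeline instance alive for the session
    instead of re-downloading the model on every Streamlit rerun. If no
    `custom_tokenizer` is given, the model's own tokenizer is used.
    """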
    summarizer = pipeline(
        task="summarization",
        model=model_name,
        tokenizer=custom_tokenizer if custom_tokenizer is not None else model_name,
    )
    return summarizer

@st.cache_data
def split_string_into_token_chunks(s: str, _tokenizer: AutoTokenizer, chunk_size: int):
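    """Split `s` into chunks of at most `chunk_size` tokens, decoded back to strings.

    The leading underscore in `_tokenizer` tells `st.cache_data` not to hash
    that argument (tokenizers are not reliably hashable).
    """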
    # Tokenize the entire string
    token_ids = _tokenizer.encode(s)
    # Split the token ids into chunks of the desired size
    chunks = [token_ids[i:i+chunk_size] for i in range(0, len(token_ids), chunk_size)]
    # Decode each chunk back into a string, dropping special tokens (e.g. BOS/EOS)
    # that encode() may have inserted
    return [_tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]

def summarization_main():
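    """Render the summarization page: collect a text (typed or uploaded),
    split it into model-sized chunks, and summarize chunk by chunk."""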
    st.markdown("<h2 style='text-align: center'>Text Summarization</h2>", unsafe_allow_html=True)
    st.markdown("<h3 style='text-align: left'><b>What is text summarization about?<b></h3>", unsafe_allow_html=True)
    
    st.write("""
        Text summarization is common NLP task concerned with producing a shorter version of a given text while preserving the important information
        contained in such text
        """)
    
    OPTION_1 = "I want to input some text"
    OPTION_2 = "I want to upload a file"
    option = st.radio("How would you like to start? Choose an option below", [OPTION_1, OPTION_2]) 
    
    # greenlight flag: True once the user has provided some text
    text_is_given = False
    if option == OPTION_1:
        sample_text = ""
        text = st.text_area(
            "Input a text in English (10,000 characters max)", 
            value=sample_text,
            max_chars=10_000, 
            height=330)
        # flag that some text has been provided
        if text != sample_text:
            text_is_given = True
        
    elif option == OPTION_2:
        uploaded_file = st.file_uploader(
                "Upload a pdf, docx, or txt file (scanned documents not supported)",
                type=["pdf", "docx", "txt"],
                help="Scanned documents are not supported yet 🥲"
            )
        if uploaded_file is not None:
            # parse the file using custom parsers and build a concatenation for the summarizer
            text = " ".join(file_to_doc(uploaded_file))
            # flag that some text has been provided
            text_is_given = True
    
    if text_is_given:
        # token bounds for each chunk's summary
        min_length, max_length = 30, 200
        user_max_length = max_length
        # slider currently disabled; a fixed maximum length is used instead
        # user_max_length = st.slider(
        #     label="Maximal number of tokens in the summary",
        #     min_value=min_length,
        #     max_value=max_length,
        #     value=150,
        #     step=10,
        # )

        summarizer_downloaded = False
        # load the tokenizer used to split the input document into model-sized chunks
        model_name = "facebook/bart-large-cnn"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # stay safely below the model's token limit, leaving headroom for the
        # special tokens the tokenizer adds around each chunk
        chunk_size = int(0.88 * tokenizer.model_max_length)
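        # e.g. for facebook/bart-large-cnn, model_max_length is 1024 tokens,
        # so each chunk holds roughly 900 tokens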

        # load the (cached) summarization pipeline
        with st.spinner(text="Loading summarization model..."):
            summarizer = summarization_model(model_name=model_name)
            summarizer_downloaded = True
        
        if summarizer_downloaded:
            button = st.button("Summarize!")
            if button:
                with st.spinner(text="Summarizing text..."):
                    # summarizing each chunk of the input text to avoid exceeding the maximum number of tokens
                    summary = ""
                    chunks = split_string_into_token_chunks(text, tokenizer, chunk_size)
                    for chunk in chunks:
                        chunk_summary = summarizer(chunk, max_length=user_max_length, min_length=min_length)
                        summary += "\n" + chunk_summary[0]["summary_text"]

                    st.markdown("<h3 style='text-align: left'><b>Summary<b></h3>", unsafe_allow_html=True)
                    print(summary)
                    st.write(summary)
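
# a minimal entry point, assuming this file is launched directly with
# `streamlit run`; in the full app, summarization_main() is presumably
# called from a multipage router instead
if __name__ == "__main__":
    summarization_main()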