File size: 6,628 Bytes
e4ebacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83023b5
 
 
 
 
 
3360601
 
83023b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4ebacd
83023b5
 
e4ebacd
 
83023b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4ebacd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# import streamlit as st
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# import nltk
# import math
# import torch

# # model_name = "fabiochiu/t5-base-medium-title-generation"
# model_name = "momoyukki/t5-base-news-title-generation_model"
# max_input_length = 512

# st.header("Generate candidate titles for news")

# st_model_load = st.text('Loading title generator model...')

# @st.cache(allow_output_mutation=True)
# def load_model():
#     print("Loading model...")
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
#     nltk.download('punkt')
#     print("Model loaded!")
#     return tokenizer, model

# tokenizer, model = load_model()
# st.success('Model loaded!')
# st_model_load.text("")

# with st.sidebar:
#     st.header("Model parameters")
#     if 'num_titles' not in st.session_state:
#         st.session_state.num_titles = 5
#     def on_change_num_titles():
#         st.session_state.num_titles = num_titles
#     num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
#     if 'temperature' not in st.session_state:
#         st.session_state.temperature = 0.7
#     def on_change_temperatures():
#         st.session_state.temperature = temperature
#     temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
#     st.markdown("_High temperature means that results are more random_")

# if 'text' not in st.session_state:
#     st.session_state.text = ""
# st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)

# def generate_title():
#     st.session_state.text = st_text_area

#     # tokenize text
#     inputs = ["summarize: " + st_text_area]
#     inputs = tokenizer(inputs, return_tensors="pt")

#     # compute span boundaries
#     num_tokens = len(inputs["input_ids"][0])
#     print(f"Input has {num_tokens} tokens")
#     max_input_length = 500
#     num_spans = math.ceil(num_tokens / max_input_length)
#     print(f"Input has {num_spans} spans")
#     overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
#     spans_boundaries = []
#     start = 0
#     for i in range(num_spans):
#         spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
#         start -= overlap
#     print(f"Span boundaries are {spans_boundaries}")
#     spans_boundaries_selected = []
#     j = 0
#     for _ in range(num_titles):
#         spans_boundaries_selected.append(spans_boundaries[j])
#         j += 1
#         if j == len(spans_boundaries):
#             j = 0
#     print(f"Selected span boundaries are {spans_boundaries_selected}")

#     # transform input with spans
#     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
#     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]

#     inputs = {
#         "input_ids": torch.stack(tensor_ids),
#         "attention_mask": torch.stack(tensor_masks)
#     }

#     # compute predictions
#     outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
#     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
#     predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]

#     st.session_state.titles = predicted_titles

# # generate title button
# st_generate_button = st.button('Generate title', on_click=generate_title)

# # title generation labels
# if 'titles' not in st.session_state:
#     st.session_state.titles = []

# if len(st.session_state.titles) > 0:
#     with st.container():
#         st.subheader("Generated titles")
#         for title in st.session_state.titles:
#             st.markdown("__" + title + "__")


import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch

# Hugging Face checkpoint loaded by load_model(); the commented line is an
# alternative checkpoint.
# model_name = "fabiochiu/t5-base-medium-title-generation"
model_name = "momoyukki/t5-base-news-title-generation_model"
# Maximum number of input tokens passed to the encoder; generate_title()
# truncates longer texts to this length.
max_input_length = 512

st.header("Generate candidate titles for news")

# Temporary status line shown while the model loads; it is blanked out once
# loading has finished.
st_model_load = st.text('Loading title generator model...')

# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource; switch once the deployed Streamlit version supports it.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and seq2seq model named by ``model_name``.

    Also downloads the NLTK sentence-tokenizer data needed by
    ``nltk.sent_tokenize``, which is used when post-processing generated
    titles. Because of the ``st.cache`` decorator, Streamlit reuses the
    returned objects across reruns instead of reloading them.
    ``allow_output_mutation=True`` disables Streamlit's hash-based check that
    the cached (mutable) model objects were not modified.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # sent_tokenize loads "punkt" on older NLTK releases but "punkt_tab" on
    # NLTK 3.8.2+, so fetch both. nltk.download reports (rather than raises)
    # a failure for a resource the installed NLTK version does not know.
    for resource in ("punkt", "punkt_tab"):
        nltk.download(resource)
    print("Model loaded!")
    return tokenizer, model

# Load (or fetch from Streamlit's cache) the tokenizer and model used by
# generate_title().
tokenizer, model = load_model()
st.success('Model loaded!')
# Clear the "Loading title generator model..." placeholder.
st_model_load.text("")

with st.sidebar:
    st.header("Model parameters")
    # Each slider is bound to a session-state key, so st.session_state.num_titles
    # and st.session_state.temperature always hold the widget's current value.
    # An on_change callback that copies the module-level slider variable would
    # see the previous run's value, because callbacks run before the script
    # re-executes.
    num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, key="num_titles")
    temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, key="temperature")
    st.markdown("_High temperature means that results are more random_")

# The text area's initial content is persisted in session state. Seed it
# empty on the first run; generate_title() stores the submitted text back
# into this slot so it survives reruns.
if "text" not in st.session_state:
    st.session_state["text"] = ""
st_text_area = st.text_area(
    "Text to generate the title for",
    value=st.session_state["text"],
    height=500,
)

def generate_title():
    st.session_state.text = st_text_area

    # tokenize text
    inputs = tokenizer.encode("summarize: " + st_text_area, return_tensors="pt", max_length=max_input_length, truncation=True)

    # compute predictions
    outputs = model.generate(inputs, num_return_sequences=num_titles, temperature=temperature, max_length=150)
    decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]

    st.session_state.titles = predicted_titles

# Clicking the button runs generate_title() as a callback before the rerun,
# so freshly generated titles are already in session state below.
st_generate_button = st.button('Generate title', on_click=generate_title)

# Make sure the titles slot exists on the very first run.
if 'titles' not in st.session_state:
    st.session_state.titles = []

# Render the candidate titles (bold) only when there is at least one.
generated = st.session_state.titles
if generated:
    with st.container():
        st.subheader("Generated titles")
        for candidate in generated:
            st.markdown(f"__{candidate}__")