Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
import nltk
|
@@ -47,43 +151,11 @@ def generate_title():
|
|
47 |
st.session_state.text = st_text_area
|
48 |
|
49 |
# tokenize text
|
50 |
-
inputs =
|
51 |
-
inputs = tokenizer(inputs, return_tensors="pt")
|
52 |
-
|
53 |
-
# compute span boundaries
|
54 |
-
num_tokens = len(inputs["input_ids"][0])
|
55 |
-
print(f"Input has {num_tokens} tokens")
|
56 |
-
max_input_length = 500
|
57 |
-
num_spans = math.ceil(num_tokens / max_input_length)
|
58 |
-
print(f"Input has {num_spans} spans")
|
59 |
-
overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
|
60 |
-
spans_boundaries = []
|
61 |
-
start = 0
|
62 |
-
for i in range(num_spans):
|
63 |
-
spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
|
64 |
-
start -= overlap
|
65 |
-
print(f"Span boundaries are {spans_boundaries}")
|
66 |
-
spans_boundaries_selected = []
|
67 |
-
j = 0
|
68 |
-
for _ in range(num_titles):
|
69 |
-
spans_boundaries_selected.append(spans_boundaries[j])
|
70 |
-
j += 1
|
71 |
-
if j == len(spans_boundaries):
|
72 |
-
j = 0
|
73 |
-
print(f"Selected span boundaries are {spans_boundaries_selected}")
|
74 |
-
|
75 |
-
# transform input with spans
|
76 |
-
tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
|
77 |
-
tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
|
78 |
-
|
79 |
-
inputs = {
|
80 |
-
"input_ids": torch.stack(tensor_ids),
|
81 |
-
"attention_mask": torch.stack(tensor_masks)
|
82 |
-
}
|
83 |
|
84 |
# compute predictions
|
85 |
-
outputs = model.generate(
|
86 |
-
decoded_outputs = tokenizer.
|
87 |
predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
|
88 |
|
89 |
st.session_state.titles = predicted_titles
|
@@ -99,4 +171,4 @@ if len(st.session_state.titles) > 0:
|
|
99 |
with st.container():
|
100 |
st.subheader("Generated titles")
|
101 |
for title in st.session_state.titles:
|
102 |
-
st.markdown("__" + title + "__")
|
|
|
1 |
+
# import streamlit as st
|
2 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
+
# import nltk
|
4 |
+
# import math
|
5 |
+
# import torch
|
6 |
+
|
7 |
+
# # model_name = "fabiochiu/t5-base-medium-title-generation"
|
8 |
+
# model_name = "momoyukki/t5-base-news-title-generation_model"
|
9 |
+
# max_input_length = 512
|
10 |
+
|
11 |
+
# st.header("Generate candidate titles for news")
|
12 |
+
|
13 |
+
# st_model_load = st.text('Loading title generator model...')
|
14 |
+
|
15 |
+
# @st.cache(allow_output_mutation=True)
|
16 |
+
# def load_model():
|
17 |
+
# print("Loading model...")
|
18 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
19 |
+
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
20 |
+
# nltk.download('punkt')
|
21 |
+
# print("Model loaded!")
|
22 |
+
# return tokenizer, model
|
23 |
+
|
24 |
+
# tokenizer, model = load_model()
|
25 |
+
# st.success('Model loaded!')
|
26 |
+
# st_model_load.text("")
|
27 |
+
|
28 |
+
# with st.sidebar:
|
29 |
+
# st.header("Model parameters")
|
30 |
+
# if 'num_titles' not in st.session_state:
|
31 |
+
# st.session_state.num_titles = 5
|
32 |
+
# def on_change_num_titles():
|
33 |
+
# st.session_state.num_titles = num_titles
|
34 |
+
# num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
|
35 |
+
# if 'temperature' not in st.session_state:
|
36 |
+
# st.session_state.temperature = 0.7
|
37 |
+
# def on_change_temperatures():
|
38 |
+
# st.session_state.temperature = temperature
|
39 |
+
# temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
|
40 |
+
# st.markdown("_High temperature means that results are more random_")
|
41 |
+
|
42 |
+
# if 'text' not in st.session_state:
|
43 |
+
# st.session_state.text = ""
|
44 |
+
# st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
|
45 |
+
|
46 |
+
# def generate_title():
|
47 |
+
# st.session_state.text = st_text_area
|
48 |
+
|
49 |
+
# # tokenize text
|
50 |
+
# inputs = ["summarize: " + st_text_area]
|
51 |
+
# inputs = tokenizer(inputs, return_tensors="pt")
|
52 |
+
|
53 |
+
# # compute span boundaries
|
54 |
+
# num_tokens = len(inputs["input_ids"][0])
|
55 |
+
# print(f"Input has {num_tokens} tokens")
|
56 |
+
# max_input_length = 500
|
57 |
+
# num_spans = math.ceil(num_tokens / max_input_length)
|
58 |
+
# print(f"Input has {num_spans} spans")
|
59 |
+
# overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
|
60 |
+
# spans_boundaries = []
|
61 |
+
# start = 0
|
62 |
+
# for i in range(num_spans):
|
63 |
+
# spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
|
64 |
+
# start -= overlap
|
65 |
+
# print(f"Span boundaries are {spans_boundaries}")
|
66 |
+
# spans_boundaries_selected = []
|
67 |
+
# j = 0
|
68 |
+
# for _ in range(num_titles):
|
69 |
+
# spans_boundaries_selected.append(spans_boundaries[j])
|
70 |
+
# j += 1
|
71 |
+
# if j == len(spans_boundaries):
|
72 |
+
# j = 0
|
73 |
+
# print(f"Selected span boundaries are {spans_boundaries_selected}")
|
74 |
+
|
75 |
+
# # transform input with spans
|
76 |
+
# tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
|
77 |
+
# tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
|
78 |
+
|
79 |
+
# inputs = {
|
80 |
+
# "input_ids": torch.stack(tensor_ids),
|
81 |
+
# "attention_mask": torch.stack(tensor_masks)
|
82 |
+
# }
|
83 |
+
|
84 |
+
# # compute predictions
|
85 |
+
# outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
|
86 |
+
# decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
87 |
+
# predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
|
88 |
+
|
89 |
+
# st.session_state.titles = predicted_titles
|
90 |
+
|
91 |
+
# # generate title button
|
92 |
+
# st_generate_button = st.button('Generate title', on_click=generate_title)
|
93 |
+
|
94 |
+
# # title generation labels
|
95 |
+
# if 'titles' not in st.session_state:
|
96 |
+
# st.session_state.titles = []
|
97 |
+
|
98 |
+
# if len(st.session_state.titles) > 0:
|
99 |
+
# with st.container():
|
100 |
+
# st.subheader("Generated titles")
|
101 |
+
# for title in st.session_state.titles:
|
102 |
+
# st.markdown("__" + title + "__")
|
103 |
+
|
104 |
+
|
105 |
import streamlit as st
|
106 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
107 |
import nltk
|
|
|
151 |
st.session_state.text = st_text_area
|
152 |
|
153 |
# tokenize text
|
154 |
+
inputs = tokenizer.encode("summarize: " + st_text_area, return_tensors="pt", max_length=max_input_length, truncation=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
# compute predictions
|
157 |
+
outputs = model.generate(inputs, num_return_sequences=num_titles, temperature=temperature, max_length=150)
|
158 |
+
decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
|
159 |
predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
|
160 |
|
161 |
st.session_state.titles = predicted_titles
|
|
|
171 |
with st.container():
|
172 |
st.subheader("Generated titles")
|
173 |
for title in st.session_state.titles:
|
174 |
+
st.markdown("__" + title + "__")
|