momoyukki commited on
Commit
2d93cf1
·
verified ·
1 Parent(s): 51daf4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -107
app.py CHANGED
@@ -1,107 +1,3 @@
1
- # import streamlit as st
2
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- # import nltk
4
- # import math
5
- # import torch
6
-
7
- # # model_name = "fabiochiu/t5-base-medium-title-generation"
8
- # model_name = "momoyukki/t5-base-news-title-generation_model"
9
- # max_input_length = 512
10
-
11
- # st.header("Generate candidate titles for news")
12
-
13
- # st_model_load = st.text('Loading title generator model...')
14
-
15
- # @st.cache(allow_output_mutation=True)
16
- # def load_model():
17
- # print("Loading model...")
18
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
20
- # nltk.download('punkt')
21
- # print("Model loaded!")
22
- # return tokenizer, model
23
-
24
- # tokenizer, model = load_model()
25
- # st.success('Model loaded!')
26
- # st_model_load.text("")
27
-
28
- # with st.sidebar:
29
- # st.header("Model parameters")
30
- # if 'num_titles' not in st.session_state:
31
- # st.session_state.num_titles = 5
32
- # def on_change_num_titles():
33
- # st.session_state.num_titles = num_titles
34
- # num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
35
- # if 'temperature' not in st.session_state:
36
- # st.session_state.temperature = 0.7
37
- # def on_change_temperatures():
38
- # st.session_state.temperature = temperature
39
- # temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
40
- # st.markdown("_High temperature means that results are more random_")
41
-
42
- # if 'text' not in st.session_state:
43
- # st.session_state.text = ""
44
- # st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
45
-
46
- # def generate_title():
47
- # st.session_state.text = st_text_area
48
-
49
- # # tokenize text
50
- # inputs = ["summarize: " + st_text_area]
51
- # inputs = tokenizer(inputs, return_tensors="pt")
52
-
53
- # # compute span boundaries
54
- # num_tokens = len(inputs["input_ids"][0])
55
- # print(f"Input has {num_tokens} tokens")
56
- # max_input_length = 500
57
- # num_spans = math.ceil(num_tokens / max_input_length)
58
- # print(f"Input has {num_spans} spans")
59
- # overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
60
- # spans_boundaries = []
61
- # start = 0
62
- # for i in range(num_spans):
63
- # spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
64
- # start -= overlap
65
- # print(f"Span boundaries are {spans_boundaries}")
66
- # spans_boundaries_selected = []
67
- # j = 0
68
- # for _ in range(num_titles):
69
- # spans_boundaries_selected.append(spans_boundaries[j])
70
- # j += 1
71
- # if j == len(spans_boundaries):
72
- # j = 0
73
- # print(f"Selected span boundaries are {spans_boundaries_selected}")
74
-
75
- # # transform input with spans
76
- # tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
- # tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
78
-
79
- # inputs = {
80
- # "input_ids": torch.stack(tensor_ids),
81
- # "attention_mask": torch.stack(tensor_masks)
82
- # }
83
-
84
- # # compute predictions
85
- # outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
86
- # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
87
- # predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
88
-
89
- # st.session_state.titles = predicted_titles
90
-
91
- # # generate title button
92
- # st_generate_button = st.button('Generate title', on_click=generate_title)
93
-
94
- # # title generation labels
95
- # if 'titles' not in st.session_state:
96
- # st.session_state.titles = []
97
-
98
- # if len(st.session_state.titles) > 0:
99
- # with st.container():
100
- # st.subheader("Generated titles")
101
- # for title in st.session_state.titles:
102
- # st.markdown("__" + title + "__")
103
-
104
-
105
  import streamlit as st
106
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
107
  import nltk
@@ -151,11 +47,43 @@ def generate_title():
151
  st.session_state.text = st_text_area
152
 
153
  # tokenize text
154
- inputs = tokenizer.encode("summarize: " + st_text_area, return_tensors="pt", max_length=max_input_length, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # compute predictions
157
- outputs = model.generate(inputs, num_return_sequences=num_titles, temperature=temperature, max_length=150)
158
- decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
159
  predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
160
 
161
  st.session_state.titles = predicted_titles
@@ -172,3 +100,4 @@ if len(st.session_state.titles) > 0:
172
  st.subheader("Generated titles")
173
  for title in st.session_state.titles:
174
  st.markdown("__" + title + "__")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import nltk
 
47
  st.session_state.text = st_text_area
48
 
49
  # tokenize text
50
+ inputs = ["summarize: " + st_text_area]
51
+ inputs = tokenizer(inputs, return_tensors="pt")
52
+
53
+ # compute span boundaries
54
+ num_tokens = len(inputs["input_ids"][0])
55
+ print(f"Input has {num_tokens} tokens")
56
+ max_input_length = 500
57
+ num_spans = math.ceil(num_tokens / max_input_length)
58
+ print(f"Input has {num_spans} spans")
59
+ overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
60
+ spans_boundaries = []
61
+ start = 0
62
+ for i in range(num_spans):
63
+ spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
64
+ start -= overlap
65
+ print(f"Span boundaries are {spans_boundaries}")
66
+ spans_boundaries_selected = []
67
+ j = 0
68
+ for _ in range(num_titles):
69
+ spans_boundaries_selected.append(spans_boundaries[j])
70
+ j += 1
71
+ if j == len(spans_boundaries):
72
+ j = 0
73
+ print(f"Selected span boundaries are {spans_boundaries_selected}")
74
+
75
+ # transform input with spans
76
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
78
+
79
+ inputs = {
80
+ "input_ids": torch.stack(tensor_ids),
81
+ "attention_mask": torch.stack(tensor_masks)
82
+ }
83
 
84
  # compute predictions
85
+ outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
86
+ decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
87
  predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
88
 
89
  st.session_state.titles = predicted_titles
 
100
  st.subheader("Generated titles")
101
  for title in st.session_state.titles:
102
  st.markdown("__" + title + "__")
103
+