momoyukki commited on
Commit
83023b5
1 Parent(s): f977cdd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import nltk
4
+ import math
5
+ import torch
6
+
7
+ model_name = "momoyukki/t5-base-news-title-generation_model"
8
+ max_input_length = 512
9
+
10
+ st.header("Generate candidate titles for news")
11
+
12
+ st_model_load = st.text('Loading title generator model...')
13
+
14
+ @st.cache(allow_output_mutation=True)
15
+ def load_model():
16
+ print("Loading model...")
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
+ nltk.download('punkt')
20
+ print("Model loaded!")
21
+ return tokenizer, model
22
+
23
+ tokenizer, model = load_model()
24
+ st.success('Model loaded!')
25
+ st_model_load.text("")
26
+
27
+ with st.sidebar:
28
+ st.header("Model parameters")
29
+ if 'num_titles' not in st.session_state:
30
+ st.session_state.num_titles = 5
31
+ def on_change_num_titles():
32
+ st.session_state.num_titles = num_titles
33
+ num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
34
+ if 'temperature' not in st.session_state:
35
+ st.session_state.temperature = 0.7
36
+ def on_change_temperatures():
37
+ st.session_state.temperature = temperature
38
+ temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
39
+ st.markdown("_High temperature means that results are more random_")
40
+
41
+ if 'text' not in st.session_state:
42
+ st.session_state.text = ""
43
+ st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
44
+
45
+ def generate_title():
46
+ st.session_state.text = st_text_area
47
+
48
+ # tokenize text
49
+ inputs = ["summarize: " + st_text_area]
50
+ inputs = tokenizer(inputs, return_tensors="pt")
51
+
52
+ # compute span boundaries
53
+ num_tokens = len(inputs["input_ids"][0])
54
+ print(f"Input has {num_tokens} tokens")
55
+ max_input_length = 500
56
+ num_spans = math.ceil(num_tokens / max_input_length)
57
+ print(f"Input has {num_spans} spans")
58
+ overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
59
+ spans_boundaries = []
60
+ start = 0
61
+ for i in range(num_spans):
62
+ spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
63
+ start -= overlap
64
+ print(f"Span boundaries are {spans_boundaries}")
65
+ spans_boundaries_selected = []
66
+ j = 0
67
+ for _ in range(num_titles):
68
+ spans_boundaries_selected.append(spans_boundaries[j])
69
+ j += 1
70
+ if j == len(spans_boundaries):
71
+ j = 0
72
+ print(f"Selected span boundaries are {spans_boundaries_selected}")
73
+
74
+ # transform input with spans
75
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
76
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
+
78
+ inputs = {
79
+ "input_ids": torch.stack(tensor_ids),
80
+ "attention_mask": torch.stack(tensor_masks)
81
+ }
82
+
83
+ # compute predictions
84
+ outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
85
+ decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
86
+ predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
87
+
88
+ st.session_state.titles = predicted_titles
89
+
90
+ # generate title button
91
+ st_generate_button = st.button('Generate title', on_click=generate_title)
92
+
93
+ # title generation labels
94
+ if 'titles' not in st.session_state:
95
+ st.session_state.titles = []
96
+
97
+ if len(st.session_state.titles) > 0:
98
+ with st.container():
99
+ st.subheader("Generated titles")
100
+ for title in st.session_state.titles:
101
+ st.markdown("__" + title + "__")