File size: 2,369 Bytes
5e92cc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import streamlit as st
from generate import generate


# Marker tokens that delimit the sections of one generated sample.
START_TOKEN = "<|startoftext|>"
HEADLINE_TOKEN = "<|headline|>"
CONTENT_TOKEN = "<|content|>"


def parse_generated_article(text):
    """Split one generated sample into ``(category, headline, content)``.

    The expected layout is::

        <|startoftext|> category <|headline|> headline <|content|> content

    For a well-formed sample each section is the text between its marker
    and the next occurrence of a marker of the same kind, stripped of
    surrounding whitespace. A sample that lacks one or more markers (for
    example because generation was cut off early) does not raise: the
    missing sections are returned as empty strings, and when the start
    marker is absent the whole text is treated as the body.

    Args:
        text: Raw string produced by the generator.

    Returns:
        A ``(category, headline, content)`` tuple of stripped strings.
    """
    pieces = text.split(START_TOKEN)
    # Take the segment following the first start marker; fall back to the
    # whole text when the marker never appears.
    body = (pieces[1] if len(pieces) > 1 else pieces[0]).strip()

    category, _, after_headline = body.partition(HEADLINE_TOKEN)
    # Ignore anything after a repeated headline marker.
    headline_section = after_headline.split(HEADLINE_TOKEN, 1)[0]

    headline, _, after_content = headline_section.partition(CONTENT_TOKEN)
    # Ignore anything after a repeated content marker.
    content = after_content.split(CONTENT_TOKEN, 1)[0]

    return category.strip(), headline.strip(), content.strip()


def main():
    """Render the generation form and display the generated articles."""
    # Group the inputs in a form so they are sent together on submit.
    with st.form(key='my_form'):
        print("Loading model...")

        print("Model is ready to serve...")
        desc = "Vietnamese News Generative Model - finetuned GPT2"

        st.title('Vietnamese News Generative Model!')
        st.write(desc)
        st.write("##")
        option = st.selectbox(
            'Choose a category',
            ('thời sự ', 'thế giới', 'tài chính kinh doanh',
             'đời sống', 'văn hoá', 'giải trí', 'giới trẻ', 'giáo dục',
             'công nghệ', 'sức khoẻ'))

        st.write("##")
        category = str(option)
        headline = st.text_input('Headline (or part of the headline)')
        num_return_sequences = st.slider(
            'Number of return sequences', min_value=1, max_value=5, value=2)
        max_len = st.slider('Max Length', min_value=80, max_value=500, value=300)
        with st.expander("Setting parameters"):
            min_len = st.slider('Min Length', min_value=0, max_value=50, value=50)
            # NOTE(review): the three sliders below are rendered but their
            # values are not forwarded to generate() -- confirm whether
            # generate() accepts top_k / top_p / num_beams and wire them in.
            st.slider('Top k', min_value=30, max_value=200, value=50)
            st.slider('Top p', min_value=0.0, max_value=1.0, value=0.8)
            st.slider('Num Beams', min_value=1, max_value=6, value=2)

        submit_button = st.form_submit_button(label='Generate')

    if not submit_button:
        return

    print("Generating output")
    with st.spinner('Wait for it...'):
        outputs = generate(
            category=category,
            headline=str(headline),
            min_len=min_len,
            max_len=max_len,
            num_return_sequences=num_return_sequences,
        )

    for i, output in enumerate(outputs):
        # Use separate names so the form inputs are not overwritten.
        out_category, out_headline, out_content = parse_generated_article(output)

        st.header(f"Output: {i}")
        st.subheader("Category")
        st.write(out_category)
        st.subheader("Headline")
        st.write(out_headline)
        st.subheader("Content")
        st.write(out_content)
        st.write("##")

    st.balloons()


if __name__ == "__main__":
    main()