File size: 3,833 Bytes
bea6893
 
 
 
 
b9afdfb
bea6893
 
 
b9afdfb
bea6893
ca0f425
bea6893
 
ca0f425
bea6893
1cf06d2
 
 
bea6893
1cf06d2
bea6893
 
 
 
 
ca0f425
bea6893
 
 
1cf06d2
bea6893
 
 
 
1cf06d2
bea6893
b9afdfb
1cf06d2
b9afdfb
1cf06d2
 
 
 
 
 
 
 
 
 
0951988
1cf06d2
 
 
 
 
 
 
 
b9afdfb
 
bea6893
 
 
 
 
 
 
 
 
 
1cf06d2
bea6893
 
 
 
 
 
 
 
 
 
1cf06d2
 
bea6893
 
1cf06d2
 
bea6893
 
 
ca0f425
1cf06d2
0951988
 
 
1cf06d2
 
 
0951988
 
 
1cf06d2
 
bea6893
1cf06d2
 
52f8d54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os


@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load the vlT5 tokenizer and both models, cached by Streamlit.

    Returns:
        A ``(tokenizer_pl, model_pl, model_det_pl)`` tuple: the shared
        tokenizer, the RfC-generation model, and the RfC-detection model.
        The Streamlit cache ensures the expensive downloads run only once.
    """
    # Prefer the token injected via the environment; fall back to True,
    # which tells huggingface_hub to use the locally stored credentials.
    token = os.environ.get("TOKEN_FROM_SECRET") or True

    generation_repo = "Voicelab/vlt5-base-rfc-v1_2"
    detection_repo = "Voicelab/vlt5-base-rfc-detector-1.0"

    tokenizer = T5Tokenizer.from_pretrained(generation_repo, use_auth_token=token)
    generator = T5ForConditionalGeneration.from_pretrained(
        generation_repo, use_auth_token=token
    )
    detector = T5ForConditionalGeneration.from_pretrained(
        detection_repo, use_auth_token=token
    )

    return tokenizer, generator, detector


# Static UI assets (paths are relative to the app's working directory).
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Upper bound on characters accepted in the input text area (enforced
# by the trim_length on_change callback below).
max_length: int = 5000
# NOTE(review): cache_size is never read anywhere in this file — dead
# constant; confirm before removing.
cache_size: int = 100

st.set_page_config(
    page_title="DEMO - Reason for Contact generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

# Module-level globals used by get_predictions(); loading here (after
# set_page_config) triggers the cached download on first run.
tokenizer_pl, model_pl, model_det_pl = load_model_cache()


def get_predictions(text, mode):
    """Run the model selected by *mode* on *text* and decode the result.

    Args:
        text: Input text (conversation or single turn) to feed the model.
        mode: One of the sidebar radio options — "Polish - RfC Generation"
            or "Polish - RfC Detection".

    Returns:
        The decoded model output as a string (may be empty for the
        detection model when no Reason-for-Contact is found).

    Raises:
        ValueError: If *mode* is not one of the supported options.
    """
    input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
    if mode == "Polish - RfC Generation":
        output = model_pl.generate(
            input_ids,
            no_repeat_ngram_size=1,
            num_beams=3,
            num_beam_groups=3,
            min_length=10,
            max_length=100,
        )
    elif mode == "Polish - RfC Detection":
        output = model_det_pl.generate(
            input_ids,
            no_repeat_ngram_size=2,
            num_beams=3,
            num_beam_groups=3,
            repetition_penalty=1.5,
            diversity_penalty=2.0,
            length_penalty=2.0,
        )
    else:
        # Previously an unrecognized mode fell through both branches and
        # crashed with a NameError on `output`; fail fast and clearly.
        raise ValueError(f"Unsupported mode: {mode!r}")
    predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
    return predicted_rfc


def trim_length():
    """on_change callback: clamp the "input" widget value to max_length chars."""
    text = st.session_state["input"]
    if len(text) > max_length:
        st.session_state["input"] = text[:max_length]


if __name__ == "__main__":
    # --- Page chrome -----------------------------------------------------
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - Reason for Contact generator")

    # --- Input area (trimmed to max_length via the on_change callback) ---
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )

    # --- Model selection -------------------------------------------------
    # st.sidebar.title returns nothing useful; don't assign it to `mode`
    # (the original did, then immediately overwrote it).
    st.sidebar.title("Model settings")
    mode = st.sidebar.radio(
        "Select model to test",
        [
            "Polish - RfC Generation",
            "Polish - RfC Detection",
        ],
    )

    result = st.button("Find reason for contact")

    # BUG FIX: the original compared `mode` against labels with
    # "(accepts whole conversation)" / "(accepts one turn)" suffixes that
    # can never equal the radio options above, so neither branch ran: the
    # help text never rendered and `text_area` was undefined, raising a
    # NameError as soon as the button was pressed.
    text_area = "Result"  # defensive default for the output widget label
    if mode == "Polish - RfC Generation":
        st.markdown("### You selected RfC Generation model.")
        st.markdown("-- *Input*: Whole conversation. Should specify roles (e.g. **AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.**)")
        st.markdown("-- *Output*: Reason for calling for the whole conversation.")
        text_area = "Put a whole conversation or full e-mail here."

    elif mode == "Polish - RfC Detection":
        st.markdown("### You selected RfC Detection model.")
        st.markdown("-- *Input*: A single turn from the conversation e.g. **'Hello, how can I help you?'** or **'Hi, I would like to report a stolen card.'**")
        st.markdown("-- *Output*: Model will return an empty string if a turn possibly does not include Reason for Calling, or a sentence if the RfC is detected.")
        text_area = "Put a single turn or a few sentences here."

    if result:
        generated_rfc = get_predictions(text=user_input, mode=mode)
        st.text_area(text_area, generated_rfc)
        print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")