File size: 3,631 Bytes
bea6893
 
 
 
 
b9afdfb
bea6893
 
 
b9afdfb
bea6893
ca0f425
bea6893
 
ca0f425
bea6893
1cf06d2
 
 
bea6893
1cf06d2
bea6893
 
 
 
 
ca0f425
bea6893
 
 
1cf06d2
bea6893
 
 
 
1cf06d2
bea6893
b9afdfb
1cf06d2
b9afdfb
1cf06d2
 
 
 
 
 
 
 
64def62
1cf06d2
 
0951988
1cf06d2
 
 
 
 
 
 
 
b9afdfb
 
bea6893
 
 
 
 
 
 
 
 
 
1cf06d2
c80bd44
76cab36
c80bd44
 
 
 
64def62
 
bea6893
 
 
 
 
 
 
 
1cf06d2
 
bea6893
 
1cf06d2
 
bea6893
 
 
ca0f425
64def62
bea6893
1cf06d2
c80bd44
52f8d54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os


def _resource_cache():
    """Return a Streamlit decorator that keeps one shared instance of a resource.

    ``st.cache`` is deprecated in newer Streamlit releases in favour of
    ``st.cache_resource``. Prefer the new API when it exists and fall back to
    the legacy decorator on older Streamlit versions. The legacy decorator is
    configured with ``allow_output_mutation=True`` so Streamlit does not try to
    hash the returned model objects.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def load_model_cache():
    """Load the Polish RfC tokenizer, generation model and detection model.

    The result is cached by Streamlit, so the (slow) downloads and model
    construction happen once per server process rather than on every rerun.

    Returns:
        tuple: ``(tokenizer_pl, model_pl, model_det_pl)``.
    """
    # Use the token from the environment when set. Otherwise pass ``True``,
    # which asks ``from_pretrained`` to use a locally stored Hugging Face
    # token. Note that an empty string also falls back to ``True`` (``or``).
    auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

    # The tokenizer and the generation model share one repository.
    rfc_repo = "Voicelab/vlt5-base-rfc-v1_2"
    detector_repo = "Voicelab/vlt5-base-rfc-detector-1.0"

    tokenizer_pl = T5Tokenizer.from_pretrained(rfc_repo, use_auth_token=auth_token)
    model_pl = T5ForConditionalGeneration.from_pretrained(
        rfc_repo, use_auth_token=auth_token
    )
    # The detector reuses the RfC tokenizer loaded above.
    model_det_pl = T5ForConditionalGeneration.from_pretrained(
        detector_repo, use_auth_token=auth_token
    )

    return tokenizer_pl, model_pl, model_det_pl


# Branding images: full logo in the page body, short logo in the sidebar,
# favicon for the browser tab (used by the __main__ UI and set_page_config).
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Maximum number of characters kept from the user's input; trim_length
# truncates the "input" text area to this length.
max_length: int = 5000
# NOTE(review): cache_size is not referenced anywhere in the visible code —
# confirm whether it is still needed.
cache_size: int = 100

# Page configuration is issued before any element is rendered; older
# Streamlit releases require set_page_config to be the first Streamlit command.
st.set_page_config(
    page_title="DEMO - Reason for Contact generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

# Load the tokenizer and both models. load_model_cache is wrapped in a
# Streamlit cache decorator, so reruns of this script reuse the loaded objects.
tokenizer_pl, model_pl, model_det_pl = load_model_cache()


def get_predictions(text, mode):
    """Run the selected VLT5 model on ``text`` and return its decoded output.

    Args:
        text: Input text. Use a whole conversation for generation, or a single
            turn for detection.
        mode: Either ``"Polish - RfC Generation"`` (uses ``model_pl``) or
            ``"Polish - RfC Detection"`` (uses ``model_det_pl``).

    Returns:
        str: The first generated sequence, decoded without special tokens.

    Raises:
        ValueError: If ``mode`` is not one of the two supported labels.
            Previously an unknown mode crashed with ``UnboundLocalError``.
    """
    # Pick the model and its decoding settings first, so an invalid mode fails
    # fast without tokenizing the input.
    if mode == "Polish - RfC Generation":
        model = model_pl
        generation_kwargs = dict(
            no_repeat_ngram_size=1,
            num_beams=3,
            num_beam_groups=3,  # diverse (group) beam search
            min_length=10,
            max_length=100,
            diversity_penalty=1.0,
        )
    elif mode == "Polish - RfC Detection":
        model = model_det_pl
        generation_kwargs = dict(
            no_repeat_ngram_size=2,
            num_beams=3,
            num_beam_groups=3,  # diverse (group) beam search
            repetition_penalty=1.5,
            diversity_penalty=2.0,
            length_penalty=2.0,
        )
    else:
        raise ValueError(f"Unknown mode: {mode!r}")

    # truncation=True keeps over-long inputs within the tokenizer's maximum length.
    input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
    output = model.generate(input_ids, **generation_kwargs)
    return tokenizer_pl.decode(output[0], skip_special_tokens=True)


def trim_length():
    """Clamp the ``"input"`` text area value to at most ``max_length`` characters.

    Intended as the text area's ``on_change`` callback. Session state is only
    rewritten when the current text is actually too long.
    """
    current = st.session_state["input"]
    if len(current) <= max_length:
        return
    st.session_state["input"] = current[:max_length]


if __name__ == "__main__":
    # Branding and model descriptions.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - Reason for Contact generator")
    st.markdown("#### RfC Generation model.")
    st.markdown("**Input**: Whole conversation. Should specify roles e.g. *AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.* Put a whole conversation or full e-mail here.")
    st.markdown("**Output**: Reason for calling for the whole conversation.")
    st.markdown("#### RfC Detection model.")
    st.markdown("**Input**: A single turn from the conversation e.g. *'Hello, how can I help you?'* or *'Hi, I would like to report a stolen card.'* Put a single turn or a few sentences here.")
    st.markdown("**Output**: Model will return an empty string if a turn possibly does not include a Reason for Contact, or a sentence if the RfC is detected.")

    # max_chars enforces the limit in the widget itself; trim_length also
    # truncates the stored value when the text changes.
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        max_chars=max_length,
        on_change=trim_length,
        key="input",
    )

    # The sidebar title is purely decorative; only the radio selection is used.
    st.sidebar.title("Model settings")
    mode = st.sidebar.radio(
        "Select model to test",
        [
            "Polish - RfC Generation",
            "Polish - RfC Detection",
        ],
    )

    if st.button("Find reason for contact"):
        if not user_input.strip():
            # Avoid running a costly beam search on empty input.
            st.warning("Please enter some text first.")
        else:
            generated_rfc = get_predictions(text=user_input, mode=mode)
            st.text_area("Find reason for contact", generated_rfc)
            # NOTE(review): this writes raw user conversations to stdout;
            # confirm this logging is acceptable for the deployment.
            print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")