Spaces:
Sleeping
Sleeping
File size: 3,833 Bytes
bea6893 b9afdfb bea6893 b9afdfb bea6893 ca0f425 bea6893 ca0f425 bea6893 1cf06d2 bea6893 1cf06d2 bea6893 ca0f425 bea6893 1cf06d2 bea6893 1cf06d2 bea6893 b9afdfb 1cf06d2 b9afdfb 1cf06d2 0951988 1cf06d2 b9afdfb bea6893 1cf06d2 bea6893 1cf06d2 bea6893 1cf06d2 bea6893 ca0f425 1cf06d2 0951988 1cf06d2 0951988 1cf06d2 bea6893 1cf06d2 52f8d54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os
@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load the shared tokenizer and both VLT5 models, cached by Streamlit.

    Returns:
        tuple: (tokenizer, RfC-generation model, RfC-detection model).
    """
    # Prefer a token injected via Spaces secrets; `True` tells
    # `from_pretrained` to fall back to the locally stored HF credentials.
    token = os.environ.get("TOKEN_FROM_SECRET") or True
    generation_repo = "Voicelab/vlt5-base-rfc-v1_2"
    detection_repo = "Voicelab/vlt5-base-rfc-detector-1.0"
    tokenizer = T5Tokenizer.from_pretrained(generation_repo, use_auth_token=token)
    generator = T5ForConditionalGeneration.from_pretrained(
        generation_repo, use_auth_token=token
    )
    detector = T5ForConditionalGeneration.from_pretrained(
        detection_repo, use_auth_token=token
    )
    return tokenizer, generator, detector
# Branding assets used across the page (full logo, sidebar logo, favicon).
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Hard cap on the input text length enforced by trim_length().
max_length: int = 5000
# NOTE(review): declared but never read anywhere in this file — confirm
# whether it is dead or used by code outside this view.
cache_size: int = 100
st.set_page_config(
    page_title="DEMO - Reason for Contact generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)
# Download/load the models once; subsequent reruns hit the Streamlit cache.
tokenizer_pl, model_pl, model_det_pl = load_model_cache()
def get_predictions(text, mode):
    """Run the model selected by ``mode`` on ``text`` and decode the result.

    Args:
        text: Raw input text (conversation or single turn).
        mode: One of "Polish - RfC Generation" or "Polish - RfC Detection",
            matching the sidebar radio options.

    Returns:
        str: The decoded model output (may be empty for the detection model).

    Raises:
        ValueError: If ``mode`` is not one of the supported options.
    """
    input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
    if mode == "Polish - RfC Generation":
        output = model_pl.generate(
            input_ids,
            no_repeat_ngram_size=1,
            num_beams=3,
            num_beam_groups=3,
            min_length=10,
            max_length=100,
        )
    elif mode == "Polish - RfC Detection":
        output = model_det_pl.generate(
            input_ids,
            no_repeat_ngram_size=2,
            num_beams=3,
            num_beam_groups=3,
            repetition_penalty=1.5,
            diversity_penalty=2.0,
            length_penalty=2.0,
        )
    else:
        # Previously an unknown mode fell through and raised a confusing
        # UnboundLocalError on `output`; fail loudly and clearly instead.
        raise ValueError(f"Unknown mode: {mode!r}")
    predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
    return predicted_rfc
def trim_length():
    """Clip the text-area session value to at most ``max_length`` chars."""
    current = st.session_state["input"]
    if len(current) > max_length:
        st.session_state["input"] = current[:max_length]
if __name__ == "__main__":
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - Reason for Contact generator")
    generated_keywords = ""
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )
    # FIX: `st.sidebar.title` returns a DeltaGenerator, not the mode; the
    # original assigned it to `mode` only to overwrite it on the next line.
    st.sidebar.title("Model settings")
    mode = st.sidebar.radio(
        "Select model to test",
        [
            "Polish - RfC Generation",
            "Polish - RfC Detection",
        ],
    )
    result = st.button("Find reason for contact")
    # Fallback label so st.text_area below is always valid.
    text_area = "Model output"
    # FIX: the comparisons below previously used longer strings with an
    # "(accepts ...)" suffix that never matched the radio options above,
    # so the help text never rendered and `text_area` was undefined
    # (NameError) when the button was pressed.
    if mode == "Polish - RfC Generation":
        st.markdown("### You selected RfC Generation model.")
        st.markdown("-- *Input*: Whole conversation. Should specify roles (e.g. **AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.**")
        st.markdown("-- *Output*: Reason for calling for the whole conversation.")
        text_area = "Put a whole conversation or full e-mail here."
    elif mode == "Polish - RfC Detection":
        st.markdown("### You selected RfC Detection model.")
        st.markdown("-- *Input*: A single turn from the conversation e.g. **'Hello, how can I help you?'** or **'Hi, I would like to report a stolen card.'**")
        st.markdown("-- *Output*: Model will return an empty string if a turn possibly does not includes Reason for Calling, or a sentence if the RfC is detected.")
        text_area = "Put a single turn or a few sentences here."
    if result:
        generated_rfc = get_predictions(text=user_input, mode=mode)
        st.text_area(text_area, generated_rfc)
        # FIX: removed a stray trailing "|" (scrape artifact) that made
        # this line a syntax error.
        print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")