AgaMiko commited on
Commit
1cf06d2
1 Parent(s): cf7e264

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -18
app.py CHANGED
@@ -14,8 +14,11 @@ def load_model_cache():
14
  model_pl = T5ForConditionalGeneration.from_pretrained(
15
  "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
16
  )
 
 
 
17
 
18
- return tokenizer_pl, model_pl
19
 
20
 
21
  img_full = Image.open("images/vl-logo-nlp-blue.png")
@@ -25,24 +28,35 @@ max_length: int = 5000
25
  cache_size: int = 100
26
 
27
  st.set_page_config(
28
- page_title="DEMO - Reason for Contact detection",
29
  page_icon=img_favicon,
30
  initial_sidebar_state="expanded",
31
  )
32
 
33
- tokenizer_pl, model_pl = load_model_cache()
34
 
35
 
36
- def get_predictions(text):
37
  input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
38
- output = model_pl.generate(
39
- input_ids,
40
- no_repeat_ngram_size=1,
41
- num_beams=3,
42
- num_beam_groups=3,
43
- min_length=10,
44
- max_length=100,
45
- )
 
 
 
 
 
 
 
 
 
 
 
46
  predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
47
  return predicted_rfc
48
 
@@ -55,7 +69,7 @@ def trim_length():
55
  if __name__ == "__main__":
56
  st.sidebar.image(img_short)
57
  st.image(img_full)
58
- st.title("VLT5 - RfC generation")
59
 
60
  generated_keywords = ""
61
  user_input = st.text_area(
@@ -66,16 +80,29 @@ if __name__ == "__main__":
66
  key="input",
67
  )
68
 
69
- language = st.sidebar.title("Model settings")
70
- language = st.sidebar.radio(
71
  "Select model to test",
72
  [
73
- "Polish",
 
74
  ],
75
  )
76
 
77
  result = st.button("Find reason for contact")
 
 
 
 
 
 
 
 
 
 
 
 
78
  if result:
79
- generated_rfc = get_predictions(text=user_input)
80
- st.text_area("Reason", generated_rfc)
81
  print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")
 
14
  model_pl = T5ForConditionalGeneration.from_pretrained(
15
  "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
16
  )
17
+ model_det_pl = T5ForConditionalGeneration.from_pretrained(
18
+ "Voicelab/vlt5-base-rfc-detector-1.0", use_auth_token=auth_token
19
+ )
20
 
21
+ return tokenizer_pl, model_pl, model_det_pl
22
 
23
 
24
  img_full = Image.open("images/vl-logo-nlp-blue.png")
 
28
  cache_size: int = 100
29
 
30
  st.set_page_config(
31
+ page_title="DEMO - Reason for Contact generation",
32
  page_icon=img_favicon,
33
  initial_sidebar_state="expanded",
34
  )
35
 
36
+ tokenizer_pl, model_pl, model_det_pl = load_model_cache()
37
 
38
 
39
+ def get_predictions(text, mode):
40
  input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
41
+ if mode == "Polish - RfC Generation":
42
+ output = model_pl.generate(
43
+ input_ids,
44
+ no_repeat_ngram_size=1,
45
+ num_beams=3,
46
+ num_beam_groups=3,
47
+ min_length=10,
48
+ max_length=100,
49
+ )
50
+ elif mode == "Polish - RfC Detection":
51
+ output = model.generate(
52
+ input_ids,
53
+ no_repeat_ngram_size=2,
54
+ num_beams=3,
55
+ num_beam_groups=3,
56
+ repetition_penalty=1.5,
57
+ diversity_penalty=2.0,
58
+ length_penalty=2.0,
59
+ )
60
  predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
61
  return predicted_rfc
62
 
 
69
  if __name__ == "__main__":
70
  st.sidebar.image(img_short)
71
  st.image(img_full)
72
+ st.title("VLT5 - Reason for Contact generator")
73
 
74
  generated_keywords = ""
75
  user_input = st.text_area(
 
80
  key="input",
81
  )
82
 
83
+ mode = st.sidebar.title("Model settings")
84
+ mode = st.sidebar.radio(
85
  "Select model to test",
86
  [
87
+ "Polish - RfC Generation",
88
+ "Polish - RfC Detection",
89
  ],
90
  )
91
 
92
  result = st.button("Find reason for contact")
93
+ if mode == "Polish - RfC Generation (accepts whole conversation)":
94
+ print("You selected RfC Generation model.")
95
+ print("-- Input: Whole conversation. Should specify roles (e.g. AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.")
96
+ print("-- Output: Reason for calling for the whole conversation.")
97
+ text_area = "Put a whole conversation or full e-mail here."
98
+
99
+ elif mode == "Polish - RfC Detection (accepts one turn)":
100
+ print("You selected RfC Detection model.")
101
+ print("-- Input: A single turn from the conversation e.g. 'Hello, how can I help you?' or 'Hi, I would like to report a stolen card.'")
102
+ print("-- Output: Model will return an empty string if a turn possibly does not includes Reason for Calling, or a sentence if the RfC is detected.")
103
+ text_area = "Put a single turn or a few sentences here."
104
+
105
  if result:
106
+ generated_rfc = get_predictions(text=user_input, mode=mode)
107
+ st.text_area(text_area, generated_rfc)
108
  print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")