AgaMiko committed on
Commit
ca0f425
1 Parent(s): bea6893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -35
app.py CHANGED
@@ -6,62 +6,42 @@ import os
6
  @st.cache(allow_output_mutation=True)
7
  def load_model_cache():
8
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
9
- tokenizer_en = T5Tokenizer.from_pretrained(
10
- "Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token
11
- )
12
- model_en = T5ForConditionalGeneration.from_pretrained(
13
- "Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token
14
- )
15
 
16
  tokenizer_pl = T5Tokenizer.from_pretrained(
17
- "Voicelab/vlt5-base-keywords-v4_3", use_auth_token=auth_token
18
  )
19
  model_pl = T5ForConditionalGeneration.from_pretrained(
20
- "Voicelab/vlt5-base-keywords-v4_3", use_auth_token=auth_token
21
  )
22
 
23
- return tokenizer_en, model_en, tokenizer_pl, model_pl
24
 
25
 
26
  img_full = Image.open("images/vl-logo-nlp-blue.png")
27
  img_short = Image.open("images/sVL-NLP-short.png")
28
  img_favicon = Image.open("images/favicon_vl.png")
29
- max_length: int = 1000
30
  cache_size: int = 100
31
 
32
  st.set_page_config(
33
- page_title="DEMO - keywords generation",
34
  page_icon=img_favicon,
35
  initial_sidebar_state="expanded",
36
  )
37
 
38
  tokenizer_pl, model_pl = load_model_cache()
39
 
40
- def get_predictions(text, language):
41
- if language == "Polish":
42
  input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
43
  output = model_pl.generate(
44
  input_ids,
45
- no_repeat_ngram_size=2,
46
  num_beams=3,
47
  num_beam_groups=3,
48
- repetition_penalty=1.5,
49
- diversity_penalty=2.0,
50
- length_penalty=2.0,
51
  )
52
  predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
53
- elif language == "English":
54
- input_ids = tokenizer_en(text, return_tensors="pt", truncation=True).input_ids
55
- output = model_en.generate(
56
- input_ids,
57
- no_repeat_ngram_size=2,
58
- num_beams=3,
59
- num_beam_groups=3,
60
- repetition_penalty=1.5,
61
- diversity_penalty=2.0,
62
- length_penalty=2.0,
63
- )
64
- predicted_kw = tokenizer_en.decode(output[0], skip_special_tokens=True)
65
  return predicted_kw
66
 
67
 
@@ -73,7 +53,7 @@ def trim_length():
73
  if __name__ == "__main__":
74
  st.sidebar.image(img_short)
75
  st.image(img_full)
76
- st.title("VLT5 - keywords generation")
77
 
78
  generated_keywords = ""
79
  user_input = st.text_area(
@@ -89,12 +69,11 @@ if __name__ == "__main__":
89
  "Select model to test",
90
  [
91
  "Polish",
92
- "English",
93
  ],
94
  )
95
 
96
- result = st.button("Generate keywords")
97
  if result:
98
- generated_keywords = get_predictions(text=user_input, language=language)
99
- st.text_area("Generated keywords", generated_keywords)
100
- print(f"Input: {user_input}---> Keywords: {generated_keywords}")
 
6
  @st.cache(allow_output_mutation=True)
7
  def load_model_cache():
8
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
 
 
 
 
 
 
9
 
10
  tokenizer_pl = T5Tokenizer.from_pretrained(
11
+ "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
12
  )
13
  model_pl = T5ForConditionalGeneration.from_pretrained(
14
+ "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
15
  )
16
 
17
+ return tokenizer_pl, model_pl
18
 
19
 
20
  img_full = Image.open("images/vl-logo-nlp-blue.png")
21
  img_short = Image.open("images/sVL-NLP-short.png")
22
  img_favicon = Image.open("images/favicon_vl.png")
23
+ max_length: int = 5000
24
  cache_size: int = 100
25
 
26
  st.set_page_config(
27
+ page_title="DEMO - Reason for Contact detection",
28
  page_icon=img_favicon,
29
  initial_sidebar_state="expanded",
30
  )
31
 
32
  tokenizer_pl, model_pl = load_model_cache()
33
 
34
+ def get_predictions(text):
 
35
  input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
36
  output = model_pl.generate(
37
  input_ids,
38
+ no_repeat_ngram_size=1,
39
  num_beams=3,
40
  num_beam_groups=3,
41
+ min_length=10,
42
+ max_length=100,
 
43
  )
44
  predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
45
  return predicted_kw
46
 
47
 
 
53
  if __name__ == "__main__":
54
  st.sidebar.image(img_short)
55
  st.image(img_full)
56
+ st.title("VLT5 - RfC generation")
57
 
58
  generated_keywords = ""
59
  user_input = st.text_area(
 
69
  "Select model to test",
70
  [
71
  "Polish",
 
72
  ],
73
  )
74
 
75
+ result = st.button("Find reason for contact")
76
  if result:
77
+ generated_rfc = get_predictions(text=user_input)
78
+ st.text_area("Reason", generated_rfc)
79
+ print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")