Camille committed on
Commit
619dca5
1 Parent(s): 6c85b9e

vocab dict.txt added

Browse files
app.py CHANGED
@@ -16,7 +16,7 @@ DEFAULT_QUERY = "Machines will take over the world soon"
16
  N_RHYMES = 10
17
 
18
 
19
- LANGUAGE = st.sidebar.radio("Language", ["english", "dutch", "french"],0)
20
  if LANGUAGE == "english":
21
  MODEL_PATH = "bert-large-cased-whole-word-masking"
22
  ITER_FACTOR = 5
@@ -24,10 +24,14 @@ elif LANGUAGE == "dutch":
24
  MODEL_PATH = "GroNLP/bert-base-dutch-cased"
25
  ITER_FACTOR = 10 # Faster model
26
  elif LANGUAGE == "french":
27
- MODEL_PATH = "camembert-base"
28
- ITER_FACTOR = 5
29
  else:
30
  raise NotImplementedError(f"Unsupported language ({LANGUAGE}) expected 'english','dutch' or 'french.")
 
 
 
 
 
31
 
32
  def main():
33
  st.markdown(
@@ -93,6 +97,7 @@ def load_model(model_path, language):
93
  BertTokenizer.from_pretrained(model_path),
94
  )
95
  else :
 
96
  return (
97
  CamembertModel.from_pretrained(model_path),
98
  CamembertTokenizer.from_pretrained(model_path),
 
16
  N_RHYMES = 10
17
 
18
 
19
+ """LANGUAGE = st.sidebar.radio("Language", ["english", "dutch", "french"],0)
20
  if LANGUAGE == "english":
21
  MODEL_PATH = "bert-large-cased-whole-word-masking"
22
  ITER_FACTOR = 5
 
24
  MODEL_PATH = "GroNLP/bert-base-dutch-cased"
25
  ITER_FACTOR = 10 # Faster model
26
  elif LANGUAGE == "french":
27
+
 
28
  else:
29
  raise NotImplementedError(f"Unsupported language ({LANGUAGE}) expected 'english','dutch' or 'french.")
30
+ """
31
+
32
+ LANGUAGE = "french"
33
+ MODEL_PATH = "camembert-base"
34
+ ITER_FACTOR = 5
35
 
36
  def main():
37
  st.markdown(
 
97
  BertTokenizer.from_pretrained(model_path),
98
  )
99
  else :
100
+ tokenizer = CamembertTokenizer(vocab_file='rhyme_with_ai/dict.txt')
101
  return (
102
  CamembertModel.from_pretrained(model_path),
103
  CamembertTokenizer.from_pretrained(model_path),
rhyme_with_ai/dict.txt ADDED
The diff for this file is too large to render. See raw diff
 
rhyme_with_ai/token_weighter.py CHANGED
@@ -7,7 +7,7 @@ class TokenWeighter:
7
  self.proba = self.get_token_proba()
8
 
9
  def get_token_proba(self):
10
- valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
11
  return valid_token_mask
12
 
13
  def _filter_short_partial(self, vocab):
 
7
  self.proba = self.get_token_proba()
8
 
9
  def get_token_proba(self):
10
+ valid_token_mask = self._filter_short_partial(self.tokenizer_.vocabulaire)
11
  return valid_token_mask
12
 
13
  def _filter_short_partial(self, vocab):