Camille commited on
Commit
8384a73
·
1 Parent(s): 8d252dd

fix: laod model

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -4,6 +4,7 @@ from typing import List
4
 
5
  import streamlit as st
6
  from transformers import BertTokenizer, TFAutoModelForMaskedLM
 
7
 
8
  from rhyme_with_ai.utils import color_new_words, sanitize
9
  from rhyme_with_ai.rhyme import query_rhyme_words
@@ -70,7 +71,7 @@ def start_rhyming(query, rhyme_words_options):
70
 
71
  rhyme_words = rhyme_words_options[:N_RHYMES]
72
 
73
- model, tokenizer = load_model(MODEL_PATH)
74
  sentence_generator = RhymeGenerator(model, tokenizer)
75
  sentence_generator.start(query, rhyme_words)
76
 
@@ -84,13 +85,18 @@ def start_rhyming(query, rhyme_words_options):
84
 
85
 
86
  @st.cache(allow_output_mutation=True)
87
- def load_model(model_path):
88
- return (
89
- TFAutoModelForMaskedLM.from_pretrained(model_path),
90
- BertTokenizer.from_pretrained(model_path),
 
 
 
 
 
 
91
  )
92
 
93
-
94
  def display_output(status_text, query, current_sentences, previous_sentences):
95
  print_sentences = []
96
  for new, old in zip(current_sentences, previous_sentences):
 
4
 
5
  import streamlit as st
6
  from transformers import BertTokenizer, TFAutoModelForMaskedLM
7
+ from transformers import CamembertModel, CamembertTokenizer
8
 
9
  from rhyme_with_ai.utils import color_new_words, sanitize
10
  from rhyme_with_ai.rhyme import query_rhyme_words
 
71
 
72
  rhyme_words = rhyme_words_options[:N_RHYMES]
73
 
74
+ model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
75
  sentence_generator = RhymeGenerator(model, tokenizer)
76
  sentence_generator.start(query, rhyme_words)
77
 
 
85
 
86
 
87
  @st.cache(allow_output_mutation=True)
88
+ def load_model(model_path, language):
89
+ if language != "french":
90
+ return (
91
+ TFAutoModelForMaskedLM.from_pretrained(model_path),
92
+ BertTokenizer.from_pretrained(model_path),
93
+ )
94
+ else :
95
+ return (
96
+ CamembertModel.from_pretrained(model_path),
97
+ CamembertTokenizer.from_pretrained(model_path),
98
  )
99
 
 
100
  def display_output(status_text, query, current_sentences, previous_sentences):
101
  print_sentences = []
102
  for new, old in zip(current_sentences, previous_sentences):