mipatov commited on
Commit
27bdf54
·
1 Parent(s): 0fe7317

t5 model get func

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -8,12 +8,19 @@ from PIL import Image
8
 
9
 
10
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
11
- def get_model(model_name, model_path):
12
  tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
13
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
14
  model.eval()
15
  return model, tokenizer
16
 
 
 
 
 
 
 
 
17
 
18
  def predict_gpt(text, model, tokenizer, temperature):
19
  input_ids = tokenizer.encode(text, return_tensors="pt")
@@ -54,10 +61,8 @@ def predict_t5(text, model, tokenizer, temperature):
54
  generated_text = list(map(decode, out['sequences']))[0]
55
  return generated_text
56
 
57
- gpt_model, gpt_tokenizer = get_model('mipatov/rugpt3_nb_descr', 'mipatov/rugpt3_nb_descr')
58
- t5_model, t5_tokenizer = get_model('mipatov/rut5_nb_descr', 'mipatov/rut5_nb_descr')
59
-
60
- # st.title("NeuroKorzh")
61
 
62
 
63
  option = st.selectbox('Выберите модель', ('GPT', 'T5'))
 
8
 
9
 
10
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
11
+ def get_model_gpt(model_name):
12
  tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
13
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
14
  model.eval()
15
  return model, tokenizer
16
 
17
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
18
+ def get_model_t5(model_name):
19
+ tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
20
+ model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
21
+ model.eval()
22
+ return model, tokenizer
23
+
24
 
25
  def predict_gpt(text, model, tokenizer, temperature):
26
  input_ids = tokenizer.encode(text, return_tensors="pt")
 
61
  generated_text = list(map(decode, out['sequences']))[0]
62
  return generated_text
63
 
64
+ gpt_model, gpt_tokenizer = get_model_gpt('mipatov/rugpt3_nb_descr', 'mipatov/rugpt3_nb_descr')
65
+ t5_model, t5_tokenizer = get_model_t5('mipatov/rut5_nb_descr', 'mipatov/rut5_nb_descr')
 
 
66
 
67
 
68
  option = st.selectbox('Выберите модель', ('GPT', 'T5'))