SaviAnna commited on
Commit
3f4021d
·
1 Parent(s): 52f2101

Update pages/📖History_Mystery.py

Browse files
Files changed (1) hide show
  1. pages/📖History_Mystery.py +6 -6
pages/📖History_Mystery.py CHANGED
@@ -15,31 +15,31 @@ max_len = st.slider(" Длина сгенерированного отрывк
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
- model = GPT2LMHeadModel.from_pretrained(
19
  'sberbank-ai/rugpt3small_based_on_gpt2',
20
  output_attentions = False,
21
  output_hidden_states = False,
22
  )
23
- tokenizer = GPT2Tokenizer.from_pretrained(
24
  'sberbank-ai/rugpt3small_based_on_gpt2',
25
  output_attentions = False,
26
  output_hidden_states = False,
27
  )
28
 
29
  # # Вешаем сохраненные веса на нашу модель
30
- model.load_state_dict(torch.load('model_history_new.pt',map_location=torch.device('cpu')))
31
  # Функция для генерации текста
32
  def generate_text(prompt):
33
  # Преобразование входной строки в токены
34
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
35
 
36
  # Генерация текста
37
- output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
38
  temperature=5, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
39
  num_return_sequences=3)
40
 
41
  # Декодирование сгенерированного текста
42
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
43
 
44
  return generated_text
45
 
 
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
+ model_GPT = GPT2LMHeadModel.from_pretrained(
19
  'sberbank-ai/rugpt3small_based_on_gpt2',
20
  output_attentions = False,
21
  output_hidden_states = False,
22
  )
23
+ tokenizer_GPT = GPT2Tokenizer.from_pretrained(
24
  'sberbank-ai/rugpt3small_based_on_gpt2',
25
  output_attentions = False,
26
  output_hidden_states = False,
27
  )
28
 
29
  # # Вешаем сохраненные веса на нашу модель
30
+ model_GPT.load_state_dict(torch.load('model_history_new.pt',map_location=torch.device('cpu')))
31
  # Функция для генерации текста
32
  def generate_text(prompt):
33
  # Преобразование входной строки в токены
34
+ input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
35
 
36
  # Генерация текста
37
+ output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
38
  temperature=5, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
39
  num_return_sequences=3)
40
 
41
  # Декодирование сгенерированного текста
42
+ generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
43
 
44
  return generated_text
45