Spaces:
Sleeping
Sleeping
Update pages/📖History_Mystery.py
Browse files
pages/📖History_Mystery.py
CHANGED
@@ -15,31 +15,31 @@ max_len = st.slider(" Длина сгенерированного отрывк
|
|
15 |
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
|
16 |
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
|
17 |
# #Задаем класс модели (уже в streamlit/tg_bot)
|
18 |
-
|
19 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
20 |
output_attentions = False,
|
21 |
output_hidden_states = False,
|
22 |
)
|
23 |
-
|
24 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
25 |
output_attentions = False,
|
26 |
output_hidden_states = False,
|
27 |
)
|
28 |
|
29 |
# # Вешаем сохраненные веса на нашу модель
|
30 |
-
|
31 |
# Функция для генерации текста
|
32 |
def generate_text(prompt):
|
33 |
# Преобразование входной строки в токены
|
34 |
-
input_ids =
|
35 |
|
36 |
# Генерация текста
|
37 |
-
output =
|
38 |
temperature=5, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
|
39 |
num_return_sequences=3)
|
40 |
|
41 |
# Декодирование сгенерированного текста
|
42 |
-
generated_text =
|
43 |
|
44 |
return generated_text
|
45 |
|
|
|
15 |
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
|
16 |
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
|
17 |
# #Задаем класс модели (уже в streamlit/tg_bot)
|
18 |
+
model_GPT = GPT2LMHeadModel.from_pretrained(
|
19 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
20 |
output_attentions = False,
|
21 |
output_hidden_states = False,
|
22 |
)
|
23 |
+
tokenizer_GPT = GPT2Tokenizer.from_pretrained(
|
24 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
25 |
output_attentions = False,
|
26 |
output_hidden_states = False,
|
27 |
)
|
28 |
|
29 |
# # Вешаем сохраненные веса на нашу модель
|
30 |
+
model_GPT.load_state_dict(torch.load('model_history_new.pt',map_location=torch.device('cpu')))
|
31 |
# Функция для генерации текста
|
32 |
def generate_text(prompt):
|
33 |
# Преобразование входной строки в токены
|
34 |
+
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
|
35 |
|
36 |
# Генерация текста
|
37 |
+
output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
|
38 |
temperature=5, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
|
39 |
num_return_sequences=3)
|
40 |
|
41 |
# Декодирование сгенерированного текста
|
42 |
+
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
|
43 |
|
44 |
return generated_text
|
45 |
|