SaviAnna commited on
Commit
425069f
·
1 Parent(s): a7f32c3

Update pages/📖History_Mystery.py

Browse files
Files changed (1) hide show
  1. pages/📖History_Mystery.py +72 -16
pages/📖History_Mystery.py CHANGED
@@ -9,31 +9,63 @@ st.title("""
9
  History Mystery
10
  """)
11
  # Добавление слайдера
12
- temperature = st.slider("Градус дичи", 1, 20, 1)
13
  max_length = st.slider(" Длина сгенерированного отрывка", 60, 120, 2)
14
  # Загрузка модели и токенизатора
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
 
19
- @st.cache
20
- def load_gpt():
21
- model_GPT = GPT2LMHeadModel.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  'sberbank-ai/rugpt3small_based_on_gpt2',
23
  output_attentions = False,
24
  output_hidden_states = False,
25
- )
26
- tokenizer_GPT = GPT2Tokenizer.from_pretrained(
27
- 'sberbank-ai/rugpt3small_based_on_gpt2',
28
- output_attentions = False,
29
- output_hidden_states = False,
30
- )
31
- model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
32
- return model_GPT, tokenizer_GPT
33
 
34
  # # Вешаем сохраненные веса на нашу модель
35
-
36
  # Функция для генерации текста
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def generate_text(model_GPT, tokenizer_GPT, prompt):
38
  # Преобразование входной строки в токены
39
  input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
@@ -48,9 +80,34 @@ def generate_text(model_GPT, tokenizer_GPT, prompt):
48
 
49
  return generated_text
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Streamlit приложение
52
  def main():
53
- model_GPT, tokenizer_GPT = load_gpt()
54
  st.write("""
55
  # GPT-3 генерация текста
56
  """)
@@ -65,7 +122,7 @@ def main():
65
  # Обработка события нажатия кнопки
66
  if generate_button:
67
  # Вывод сгенерированного текста
68
- generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
69
  st.subheader("Продолжени��:")
70
  st.write(generated_text)
71
 
@@ -73,4 +130,3 @@ def main():
73
 
74
  if __name__ == "__main__":
75
  main()
76
-
 
9
  History Mystery
10
  """)
11
  # Добавление слайдера
12
+ temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
13
  max_length = st.slider(" Длина сгенерированного отрывка", 60, 120, 2)
14
  # Загрузка модели и токенизатора
15
  # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
  # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
  # #Задаем класс модели (уже в streamlit/tg_bot)
18
 
19
+ # @st.cache
20
+ # def load_gpt():
21
+ # model_GPT = GPT2LMHeadModel.from_pretrained(
22
+ # 'sberbank-ai/rugpt3small_based_on_gpt2',
23
+ # output_attentions = False,
24
+ # output_hidden_states = False,
25
+ # )
26
+ # tokenizer_GPT = GPT2Tokenizer.from_pretrained(
27
+ # 'sberbank-ai/rugpt3small_based_on_gpt2',
28
+ # output_attentions = False,
29
+ # output_hidden_states = False,
30
+ # )
31
+ # model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
32
+ # return model_GPT, tokenizer_GPT
33
+
34
+ # # # Вешаем сохраненные веса на нашу модель
35
+
36
+ # # Функция для генерации текста
37
+ mperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
38
+ # Загрузка модели и токенизатора
39
+ # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
40
+ # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
41
+ # #Задаем класс модели (уже в streamlit/tg_bot)
42
+ model = GPT2LMHeadModel.from_pretrained(
43
  'sberbank-ai/rugpt3small_based_on_gpt2',
44
  output_attentions = False,
45
  output_hidden_states = False,
46
+ )
47
+ tokenizer = GPT2Tokenizer.from_pretrained(
48
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
49
+ output_attentions = False,
50
+ output_hidden_states = False,
51
+ )
 
 
52
 
53
  # # Вешаем сохраненные веса на нашу модель
54
+ model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
55
  # Функция для генерации текста
56
+ def generate_text(prompt):
57
+ # Преобразование входной строки в токены
58
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
59
+
60
+ # Генерация текста
61
+ output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
62
+ temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
63
+ num_return_sequences=3)
64
+
65
+ # Декодирование сгенерированного текста
66
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
67
+
68
+ return generated_text
69
  def generate_text(model_GPT, tokenizer_GPT, prompt):
70
  # Преобразование входной строки в токены
71
  input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
 
80
 
81
  return generated_text
82
 
83
+ # Streamlit приложение
84
+ # def main():
85
+ # model_GPT, tokenizer_GPT = load_gpt()
86
+ # st.write("""
87
+ # # GPT-3 генерация текста
88
+ # """)
89
+
90
+ # # Ввод строки пользователем
91
+ # prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
92
+
93
+ # # # Генерация текста по введенной строке
94
+ # # generated_text = generate_text(prompt)
95
+ # # Создание кнопки "Сгенерировать"
96
+ # generate_button = st.button("За работу!")
97
+ # # Обработка события нажатия кнопки
98
+ # if generate_button:
99
+ # # Вывод сгенерированного текста
100
+ # generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
101
+ # st.subheader("Продолжение:")
102
+ # st.write(generated_text)
103
+
104
+
105
+
106
+ # if __name__ == "__main__":
107
+ # main()
108
+
109
  # Streamlit приложение
110
  def main():
 
111
  st.write("""
112
  # GPT-3 генерация текста
113
  """)
 
122
  # Обработка события нажатия кнопки
123
  if generate_button:
124
  # Вывод сгенерированного текста
125
+ generated_text = generate_text(prompt)
126
  st.subheader("Продолжени��:")
127
  st.write(generated_text)
128
 
 
130
 
131
  if __name__ == "__main__":
132
  main()