A1ex1 commited on
Commit
f81d65c
·
1 Parent(s): a5ffd76

edit application file 2

Browse files
Files changed (1) hide show
  1. app.py +32 -35
app.py CHANGED
@@ -20,16 +20,17 @@ model_init = GPT2LMHeadModel.from_pretrained(
20
  output_hidden_states = False,
21
  )
22
  model_init.to(device);
23
- # # Это обученная модель, в нее загружаем веса
24
- # model = GPT2LMHeadModel.from_pretrained(
25
- # 'sberbank-ai/rugpt3small_based_on_gpt2',
26
- # output_attentions = False,
27
- # output_hidden_states = False,
28
- # )
29
 
30
- # m = torch.load('model.pt')
31
- # model.load_state_dict(m)
32
- # model.to(device);
 
33
 
34
  str = st.text_input('Введите 1-4 слова начала текста, и подождите минутку', 'Мужик спрашивает у официанта')
35
 
@@ -73,30 +74,26 @@ for out_ in out1:
73
  # print(tokenizer.decode(out_))
74
 
75
 
76
- # # дообученная модель
77
- # with torch.inference_mode():
78
- # # prompt = 'Мужик спрашивает официанта'
79
- # # prompt = tokenizer.encode(str, return_tensors='pt')
80
- # out2 = model.generate(
81
- # input_ids=prompt,
82
- # max_length=150,
83
- # num_beams=1,
84
- # do_sample=True,
85
- # temperature=1.,
86
- # top_k=5,
87
- # top_p=0.6,
88
- # no_repeat_ngram_size=2,
89
- # num_return_sequences=3,
90
- # ).numpy() #).cpu().numpy()
91
 
92
- # st.subheader('Тексты на модели, обученной документами всех тематик и дообученной анекдотами:')
93
- # n = 0
94
- # for out_ in out2:
95
- # n += 1
96
- # st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
97
- # # print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
98
- <<<<<<< HEAD
99
- # st.write('\n------------------\n')
100
- =======
101
- # st.write('\n------------------\n')
102
- >>>>>>> da65de15227afe7841c21d51b9e43521b1a62c1b
 
20
  output_hidden_states = False,
21
  )
22
  model_init.to(device);
23
+ # Это обученная модель, в нее загружаем веса
24
+ model = GPT2LMHeadModel.from_pretrained(
25
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
26
+ output_attentions = False,
27
+ output_hidden_states = False,
28
+ )
29
 
30
+ # Подгружаем сохраненные веса модели и загружаем их в модель
31
+ m = torch.load('model.pt')
32
+ model.load_state_dict(m)
33
+ model.to(device);
34
 
35
  str = st.text_input('Введите 1-4 слова начала текста, и подождите минутку', 'Мужик спрашивает у официанта')
36
 
 
74
  # print(tokenizer.decode(out_))
75
 
76
 
77
+ # дообученная модель
78
+ with torch.inference_mode():
79
+ # prompt = 'Мужик спрашивает официанта'
80
+ # prompt = tokenizer.encode(str, return_tensors='pt')
81
+ out2 = model.generate(
82
+ input_ids=prompt,
83
+ max_length=150,
84
+ num_beams=1,
85
+ do_sample=True,
86
+ temperature=1.,
87
+ top_k=5,
88
+ top_p=0.6,
89
+ no_repeat_ngram_size=2,
90
+ num_return_sequences=3,
91
+ ).cpu().numpy() #).cpu().numpy()
92
 
93
+ st.subheader('Тексты на модели, обученной документами всех тематик и дообученной анекдотами:')
94
+ n = 0
95
+ for out_ in out2:
96
+ n += 1
97
+ st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
98
+ # print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
99
+ st.write('\n------------------\n')