vvv-knyazeva commited on
Commit
9e9df39
·
1 Parent(s): 2c52d39

Update stri.py

Browse files
Files changed (1) hide show
  1. stri.py +28 -29
stri.py CHANGED
@@ -53,36 +53,35 @@ for inputs, attention_masks in zip(input_ids, attention_mask):
53
  book_embedding = model(inputs.unsqueeze(0), attention_mask=attention_masks.unsqueeze(0))
54
  book_embedding = book_embedding[0][:, 0, :] #.detach().cpu().numpy()
55
  book_embeddings.append(np.squeeze(book_embedding))
 
56
 
57
  # Определение запроса пользователя
58
  query = st.text_input("Введите запрос")
59
- query_tokens = tokenizer.encode(query, add_special_tokens=True,
60
- truncation=True, max_length=max_len)
61
 
62
- query_padded = np.array(query_tokens + [0] * (max_len - len(query_tokens)))
63
- query_mask = np.where(query_padded != 0, 1, 0)
64
-
65
- # Переведем numpy массивы в тензоры PyTorch
66
- query_padded = torch.tensor(query_padded, dtype=torch.long)
67
- query_mask = torch.tensor(query_mask, dtype=torch.long)
68
-
69
- with torch.no_grad():
70
- query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
71
- query_embedding = query_embedding[0][:, 0, :] #.detach().cpu().numpy()
72
-
73
- # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
74
- cosine_similarities = torch.nn.functional.cosine_similarity(
75
- query_embedding.squeeze(0),
76
- torch.stack(book_embeddings)
77
- )
78
- #cosine_similarities = torch.nn.functional.cosine_similarity(
79
- # torch.tensor(query_embedding.squeeze(0)),
80
- # torch.stack([torch.tensor(embedding) for embedding in book_embeddings])
81
- #)
82
-
83
- cosine_similarities = cosine_similarities.numpy()
84
-
85
- indices = np.argsort(cosine_similarities)[::-1] # Сортировка по убыванию
86
-
87
- for i in indices[:10]:
88
- st.write(books['title'][i])
 
53
  book_embedding = model(inputs.unsqueeze(0), attention_mask=attention_masks.unsqueeze(0))
54
  book_embedding = book_embedding[0][:, 0, :] #.detach().cpu().numpy()
55
  book_embeddings.append(np.squeeze(book_embedding))
56
+
57
 
58
  # Определение запроса пользователя
59
  query = st.text_input("Введите запрос")
 
 
60
 
61
+ if st.button('**Generate text**'):
62
+ query_tokens = tokenizer.encode(query, add_special_tokens=True,
63
+ truncation=True, max_length=max_len)
64
+
65
+ query_padded = np.array(query_tokens + [0] * (max_len - len(query_tokens)))
66
+ query_mask = np.where(query_padded != 0, 1, 0)
67
+
68
+ # Переведем numpy массивы в тензоры PyTorch
69
+ query_padded = torch.tensor(query_padded, dtype=torch.long)
70
+ query_mask = torch.tensor(query_mask, dtype=torch.long)
71
+
72
+ with torch.no_grad():
73
+ query_embedding = model(query_padded.unsqueeze(0), query_mask.unsqueeze(0))
74
+ query_embedding = query_embedding[0][:, 0, :] #.detach().cpu().numpy()
75
+
76
+ # Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
77
+ cosine_similarities = torch.nn.functional.cosine_similarity(
78
+ query_embedding.squeeze(0),
79
+ torch.stack(book_embeddings)
80
+ )
81
+
82
+ cosine_similarities = cosine_similarities.numpy()
83
+
84
+ indices = np.argsort(cosine_similarities)[::-1] # Сортировка по убыванию
85
+
86
+ for i in indices[:10]:
87
+ st.write(books['title'][i])