FindMyBook / stri.py
MARI-posa's picture
Update stri.py
feb52ca
raw
history blame
3.97 kB
import streamlit as st
import torch
import numpy as np
import pandas as pd
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import re
import pickle
import requests
from io import BytesIO
st.title("Книжные рекомендации")
# Загрузка модели и токенизатора
model_name = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
# Загрузка датасета и аннотаций к книгам
books = pd.read_csv('all+++.csv')
books['author'].fillna('other', inplace=True)
#books.dropna(inplace=True)
#books = books[books['annotation'].apply(lambda x: len(x.split()) >= 40)]
#books.drop_duplicates(subset='title', keep='first', inplace=True)
#books = books.reset_index(drop=True)
#def data_preprocessing(text: str) -> str:
#text = re.sub(r'http\S+', " ", text) # удаляем ссылки
#text = re.sub(r'@\w+', ' ', text) # удаляем упоминания пользователей
#text = re.sub(r'#\w+', ' ', text) # удаляем хэштеги
#text = re.sub(r'<.*?>', ' ', text) # html tags
# return text
#for i in ['author', 'title', 'annotation']:
#books[i] = books[i].apply(data_preprocessing)
annot = books['annotation']
# Получение эмбеддингов аннотаций каждой книги в датасете
length = 256
# Определение запроса пользователя
query = st.text_input("Введите запрос")
if st.button('Сгенерировать'):
with open("book_embeddings256xxx.pkl", "rb") as f:
book_embeddings = pickle.load(f)
#book_embeddings = torch.tensor(book_embeddings, device=torch.device('cpu'))
#if st.button('Сгенерировать'):
#with open("book_embeddingsN.pkl", "rb") as f:
#book_embeddings = torch.load("book_embeddingsN.pkl", map_location=torch.device('cpu'))
#
#book_embeddings = pickle.load(f)
query_tokens = tokenizer.encode_plus(
query,
add_special_tokens=True,
max_length=length, # Ограничение на максимальную длину входной последовательности
pad_to_max_length=True, # Дополним последовательность нулями до максимальной длины
return_tensors='pt' # Вернём тензоры PyTorch
)
with torch.no_grad():
query_outputs = model(**query_tokens)
query_hidden_states = query_outputs.hidden_states[-1][:,0,:]
query_hidden_states = torch.nn.functional.normalize(query_hidden_states)
# Вычисление косинусного расстояния между эмбеддингом запроса и каждой аннотацией
cosine_similarities = torch.nn.functional.cosine_similarity(
query_hidden_states.squeeze(0),
torch.stack(book_embeddings)
)
cosine_similarities = cosine_similarities.numpy()
indices = np.argsort(cosine_similarities)[::-1] # Сортировка по убыванию
num_books_per_page = st.selectbox("Количество книг на странице:", [3, 5, 10], index=0)
for i in indices[:num_books_per_page]:
cols = st.columns(2) # Создание двух столбцов для размещения информации и изображения
cols[1].write("## " + books['title'][i])
cols[1].markdown("**Автор:** " + books['author'][i])
cols[1].markdown("**Аннотация:** " + books['annotation'][i])
image_url = books['image_url'][i]
response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
cols[0].image(image)
cols[0].write(cosine_similarities[i])
cols[1].write("---")