Spaces:
Sleeping
Sleeping
import gradio as gr | |
from datasets import load_dataset | |
import torch.nn.functional as F | |
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
import pickle | |
# Precomputed book-description embeddings shipped next to the app.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# this file must come from a trusted source (it is read from the local
# working directory, not from user input).
with open('book_embeddings.pkl', 'rb') as embeddings_fh:
    book_embeddings = pickle.load(embeddings_fh)

# Encoder used to embed user queries at request time. Presumably the same
# checkpoint that produced book_embeddings — TODO confirm; vectors from a
# different model would make the similarity scores meaningless.
model_checkpoint = 'intfloat/multilingual-e5-large'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModel.from_pretrained(model_checkpoint),
)

# Book metadata (train and test splits concatenated). With the pandas format,
# indexing the dataset yields pandas DataFrames. Row i is presumably aligned
# with row i of book_embeddings — TODO confirm.
books_data = load_dataset(path='vojtam/czech_books_descriptions', split="train+test")
books_data.set_format(type='pandas')
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token vectors over the positions the attention mask keeps.

    Positions where ``attention_mask`` is zero contribute nothing to the sum.
    The sum is divided by the number of kept positions per row. A row whose
    mask is all zeros therefore divides 0 by 0 and yields NaN.

    Assumes ``last_hidden_states`` is (batch, seq_len, hidden) and
    ``attention_mask`` is (batch, seq_len) — TODO confirm against callers.

    Returns:
        A (batch, hidden) tensor of averaged vectors.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    # torch.where (not multiplication) so masked-out NaN/inf values become 0.
    summed = torch.where(keep, last_hidden_states, torch.zeros_like(last_hidden_states)).sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts
def create_embeddings(tokenizer, model, input_texts, batch_size=32):
    """Embed texts in batches, mean-pooling and L2-normalising each one.

    Args:
        tokenizer: Tokenizer called with padding, truncation to 512 tokens and
            ``return_tensors='pt'``; its output must contain ``attention_mask``.
        model: Model whose output exposes ``last_hidden_state``.
        input_texts: A single string or a sequence of strings. A bare string
            is treated as ONE text, not as a sequence of characters.
        batch_size: Number of texts sent to the model per forward pass.

    Returns:
        A tensor with one unit-length (L2-normalised) row per input text, in
        the same order as ``input_texts``.

    Raises:
        RuntimeError: from ``torch.cat`` when ``input_texts`` is empty, since
            no batch embeddings are produced.
    """
    # A str is itself a sequence, so without this it would be sliced into
    # batch_size-character chunks, each embedded as a separate "text".
    if isinstance(input_texts, str):
        input_texts = [input_texts]
    total = len(input_texts)
    embeddings_list = []
    for start in range(0, total, batch_size):
        batch_texts = input_texts[start:start + batch_size]
        batch_dict = tokenizer(batch_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**batch_dict)
        batch_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings_list.append(F.normalize(batch_embeddings, p=2, dim=1))
        done = start + batch_size
        # Report progress every 10 batches; clamp so the count never exceeds total.
        if done % (batch_size * 10) == 0:
            print(f"Processed {min(done, total)}/{total} texts")
    return torch.cat(embeddings_list, dim=0)
def find_similar_books(query: str, n=5):
    """Return the ``n`` books whose descriptions best match ``query``.

    Args:
        query: Free-text description of the wanted book. ``None`` or a
            blank string yields an empty result instead of an error.
        n: Number of books to return. Gradio's ``gr.Number`` may deliver a
            float, or ``None`` when the field is emptied. The value is converted
            to ``int`` and clamped to at least 1, and ``None`` falls back to 5.

    Returns:
        The matching rows of ``books_data`` (a pandas DataFrame, because the
        dataset format is ``'pandas'``), ordered from most to least similar.
    """
    if query is None or not query.strip():
        # Nothing to search for; an empty slice keeps the table's columns.
        return books_data[:0]
    n = 5 if n is None else max(int(n), 1)
    # "query: " is the prefix the e5 model family expects on search queries.
    # Pass a one-element list so the whole query is embedded as one text.
    query_embedding = create_embeddings(tokenizer, model, ["query: " + query])
    # The query row is L2-normalised. If book_embeddings were normalised the
    # same way (presumably — TODO confirm), this dot product is cosine similarity.
    scores = (query_embedding @ book_embeddings.T).detach().numpy()[0]
    top_indices = np.argsort(scores)[-n:][::-1]
    # Plain Python ints are the safest index type for datasets.Dataset.__getitem__.
    return books_data[top_indices.tolist()]
# Let the results table fill the viewport (scrolling inside it) and give the
# search button an orange accent.
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""

with gr.Blocks(css=css) as intf:
    with gr.Row():
        text_input = gr.Textbox(
            label="Popis knihy",
            info="Zadejte popis knihy, kterou byste si chtěli přečíst a aplikace najde nejpodobnější knihy dle vašeho popisu",
            placeholder='Zadejte popis, například "drama z prostředí nemocnice"',
        )
        # precision=0 makes Gradio round the value and pass it as an int. With
        # the default (None) the handler receives a float, which cannot be used
        # as a slice bound when selecting the top-n books.
        n_books = gr.Number(
            value=5,
            label="Počet knih",
            info="Počet nejpodobnějších knih, které si přejete zobrazit",
            minimum=1,
            step=1,
            precision=0,
        )
    with gr.Row():
        submit_btn = gr.Button("Vyhledat knihy", elem_id="submit-btn")
        clear_btn = gr.Button("Smazat")
    with gr.Row():
        dataframe = gr.Dataframe(label="Podobné knihy", show_label=False, elem_classes=["full-height-gallery"])
    submit_btn.click(fn=find_similar_books, inputs=[text_input, n_books], outputs=dataframe)
    # Reset the textbox to "" (not None) so a later search does not hand None
    # to the search handler.
    clear_btn.click(fn=lambda: ["", []], inputs=None, outputs=[text_input, dataframe])

intf.launch(share=True)