"""Gradio app that finds books whose descriptions best match a free-text query.

The query is embedded with ``intfloat/multilingual-e5-large`` (mean pooling +
L2 normalisation) and compared by dot product against pre-computed book
embeddings loaded from ``book_embeddings.pkl``.
"""
import pickle

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

# NOTE(review): pickle.load can execute arbitrary code; only ship a trusted,
# locally produced book_embeddings.pkl. Presumably a torch tensor of shape
# (num_books, hidden) built with create_embeddings and the same checkpoint -
# confirm, since query embeddings are matrix-multiplied against it below.
with open("book_embeddings.pkl", "rb") as file:
    book_embeddings = pickle.load(file)

model_checkpoint = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

# "pandas" format makes row selection return a DataFrame that gr.Dataframe can
# display. NOTE(review): assumes the rows of book_embeddings follow the order of
# this train+test concatenation - confirm against the embedding script.
books_data = load_dataset("vojtam/czech_books_descriptions", split="train+test")
books_data.set_format("pandas")


def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token embeddings, ignoring padded positions.

    Args:
        last_hidden_states: Token embeddings, (batch, seq_len, hidden).
        attention_mask: 1 for real tokens, 0 for padding, (batch, seq_len).

    Returns:
        Tensor of shape (batch, hidden).
    """
    # Zero out padding so it does not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Divide by the number of real tokens per sequence, not by seq_len.
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def create_embeddings(tokenizer, model, input_texts, batch_size=32):
    """Embed texts in batches using mean pooling and L2 normalisation.

    Args:
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Hugging Face encoder whose output has ``last_hidden_state``.
        input_texts: A list of strings, or a single string (treated as one text).
        batch_size: Number of texts encoded per forward pass.

    Returns:
        Tensor with one unit-norm row per input text.
    """
    if isinstance(input_texts, str):
        # Without this, slicing below would split a single string into
        # 32-character chunks and embed each chunk as a separate text.
        input_texts = [input_texts]
    total = len(input_texts)
    embeddings_list = []
    for i in range(0, total, batch_size):
        batch_texts = input_texts[i:i + batch_size]
        batch_dict = tokenizer(
            batch_texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**batch_dict)
        batch_embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
        embeddings_list.append(batch_embeddings)
        if (i + batch_size) % (batch_size * 10) == 0:
            print(f"Processed {min(i + batch_size, total)}/{total} texts")
    return torch.cat(embeddings_list, dim=0)


def find_similar_books(query: str, n=5):
    """Return the ``n`` books most similar to ``query``, best match first.

    Args:
        query: Free-text description of the desired book.
        n: Number of results; may arrive as a float (or None) from gr.Number,
            so it is converted to int.

    Returns:
        A pandas DataFrame of matching rows from ``books_data``, or an empty
        list when the query is blank or ``n`` is not positive.
    """
    n = 5 if n is None else int(n)
    if n <= 0 or not query or not query.strip():
        # Same "empty table" value the clear button sends to the Dataframe.
        return []
    # The e5 models expect the "query: " prefix on search queries. Wrap the
    # text in a list so it is embedded as one text, not 32-char chunks.
    query_embedding = create_embeddings(tokenizer, model, ["query: " + query])
    # Embeddings are L2-normalised, so the dot product ranks by cosine
    # similarity; any positive scaling would not change the order.
    scores = (query_embedding @ book_embeddings.T).detach().numpy()[0]
    top_indices = np.argsort(scores)[::-1][:n]
    return books_data[top_indices.tolist()]


css = """
.full-height-gallery {
    height: calc(100vh - 250px);
    overflow-y: auto;
}
#submit-btn {
    background-color: #ff5b00;
    color: #ffffff;
}
"""

with gr.Blocks(css=css) as intf:
    with gr.Row():
        text_input = gr.Textbox(
            label="Popis knihy",
            info="Zadejte popis knihy, kterou byste si chtěli přečíst a aplikace najde nejpodobnější knihy dle vašeho popisu",
            placeholder='Zadejte popis, například "drama z prostředí nemocnice"',
        )
        # precision=0 makes the widget deliver an int rather than a float.
        n_books = gr.Number(
            value=5,
            label="Počet knih",
            info="Počet nejpodobnějších knih, které si přejete zobrazit",
            minimum=1,
            step=1,
            precision=0,
        )
    with gr.Row():
        submit_btn = gr.Button("Vyhledat knihy", elem_id="submit-btn")
        clear_btn = gr.Button("Smazat")
    with gr.Row():
        dataframe = gr.Dataframe(
            label="Podobné knihy", show_label=False, elem_classes=["full-height-gallery"]
        )

    submit_btn.click(fn=find_similar_books, inputs=[text_input, n_books], outputs=dataframe)
    # Reset the query box and empty the results table.
    clear_btn.click(fn=lambda: [None, []], inputs=None, outputs=[text_input, dataframe])

if __name__ == "__main__":
    intf.launch(share=True)