File size: 3,009 Bytes
6ed0d96
f80579e
762ed00
832fef5
bf0ac5d
f80579e
832fef5
f80579e
b09ffa7
 
 
 
f80579e
 
 
7521548
bf0ac5d
f80579e
 
b09ffa7
f80579e
 
 
bf0ac5d
f80579e
bf0ac5d
 
 
762ed00
 
 
 
bf0ac5d
b09ffa7
 
 
bf0ac5d
b09ffa7
762ed00
 
 
 
bf0ac5d
974b378
 
 
 
762ed00
 
 
b09ffa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b4800e
b09ffa7
 
832fef5
 
bf0ac5d
b09ffa7
762ed00
8b4800e
832fef5
 
bf0ac5d
832fef5
 
 
 
 
 
 
 
 
 
 
6ed0d96
 
832fef5
6ed0d96
832fef5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
from huggingface_hub import HfApi, HfFolder
import datasets
import pandas as pd
import logging
import os


from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F


@st.cache_data
def login():
    """Log this process in to the Hugging Face Hub using ``HF_TOKEN``.

    The token is read from the ``HF_TOKEN`` environment variable, registered
    with ``HfApi.set_access_token`` and persisted with
    ``HfFolder.save_token``. The session is flagged as logged in
    (``st.session_state['logged']``) only after both calls succeed.

    Returns:
        True if this call performed the login; False if the session was
        already flagged as logged in or no token is configured.

    NOTE(review): ``st.cache_data`` memoises the return value of this
    zero-argument function, so the body normally runs once per cache
    lifetime; exceptions are not cached, so a failed login is retried on the
    next call.
    NOTE(review): ``HfApi.set_access_token`` / ``HfFolder`` are deprecated in
    newer ``huggingface_hub`` releases — confirm against the pinned version.
    """
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Handing None to the token APIs fails; skip instead of crashing and
        # leave the session unflagged.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    api = HfApi()
    api.set_access_token(hf_token)
    HfFolder.save_token(hf_token)
    # Flag the session only once the token calls above have succeeded.
    # Otherwise a failure would leave it wrongly marked as logged in.
    st.session_state['logged'] = True
    logging.info("Succesfully logged")
    return True


@st.cache_resource
def load_model(model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """Load the sentence-embedding model together with its tokenizer.

    Streamlit caches the returned objects as a shared resource.

    Args:
        model_name: Hugging Face model id (or local path) passed to both
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``,
            so the two can never refer to different checkpoints. Defaults to
            the checkpoint this app was built around.

    Returns:
        A ``(model, tokenizer)`` tuple — note the model comes first.
    """
    logging.info("Trying to load model")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    logging.info("Model loaded")
    return model, tokenizer


@st.cache_resource
def load_index():
    """Load the articles dataset from disk and attach its FAISS index.

    Streamlit caches the returned dataset as a shared resource.

    Returns:
        The ``datasets.Dataset`` read from ``Data/articles.hf`` with the
        FAISS index from ``Data/articles.index`` registered under the name
        ``"embedding"``.
    """
    logging.info("Trying to load index")
    articles = datasets.Dataset.load_from_disk("Data/articles.hf")
    logging.info("Articles dataset loaded")
    # The index name must match the one used by get_nearest_examples callers.
    articles.load_faiss_index("embedding", "Data/articles.index")
    logging.info("FAISS index loaded")
    return articles


def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, skipping padding.

    Args:
        model_output: model forward output. Its first element is used as the
            per-token embeddings (assumed shape (batch, seq, dim) — TODO
            confirm).
        attention_mask: 1 for real tokens and 0 for padding, assumed
            (batch, seq).

    Returns:
        Tensor with one pooled embedding per batch row.
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(1)
    # Clamp the token count so fully-masked rows divide by a tiny epsilon
    # instead of zero.
    counts = weights.sum(1).clamp(min=1e-9)
    return summed / counts


def get_embedding(query, model, tokenizer):
    """Encode text into L2-normalised, mean-pooled sentence embeddings.

    Args:
        query: text (or texts) to encode; tokenised with padding and
            truncation into PyTorch tensors.
        model: model whose forward output is fed to ``mean_pooling``.
        tokenizer: tokenizer matching ``model``.

    Returns:
        NumPy array of embeddings, each row scaled to unit L2 norm
        (presumably one row per input text — confirm for non-str inputs).
    """
    batch = tokenizer(query, padding=True, truncation=True,
                      return_tensors='pt')
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(**batch)
    pooled = mean_pooling(output, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1).numpy()


@st.cache_data
def get_answers(query, num_answers):
    """Return the ``num_answers`` articles whose embeddings are nearest ``query``.

    Results are memoised by Streamlit per ``(query, num_answers)``.

    Returns:
        List of dicts, one per retrieved article, keyed by dataset column.
    """
    logging.info("Getting answers for {}".format(query))
    model, tokenizer = load_model()
    index = load_index()
    # Flatten to a 1-D vector for the single-query FAISS search.
    flat_query = get_embedding(query, model, tokenizer).reshape(-1)
    # get_nearest_examples returns (scores, examples); only examples are used.
    nearest = index.get_nearest_examples('embedding', flat_query, num_answers)[1]
    frame = pd.DataFrame.from_dict(nearest)
    logging.info("Succesfully got answers for {}".format(query))
    return frame.to_dict('records')


def display_article(article):
    """Render one search hit: a linked title, its abstract, then a divider.

    Args:
        article: mapping with at least ``'id'`` (used to build the
            ``https://arxiv.org/abs/<id>`` link), ``'title'`` and
            ``'abstract'`` keys.
    """
    import html  # local import keeps this block self-contained

    with st.container():
        href = "https://arxiv.org/abs/{}".format(article['id'])
        # The title is injected as raw HTML (unsafe_allow_html=True). Escape
        # it and the attribute value so characters like '<', '&' or '"' in a
        # title cannot break or inject markup.
        title = "<h3><a href=\"{}\">{}</a></h3>".format(
            html.escape(href, quote=True), html.escape(str(article['title'])))
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
    st.write("---")


def display_answers(query, max_answers=100):
    """Render a divider followed by the top search hits for ``query``.

    Args:
        query: search text passed to ``get_answers``.
        max_answers: how many nearest neighbours to retrieve (and cache).

    Only the first ``st.session_state['num_articles_to_show']`` results are
    rendered; that session-state key must already be set.
    """
    st.write("---")
    results = get_answers(query, max_answers)
    limit = st.session_state['num_articles_to_show']
    for result in results[:limit]:
        display_article(result)