Spaces:
Runtime error
Runtime error
File size: 3,009 Bytes
6ed0d96 f80579e 762ed00 832fef5 bf0ac5d f80579e 832fef5 f80579e b09ffa7 f80579e 7521548 bf0ac5d f80579e b09ffa7 f80579e bf0ac5d f80579e bf0ac5d 762ed00 bf0ac5d b09ffa7 bf0ac5d b09ffa7 762ed00 bf0ac5d 974b378 762ed00 b09ffa7 8b4800e b09ffa7 832fef5 bf0ac5d b09ffa7 762ed00 8b4800e 832fef5 bf0ac5d 832fef5 6ed0d96 832fef5 6ed0d96 832fef5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import streamlit as st
from huggingface_hub import HfApi, HfFolder
import datasets
import pandas as pd
import logging
import os
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
@st.cache_data
def login():
    """Save the ``HF_TOKEN`` environment variable as the Hugging Face token.

    The work is skipped when ``st.session_state`` already contains the
    ``'logged'`` key.

    Returns:
        True if a token was saved by this call; False if the session was
        already flagged as logged in or ``HF_TOKEN`` is not set.

    NOTE(review): ``@st.cache_data`` memoises the return value of this
    zero-argument function, so after the first completed run the body (and
    its session-state side effect) is not executed again; later calls just
    receive the cached result.
    """
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Nothing to save; passing None to save_token would fail. Leave the
        # session un-flagged so the missing token is not mistaken for a login.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    api = HfApi()
    # HfApi.set_access_token was deprecated and removed in newer
    # huggingface_hub releases; only call it where it still exists.
    if hasattr(api, "set_access_token"):
        api.set_access_token(hf_token)
    HfFolder().save_token(hf_token)

    # Flag the session only once the token has actually been saved, so a
    # failure above cannot leave the session falsely marked as logged in.
    st.session_state['logged'] = True
    logging.info("Succesfully logged")
    return True
@st.cache_resource
def load_model(model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """Load a sentence-embedding model and its tokenizer.

    The result is cached by ``st.cache_resource`` (keyed on ``model_name``),
    so repeated calls reuse the already-loaded objects.

    Args:
        model_name: Hugging Face Hub id used for BOTH the tokenizer and the
            model, so the two always match. Defaults to the id the app was
            built around.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    logging.info("Trying to load model")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    logging.info("Model loaded")
    return model, tokenizer
@st.cache_resource
def load_index():
    """Open the on-disk articles dataset and attach its FAISS index.

    The dataset is read from ``Data/articles.hf``; the FAISS index stored in
    ``Data/articles.index`` is attached under the ``embedding`` index name.
    ``st.cache_resource`` keeps the loaded object around between calls.

    Returns:
        The ``datasets.Dataset`` with its ``embedding`` FAISS index loaded.
    """
    dataset_path, faiss_path = "Data/articles.hf", "Data/articles.index"

    logging.info("Trying to load index")
    articles = datasets.Dataset.load_from_disk(dataset_path)
    logging.info("Articles dataset loaded")

    articles.load_faiss_index("embedding", faiss_path)
    logging.info("FAISS index loaded")
    return articles
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring padded positions.

    Args:
        model_output: indexable whose element 0 holds the token embeddings;
            assumed to be ``(batch, seq_len, hidden)`` with ``attention_mask``
            of shape ``(batch, seq_len)`` (required for the ``expand`` below).
        attention_mask: 1 for real tokens, 0 for padding.

    Returns:
        A ``(batch, hidden)`` float tensor of masked means. The token count is
        clamped to ``1e-9`` so an all-padding row yields zeros instead of NaN.
    """
    hidden = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def get_embedding(query, model, tokenizer):
    """Encode ``query`` into L2-normalised sentence embedding(s).

    The text is tokenised (padded/truncated, PyTorch tensors), run through
    ``model`` without gradient tracking, mean-pooled over non-padding tokens
    and normalised to unit length along dim 1.

    Returns:
        A NumPy array with one row per input text.
    """
    batch = tokenizer(
        query, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch)
    pooled = mean_pooling(outputs, batch['attention_mask'])
    unit_vectors = F.normalize(pooled, p=2, dim=1)
    return unit_vectors.numpy()
@st.cache_data
def get_answers(query, num_answers):
    """Return the ``num_answers`` indexed articles nearest to ``query``.

    Results are memoised per ``(query, num_answers)`` by ``st.cache_data``.

    Returns:
        A list of dicts, one per retrieved example (one key per dataset
        column).
    """
    logging.info("Getting answers for {}".format(query))
    model, tokenizer = load_model()
    index = load_index()

    # Flatten the embedding to the 1-D vector get_nearest_examples takes for
    # a single query.
    flat_query = get_embedding(query, model, tokenizer).reshape(-1)
    _scores, nearest = index.get_nearest_examples(
        'embedding', flat_query, num_answers)

    frame = pd.DataFrame.from_dict(nearest)
    logging.info("Succesfully got answers for {}".format(query))
    return frame.to_dict('records')
def display_article(article):
    """Render one search hit: linked title, abstract and a separator line.

    Args:
        article: mapping with at least ``'id'``, ``'title'`` and
            ``'abstract'`` keys, such as one record returned by
            ``get_answers``.

    The title is written as raw HTML (``unsafe_allow_html=True``), so the link
    target and the title text are HTML-escaped first. Otherwise a ``<`` or
    ``&`` in a title would be parsed as markup, or could inject arbitrary HTML.
    """
    import html

    with st.container():
        href = "https://arxiv.org/abs/{}".format(article['id'])
        title = "<h3><a href=\"{}\">{}</a></h3>".format(
            html.escape(href, quote=True), html.escape(str(article['title'])))
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
        st.write("---")
def display_answers(query, max_answers=100):
    """Fetch up to ``max_answers`` results for ``query`` and render them.

    Only the first ``st.session_state['num_articles_to_show']`` results are
    shown. If that key has not been initialised yet, every fetched result is
    shown instead of the page failing with ``KeyError``.
    """
    st.write("---")
    articles = get_answers(query, max_answers)
    limit = st.session_state.get('num_articles_to_show', max_answers)
    for article in articles[:limit]:
        display_article(article)
|