"""Streamlit app: vector search over answers with Sentence Transformers + Faiss.

The user's query is embedded through the Hugging Face Inference API
feature-extraction endpoint and matched against a prebuilt Faiss index whose
row ids index into the ``validation_asks`` split of the
``dhmeltzer/asks_validation_embedded`` dataset.
"""
import pickle

import datasets  # noqa: F401  (currently unused)
import faiss
import numpy as np
import requests
import streamlit as st
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer  # noqa: F401  (currently unused)

from vector_engine.utils import vector_search  # noqa: F401  (currently unused)

MODEL_ID = "sentence-transformers/nli-distilbert-base"
API_URL = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{MODEL_ID}"
# Generous ceiling: "wait_for_model" can block while a cold model loads, but a
# request must never hang the app forever.
REQUEST_TIMEOUT_S = 120


@st.cache_resource
def read_data(dataset_repo="dhmeltzer/asks_validation_embedded"):
    """Load the ``validation_asks`` split of *dataset_repo* from the Hugging Face Hub.

    Cached as a shared resource so the dataset is loaded once per server
    process instead of being pickled/copied on every rerun (``cache_data``).
    """
    return load_dataset(dataset_repo)["validation_asks"]


@st.cache_resource
def load_faiss_index(path_to_faiss="./faiss_index_small.pickle"):
    """Load and deserialize the Faiss index stored at *path_to_faiss*.

    The file holds a pickled buffer that is passed to
    ``faiss.deserialize_index`` (presumably written with
    ``faiss.serialize_index`` — confirm against the build script).
    The index is read-only here, so it is safe to share across sessions.
    """
    # SECURITY: pickle.load executes arbitrary code from the file; only ever
    # point this at a trusted, locally built artifact.
    with open(path_to_faiss, "rb") as handle:
        data = pickle.load(handle)
    return faiss.deserialize_index(data)


def _embed(texts, headers):
    """Return the API's feature-extraction output for *texts*.

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            HTTP status.
        ValueError: if the response body is not valid JSON.
        RuntimeError: if the API answers with an ``{"error": ...}`` payload.
    """
    response = requests.post(
        API_URL,
        headers=headers,
        json={"inputs": texts, "options": {"wait_for_model": True}},
        timeout=REQUEST_TIMEOUT_S,
    )
    response.raise_for_status()
    payload = response.json()
    # Error bodies are dicts; a successful call returns a (nested) list.
    if isinstance(payload, dict) and "error" in payload:
        raise RuntimeError(payload["error"])
    return payload


def _substitute_urls(answers, urls):
    """Replace each ``_URL_{k}_`` placeholder in *answers* with ``urls[k]``."""
    substituted = []
    for answer in answers:
        for k, url in enumerate(urls):
            answer = answer.replace(f"_URL_{k}_", url)
        substituted.append(answer)
    return substituted


def _render_result(row):
    """Write one search hit (title, top score, top answer) to the page."""
    answers = _substitute_urls(row["answers"]["text"], row["answers_urls"]["url"])
    scores = row["answers"]["score"]
    st.write(f"**Title**: {row['title']}")
    # Guard against rows without answers instead of raising IndexError.
    if scores:
        st.write(f"**Score**: {scores[0]}")
    if answers:
        st.write(f"**Top Answer**: {answers[0]} ")
    st.write("-" * 20)


def main():
    """Render the search UI and show the nearest stored answers for the query."""
    data = read_data()
    faiss_index = load_faiss_index()
    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    st.title("Vector-based searches with Sentence Transformers and Faiss")

    # User search
    user_input = st.text_area("Search box", "What is spacetime made out of?")

    # Filters
    st.sidebar.markdown("**Filters**")
    num_results = st.sidebar.slider("Number of search results", 1, 50, 1)

    # Skip the remote embedding call entirely when there is nothing to search.
    if not user_input.strip():
        return

    try:
        vector = _embed([user_input], headers)
    except (requests.RequestException, ValueError, RuntimeError) as err:
        st.error(f"Embedding request failed: {err}")
        return

    # Faiss search takes a float32 matrix of shape (n_queries, dim).
    # NOTE(review): assumes the endpoint returns one pooled vector per input
    # text for this sentence-transformers model — confirm.
    query_matrix = np.asarray(vector, dtype="float32")
    _, neighbor_ids = faiss_index.search(query_matrix, k=num_results)

    for id_ in neighbor_ids.flatten().tolist():
        # Faiss pads with -1 when fewer than k neighbours exist; data[-1]
        # would otherwise silently render the last dataset row.
        if id_ < 0:
            continue
        _render_result(data[id_])


if __name__ == "__main__":
    main()