Spaces:
Runtime error
Runtime error
import streamlit as st | |
import faiss | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
import json | |
def load_index():
    """Read the serialized FAISS index from ``cdc_search.index`` and return it."""
    return faiss.read_index('cdc_search.index')
def load_data():
    """Load and return the parsed contents of ``./data.json``.

    The path is relative to the current working directory. The file is
    opened explicitly as UTF-8 so that decoding does not depend on the
    platform's locale encoding.

    Raises:
        FileNotFoundError: if ``./data.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open('./data.json', encoding='utf-8') as f:
        return json.load(f)
def load_embedder():
    """Instantiate and return the sentence-embedding model."""
    model_name = "distilbert-base-nli-stsb-mean-tokens"
    return SentenceTransformer(model_name)
def load_qa_pipeline():
    """Build and return the ``question-answering`` transformers pipeline."""
    return pipeline(
        "question-answering",
        model="ktrapeznikov/albert-xlarge-v2-squad-v2",
    )
def load_questions():
    """Load the sample questions from ``./questions.json``.

    Returns:
        list: the elements obtained by iterating the parsed JSON document,
        in order. A list is returned instead of a one-shot generator so the
        result can be iterated more than once, measured with ``len`` and
        indexed.

    Raises:
        FileNotFoundError: if ``./questions.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit UTF-8 keeps decoding independent of the platform locale.
    with open('./questions.json', encoding='utf-8') as f:
        data = json.load(f)
    return list(data)
# Module-level resources shared by search() below.
index = load_index()        # vector index queried with the encoded question
embedder = load_embedder()  # encodes the question text for the index lookup
qa = load_qa_pipeline()     # extracts an answer span from a retrieved text
data = load_data()          # records looked up by the ids the index returns
def search(query: str, k=1):
    """Retrieve the ``k`` nearest documents for ``query`` and extract answers.

    The query is encoded with the module-level ``embedder``, the nearest
    entries are looked up in the module-level ``index``, and the ``qa``
    pipeline is run against the ``'text'`` field of each matching record in
    ``data``.

    Args:
        query: the question to answer.
        k: number of nearest neighbours to request from the index.

    Returns:
        A list of ``(answer, score)`` tuples sorted by score, highest first.
        The list is empty when the index yields no valid neighbours.
    """
    encoded_query = embedder.encode([query])
    # index.search returns (distances, ids); only the ids row for our single
    # query is needed here.
    _distances, neighbor_ids = index.search(encoded_query, k)

    answers = []
    for doc_id in neighbor_ids[0]:
        # FAISS pads the id array with -1 when fewer than k neighbours exist.
        # Indexing `data` with -1 would silently pick the last record, so
        # skip those placeholder ids.
        if doc_id < 0:
            continue
        prediction = qa(question=query, context=data[doc_id]['text'])
        if 'answer' in prediction:
            answers.append((prediction['answer'], prediction['score']))
    return sorted(answers, key=lambda tup: tup[1], reverse=True)
# Materialise the questions so the select widget receives a reusable,
# indexable sequence rather than a one-shot iterator.
questions = list(load_questions())
option = st.selectbox("Sample Questions", questions)
st.write('You selected: ', option)
# st.selectbox returns None when nothing is selected (e.g. no options are
# available); skip the search in that case instead of encoding None.
if option is not None:
    st.markdown("\n".join(f"* {answer}" for answer, _ in search(option)))