# Hugging Face Spaces — status: Running
import streamlit as st | |
import torch as t | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer, util | |
from time import perf_counter as timer | |
def load_data(database_file, embedding_dim=768):
    """Read a Parquet chunk database and stack its embeddings into one matrix.

    Args:
        database_file: Anything ``pd.read_parquet`` accepts, e.g. a path or a
            binary file-like object such as a Streamlit upload.
        embedding_dim: Expected length of every vector in the
            ``"chunk_embeddings"`` column. Defaults to 768, the width that
            was previously hard-coded.

    Returns:
        ``(df, chunk_embeddings)``. ``df`` is the DataFrame as read.
        ``chunk_embeddings`` is a float32 tensor of shape
        ``(len(df), embedding_dim)``. Row ``i`` holds the embedding of the
        ``i``-th row of ``df`` by position, regardless of ``df.index`` labels.

    Raises:
        ValueError: If the stacked embeddings are not a 2-D matrix whose
            width is ``embedding_dim``.
        RuntimeError: Raised by ``torch.stack`` if rows have differing lengths.
    """
    df = pd.read_parquet(database_file)

    if len(df) == 0:
        # An empty database still yields a correctly shaped (0, dim) matrix.
        return df, t.zeros((0, embedding_dim))

    # Iterate the column positionally. The old ``df.loc[df.index[idx], ...]``
    # lookup returned a Series whenever the index had duplicate labels.
    rows = [t.as_tensor(vec, dtype=t.float32) for vec in df["chunk_embeddings"]]
    chunk_embeddings = t.stack(rows)

    if chunk_embeddings.dim() != 2 or chunk_embeddings.shape[1] != embedding_dim:
        raise ValueError(
            f"expected chunk embeddings of shape (n, {embedding_dim}), "
            f"got {tuple(chunk_embeddings.shape)}"
        )
    return df, chunk_embeddings
@st.cache_resource
def _load_embedding_model(device):
    """Return the ``all-mpnet-base-v2`` SentenceTransformer placed on *device*.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on each widget interaction. Without the cache the model would be
    rebuilt on every rerun.
    """
    return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device=device)


def main():
    """Render the semantic text retrieval Streamlit app.

    The app picks a device and loads the (cached) embedding model. It then
    reads an uploaded Parquet database with ``load_data``. When "Search" is
    clicked with a non-empty query, every chunk is ranked by dot product
    against the query embedding, and the ``top_k`` best matches are shown
    with their score, text and token count.
    """
    st.title("Semantic Text Retrieval App")

    # Select device
    device = "cuda" if t.cuda.is_available() else "cpu"
    st.write(f"Using device: {device}")

    # Load embedding model (cached across reruns by _load_embedding_model)
    embedding_model = _load_embedding_model(device)

    # File upload for the database
    database_file = st.file_uploader("Upload the Parquet database file", type=["parquet"])
    if database_file is None:
        return

    df, chunk_embeddings = load_data(database_file)
    n_chunks = len(chunk_embeddings)
    if n_chunks == 0:
        # Nothing to rank; t.topk below would fail on an empty score vector.
        st.warning("The uploaded database contains no chunks.")
        return
    st.success("Database loaded successfully!")

    query = st.text_area("Enter your query:")

    # The slider must be rendered before (outside) the button branch.
    # st.button is True only on the single rerun triggered by the click.
    # A slider created inside that branch disappears, together with the
    # results, as soon as it is moved, so its value could never drive a search.
    top_k = st.slider("Select number of top results to display", min_value=1, max_value=10, value=5)

    if not (st.button("Search") and query):
        return

    query_embedding = embedding_model.encode(query)

    # Compute dot product scores
    start_time = timer()
    dot_scores = util.dot_score(query_embedding, chunk_embeddings)[0]
    end_time = timer()
    st.write(f"Time taken to compute scores: {end_time - start_time:.5f} seconds")

    # Get top results; clamp k because t.topk raises when k exceeds the
    # number of scores (databases with fewer than top_k chunks).
    top_results = t.topk(dot_scores, k=min(top_k, n_chunks))

    st.subheader("Query Results")
    st.write(f"Query: {query}")
    for score, idx in zip(top_results.values, top_results.indices):
        row = df.iloc[int(idx)]
        st.write(f"### Score: {score:.4f}")
        # NOTE(review): column "ext" looks like a typo for "text" — confirm
        # against the Parquet schema before changing it.
        st.write(f"**Text:** {row['ext']}")
        st.write(f"**Number of tokens:** {row['tokens']}")
        st.write("---")
if __name__ == "__main__":
    # Entry point: build the app only when this file is executed as a script
    # (e.g. launched with `streamlit run`); a plain import does not call main().
    main()