Spaces:
Running
Running
import pandas as pd | |
import numpy as np | |
import nmslib | |
from sentence_transformers import SentenceTransformer | |
# TODO: Use pipe, remove embeddings | |
class SentenceEncoder: | |
"""Encodes the querry and papers data set and finds elements with the lowest cosine similarity | |
This application uses Sentence-BERT embeddings. | |
Sentence Embedding is achieved here via Siamese BERT-Networks from https://arxiv.org/abs/1908.10084 | |
The implementation used is that of SBERT.net (https://www.sbert.net/) | |
""" | |
def load_and_encode(self): | |
"""prepare data before running search querry""" | |
# load | |
df = self._load() | |
# encode | |
df, model, embeddings = self._encode_papers(df) | |
return df, model, embeddings | |
def transform(self, df, querry, model, embeddings): | |
"""main querry pipeline""" | |
# create_index | |
emb_querry = self._econde_querry(querry, model) | |
# search | |
result = self._make_search(df, emb_querry, embeddings) | |
# add_relevant_columns | |
df = self._add_relevant_columns(df, result) | |
return df, result | |
def _load(self): | |
# Load data | |
df = pd.read_csv("data/arxiv.csv") | |
return df | |
def _encode_papers(self, df): | |
# Encode the papers title | |
checkpoint = "distilbert-base-uncased" | |
model = SentenceTransformer(checkpoint) | |
embeddings = model.encode(df["title"], convert_to_tensor=True) | |
# embeddings column | |
df["embeddings"] = np.array(embeddings).tolist() | |
return df, model, embeddings | |
def _econde_querry(self, querry, model): | |
# Encode the querry | |
emb_querry = model.encode([querry]) | |
return emb_querry | |
def _make_search(self, df, emb_querry, embeddings): | |
"""search for nearest K neighbours in the embedding space""" | |
# initialize a new index, using a HNSW index on Cosine Similarity | |
index = nmslib.init(method="hnsw", space="cosinesimil") | |
index.addDataPointBatch(embeddings) | |
index.createIndex({"post": 2}, print_progress=True) | |
# search | |
result = self._extract_search_result(index, emb_querry, df, k=10) | |
return result | |
def _extract_search_result(self, index, emb_querry, df, k): | |
data = [] | |
idx, distances = index.knnQuery(emb_querry, k=k) | |
for i, j in zip(idx, distances): | |
data.append( | |
{ | |
"index": i, | |
"title": df.title[i], | |
"abstract": df.abstract[i], | |
"similarity": 1.0 - j, | |
} | |
) | |
return pd.DataFrame(data) | |
def _add_relevant_columns(self, df, result): | |
"""post processing""" | |
# get categories | |
df["categories_parsed"] = ( | |
df.categories.str.split() | |
.apply(lambda x: x[0]) | |
.str.split(".") | |
.apply(lambda x: x[0]) | |
) | |
# create columns for plotting | |
df["index_papers"] = df.index | |
df["selected"] = df.index_papers.apply(lambda x: x in list(result["index"])) | |
return df | |