import re
import time

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from st_click_detector import click_detector
from tokenizers import AddedToken, Tokenizer
from transformers import AutoModel, AutoTokenizer
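
# Streamlit demo: semantic search over ~44k movie overviews (movies.csv) with two
# sentence-transformers models and precomputed corpus embeddings. Clicking a result
# rendered through st_click_detector turns it into a "find similar movies" query.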
def update_params():
    # on_change callback for the topic radio below: copy the selection into the
    # search box state and mirror it into the ?query=... URL parameter. Callbacks
    # run before widgets are re-created, so setting the "query" key here is allowed.
    st.session_state.query = st.session_state.topic
    st.experimental_set_query_params(query=st.session_state.topic)

options = ["artificial intelligence", "robot", "VR", "medicine", "genomics", "cure", "heal", "brain", "support", "friendship", "memory", "aging", "pharma", "virus", "nurse", "doctor", "therapist", "nutrition", "technology", "computer", "software", "neuroscience", "birth", "death", "soul", "space", "sci-fi"]

# Pre-select the radio button from the ?query=... URL parameter, if present.
query_params = st.experimental_get_query_params()
ix = 0
if query_params:
    try:
        q0 = query_params["query"][0]
        ix = options.index(q0)
    except (KeyError, ValueError):
        pass

# key="topic" rather than "query", so the radio does not collide with the search
# text input below, which owns the "query" key.
selected_option = st.radio(
    "Param", options, index=ix, key="topic", on_change=update_params
)
# Lay the radio options out horizontally.
st.write("<style>div.row-widget.stRadio > div{flex-direction:row;}</style>", unsafe_allow_html=True)

st.experimental_set_query_params(query=selected_option)

# Seed the search box from the selected topic on the first run only; re-seeding on
# every rerun would clobber typed queries and clicked "[Similar:...]" queries, and
# Streamlit refuses to set a widget-backed key after that widget has been created.
if "query" not in st.session_state:
    try:
        st.session_state.query = selected_option
    except Exception:
        print("Error: can't set 'query' after the widget has been initialized")

if "query" not in st.session_state:
    query = st.text_input("", value="artificial intelligence", key="query")
else:
    # The value already lives in st.session_state["query"]; passing value= as well
    # would trigger Streamlit's default-value vs. Session State warning.
    query = st.text_input("", key="query")

# Sidebar topic picker driven by the same ?query=... URL parameter.
try:
    query_params = st.experimental_get_query_params()
    query_option = query_params["query"][0]
    option_selected = st.sidebar.selectbox(
        "Pick option", options, index=options.index(query_option)
    )
except (KeyError, ValueError):
    # The fallback has to be a member of `options`, otherwise the index lookup
    # above fails again on the next rerun.
    st.experimental_set_query_params(query="artificial intelligence")
    query_params = st.experimental_get_query_params()
    query_option = query_params["query"][0]
# query_option = "ai"  # debug override; "ai" is not in `options`

DEVICE = "cpu"
MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"]
DESCRIPTION = """
# Semantic search

**Enter your query and hit enter**

Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""

# Session-state scratchpad: dict-style and attribute-style access are equivalent.
if "key" not in st.session_state:
    st.session_state["key"] = "value"  # same as: st.session_state.key = "value"
st.write(st.session_state.key)
st.write(st.session_state)

# Clearing every key would also delete "query" and "last_clicked" and break the
# click-to-similar flow further down, so the wipe stays commented out:
# for key in st.session_state.keys():
#     del st.session_state[key]

@st.cache(
    show_spinner=False,
    hash_funcs={
        AutoModel: lambda _: None,
        AutoTokenizer: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load both models and tokenizers, the precomputed corpus embeddings, and the movie metadata."""
    models, tokenizers, embeddings = [], [], []
    for model_option in MODEL_OPTIONS:
        tokenizers.append(
            AutoTokenizer.from_pretrained(f"sentence-transformers/{model_option}")
        )
        models.append(
            AutoModel.from_pretrained(f"sentence-transformers/{model_option}").to(DEVICE)
        )
    # One precomputed embedding matrix per model.
    embeddings.append(np.load("embeddings.npy"))
    embeddings.append(np.load("embeddings2.npy"))
    df = pd.read_csv("movies.csv")
    return tokenizers, models, embeddings, df


tokenizers, models, embeddings, df = load()
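# The app assumes each embeddings[k] has one row per movie in df, in the same order
# as movies.csv, so a dataframe row index doubles as a row index into either matrix.
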
def pooling(model_output):
    # CLS-token pooling, as used by the msmarco-distilbert-base-tas-b model.
    return model_output.last_hidden_state[:, 0]


def compute_embeddings(texts):
    # Embed a list of strings with the first model (CLS pooling, dot-product scores).
    encoded_input = tokenizers[0](
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        model_output = models[0](**encoded_input, return_dict=True)

    embeddings = pooling(model_output)
    return embeddings.cpu().numpy()
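# Quick sanity check (hypothetical, meant for an interactive session rather than
# the app itself):
#     compute_embeddings(["a robot learns to love"]).shape  # -> (1, 768)
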
def pooling2(model_output, attention_mask):
    # Mean pooling over token embeddings, ignoring padding via the attention mask.
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


def compute_embeddings2(list_of_strings):
    # Embed a list of strings with the second model (mean pooling + L2 normalization,
    # so dot products below amount to cosine similarity).
    encoded_input = tokenizers[1](
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    with torch.no_grad():
        model_output = models[1](**encoded_input)
    sentence_embeddings = pooling2(model_output, encoded_input["attention_mask"])
    return F.normalize(sentence_embeddings, p=2, dim=1).cpu().numpy()
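
# The .npy files loaded above are built offline. A minimal sketch of how they could
# be regenerated from movies.csv (an assumption: the real preprocessing script is
# not part of this file). Flip REBUILD_EMBEDDINGS to True to run it; encoding ~44k
# overviews on CPU is slow.
REBUILD_EMBEDDINGS = False
if REBUILD_EMBEDDINGS:
    batch_size = 256  # hypothetical batch size
    texts = df.overview.fillna("").tolist()
    chunks1 = [compute_embeddings(texts[i : i + batch_size]) for i in range(0, len(texts), batch_size)]
    chunks2 = [compute_embeddings2(texts[i : i + batch_size]) for i in range(0, len(texts), batch_size)]
    np.save("embeddings.npy", np.concatenate(chunks1))
    np.save("embeddings2.npy", np.concatenate(chunks2))
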
@st.cache(
    show_spinner=False,
    hash_funcs={Tokenizer: lambda _: None, AddedToken: lambda _: None},
)
def semantic_search(query, model_id):
    """Return an HTML ranked list of the ten movies whose overview best matches the query."""
    start = time.time()
    if len(query.strip()) == 0:
        return ""
    if "[Similar:" not in query:
        if model_id == 0:
            query_embedding = compute_embeddings([query])
        else:
            query_embedding = compute_embeddings2([query])
    else:
        # A "[Similar:<row index>] <title>" query reuses the precomputed embedding
        # of that movie instead of re-encoding the text.
        match = re.match(r"\[Similar:(\d{1,5}).*", query)
        if match:
            idx = int(match.groups()[0])
            query_embedding = embeddings[model_id][idx : idx + 1, :]
            if query_embedding.shape[0] == 0:
                return ""
        else:
            return ""
    # Score the whole corpus against the query and keep the ten best matches.
    indices = np.argsort(embeddings[model_id] @ np.transpose(query_embedding)[:, 0])[
        -1:-11:-1
    ]
    if len(indices) == 0:
        return ""
    result = "<ol>"
    for i in indices:
        # The anchor id is the dataframe row index; click_detector reports it back
        # so the app can run a "similar movies" query for that entry.
        result += (
            f"<li style='padding-top: 10px'><b>{df.iloc[i].title}</b> "
            f"({df.iloc[i].release_date}). {df.iloc[i].overview} "
            f"<a href='#' id='{i}'>Similar movies</a></li>"
        )
    delay = "%.3f" % (time.time() - start)
    return f"<p><i>Computation time: {delay} seconds</i></p>{result}</ol>"

st.sidebar.markdown(DESCRIPTION)

model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
model_id = 0 if model_choice == MODEL_OPTIONS[0] else 1

# Render the results; clicked is the id of the result link the user clicked, or "".
clicked = click_detector(semantic_search(query, model_id))

if clicked != "":
    st.markdown(clicked)
    change_query = False
    if "last_clicked" not in st.session_state:
        st.session_state["last_clicked"] = clicked
        change_query = True
    elif clicked != st.session_state["last_clicked"]:
        st.session_state["last_clicked"] = clicked
        change_query = True
    if change_query:
        # Rewrite the search box contents to a "[Similar:<index>] <title>" query
        # and rerun so semantic_search picks it up.
        st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
        st.experimental_rerun()
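
# To try it locally (assuming this file is saved as app.py, with embeddings.npy,
# embeddings2.npy and movies.csv alongside it):
#     streamlit run app.py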