# PicFinder / model.py
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Local module; expected to expose the image/caption DataFrame used below
from dataframe import *


def get_model_info(model_ID, device):
    # Load the pretrained CLIP model and move it to the target device
    model = CLIPModel.from_pretrained(model_ID).to(device)
    # The processor handles image preprocessing; the tokenizer handles text
    processor = CLIPProcessor.from_pretrained(model_ID)
    tokenizer = CLIPTokenizer.from_pretrained(model_ID)
    return model, processor, tokenizer
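
# Example (a minimal sketch; "openai/clip-vit-base-patch32" is the standard
# public CLIP checkpoint and is an assumption here, not pinned by this repo):
#   import torch
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, processor, tokenizer = get_model_info("openai/clip-vit-base-patch32", device)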


def get_single_text_embedding(text, model, tokenizer, device):
    # CLIP's text encoder accepts at most 77 tokens, so truncate longer queries
    inputs = tokenizer(text, return_tensors="pt", max_length=77, truncation=True).to(device)
    text_embeddings = model.get_text_features(**inputs)
    # Detach from the computation graph and convert to a NumPy array
    embedding_as_np = text_embeddings.cpu().detach().numpy()
    return embedding_as_np
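
# Example (sketch): with the base ViT-B/32 checkpoint assumed above, CLIP
# projects text into a 512-dimensional space, so the result has shape (1, 512):
#   vec = get_single_text_embedding("a dog on the beach", model, tokenizer, device)
#   vec.shape  # (1, 512)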


def df_to_array(result_df):
    # Collect the image file names from the result DataFrame into a plain list
    return [str(result_df['image_name'][i]) for i in range(len(result_df))]


def get_top_N_images(query,
                     data,
                     model, tokenizer,
                     device,
                     top_K=4,
                     search_criterion="text"):
    """
    Retrieve the top_K (4 by default) images most similar to the query.
    """
    # Text-to-image search: embed the query text with CLIP
    if search_criterion.lower() == "text":
        query_vect = get_single_text_embedding(query, model, tokenizer, device)
    # Image-to-image search is not available in this release
    # else:
    #     query_vect = get_single_image_embedding(query)
    else:
        raise NotImplementedError("Only text-to-image search is supported.")
    # Relevant columns for the result
    relevant_cols = ["comment", "image_name", "cos_sim"]
    # Run the similarity search: cosine similarity between the query vector
    # and each precomputed text embedding
    data["cos_sim"] = data["text_embeddings"].apply(lambda x: cosine_similarity(query_vect, x))
    data["cos_sim"] = data["cos_sim"].apply(lambda x: x[0][0])  # unwrap the 1x1 matrix
    data_sorted = data.sort_values(by='cos_sim', ascending=False)
    # Drop duplicate images, keeping the highest-scoring row for each
    non_repeated_images = ~data_sorted["image_name"].duplicated()
    most_similar_articles = data_sorted[non_repeated_images].head(top_K)
    result_df = most_similar_articles[relevant_cols].reset_index()
    return df_to_array(result_df)
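

# End-to-end usage (a minimal sketch; assumes `dataframe` exposes a pandas
# DataFrame -- called `df` here, a hypothetical name -- with "comment",
# "image_name", and "text_embeddings" columns, as the code above implies):
#   import torch
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, processor, tokenizer = get_model_info("openai/clip-vit-base-patch32", device)
#   top_images = get_top_N_images("a dog playing in the snow", df,
#                                 model, tokenizer, device, top_K=4)
#   print(top_images)  # a list of up to 4 image file names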