File size: 2,079 Bytes
63d858e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sklearn.metrics.pairwise import cosine_similarity

from dataframe  import *

def get_model_info(model_ID, device):
    """Load the CLIP model, processor and tokenizer for ``model_ID``.

    Args:
        model_ID: Identifier passed unchanged to each ``from_pretrained`` call.
        device: Target passed to ``.to(...)`` on the loaded model.

    Returns:
        A ``(model, processor, tokenizer)`` tuple. Only the model is moved to
        ``device``; the processor and tokenizer are returned as loaded.
    """
    # Load the model and move it to the requested device
    model = CLIPModel.from_pretrained(model_ID).to(device)

    # Get the processor
    processor = CLIPProcessor.from_pretrained(model_ID)

    # Get the tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(model_ID)

    return model, processor, tokenizer


def get_single_text_embedding(text, model, tokenizer, device):
    """Embed ``text`` with ``model`` and return the result as a NumPy array.

    Args:
        text: Input passed to ``tokenizer``.
        model: Object exposing ``get_text_features(**inputs)``.
        tokenizer: Callable producing the model inputs; its result must
            support ``.to(device)`` and ``**`` unpacking.
        device: Device the tokenized inputs are moved to before the forward
            pass (presumably the same device the model lives on).

    Returns:
        The text features converted to a NumPy array on the CPU.
    """
    # Local import keeps this block self-contained; torch is not imported at
    # the top of this file.
    import torch

    # Inputs are truncated to 77 tokens (presumably CLIP's text context
    # length -- confirm against the model config).
    inputs = tokenizer(
        text, return_tensors="pt", max_length=77, truncation=True
    ).to(device)

    # Inference only: skip autograd bookkeeping instead of building a graph
    # that is thrown away by the detach below.
    with torch.no_grad():
        text_embeddings = model.get_text_features(**inputs)

    # Convert the embeddings to a NumPy array
    return text_embeddings.detach().cpu().numpy()

def df_to_array(result_df):
    """Return the ``image_name`` column of ``result_df`` as a list of strings.

    The column is iterated by position, so the frame's index labels do not
    matter (a filtered or re-sorted frame works without ``reset_index``).

    Args:
        result_df: DataFrame containing an ``image_name`` column.

    Returns:
        One ``str(...)`` per row, in row order.
    """
    return [str(image_name) for image_name in result_df['image_name']]

def get_top_N_images(query,
                     data,
                     model, tokenizer,
                     device,
                     top_K=4,
                     search_criterion="text"):
    """Return the names of the ``top_K`` images most similar to ``query``.

    The query is embedded with ``get_single_text_embedding`` and compared,
    via cosine similarity, against every entry of ``data["text_embeddings"]``.
    Rows are ranked by similarity and only the best-scoring row per
    ``image_name`` is kept.

    Args:
        query: Text query to search for.
        data: DataFrame with ``comment``, ``image_name`` and
            ``text_embeddings`` columns. NOTE: a ``cos_sim`` column is
            written onto this frame as a side effect.
        model: Model forwarded to ``get_single_text_embedding``.
        tokenizer: Tokenizer forwarded to ``get_single_text_embedding``.
        device: Device forwarded to ``get_single_text_embedding``.
        top_K: Maximum number of image names to return (default 4).
        search_criterion: Search mode, case-insensitive. Only ``"text"`` is
            implemented.

    Returns:
        A list of image-name strings, most similar first.

    Raises:
        ValueError: If ``search_criterion`` is anything other than ``"text"``.
    """
    # Only text-to-image search exists; fail with a clear message instead of
    # tripping over an unbound query vector further down.
    if search_criterion.lower() != "text":
        raise ValueError(
            "Unsupported search_criterion {!r}; only \"text\" is "
            "implemented.".format(search_criterion)
        )
    query_vect = get_single_text_embedding(query, model, tokenizer, device)

    # Relevant columns
    relevant_cols = ["comment", "image_name", "cos_sim"]

    # Run similarity search. The result of cosine_similarity is indexed
    # [0][0] to pull out the single score (assumes each stored embedding is a
    # one-row 2-D array, like the query vector -- TODO confirm).
    data["cos_sim"] = data["text_embeddings"].apply(
        lambda emb: cosine_similarity(query_vect, emb)[0][0]
    )

    # Highest similarity first, then keep only the first (best) row per image
    data_sorted = data.sort_values(by="cos_sim", ascending=False)
    first_per_image = ~data_sorted["image_name"].duplicated()
    top_matches = data_sorted[first_per_image].head(top_K)

    # Renumber rows 0..n-1 before handing off to df_to_array
    result_df = top_matches[relevant_cols].reset_index(drop=True)
    return df_to_array(result_df)