modified faiss
Browse files
model.py
CHANGED
@@ -75,16 +75,12 @@ def faiss_add_index_cos(df, column):
|
|
75 |
|
76 |
# Create an index
|
77 |
index = faiss.IndexFlatIP(embeddings.shape[1])
|
78 |
-
|
79 |
-
faiss.normalize_L2(embeddings)
|
80 |
-
print("<<<<faiss_ after normalize")
|
81 |
|
82 |
index.train(embeddings)
|
83 |
-
print("<<<<faiss_ after index.train")
|
84 |
|
85 |
# Add the embeddings to the index
|
86 |
index.add(embeddings)
|
87 |
-
print("<<<<faiss_add")
|
88 |
|
89 |
# Return the index
|
90 |
return index
|
@@ -100,7 +96,7 @@ def faiss_get_top_N_images(query,
|
|
100 |
model, tokenizer,
|
101 |
device)
|
102 |
# Relevant columns
|
103 |
-
relevant_cols = ["comment", "image_name"
|
104 |
|
105 |
#faiss search with cos similarity
|
106 |
index = faiss_add_index_cos(data, column="text_embeddings")
|
@@ -113,5 +109,7 @@ def faiss_get_top_N_images(query,
|
|
113 |
non_repeated_images = ~data_sorted["image_name"].duplicated()
|
114 |
most_similar_articles = data_sorted[non_repeated_images].head(top_K)
|
115 |
|
116 |
-
result_df = most_similar_articles[relevant_cols].reset_index()
|
|
|
|
|
117 |
return [get_item_data(result_df, i, 'similarity') for i in range(len(result_df))]
|
|
|
75 |
|
76 |
# Create an index
|
77 |
index = faiss.IndexFlatIP(embeddings.shape[1])
|
78 |
+
faiss.normalize_L2(embeddings)
|
|
|
|
|
79 |
|
80 |
index.train(embeddings)
|
|
|
81 |
|
82 |
# Add the embeddings to the index
|
83 |
index.add(embeddings)
|
|
|
84 |
|
85 |
# Return the index
|
86 |
return index
|
|
|
96 |
model, tokenizer,
|
97 |
device)
|
98 |
# Relevant columns
|
99 |
+
relevant_cols = ["comment", "image_name"]
|
100 |
|
101 |
#faiss search with cos similarity
|
102 |
index = faiss_add_index_cos(data, column="text_embeddings")
|
|
|
109 |
non_repeated_images = ~data_sorted["image_name"].duplicated()
|
110 |
most_similar_articles = data_sorted[non_repeated_images].head(top_K)
|
111 |
|
112 |
+
result_df = most_similar_articles[relevant_cols].reset_index()
|
113 |
+
D = D.reshape(-1,1)[:top_K]
|
114 |
+
result_df = pd.concat([result_df, pd.DataFrame(D, columns=['similarity'])], axis=1)
|
115 |
return [get_item_data(result_df, i, 'similarity') for i in range(len(result_df))]
|