added search for release 2
- app.py +10 -4
- dataframe.py +12 -0
- main.py +25 -2
- model.py +60 -0
- setup.py +2 -2
app.py
CHANGED
@@ -1,4 +1,5 @@
 import streamlit as st
+import pandas as pd
 from main import *
 from setup import *
 
@@ -26,7 +27,12 @@ downlad_images()
 st.title('Find my pic!')
 
 search_request = st.text_input('', 'Search ...')
-
-if st.button('Find!'):
-
-
+
+# if st.button('Find Relsease 1!'):
+#     search_result = search1(search_request)
+#     display(search_request, search_result)
+
+if st.button('Find Relsease 2!'):
+    search_result = search2(search_request)
+    for item in search_result :
+        st.write(item)
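For reference, `search2` (added to main.py below) ultimately returns a plain list of image file names taken from the `image_name` column, so the loop above writes one name per line in the Streamlit page. A minimal smoke test of the same release-2 path outside Streamlit, assuming the data and model dependencies added elsewhere in this commit are in place (the query string is just an example):

    # Hypothetical smoke test; not part of this commit.
    from main import search2

    if __name__ == "__main__":
        for item in search2("two dogs playing in the snow"):
            print(item)  # each item is an image file name from the 'image_name' column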
dataframe.py
ADDED
@@ -0,0 +1,12 @@
+import pandas as pd
+import numpy as np
+
+def get_image_data() :
+
+    # flickr = pd.read_csv('data/results.csv', sep='|')
+    image_data_df = pd.read_csv ('data/output2.csv')
+
+    image_data_df['text_embeddings'] = image_data_df['text_embeddings'].apply(lambda x: np.fromstring(x[2:-2], sep=' ')).values
+    image_data_df['text_embeddings'] = image_data_df['text_embeddings'].apply(lambda x: np.reshape(x, (1, -1)))
+
+    return image_data_df
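The `[2:-2]` slice plus `np.fromstring(..., sep=' ')` assumes each `text_embeddings` cell in `data/output2.csv` was saved as the printed form of a (1, N) array, i.e. a string like `'[[0.1 0.2 ...]]'`, so trimming two characters from each end strips the double brackets before the floats are parsed. A small round-trip sketch of that assumed cell format (values made up):

    import numpy as np

    # Illustration of the assumed CSV cell format; not part of this commit.
    cell = "[[0.12 0.34 0.56]]"               # a (1, 3) embedding stringified into the CSV
    vec = np.fromstring(cell[2:-2], sep=' ')  # strip '[[' and ']]', parse the floats
    vec = np.reshape(vec, (1, -1))            # back to shape (1, 3), ready for cosine_similarity
    print(vec.shape)                          # -> (1, 3)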
main.py
CHANGED
@@ -1,5 +1,9 @@
 
 import random
+import torch
+
+from dataframe import *
+from model import *
 
 images = ["Girl.jpg",
           "Cat In Hat.jpg",
@@ -17,9 +21,28 @@ images = ["Girl.jpg",
 
 
 
-def
+def search1(search_prompt : str):
     """
     Given a search_prompt, return an array of pictures to display
     """
 
-    return [ (images[i], images[i].split('.')[0]) for i in random.sample(range(len(images)), 4) ]
+    return [ (images[i], images[i].split('.')[0]) for i in random.sample(range(len(images)), 4) ]
+
+def search2(search_prompt : str) :
+
+    # Set the device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Define the model ID
+    model_ID = "openai/clip-vit-base-patch32"
+
+    # Get model, processor & tokenizer
+    model, processor, tokenizer = get_model_info(model_ID, device)
+
+    image_data_df = get_image_data()
+
+    return get_top_N_images(search_prompt,
+                            data = image_data_df,
+                            model=model, tokenizer=tokenizer,
+                            device = device,
+                            top_K=4)
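As written, `search2` reloads the CLIP weights and re-reads `data/output2.csv` on every button press. A minimal sketch of memoising those two steps with the standard library, kept as a possible follow-up rather than part of this commit (it assumes the `get_model_info` and `get_image_data` signatures above):

    from functools import lru_cache

    # Hypothetical caching wrappers; not part of this commit.
    @lru_cache(maxsize=1)
    def cached_model_info(model_ID: str, device: str):
        return get_model_info(model_ID, device)

    @lru_cache(maxsize=1)
    def cached_image_data():
        return get_image_data()

`search2` could then call the cached variants so only the first query pays the model-loading and CSV-parsing cost.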
model.py
ADDED
@@ -0,0 +1,60 @@
+from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from dataframe import *
+
+def get_model_info(model_ID, device):
+    # Save the model to device
+    model = CLIPModel.from_pretrained(model_ID).to(device)
+
+    # Get the processor
+    processor = CLIPProcessor.from_pretrained(model_ID)
+
+    # Get the tokenizer
+    tokenizer = CLIPTokenizer.from_pretrained(model_ID)
+
+    # Return model, processor & tokenizer
+    return model, processor, tokenizer
+
+
+def get_single_text_embedding(text, model, tokenizer, device):
+    inputs = tokenizer(text, return_tensors = "pt", max_length=77, truncation=True).to(device)
+    text_embeddings = model.get_text_features(**inputs)
+    # convert the embeddings to numpy array
+    embedding_as_np = text_embeddings.cpu().detach().numpy()
+
+    return embedding_as_np
+
+def df_to_array(result_df) :
+    return [str(result_df['image_name'][i]) for i in range(len(result_df))]
+
+def get_top_N_images(query,
+                     data,
+                     model, tokenizer,
+                     device,
+                     top_K=4,
+                     search_criterion="text"):
+    # Text to image Search
+    if (search_criterion.lower() == "text"):
+        query_vect = get_single_text_embedding(query, model, tokenizer, device)
+    # # Image to image Search
+    # else:
+    #     query_vect = get_single_image_embedding(query)
+
+    # Relevant columns
+    revevant_cols = ["comment", "image_name", "cos_sim"]
+
+    # Run similarity Search
+    data["cos_sim"] = data["text_embeddings"].apply(lambda x: cosine_similarity(query_vect, x))# line 17
+    data["cos_sim"] = data["cos_sim"].apply(lambda x: x[0][0])
+
+    data_sorted = data.sort_values(by='cos_sim', ascending=False)
+    non_repeated_images = ~data_sorted["image_name"].duplicated()
+    most_similar_articles = data_sorted[non_repeated_images].head(top_K)
+
+    """
+    Retrieve top_K (4 is default value) articles similar to the query
+    """
+
+    result_df = most_similar_articles[revevant_cols].reset_index()
+    return df_to_array(result_df)
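The ranking in `get_top_N_images` relies on every `text_embeddings` cell being a (1, N) array, so `cosine_similarity(query_vect, x)` yields a 1x1 matrix and the follow-up `apply(lambda x: x[0][0])` unwraps it to a scalar before sorting, de-duplicating on `image_name`, and keeping the top `top_K` rows. A toy illustration with made-up vectors:

    import numpy as np
    import pandas as pd
    from sklearn.metrics.pairwise import cosine_similarity

    # Toy data, purely illustrative; real vectors come from CLIP via get_single_text_embedding.
    query_vect = np.array([[1.0, 0.0]])
    df = pd.DataFrame({
        "image_name": ["Girl.jpg", "Cat In Hat.jpg"],
        "comment": ["a girl", "a cat wearing a hat"],
        "text_embeddings": [np.array([[0.9, 0.1]]), np.array([[0.1, 0.9]])],
    })

    df["cos_sim"] = df["text_embeddings"].apply(lambda x: cosine_similarity(query_vect, x)[0][0])
    print(df.sort_values("cos_sim", ascending=False)["image_name"].tolist())
    # -> ['Girl.jpg', 'Cat In Hat.jpg']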
setup.py
CHANGED
@@ -2,8 +2,8 @@
 import os
 import streamlit as st
 
-from huggingface_hub import hf_hub_url, cached_download
-from huggingface_hub.archive import unpack_archive
+# from huggingface_hub import hf_hub_url, cached_download
+# from huggingface_hub.archive import unpack_archive
 
 
 def downlad_images() :
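The old `hf_hub_url`/`cached_download` helpers and `huggingface_hub.archive` are commented out here. If the image download step is re-enabled later, a minimal sketch using the current huggingface_hub API could look like the following (repo id, archive name, and target directory are placeholders, not taken from this repo):

    # Hypothetical replacement sketch; not part of this commit.
    import shutil
    from huggingface_hub import hf_hub_download

    archive_path = hf_hub_download(
        repo_id="some-user/some-image-dataset",  # placeholder repo id
        filename="images.tar.gz",                # placeholder archive name
        repo_type="dataset",
    )
    shutil.unpack_archive(archive_path, extract_dir="data/images")  # placeholder target dir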