osanchik committed on
Commit 63d858e · 1 Parent(s): 6e7628d

added search for release 2

Files changed (5):
  1. app.py +10 -4
  2. dataframe.py +12 -0
  3. main.py +25 -2
  4. model.py +60 -0
  5. setup.py +2 -2
app.py CHANGED
@@ -1,4 +1,5 @@
 import streamlit as st
+import pandas as pd
 from main import *
 from setup import *
 
@@ -26,7 +27,12 @@ downlad_images()
 st.title('Find my pic!')
 
 search_request = st.text_input('', 'Search ...')
-
-if st.button('Find!'):
-    search_result = search(search_request)
-    display(search_request, search_result)
+
+# if st.button('Find Release 1!'):
+#     search_result = search1(search_request)
+#     display(search_request, search_result)
+
+if st.button('Find Release 2!'):
+    search_result = search2(search_request)
+    for item in search_result:
+        st.write(item)
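The Release 2 branch prints each hit with `st.write`, which renders the bare file name rather than the picture. A minimal sketch of an image-rendering variant for the same branch; the `images/` prefix is an assumption about where `downlad_images()` unpacks the archive, which this commit does not show:

```python
import streamlit as st

# Hypothetical variant of the Release 2 branch that renders the pictures
# instead of writing their file names. The "images/" path is an assumed
# location -- adjust it to wherever downlad_images() extracts the files.
if st.button('Find Release 2!'):
    search_result = search2(search_request)
    for item in search_result:
        st.image(f"images/{item}", caption=item)
```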
dataframe.py ADDED
@@ -0,0 +1,12 @@
+import pandas as pd
+import numpy as np
+
+def get_image_data():
+
+    # flickr = pd.read_csv('data/results.csv', sep='|')
+    image_data_df = pd.read_csv('data/output2.csv')
+
+    image_data_df['text_embeddings'] = image_data_df['text_embeddings'].apply(lambda x: np.fromstring(x[2:-2], sep=' ')).values
+    image_data_df['text_embeddings'] = image_data_df['text_embeddings'].apply(lambda x: np.reshape(x, (1, -1)))
+
+    return image_data_df
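`get_image_data` undoes the string round-trip that happens when a DataFrame holding numpy arrays is written with `to_csv`: each `(1, dim)` embedding lands in the CSV cell as its text repr, so `x[2:-2]` strips the outer `[[`/`]]` before `np.fromstring` parses the space-separated floats. A small self-contained check of that round-trip, using a 4-dim stand-in for the real CLIP vectors and assuming `output2.csv` was written this way:

```python
import numpy as np

# A (1, 4) stand-in for a CLIP text embedding; the real ones are (1, 512).
vec = np.array([[0.12, 0.34, 0.56, 0.78]])

# to_csv stores the array as its string repr, e.g. "[[0.12 0.34 0.56 0.78]]"
cell = str(vec)

# get_image_data() reverses this: x[2:-2] drops the outer "[["/"]]",
# np.fromstring parses the space-separated floats, reshape restores (1, dim).
parsed = np.reshape(np.fromstring(cell[2:-2], sep=' '), (1, -1))
assert np.allclose(parsed, vec)
```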
main.py CHANGED
@@ -1,5 +1,9 @@
 
 import random
+import torch
+
+from dataframe import *
+from model import *
 
 images = ["Girl.jpg",
           "Cat In Hat.jpg",
@@ -17,9 +21,28 @@ images = ["Girl.jpg",
 
 
 
-def search(search_prompt : str):
+def search1(search_prompt: str):
     """
     Given a search_prompt, return an array of pictures to display
     """
 
-    return [ (images[i], images[i].split('.')[0]) for i in random.sample(range(len(images)), 4) ]
+    return [ (images[i], images[i].split('.')[0]) for i in random.sample(range(len(images)), 4) ]
+
+def search2(search_prompt: str):
+
+    # Set the device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Define the model ID
+    model_ID = "openai/clip-vit-base-patch32"
+
+    # Get the model, processor & tokenizer
+    model, processor, tokenizer = get_model_info(model_ID, device)
+
+    image_data_df = get_image_data()
+
+    return get_top_N_images(search_prompt,
+                            data=image_data_df,
+                            model=model, tokenizer=tokenizer,
+                            device=device,
+                            top_K=4)
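As written, every click of the button makes `search2` reload CLIP and re-read the CSV. A hedged sketch of memoizing the heavy objects with the stdlib `functools.lru_cache`; `_load_search_assets` and `search2_cached` are hypothetical names, not part of this commit:

```python
from functools import lru_cache

import torch

from dataframe import *
from model import *

@lru_cache(maxsize=1)
def _load_search_assets():
    # Loading CLIP is the slow path in search2; keep the model, tokenizer
    # and device around after the first call instead of reloading per click.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor, tokenizer = get_model_info("openai/clip-vit-base-patch32", device)
    return model, tokenizer, device

def search2_cached(search_prompt: str):
    model, tokenizer, device = _load_search_assets()
    return get_top_N_images(search_prompt,
                            data=get_image_data(),
                            model=model, tokenizer=tokenizer,
                            device=device,
                            top_K=4)
```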
model.py ADDED
@@ -0,0 +1,60 @@
+from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+from dataframe import *
+
+def get_model_info(model_ID, device):
+    # Load the model onto the device
+    model = CLIPModel.from_pretrained(model_ID).to(device)
+
+    # Get the processor
+    processor = CLIPProcessor.from_pretrained(model_ID)
+
+    # Get the tokenizer
+    tokenizer = CLIPTokenizer.from_pretrained(model_ID)
+
+    # Return model, processor & tokenizer
+    return model, processor, tokenizer
+
+
+def get_single_text_embedding(text, model, tokenizer, device):
+    inputs = tokenizer(text, return_tensors="pt", max_length=77, truncation=True).to(device)
+    text_embeddings = model.get_text_features(**inputs)
+    # Convert the embeddings to a numpy array
+    embedding_as_np = text_embeddings.cpu().detach().numpy()
+
+    return embedding_as_np
+
+def df_to_array(result_df):
+    return [str(result_df['image_name'][i]) for i in range(len(result_df))]
+
+def get_top_N_images(query,
+                     data,
+                     model, tokenizer,
+                     device,
+                     top_K=4,
+                     search_criterion="text"):
+    """
+    Retrieve the top_K (default 4) images most similar to the query
+    """
+
+    # Text-to-image search
+    if search_criterion.lower() == "text":
+        query_vect = get_single_text_embedding(query, model, tokenizer, device)
+    # # Image-to-image search
+    # else:
+    #     query_vect = get_single_image_embedding(query)
+
+    # Relevant columns
+    relevant_cols = ["comment", "image_name", "cos_sim"]
+
+    # Run the similarity search
+    data["cos_sim"] = data["text_embeddings"].apply(lambda x: cosine_similarity(query_vect, x))
+    data["cos_sim"] = data["cos_sim"].apply(lambda x: x[0][0])
+
+    data_sorted = data.sort_values(by='cos_sim', ascending=False)
+    non_repeated_images = ~data_sorted["image_name"].duplicated()
+    most_similar_articles = data_sorted[non_repeated_images].head(top_K)
+
+    result_df = most_similar_articles[relevant_cols].reset_index()
+    return df_to_array(result_df)
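`get_top_N_images` scores the query against every stored embedding one row at a time via `apply`. A hypothetical vectorized equivalent, stacking the `(1, dim)` rows into a single `(N, dim)` matrix so sklearn computes all similarities in one call; `rank_images` is an illustration, not part of the commit:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def rank_images(query_vect, data, top_K=4):
    # Stack the per-row (1, dim) embeddings into one (N, dim) matrix.
    matrix = np.vstack(data["text_embeddings"].to_numpy())
    scored = data.copy()
    # One call scores the query against every image: result shape (1, N).
    scored["cos_sim"] = cosine_similarity(query_vect, matrix)[0]
    scored = scored.sort_values(by="cos_sim", ascending=False)
    # Same de-duplication as get_top_N_images, then keep the best top_K.
    scored = scored[~scored["image_name"].duplicated()]
    return scored.head(top_K)["image_name"].astype(str).tolist()
```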
setup.py CHANGED
@@ -2,8 +2,8 @@
 import os
 import streamlit as st
 
-from huggingface_hub import hf_hub_url, cached_download
-from huggingface_hub.archive import unpack_archive
+# from huggingface_hub import hf_hub_url, cached_download
+# from huggingface_hub.archive import unpack_archive
 
 
 def downlad_images() :
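The old download path is commented out here, presumably because `hf_hub_url` / `cached_download` (and the `huggingface_hub.archive` helper) were deprecated and eventually dropped from `huggingface_hub`. A sketch of a modern equivalent with `hf_hub_download` plus the stdlib `shutil.unpack_archive`; the `repo_id` and `filename` are placeholders for whatever archive this Space actually pulls:

```python
import shutil

from huggingface_hub import hf_hub_download

def download_images_v2():
    # Placeholder repo_id/filename -- point these at the real image archive.
    archive_path = hf_hub_download(repo_id="<user>/<dataset>",
                                   filename="images.tar.gz",
                                   repo_type="dataset")
    # shutil.unpack_archive stands in for the old
    # huggingface_hub.archive.unpack_archive helper.
    shutil.unpack_archive(archive_path, extract_dir="images")
```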