smenon8 commited on
Commit
600d466
·
1 Parent(s): e592678

add nearest neighbor calculations

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +11 -2
  3. components/query_neighbor.py +33 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .venv/
2
  __pycache__/
 
 
1
  .venv/
2
  __pycache__/
3
+ .gradio/
app.py CHANGED
@@ -14,6 +14,7 @@ from torchvision import transforms
14
 
15
  from templates import openai_imagenet_template
16
  from components.query import get_sample
 
17
 
18
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -90,6 +91,10 @@ zero_shot_examples = [
90
  ],
91
  ]
92
 
 
 
 
 
93
 
94
  def indexed(lst, indices):
95
  return [lst[i] for i in indices]
@@ -146,6 +151,10 @@ def open_domain_classification(img, rank: int, return_all=False):
146
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
147
  probs = F.softmax(logits, dim=0)
148
 
 
 
 
 
149
  if rank + 1 == len(ranks):
150
  topk = probs.topk(k)
151
  prediction_dict = {
@@ -154,9 +163,9 @@ def open_domain_classification(img, rank: int, return_all=False):
154
  logger.info(f"Top K predictions: {prediction_dict}")
155
  top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
156
  logger.info(f"Top prediction name: {top_prediction_name}")
157
- sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
158
  if return_all:
159
- return prediction_dict, sample_img, taxon_url
160
  return prediction_dict
161
 
162
  output = collections.defaultdict(float)
 
14
 
15
  from templates import openai_imagenet_template
16
  from components.query import get_sample
17
+ from components.query_neighbor import QueryNeighbor
18
 
19
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
20
  logging.basicConfig(level=logging.INFO, format=log_format)
 
91
  ],
92
  ]
93
 
94
+ VECTOR_DB_PATH = "/Users/sreejithnoopur/codebase/bioclip-vector-db/vector_db"
95
+ query_neighbor = QueryNeighbor(vector_db = VECTOR_DB_PATH,
96
+ dataset_name = "BIRD")
97
+
98
 
99
  def indexed(lst, indices):
100
  return [lst[i] for i in indices]
 
151
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
152
  probs = F.softmax(logits, dim=0)
153
 
154
+ neighbor = str(query_neighbor.get_nearest_neighbor(img_features))
155
+ neighbor_image = query_neighbor.get_image(neighbor)
156
+ logger.info(f"Nearest neighbor: {neighbor}")
157
+
158
  if rank + 1 == len(ranks):
159
  topk = probs.topk(k)
160
  prediction_dict = {
 
163
  logger.info(f"Top K predictions: {prediction_dict}")
164
  top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
165
  logger.info(f"Top prediction name: {top_prediction_name}")
166
+ _, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
167
  if return_all:
168
+ return prediction_dict, neighbor_image, taxon_url
169
  return prediction_dict
170
 
171
  output = collections.defaultdict(float)
components/query_neighbor.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import chromadb
3
+ import boto3
4
+ import requests
5
+
6
+ from PIL import Image
7
+
8
+ S3_BUCKET = "tol-bird-dataset-test"
9
+
10
+ class QueryNeighbor:
11
+ def __init__(self, vector_db: str, dataset_name: str):
12
+ self._client = chromadb.PersistentClient(path=vector_db)
13
+ self._collection = self._client.get_collection(
14
+ name=dataset_name
15
+ )
16
+ self._s3_client = boto3.client("s3")
17
+
18
+
19
+ def get_nearest_neighbor(self, img_features) -> int:
20
+ ''' Returns the nearest neighbors for the given image features. '''
21
+ neighbors = self._collection.query(query_embeddings=[img_features[0].tolist()],
22
+ n_results = 2)
23
+ return neighbors["ids"][0][0]
24
+
25
+ def get_image(self, image_key: str):
26
+ ''' Returns the image for the given key. '''
27
+ img_src = self._s3_client.generate_presigned_url('get_object',
28
+ Params={'Bucket': S3_BUCKET,
29
+ 'Key': image_key}
30
+ )
31
+ img_resp = requests.get(img_src)
32
+ img = Image.open(io.BytesIO(img_resp.content))
33
+ return img