smenon8 commited on
Commit
50f1a2f
·
1 Parent(s): 600d466

Add ability to fetch a dataset of vectordb from hf hub

Browse files
Files changed (2) hide show
  1. app.py +1 -3
  2. components/query_neighbor.py +44 -2
app.py CHANGED
@@ -91,9 +91,7 @@ zero_shot_examples = [
91
  ],
92
  ]
93
 
94
- VECTOR_DB_PATH = "/Users/sreejithnoopur/codebase/bioclip-vector-db/vector_db"
95
- query_neighbor = QueryNeighbor(vector_db = VECTOR_DB_PATH,
96
- dataset_name = "BIRD")
97
 
98
 
99
  def indexed(lst, indices):
 
91
  ],
92
  ]
93
 
94
+ query_neighbor = QueryNeighbor(dataset_name = "BIRD")
 
 
95
 
96
 
97
  def indexed(lst, indices):
components/query_neighbor.py CHANGED
@@ -1,15 +1,57 @@
1
  import io
 
2
  import chromadb
3
  import boto3
4
  import requests
 
5
 
6
  from PIL import Image
 
 
 
 
 
 
7
 
8
  S3_BUCKET = "tol-bird-dataset-test"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class QueryNeighbor:
11
- def __init__(self, vector_db: str, dataset_name: str):
12
- self._client = chromadb.PersistentClient(path=vector_db)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  self._collection = self._client.get_collection(
14
  name=dataset_name
15
  )
 
1
  import io
2
+ import os
3
  import chromadb
4
  import boto3
5
  import requests
6
+ import logging
7
 
8
  from PIL import Image
9
+ from huggingface_hub import snapshot_download
10
+ from dataclasses import dataclass
11
+
12
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
13
+ logging.basicConfig(level=logging.INFO, format=log_format)
14
+ logger = logging.getLogger()
15
 
16
  S3_BUCKET = "tol-bird-dataset-test"
17
 
18
+ @dataclass
19
+ class VectorDataset:
20
+ dataset_name: str
21
+ hf_dataset_path: str
22
+ relative_vector_db_path: str
23
+
24
+ _SUPPORTED_DATASETS = {
25
+ "BIRD": VectorDataset(
26
+ dataset_name="BIRD",
27
+ hf_dataset_path="imageomics/bird-dataset-vector",
28
+ relative_vector_db_path="bird_vector_db"
29
+ ),
30
+ }
31
+
32
+
33
  class QueryNeighbor:
34
+ """
35
+ Class to query the nearest neighbor for a given image feature vector.
36
+ It uses a vector database to find the nearest neighbor and retrieves the image from S3.
37
+ The class is initialized with the vector database path and the dataset name.
38
+ The vector database is downloaded from Hugging Face Hub and stored in a local cache.
39
+ The class uses the chromadb library to interact with the vector database and boto3 to interact with S3.
40
+ """
41
+ def __init__(self, dataset_name: str):
42
+ logger.info("Initializing QueryNeighbor")
43
+ vector_dataset = _SUPPORTED_DATASETS.get(dataset_name)
44
+ if vector_dataset is None:
45
+ raise ValueError(f"Unsupported dataset: {dataset_name}")
46
+
47
+ vector_db_path = snapshot_download(
48
+ repo_id=vector_dataset.hf_dataset_path,
49
+ repo_type="dataset"
50
+ )
51
+ logger.info(f"Vector DB cache: {vector_db_path}")
52
+ self._client = chromadb.PersistentClient(
53
+ path=os.path.join(vector_db_path,
54
+ vector_dataset.relative_vector_db_path))
55
  self._collection = self._client.get_collection(
56
  name=dataset_name
57
  )