inie2003 committed
Commit 6b1bbaf · verified · 1 Parent(s): 12e44e3

Updated helper with new loading function

Files changed (1)
  helper.py +19 -43
helper.py CHANGED
@@ -46,49 +46,24 @@ def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
 
     return query_embedding
 
-def load_hf_datasets(key,dataset):
+def load_dataset_with_limit(dataset_name, dataset_subset, search_in_small_objects,limit=1000):
     """
-    Load Datasets from Hugging Face as DF
-    ---------------------------------------
-    dataset_name: str - name of dataset on Hugging Face
-    ---------------------------------------
-    RETURNS: dataset as pandas dataframe
+    Load a dataset from Hugging Face and limit the number of rows.
     """
-    df = dataset[key].to_pandas()
-
-    return df
-
-def parallel_load_and_combine(dataset_keys, dataset):
-    """
-    Load datasets in parallel and combine Main and Split keys
-    ----------------------------------------------------------
-    dataset_keys: list - keys of the dataset (e.g., ['Main_1', 'Split_1', ...])
-    dataset: DatasetDict - the loaded Hugging Face dataset
-    ----------------------------------------------------------
-    RETURNS: combined DataFrame from both Main and Split keys
-    """
-    # Separate keys into Main and Split lists
-    main_keys = [key for key in dataset_keys if key.startswith('Main')]
-    split_keys = [key for key in dataset_keys if key.startswith('Split')]
-
-    def process_key(key, key_type):
-        df = load_hf_datasets(key, dataset)
-        return df
-
-    # Parallel loading of Main keys
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        main_dfs = list(executor.map(lambda key: process_key(key, 'Main'), main_keys))
-
-    # Parallel loading of Split keys
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        split_dfs = list(executor.map(lambda key: process_key(key, 'Split'), split_keys))
-
-    # Combine Main DataFrames and Split DataFrames
-    main_combined_df = pd.concat(main_dfs, ignore_index=True) if main_dfs else pd.DataFrame()
-    split_combined_df = pd.concat(split_dfs, ignore_index=True) if split_dfs else pd.DataFrame()
-
-
-    return main_combined_df, split_combined_df
+    if search_in_small_objects:
+        split = f'Splits_{dataset_subset}'
+    else:
+        split = f'Main_{dataset_subset}'
+    dataset_name = f"quasara-io/{dataset_name}"
+    dataset = load_dataset(dataset_name, split=split)
+    total_rows = dataset.num_rows
+    # Convert to DataFrame and sample if limit is provided
+    if limit is not None:
+        df = dataset.to_pandas().sample(n=limit, random_state=42)
+    else:
+        df = dataset.to_pandas()
+
+    return df,total_rows
 
 def get_image_vectors(df):
     # Get the image vectors from the dataframe
@@ -96,7 +71,7 @@ def get_image_vectors(df):
     return torch.tensor(image_vectors, dtype=torch.float32)
 
 
-def search(query, df, limit, offset, scoring_func, search_in_images):
+def search(query, df, limit, search_in_images = True):
     if search_in_images:
         # Encode the image query
         query_vector = encode_query(query)
@@ -266,7 +241,8 @@ def main():
     dataset = load_dataset(f"quasara-io/{dataset_name}")
     print('loaded dataset')
     dataset_keys = dataset.keys()
-    main_df, split_df = parallel_load_and_combine(dataset_keys, dataset)
+    random_sample_size = 1000
+    main_df, split_df = parallel_load_and_combine(dataset_keys, dataset, n_rows=random_sample_size)
     #Now we get the coordinates and the stuff
     print('processed datasets')
     if search_in_small_objects:
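
For reference, a minimal sketch of how the new load_dataset_with_limit helper might be called after this change. This is not part of the commit; the dataset name and subset below are hypothetical placeholders, and only the function name, parameters, and return values come from the diff above.

    # Hypothetical usage; "MyDataset" and subset 1 are made-up example values.
    # The helper prefixes the name with "quasara-io/" internally, picks the
    # "Main_<subset>" or "Splits_<subset>" split, and samples `limit` rows
    # with a fixed random_state.
    df, total_rows = load_dataset_with_limit(
        dataset_name="MyDataset",
        dataset_subset=1,
        search_in_small_objects=False,   # False -> load the "Main_1" split
        limit=1000,
    )
    print(f"Sampled {len(df)} of {total_rows} rows")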