Trent commited on
Commit
a1fc7fb
1 Parent(s): 98e7562

Index locking

Browse files
Files changed (1) hide show
  1. utils.py +16 -1
utils.py CHANGED
@@ -8,8 +8,22 @@ from koclip import FlaxHybridCLIP
8
  from global_session import GlobalState
9
  from threading import Lock
10
 
11
- @st.cache(allow_output_mutation=True)
12
  def load_index(img_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  filenames, embeddings = [], []
14
  with open(img_file, "r") as f:
15
  for line in f:
@@ -37,6 +51,7 @@ def load_model(model_name="koclip/koclip-base"):
37
  print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
38
  return cached_model
39
 
 
40
  @st.cache(allow_output_mutation=True)
41
  def load_model_cached(model_name):
42
  assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
 
8
  from global_session import GlobalState
9
  from threading import Lock
10
 
11
+
12
  def load_index(img_file):
13
+ state = GlobalState(img_file)
14
+ if not hasattr(state, '_lock'):
15
+ state._lock = Lock()
16
+ print(f"Locking loading of features : {img_file} to avoid concurrent caching.")
17
+
18
+ with state._lock:
19
+ cached_index = load_index_cached(img_file)
20
+
21
+ print(f"Unlocking loading of model : {img_file} to avoid concurrent caching.")
22
+ return cached_index
23
+
24
+
25
+ @st.cache(allow_output_mutation=True)
26
+ def load_index_cached(img_file):
27
  filenames, embeddings = [], []
28
  with open(img_file, "r") as f:
29
  for line in f:
 
51
  print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
52
  return cached_model
53
 
54
+
55
  @st.cache(allow_output_mutation=True)
56
  def load_model_cached(model_name):
57
  assert model_name in {f"koclip/{model}" for model in MODEL_LIST}