lamhieu committed on
Commit
aaf7e4c
·
1 Parent(s): c54a701

chore: support embeddings caching

Browse files
lightweight_embeddings/analytics.py CHANGED
@@ -19,8 +19,14 @@ class Analytics:
19
  - redis_url: Redis connection URL (e.g., 'redis://localhost:6379/0')
20
  - sync_interval: Interval in seconds for syncing with Redis.
21
  """
22
- self.pool = redis.ConnectionPool.from_url(redis_url, decode_responses=True)
23
- self.redis_client = redis.Redis(connection_pool=self.pool)
 
 
 
 
 
 
24
  self.local_buffer = {
25
  "access": defaultdict(
26
  lambda: defaultdict(int)
@@ -122,5 +128,4 @@ class Analytics:
122
  await self._sync_to_redis()
123
  except redis.exceptions.ConnectionError as e:
124
  logger.error("Redis connection error: %s", e)
125
- self.pool.disconnect() # force reconnect on next request
126
  await asyncio.sleep(5)
 
19
  - redis_url: Redis connection URL (e.g., 'redis://localhost:6379/0')
20
  - sync_interval: Interval in seconds for syncing with Redis.
21
  """
22
+ self.redis_client = redis.from_url(
23
+ redis_url,
24
+ decode_responses=True,
25
+ health_check_interval=10,
26
+ socket_connect_timeout=5,
27
+ retry_on_timeout=True,
28
+ socket_keepalive=True,
29
+ )
30
  self.local_buffer = {
31
  "access": defaultdict(
32
  lambda: defaultdict(int)
 
128
  await self._sync_to_redis()
129
  except redis.exceptions.ConnectionError as e:
130
  logger.error("Redis connection error: %s", e)
 
131
  await asyncio.sleep(5)
lightweight_embeddings/service.py CHANGED
@@ -32,6 +32,8 @@ from typing import List, Union, Literal, Dict, Optional, NamedTuple, Any
32
  from dataclasses import dataclass
33
  from pathlib import Path
34
  from io import BytesIO
 
 
35
 
36
  import requests
37
  import numpy as np
@@ -153,6 +155,8 @@ class EmbeddingsService:
153
  """
154
 
155
  def __init__(self, config: Optional[ModelConfig] = None):
 
 
156
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
157
  self.config = config or ModelConfig()
158
 
@@ -262,11 +266,19 @@ class EmbeddingsService:
262
 
263
  def _generate_text_embeddings(self, texts: List[str]) -> np.ndarray:
264
  """
265
- Generate text embeddings using the currently configured text model.
 
266
  """
267
  try:
 
 
 
 
268
  model = self.text_models[self.config.text_model_type]
269
- embeddings = model.encode(texts) # shape: (num_items, emb_dim)
 
 
 
270
  return embeddings
271
  except Exception as e:
272
  raise RuntimeError(
 
32
  from dataclasses import dataclass
33
  from pathlib import Path
34
  from io import BytesIO
35
+ from hashlib import md5
36
+ from cachetools import LRUCache
37
 
38
  import requests
39
  import numpy as np
 
155
  """
156
 
157
  def __init__(self, config: Optional[ModelConfig] = None):
158
+ self.lru_cache = LRUCache(maxsize=50_000) # Approximate for ~500MB usage
159
+
160
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
161
  self.config = config or ModelConfig()
162
 
 
266
 
267
  def _generate_text_embeddings(self, texts: List[str]) -> np.ndarray:
268
  """
269
+ Generate text embeddings using the currently configured text model
270
+ with an LRU cache for single-text requests.
271
  """
272
  try:
273
+ if len(texts) == 1:
274
+ key = md5(texts[0].encode("utf-8")).hexdigest()
275
+ if key in self.lru_cache:
276
+ return self.lru_cache[key]
277
  model = self.text_models[self.config.text_model_type]
278
+ embeddings = model.encode(texts)
279
+
280
+ if len(texts) == 1:
281
+ self.lru_cache[key] = embeddings
282
  return embeddings
283
  except Exception as e:
284
  raise RuntimeError(