chore: support embeddings caching
lightweight_embeddings/analytics.py
CHANGED
@@ -19,8 +19,14 @@ class Analytics:
         - redis_url: Redis connection URL (e.g., 'redis://localhost:6379/0')
         - sync_interval: Interval in seconds for syncing with Redis.
         """
-        self.
-
+        self.redis_client = redis.from_url(
+            redis_url,
+            decode_responses=True,
+            health_check_interval=10,
+            socket_connect_timeout=5,
+            retry_on_timeout=True,
+            socket_keepalive=True,
+        )
         self.local_buffer = {
             "access": defaultdict(
                 lambda: defaultdict(int)
@@ -122,5 +128,4 @@ class Analytics:
             await self._sync_to_redis()
         except redis.exceptions.ConnectionError as e:
             logger.error("Redis connection error: %s", e)
-            self.pool.disconnect()  # force reconnect on next request
             await asyncio.sleep(5)
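The new keyword arguments move connection resilience into redis-py itself: idle connections are health-checked before reuse, connect attempts fail fast, and timed-out commands are retried, which is why the manual self.pool.disconnect() workaround in the sync loop can be dropped. A minimal standalone sketch of the same configuration, assuming redis-py 4.x and a placeholder local URL:

import redis

# Same resilience settings as the diff above; the URL is a placeholder.
client = redis.from_url(
    "redis://localhost:6379/0",
    decode_responses=True,     # return str values instead of bytes
    health_check_interval=10,  # PING connections idle for more than 10 s before reuse
    socket_connect_timeout=5,  # fail fast if the server is unreachable
    retry_on_timeout=True,     # retry a command once after a socket timeout
    socket_keepalive=True,     # enable TCP keepalive on long-lived sockets
)

try:
    client.ping()
except redis.exceptions.ConnectionError as e:
    # The analytics sync loop logs this and sleeps 5 s before retrying.
    print(f"Redis unavailable: {e}")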
lightweight_embeddings/service.py
CHANGED
@@ -32,6 +32,8 @@ from typing import List, Union, Literal, Dict, Optional, NamedTuple, Any
 from dataclasses import dataclass
 from pathlib import Path
 from io import BytesIO
+from hashlib import md5
+from cachetools import LRUCache

 import requests
 import numpy as np
@@ -153,6 +155,8 @@ class EmbeddingsService:
     """

     def __init__(self, config: Optional[ModelConfig] = None):
+        self.lru_cache = LRUCache(maxsize=50_000)  # Approximate for ~500MB usage
+
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.config = config or ModelConfig()

@@ -262,11 +266,19 @@ class EmbeddingsService:

     def _generate_text_embeddings(self, texts: List[str]) -> np.ndarray:
         """
-        Generate text embeddings using the currently configured text model
+        Generate text embeddings using the currently configured text model
+        with an LRU cache for single-text requests.
         """
         try:
+            if len(texts) == 1:
+                key = md5(texts[0].encode("utf-8")).hexdigest()
+                if key in self.lru_cache:
+                    return self.lru_cache[key]
             model = self.text_models[self.config.text_model_type]
-            embeddings = model.encode(texts)
+            embeddings = model.encode(texts)
+
+            if len(texts) == 1:
+                self.lru_cache[key] = embeddings
             return embeddings
         except Exception as e:
             raise RuntimeError(
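The cache keys on the md5 hexdigest of the raw input text and only applies to single-text requests, so batch calls bypass it entirely. The maxsize=50_000 comment implies a budget of roughly 10 KB per entry, since 50,000 × 10 KB ≈ 500 MB; a vector of around a thousand float64 values plus container overhead would fit that budget. A standalone sketch of the same pattern, with a stubbed encoder standing in for the real model (the 384-dimensional output shape is an assumption for illustration):

from hashlib import md5

import numpy as np
from cachetools import LRUCache

cache = LRUCache(maxsize=50_000)  # least-recently-used entries are evicted first

def encode_stub(texts):
    # Stand-in for model.encode(texts): one vector per input text.
    # The 384-dim float32 shape is illustrative, not the service's actual model.
    return np.zeros((len(texts), 384), dtype=np.float32)

def embed(texts):
    if len(texts) == 1:
        key = md5(texts[0].encode("utf-8")).hexdigest()
        if key in cache:  # hit: skip the encoder entirely
            return cache[key]
    embeddings = encode_stub(texts)
    if len(texts) == 1:
        cache[key] = embeddings  # only single-text requests are cached
    return embeddings

embed(["hello world"])  # cold call: runs the encoder and fills the cache
embed(["hello world"])  # warm call: served from the LRU cache

One detail visible in the diff itself: a cache hit returns the full array that encode() originally produced, so callers see identical shapes on hot and cold paths.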