lamhieu committed
Commit 9604bdd · 1 Parent(s): 422dfa9

chore: update something

Files changed (1)
  1. lightweight_embeddings/service.py +40 -74
lightweight_embeddings/service.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 import logging
+import threading
 from enum import Enum
 from typing import List, Union, Dict, Optional, NamedTuple, Any
 from dataclasses import dataclass
@@ -27,7 +28,6 @@ class TextModelType(str, Enum):
     """
     Enumeration of supported text models.
     """
-
     MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
     MULTILINGUAL_E5_BASE = "multilingual-e5-base"
     MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
@@ -42,7 +42,6 @@ class ImageModelType(str, Enum):
     """
     Enumeration of supported image models.
     """
-
     SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
 
 
@@ -50,7 +49,6 @@ class MaxModelLength(str, Enum):
     """
    Enumeration of maximum token lengths for supported text models.
     """
-
     MULTILINGUAL_E5_SMALL = 512
     MULTILINGUAL_E5_BASE = 512
     MULTILINGUAL_E5_LARGE = 512
@@ -65,7 +63,6 @@ class ModelInfo(NamedTuple):
     """
     Container mapping a model type to its model identifier and optional ONNX file.
     """
-
     model_id: str
     onnx_file: Optional[str] = None
 
@@ -75,11 +72,8 @@ class ModelConfig:
     """
     Configuration for text and image models.
     """
-
     text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
-    image_model_type: ImageModelType = (
-        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
-    )
+    image_model_type: ImageModelType = ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
     logit_scale: float = 4.60517  # Example scale used in cross-modal similarity
 
     @property
@@ -140,7 +134,6 @@ class ModelKind(str, Enum):
     """
     Indicates the type of model: text or image.
     """
-
     TEXT = "text"
     IMAGE = "image"
 
@@ -184,6 +177,11 @@ class EmbeddingsService:
         self.image_models: Dict[ImageModelType, AutoModel] = {}
         self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
 
+        # Create reentrant locks for each text model to ensure thread safety.
+        self.text_model_locks: Dict[TextModelType, threading.RLock] = {
+            t: threading.RLock() for t in TextModelType
+        }
+
         # Create a persistent asynchronous HTTP client.
         self.async_http_client = httpx.AsyncClient(timeout=10)
 
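Note: the hunk above introduces one threading.RLock per text model, so concurrent requests that hit the same SentenceTransformer serialize their access while different models remain independent. A minimal sketch of the pattern in isolation (the names here are illustrative, not part of the repo):

import threading
from enum import Enum

class TextModelType(str, Enum):
    MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
    MULTILINGUAL_E5_BASE = "multilingual-e5-base"

# One reentrant lock per model: a thread that already holds a model's lock
# (e.g. a helper invoked from a locked section) may re-acquire it safely.
locks = {t: threading.RLock() for t in TextModelType}

def encode_safely(model_type, do_encode):
    with locks[model_type]:  # serializes callers per model, not globally
        return do_encode()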
@@ -220,17 +218,11 @@
             # Set maximum sequence length based on configuration.
             max_length = int(MaxModelLength[t_model_type.name].value)
             self.text_models[t_model_type].max_seq_length = max_length
-            logger.info(
-                "Set max_seq_length=%d for text model: %s",
-                max_length,
-                info.model_id,
-            )
+            logger.info("Set max_seq_length=%d for text model: %s", max_length, info.model_id)
 
         # Preload image models.
         for i_model_type in ImageModelType:
-            model_id = ModelConfig(
-                image_model_type=i_model_type
-            ).image_model_info.model_id
+            model_id = ModelConfig(image_model_type=i_model_type).image_model_info.model_id
             logger.info("Loading image model: %s", model_id)
             model = AutoModel.from_pretrained(model_id).to(self.device)
             model.eval()  # Set the model to evaluation mode.
@@ -257,9 +249,7 @@
                 raise ValueError("Text input cannot be empty.")
             return [input_text]
 
-        if not isinstance(input_text, list) or not all(
-            isinstance(x, str) for x in input_text
-        ):
+        if not isinstance(input_text, list) or not all(isinstance(x, str) for x in input_text):
             raise ValueError("Text input must be a string or a list of strings.")
 
         if len(input_text) == 0:
@@ -280,9 +270,7 @@
                 raise ValueError("Image input cannot be empty.")
             return [input_images]
 
-        if not isinstance(input_images, list) or not all(
-            isinstance(x, str) for x in input_images
-        ):
+        if not isinstance(input_images, list) or not all(isinstance(x, str) for x in input_images):
             raise ValueError("Image input must be a string or a list of strings.")
 
         if len(input_images) == 0:
@@ -305,17 +293,15 @@
         try:
             # Attempt to get the tokenizer from the first module of the SentenceTransformer.
             module = model._first_module()
-            if not hasattr(module, "tokenizer"):
+            if not hasattr(module, 'tokenizer'):
                 return text
             tokenizer = module.tokenizer
             # Tokenize without truncation.
             encoded = tokenizer(text, add_special_tokens=True, truncation=False)
             max_length = model.max_seq_length
-            if len(encoded["input_ids"]) > max_length:
-                truncated_ids = encoded["input_ids"][:max_length]
-                truncated_text = tokenizer.decode(
-                    truncated_ids, skip_special_tokens=True
-                )
+            if len(encoded['input_ids']) > max_length:
+                truncated_ids = encoded['input_ids'][:max_length]
+                truncated_text = tokenizer.decode(truncated_ids, skip_special_tokens=True)
                 return truncated_text
         except Exception as e:
             logger.warning("Error during text truncation: %s", str(e))
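Note: the truncation logic above tokenizes without truncation, compares the token count to max_seq_length, and decodes the first max_seq_length ids back to text. A decode/re-encode round trip is not guaranteed to be exact, so the result is approximate rather than token-perfect. A self-contained sketch of the same idea (the model id is an assumption for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")

def truncate_to_max_tokens(text: str, max_length: int = 512) -> str:
    # Tokenize without truncation to measure the true token length.
    ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
    if len(ids) <= max_length:
        return text
    # Keep the first max_length ids and decode back to a string.
    return tokenizer.decode(ids[:max_length], skip_special_tokens=True)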
@@ -367,9 +353,7 @@
         processed_data = processor(images=img, return_tensors="pt").to(self.device)
         return processed_data
 
-    def _generate_text_embeddings(
-        self, model_id: TextModelType, texts: List[str]
-    ) -> np.ndarray:
+    def _generate_text_embeddings(self, model_id: TextModelType, texts: List[str]) -> np.ndarray:
         """
         Generate text embeddings using the SentenceTransformer model.
         Single-text requests are cached using an LRU cache.
@@ -385,26 +369,25 @@
             RuntimeError: If text embedding generation fails.
         """
         try:
-            if len(texts) == 1:
-                single_text = texts[0]
-                key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
-                if key in self.lru_cache:
-                    return self.lru_cache[key]
-                model = self.text_models[model_id]
-                emb = model.encode([single_text])
-                self.lru_cache[key] = emb
-                return emb
-
             model = self.text_models[model_id]
-            return model.encode(texts)
+            lock = self.text_model_locks[model_id]
+            with lock:
+                if len(texts) == 1:
+                    single_text = texts[0]
+                    key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
+                    if key in self.lru_cache:
+                        return self.lru_cache[key]
+                    emb = model.encode([single_text])
+                    self.lru_cache[key] = emb
+                    return emb
+
+                return model.encode(texts)
         except Exception as e:
             raise RuntimeError(
                 f"Error generating text embeddings with model '{model_id}': {e}"
             ) from e
 
-    async def _async_generate_image_embeddings(
-        self, model_id: ImageModelType, images: List[str]
-    ) -> np.ndarray:
+    async def _async_generate_image_embeddings(self, model_id: ImageModelType, images: List[str]) -> np.ndarray:
         """
         Asynchronously generate image embeddings.
 
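Note: moving the single-text cache lookup inside the per-model lock closes a race where two threads could both miss the cache and encode the same text twice. One caveat: an 8-hex-character md5 prefix is only 32 bits, so distinct inputs can collide and return each other's cached embedding. The locked check-then-encode pattern in isolation (illustrative names):

import threading
from hashlib import md5

lock = threading.RLock()
cache = {}  # stands in for the service's LRU cache

def embed_single(model_id: str, text: str, encode):
    key = md5(f"{model_id}:{text}".encode("utf-8")).hexdigest()[:8]
    with lock:
        if key in cache:  # hit: skip the encode call entirely
            return cache[key]
        emb = encode([text])  # miss: compute while holding the lock
        cache[key] = emb
        return emb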
@@ -428,15 +411,11 @@
         )
         # Assume all processed outputs have the same keys.
         keys = processed_tensors[0].keys()
-        combined = {
-            k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
-        }
+        combined = {k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys}
 
         def infer():
             with torch.no_grad():
-                embeddings = self.image_models[model_id].get_image_features(
-                    **combined
-                )
+                embeddings = self.image_models[model_id].get_image_features(**combined)
             return embeddings.cpu().numpy()
 
         return await asyncio.to_thread(infer)
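Note: the image path concatenates the per-image processor outputs along the batch dimension and runs get_image_features inside torch.no_grad() on a worker thread via asyncio.to_thread, which keeps the event loop responsive during inference. The combine step on its own (dummy tensors for illustration):

import torch

# Each processed image is a dict of tensors with identical keys,
# e.g. {"pixel_values": tensor of shape (1, 3, H, W)}.
processed = [
    {"pixel_values": torch.randn(1, 3, 256, 256)},
    {"pixel_values": torch.randn(1, 3, 256, 256)},
]
keys = processed[0].keys()
combined = {k: torch.cat([p[k] for p in processed], dim=0) for k in keys}
print(combined["pixel_values"].shape)  # torch.Size([2, 3, 256, 256])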
@@ -445,9 +424,7 @@
                 f"Error generating image embeddings with model '{model_id}': {e}"
             ) from e
 
-    async def generate_embeddings(
-        self, model: str, inputs: Union[str, List[str]]
-    ) -> np.ndarray:
+    async def generate_embeddings(self, model: str, inputs: Union[str, List[str]]) -> np.ndarray:
         """
         Asynchronously generate embeddings for text or image inputs based on model type.
 
@@ -463,26 +440,19 @@
             text_model_enum = TextModelType(model)
             text_list = self._validate_text_list(inputs)
             model_instance = self.text_models[text_model_enum]
-            # Truncate each text if it exceeds the maximum allowed token length.
-            truncated_texts = [
-                self._truncate_text(text, model_instance) for text in text_list
-            ]
+            lock = self.text_model_locks[text_model_enum]
+            with lock:
+                # Truncate each text if it exceeds the maximum allowed token length.
+                truncated_texts = [self._truncate_text(text, model_instance) for text in text_list]
             return await asyncio.to_thread(
                 self._generate_text_embeddings, text_model_enum, truncated_texts
             )
         elif modality == ModelKind.IMAGE:
             image_model_enum = ImageModelType(model)
             image_list = self._validate_image_list(inputs)
-            return await self._async_generate_image_embeddings(
-                image_model_enum, image_list
-            )
+            return await self._async_generate_image_embeddings(image_model_enum, image_list)
 
-    async def rank(
-        self,
-        model: str,
-        queries: Union[str, List[str]],
-        candidates: Union[str, List[str]],
-    ) -> Dict[str, Any]:
+    async def rank(self, model: str, queries: Union[str, List[str]], candidates: Union[str, List[str]]) -> Dict[str, Any]:
         """
         Asynchronously rank candidate texts/images against the provided queries.
         Embeddings for queries and candidates are generated concurrently.
@@ -503,12 +473,8 @@
 
         # Concurrently generate embeddings.
         query_task = asyncio.create_task(self.generate_embeddings(model, queries))
-        candidate_task = asyncio.create_task(
-            self.generate_embeddings(model, candidates)
-        )
-        query_embeds, candidate_embeds = await asyncio.gather(
-            query_task, candidate_task
-        )
+        candidate_task = asyncio.create_task(self.generate_embeddings(model, candidates))
+        query_embeds, candidate_embeds = await asyncio.gather(query_task, candidate_task)
 
         # Compute cosine similarity.
         sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
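Note: rank starts both embedding calls as tasks and joins them with asyncio.gather, so the query and candidate batches are embedded concurrently rather than sequentially. The shape of that pattern with placeholder coroutines:

import asyncio

async def embed(label: str) -> str:
    await asyncio.sleep(0.1)  # stands in for real embedding work
    return f"{label}-embeddings"

async def rank_demo() -> None:
    query_task = asyncio.create_task(embed("query"))
    candidate_task = asyncio.create_task(embed("candidate"))
    # Both tasks run concurrently; gather returns results in argument order.
    query_embeds, candidate_embeds = await asyncio.gather(query_task, candidate_task)
    print(query_embeds, candidate_embeds)

asyncio.run(rank_demo())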
 