lamhieu committed
Commit 422dfa9 · 1 Parent(s): 716eebd

fix: truncate texts that are longer than allowed

Files changed (1)
  1. lightweight_embeddings/service.py +92 -14
lightweight_embeddings/service.py CHANGED
@@ -1,3 +1,5 @@
+# filename: service.py
+
 from __future__ import annotations
 
 import asyncio
@@ -44,6 +46,21 @@ class ImageModelType(str, Enum):
     SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
 
 
+class MaxModelLength(str, Enum):
+    """
+    Enumeration of maximum token lengths for supported text models.
+    """
+
+    MULTILINGUAL_E5_SMALL = 512
+    MULTILINGUAL_E5_BASE = 512
+    MULTILINGUAL_E5_LARGE = 512
+    SNOWFLAKE_ARCTIC_EMBED_L_V2 = 8192
+    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = 128
+    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = 128
+    BGE_M3 = 8192
+    GTE_MULTILINGUAL_BASE = 8192
+
+
 class ModelInfo(NamedTuple):
     """
     Container mapping a model type to its model identifier and optional ONNX file.
@@ -200,6 +217,14 @@ class EmbeddingsService:
                 device=self.device,
                 trust_remote_code=True,
             )
+            # Set maximum sequence length based on configuration.
+            max_length = int(MaxModelLength[t_model_type.name].value)
+            self.text_models[t_model_type].max_seq_length = max_length
+            logger.info(
+                "Set max_seq_length=%d for text model: %s",
+                max_length,
+                info.model_id,
+            )
 
         # Preload image models.
         for i_model_type in ImageModelType:
@@ -265,6 +290,37 @@
 
         return input_images
 
+    def _truncate_text(self, text: str, model: SentenceTransformer) -> str:
+        """
+        Truncate the input text to the maximum allowed tokens for the given model.
+
+        Args:
+            text: The input text.
+            model: The SentenceTransformer model used for tokenization.
+
+        Returns:
+            The truncated text if token length exceeds the maximum allowed length,
+            otherwise the original text.
+        """
+        try:
+            # Attempt to get the tokenizer from the first module of the SentenceTransformer.
+            module = model._first_module()
+            if not hasattr(module, "tokenizer"):
+                return text
+            tokenizer = module.tokenizer
+            # Tokenize without truncation.
+            encoded = tokenizer(text, add_special_tokens=True, truncation=False)
+            max_length = model.max_seq_length
+            if len(encoded["input_ids"]) > max_length:
+                truncated_ids = encoded["input_ids"][:max_length]
+                truncated_text = tokenizer.decode(
+                    truncated_ids, skip_special_tokens=True
+                )
+                return truncated_text
+        except Exception as e:
+            logger.warning("Error during text truncation: %s", str(e))
+        return text
+
     async def _fetch_image(self, path_or_url: str) -> Image.Image:
         """
         Asynchronously fetch an image from a URL or load from a local path.
@@ -312,14 +368,16 @@
         return processed_data
 
     def _generate_text_embeddings(
-        self,
-        model_id: TextModelType,
-        texts: List[str],
+        self, model_id: TextModelType, texts: List[str]
     ) -> np.ndarray:
         """
        Generate text embeddings using the SentenceTransformer model.
         Single-text requests are cached using an LRU cache.
 
+        Args:
+            model_id: The text model type.
+            texts: A list of input texts.
+
         Returns:
             A NumPy array of text embeddings.
 
@@ -345,9 +403,7 @@
             ) from e
 
     async def _async_generate_image_embeddings(
-        self,
-        model_id: ImageModelType,
-        images: List[str],
+        self, model_id: ImageModelType, images: List[str]
     ) -> np.ndarray:
         """
         Asynchronously generate image embeddings.
@@ -355,6 +411,10 @@
         This method concurrently processes multiple images and offloads
         the blocking model inference to a separate thread.
 
+        Args:
+            model_id: The image model type.
+            images: A list of image URLs or file paths.
+
         Returns:
             A NumPy array of image embeddings.
 
@@ -386,9 +446,7 @@
             ) from e
 
     async def generate_embeddings(
-        self,
-        model: str,
-        inputs: Union[str, List[str]],
+        self, model: str, inputs: Union[str, List[str]]
     ) -> np.ndarray:
         """
         Asynchronously generate embeddings for text or image inputs based on model type.
@@ -402,16 +460,21 @@
         """
         modality = detect_model_kind(model)
         if modality == ModelKind.TEXT:
-            text_model_id = TextModelType(model)
+            text_model_enum = TextModelType(model)
             text_list = self._validate_text_list(inputs)
+            model_instance = self.text_models[text_model_enum]
+            # Truncate each text if it exceeds the maximum allowed token length.
+            truncated_texts = [
+                self._truncate_text(text, model_instance) for text in text_list
+            ]
             return await asyncio.to_thread(
-                self._generate_text_embeddings, text_model_id, text_list
+                self._generate_text_embeddings, text_model_enum, truncated_texts
            )
         elif modality == ModelKind.IMAGE:
-            image_model_id = ImageModelType(model)
+            image_model_enum = ImageModelType(model)
             image_list = self._validate_image_list(inputs)
             return await self._async_generate_image_embeddings(
-                image_model_id, image_list
+                image_model_enum, image_list
             )
 
     async def rank(
@@ -424,6 +487,11 @@
         Asynchronously rank candidate texts/images against the provided queries.
         Embeddings for queries and candidates are generated concurrently.
 
+        Args:
+            model: The model identifier.
+            queries: The query input(s).
+            candidates: The candidate input(s).
+
         Returns:
             A dictionary containing probabilities, cosine similarities, and usage statistics.
         """
@@ -469,6 +537,9 @@
         """
         Estimate the token count for the given text input using the SentenceTransformer tokenizer.
 
+        Args:
+            input_data: The text input(s).
+
         Returns:
             The total number of tokens.
         """
@@ -482,8 +553,11 @@
         """
         Compute the softmax over the last dimension of the input array.
 
+        Args:
+            scores: A NumPy array of scores.
+
         Returns:
-            The softmax probabilities.
+            A NumPy array of softmax probabilities.
         """
         exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
         return exps / np.sum(exps, axis=-1, keepdims=True)
@@ -493,6 +567,10 @@
         """
         Compute the pairwise cosine similarity between all rows of arrays a and b.
 
+        Args:
+            a: A NumPy array.
+            b: A NumPy array.
+
         Returns:
             A (N x M) matrix of cosine similarities.
         """