fix: truncate texts that are longer than allowed
lightweight_embeddings/service.py
CHANGED
@@ -1,3 +1,5 @@
+# filename: service.py
+
 from __future__ import annotations
 
 import asyncio
@@ -44,6 +46,21 @@ class ImageModelType(str, Enum):
     SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
 
 
+class MaxModelLength(str, Enum):
+    """
+    Enumeration of maximum token lengths for supported text models.
+    """
+
+    MULTILINGUAL_E5_SMALL = 512
+    MULTILINGUAL_E5_BASE = 512
+    MULTILINGUAL_E5_LARGE = 512
+    SNOWFLAKE_ARCTIC_EMBED_L_V2 = 8192
+    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = 128
+    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = 128
+    BGE_M3 = 8192
+    GTE_MULTILINGUAL_BASE = 8192
+
+
 class ModelInfo(NamedTuple):
     """
     Container mapping a model type to its model identifier and optional ONNX file.
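Note that MaxModelLength mixes str into an Enum but assigns integer values. Python coerces those values to the strings "512", "8192", etc., and members with equal values collapse into aliases of the first member; lookup by name still works, which is what the preload code below relies on. A quick standalone check of that behavior (plain Python, no extra dependencies):

    from enum import Enum

    class MaxModelLength(str, Enum):
        # On a str-mixin enum, int values are coerced to strings ("512"),
        # and equal values make later members aliases of the first one.
        MULTILINGUAL_E5_SMALL = 512
        MULTILINGUAL_E5_BASE = 512

    print(repr(MaxModelLength.MULTILINGUAL_E5_SMALL.value))  # '512' (a string)
    print(MaxModelLength["MULTILINGUAL_E5_BASE"]
          is MaxModelLength.MULTILINGUAL_E5_SMALL)           # True (alias)
    print(int(MaxModelLength["MULTILINGUAL_E5_BASE"].value)) # 512

This is why the loading code below wraps the looked-up value in int() before assigning it to max_seq_length.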
@@ -200,6 +217,14 @@ class EmbeddingsService:
                 device=self.device,
                 trust_remote_code=True,
             )
+            # Set maximum sequence length based on configuration.
+            max_length = int(MaxModelLength[t_model_type.name].value)
+            self.text_models[t_model_type].max_seq_length = max_length
+            logger.info(
+                "Set max_seq_length=%d for text model: %s",
+                max_length,
+                info.model_id,
+            )
 
         # Preload image models.
         for i_model_type in ImageModelType:
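Setting max_seq_length on a SentenceTransformer caps how many tokens encode() feeds the model, so over-long inputs are clipped at tokenization time. Note also that MaxModelLength[t_model_type.name] requires TextModelType and MaxModelLength to keep their member names in sync; a missing name would raise KeyError during preload. A minimal sketch of the effect, assuming one of the Space's E5 models (any SentenceTransformer works the same way):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("intfloat/multilingual-e5-small")
    model.max_seq_length = 512  # what the preload loop now enforces

    long_text = "word " * 10_000
    emb = model.encode([long_text])  # tokenizer truncates past 512 tokens
    print(emb.shape)  # (1, 384) for this model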
@@ -265,6 +290,37 @@ class EmbeddingsService:
 
         return input_images
 
+    def _truncate_text(self, text: str, model: SentenceTransformer) -> str:
+        """
+        Truncate the input text to the maximum allowed tokens for the given model.
+
+        Args:
+            text: The input text.
+            model: The SentenceTransformer model used for tokenization.
+
+        Returns:
+            The truncated text if token length exceeds the maximum allowed length,
+            otherwise the original text.
+        """
+        try:
+            # Attempt to get the tokenizer from the first module of the SentenceTransformer.
+            module = model._first_module()
+            if not hasattr(module, "tokenizer"):
+                return text
+            tokenizer = module.tokenizer
+            # Tokenize without truncation.
+            encoded = tokenizer(text, add_special_tokens=True, truncation=False)
+            max_length = model.max_seq_length
+            if len(encoded["input_ids"]) > max_length:
+                truncated_ids = encoded["input_ids"][:max_length]
+                truncated_text = tokenizer.decode(
+                    truncated_ids, skip_special_tokens=True
+                )
+                return truncated_text
+        except Exception as e:
+            logger.warning("Error during text truncation: %s", str(e))
+        return text
+
     async def _fetch_image(self, path_or_url: str) -> Image.Image:
         """
         Asynchronously fetch an image from a URL or load from a local path.
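The method measures the true token length first and only decodes back to text when the input is over the limit. One subtlety: decoding token ids and re-tokenizing is not a perfect round trip, so the truncated string may re-encode to slightly more or fewer than max_seq_length tokens; that is harmless here because max_seq_length still hard-caps encoding. A standalone sketch of the same idea with a plain Hugging Face tokenizer (model name is an assumption):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")
    MAX_LENGTH = 512

    def truncate_text(text: str) -> str:
        # Tokenize without truncation to measure the real length.
        ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
        if len(ids) <= MAX_LENGTH:
            return text
        # Keep the first MAX_LENGTH ids and decode back to a string.
        return tokenizer.decode(ids[:MAX_LENGTH], skip_special_tokens=True)

    short = truncate_text("word " * 5_000)
    print(len(tokenizer(short)["input_ids"]))  # roughly 512, not exactly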
@@ -312,14 +368,16 @@
         return processed_data
 
     def _generate_text_embeddings(
-        self,
-        model_id: TextModelType,
-        texts: List[str],
+        self, model_id: TextModelType, texts: List[str]
     ) -> np.ndarray:
         """
         Generate text embeddings using the SentenceTransformer model.
         Single-text requests are cached using an LRU cache.
 
+        Args:
+            model_id: The text model type.
+            texts: A list of input texts.
+
         Returns:
             A NumPy array of text embeddings.
 
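The "single-text requests are cached using an LRU cache" line refers to a pattern like the following. This is a hypothetical sketch of the pattern only, not the Space's actual cache, which is outside this diff. functools.lru_cache needs hashable arguments and return values, hence the tuple round-trip:

    from functools import lru_cache

    from sentence_transformers import SentenceTransformer

    MODELS = {"e5-small": SentenceTransformer("intfloat/multilingual-e5-small")}

    @lru_cache(maxsize=1024)
    def cached_single_embedding(model_name: str, text: str) -> tuple:
        # Hashable in, hashable out: the vector is stored as a tuple of floats.
        return tuple(MODELS[model_name].encode(text).tolist())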
@@ -345,9 +403,7 @@
         ) from e
 
     async def _async_generate_image_embeddings(
-        self,
-        model_id: ImageModelType,
-        images: List[str],
+        self, model_id: ImageModelType, images: List[str]
     ) -> np.ndarray:
         """
         Asynchronously generate image embeddings.
@@ -355,6 +411,10 @@
         This method concurrently processes multiple images and offloads
         the blocking model inference to a separate thread.
 
+        Args:
+            model_id: The image model type.
+            images: A list of image URLs or file paths.
+
         Returns:
             A NumPy array of image embeddings.
 
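"Concurrently processes multiple images and offloads the blocking inference" is the usual gather-then-to_thread shape: download every image concurrently on the event loop, then run the blocking encode in a worker thread. A self-contained sketch of that pattern under stated assumptions (httpx for fetching, a CLIP-style SentenceTransformer whose encode accepts PIL images); the Space's exact fetch logic lives in _fetch_image above:

    import asyncio
    from io import BytesIO

    import httpx
    from PIL import Image
    from sentence_transformers import SentenceTransformer

    async def fetch_image(client: httpx.AsyncClient, url: str) -> Image.Image:
        # Non-blocking download; PIL decoding is cheap enough to stay on the loop.
        resp = await client.get(url, timeout=10)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB")

    async def embed_images(model: SentenceTransformer, urls: list[str]):
        async with httpx.AsyncClient() as client:
            # All downloads run concurrently on the event loop...
            images = await asyncio.gather(*(fetch_image(client, u) for u in urls))
        # ...then the blocking model call moves off the loop into a thread.
        return await asyncio.to_thread(model.encode, images)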
@@ -386,9 +446,7 @@
         ) from e
 
     async def generate_embeddings(
-        self,
-        model: str,
-        inputs: Union[str, List[str]],
+        self, model: str, inputs: Union[str, List[str]]
     ) -> np.ndarray:
         """
         Asynchronously generate embeddings for text or image inputs based on model type.
@@ -402,16 +460,21 @@
         """
         modality = detect_model_kind(model)
         if modality == ModelKind.TEXT:
-            …
+            text_model_enum = TextModelType(model)
             text_list = self._validate_text_list(inputs)
+            model_instance = self.text_models[text_model_enum]
+            # Truncate each text if it exceeds the maximum allowed token length.
+            truncated_texts = [
+                self._truncate_text(text, model_instance) for text in text_list
+            ]
             return await asyncio.to_thread(
-                self._generate_text_embeddings,
+                self._generate_text_embeddings, text_model_enum, truncated_texts
             )
         elif modality == ModelKind.IMAGE:
-            …
+            image_model_enum = ImageModelType(model)
             image_list = self._validate_image_list(inputs)
             return await self._async_generate_image_embeddings(
-                …
+                image_model_enum, image_list
             )
 
     async def rank(
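(The `…` lines above mark removed content that the diff rendering did not preserve.) After this change a caller no longer has to pre-truncate; over-long inputs are clipped to the model's limit instead of erroring. A hypothetical end-to-end call, assuming EmbeddingsService constructs with defaults and "multilingual-e5-small" is a valid TextModelType value:

    import asyncio

    async def main():
        service = EmbeddingsService()
        very_long = "lorem ipsum " * 20_000  # far beyond 512 tokens
        vecs = await service.generate_embeddings(
            "multilingual-e5-small", [very_long, "short query"]
        )
        print(vecs.shape)  # (2, embedding_dim)

    asyncio.run(main())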
@@ -424,6 +487,11 @@
         Asynchronously rank candidate texts/images against the provided queries.
         Embeddings for queries and candidates are generated concurrently.
 
+        Args:
+            model: The model identifier.
+            queries: The query input(s).
+            candidates: The candidate input(s).
+
         Returns:
             A dictionary containing probabilities, cosine similarities, and usage statistics.
         """
@@ -469,6 +537,9 @@
         """
         Estimate the token count for the given text input using the SentenceTransformer tokenizer.
 
+        Args:
+            input_data: The text input(s).
+
         Returns:
             The total number of tokens.
         """
@@ -482,8 +553,11 @@
         """
         Compute the softmax over the last dimension of the input array.
 
+        Args:
+            scores: A NumPy array of scores.
+
         Returns:
-            …
+            A NumPy array of softmax probabilities.
         """
         exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
         return exps / np.sum(exps, axis=-1, keepdims=True)
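The max-subtraction in this hunk is the standard numerical-stability trick: softmax is shift-invariant, so subtracting the row maximum changes nothing mathematically but keeps np.exp finite. A quick demonstration:

    import numpy as np

    scores = np.array([[1000.0, 1001.0, 1002.0]])

    # Naive softmax overflows: np.exp(1002.0) is inf, and inf/inf is nan.
    naive = np.exp(scores) / np.sum(np.exp(scores), axis=-1, keepdims=True)

    # Shifted softmax stays finite and gives the true probabilities.
    exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    stable = exps / np.sum(exps, axis=-1, keepdims=True)

    print(naive)   # [[nan nan nan]]
    print(stable)  # [[0.09003057 0.24472847 0.66524096]]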
@@ -493,6 +567,10 @@
         """
         Compute the pairwise cosine similarity between all rows of arrays a and b.
 
+        Args:
+            a: A NumPy array.
+            b: A NumPy array.
+
         Returns:
             A (N x M) matrix of cosine similarities.
         """
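The hunk ends before the implementation, but the documented contract (all rows of a against all rows of b, yielding N x M) matches the normalize-then-matmul pattern. A sketch consistent with that contract, not necessarily the Space's exact code:

    import numpy as np

    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Normalize rows to unit length; one matmul then covers every pair:
        # result[i, j] == cos(a[i], b[j]).
        a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
        return a_norm @ b_norm.T

    queries = np.random.rand(2, 4)     # N = 2 vectors
    candidates = np.random.rand(3, 4)  # M = 3 vectors
    print(cosine_similarity(queries, candidates).shape)  # (2, 3)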