lamhieu committed on
Commit
65c747d
·
1 Parent(s): ea8754a

chore: support other models

lightweight_embeddings/__init__.py CHANGED
@@ -14,13 +14,27 @@ Supported image model ID:
   - "google/siglip-base-patch16-256-multilingual"
 """
 
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
 import gradio as gr
 import requests
 import json
+import logging
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from gradio.routes import mount_gradio_app
 
+
+# Filter out /v1 requests from the access log
+class LogFilter(logging.Filter):
+    def filter(self, record):
+        if record.args and len(record.args) >= 3:
+            if "/v1" in str(record.args[2]):
+                return False
+        return True
+
+
+logger = logging.getLogger("uvicorn.access")
+logger.addFilter(LogFilter())
+
 # Application metadata
 __version__ = "1.0.0"
 __author__ = "lamhieu"
@@ -41,17 +55,18 @@ __metadata__ = {
 EMBEDDINGS_API_URL = "http://localhost:7860/v1/embeddings"
 
 # Markdown description for the main interface
-APP_DESCRIPTION = f"""\
+APP_DESCRIPTION = f"""
+<br />
 ## 🚀 **Lightweight Embeddings API**
 
 The **Lightweight Embeddings API** is a fast, free, and multilingual service designed for generating embeddings and reranking with support for both **text** and **image** inputs. Get started below by exploring our interactive playground or using the cURL examples provided.
 
----
-
-### 📦 Features
-- **Multilingual Support**: Process inputs in multiple languages.
-- **Versatile API**: Generate embeddings, perform ranking, and more.
-- **Developer-Friendly**: Quick to integrate with documentation and examples.
+### ✨ Key Features
+
+- **Free, Unlimited, and Multilingual**: A fully free API service with no usage limits, capable of processing text in over 100 languages to support global applications seamlessly.
+- **Advanced Embedding and Reranking**: Generate high-quality text and image-text embeddings using state-of-the-art models, alongside robust reranking capabilities for enhanced results.
+- **Optimized and Flexible**: Built for speed with lightweight transformer models and efficient backends for rapid inference on low-resource systems, with support for diverse use cases.
+- **Production-Ready with Ease of Use**: Deploy effortlessly using Docker for a hassle-free setup, and experiment interactively through a **Gradio-powered playground** with comprehensive REST API documentation.
 
 ### 🔗 Links
 - [Documentation]({__metadata__["docs"]}) | [GitHub]({__metadata__["github"]}) | [Playground]({__metadata__["spaces"]})
@@ -117,7 +132,11 @@ def create_main_interface():
     # Available model options for the dropdown
     model_options = [
         "multilingual-e5-small",
+        "multilingual-e5-base",
+        "multilingual-e5-large",
+        "snowflake-arctic-embed-l-v2.0",
         "paraphrase-multilingual-MiniLM-L12-v2",
+        "paraphrase-multilingual-mpnet-base-v2",
         "bge-m3",
         "google/siglip-base-patch16-256-multilingual",
     ]
@@ -167,7 +186,7 @@
       -H 'Content-Type: application/json' \\
       -d '{
         "model": "multilingual-e5-small",
-        "input": "Translate this text into Spanish."
+        "input": "That is a happy person"
       }'
     ```
 
@@ -179,11 +198,11 @@
       -H 'Content-Type: application/json' \\
       -d '{
         "model": "multilingual-e5-small",
-        "queries": "Find the best match for this query.",
+        "queries": "That is a happy person",
         "candidates": [
-            "Candidate A",
-            "Candidate B",
-            "Candidate C"
+            "That is a happy dog",
+            "That is a very happy person",
+            "Today is a sunny day"
         ]
       }'
     ```
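
The cURL examples embedded in `APP_DESCRIPTION` above map one-to-one onto the `/v1/embeddings` and `/v1/rank` endpoints. A minimal Python sketch of the same two calls, assuming a local deployment on port 7860 as in `EMBEDDINGS_API_URL` (the response fields mirror `EmbeddingResponse` and `RankResponse` from router.py below; this client is illustrative, not part of the commit):

```python
import requests

BASE_URL = "http://localhost:7860/v1"  # assumption: service running locally, as in EMBEDDINGS_API_URL

# Embeddings: mirrors the first cURL example above
emb = requests.post(
    f"{BASE_URL}/embeddings",
    json={"model": "multilingual-e5-small", "input": "That is a happy person"},
    timeout=30,
)
emb.raise_for_status()
print(len(emb.json()["data"][0]["embedding"]))  # embedding dimensionality

# Ranking: mirrors the second cURL example above
rank = requests.post(
    f"{BASE_URL}/rank",
    json={
        "model": "multilingual-e5-small",
        "queries": "That is a happy person",
        "candidates": [
            "That is a happy dog",
            "That is a very happy person",
            "Today is a sunny day",
        ],
    },
    timeout=30,
)
rank.raise_for_status()
print(rank.json()["probabilities"])  # one row of softmax-normalized scores per query
```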
lightweight_embeddings/router.py CHANGED
@@ -1,18 +1,16 @@
-# filename: router.py
-
 """
-FastAPI Router for Embeddings Service
+FastAPI Router for Embeddings Service (Revised & Simplified)
 
-This file exposes the EmbeddingsService functionality via a RESTful API
-to generate embeddings and rank candidates.
+Exposes the EmbeddingsService methods via a RESTful API.
 
 Supported Text Model IDs:
 - "multilingual-e5-small"
 - "paraphrase-multilingual-MiniLM-L12-v2"
 - "bge-m3"
 
-Supported Image Model ID:
-- "google/siglip-base-patch16-256-multilingual"
+Supported Image Model IDs:
+- "siglip-base-patch16-256-multilingual"
+  (Extend as needed)
 """
 
 from __future__ import annotations
@@ -24,143 +22,87 @@ from enum import Enum
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel, Field
 
-from .service import ModelConfig, TextModelType, EmbeddingsService
+from .service import (
+    ModelConfig,
+    TextModelType,
+    ImageModelType,
+    EmbeddingsService,
+)
 
 logger = logging.getLogger(__name__)
 
-# Initialize FastAPI router
 router = APIRouter(
     tags=["v1"],
     responses={404: {"description": "Not found"}},
 )
 
 
-class ModelType(str, Enum):
-    """
-    High-level distinction for text vs. image models.
-    """
-
+class ModelKind(str, Enum):
     TEXT = "text"
     IMAGE = "image"
 
 
-def detect_model_type(model_id: str) -> ModelType:
+def detect_model_kind(model_id: str) -> ModelKind:
     """
-    Detect whether the provided model ID is for text or image.
-
-    Supported text model IDs:
-    - "multilingual-e5-small"
-    - "paraphrase-multilingual-MiniLM-L12-v2"
-    - "bge-m3"
-
-    Supported image model ID:
-    - "google/siglip-base-patch16-256-multilingual"
-      (or any model containing "siglip" in its identifier).
-
-    Args:
-        model_id: String identifier of the model.
-
-    Returns:
-        ModelType.TEXT if it matches one of the recognized text model IDs,
-        ModelType.IMAGE if it matches (or contains "siglip").
-
-    Raises:
-        ValueError: If the model_id is not recognized as either text or image.
+    Detect whether model_id is for a text or an image model.
+    Raises ValueError if unrecognized.
     """
-    # Gather all known text model IDs (from TextModelType enum)
-    text_model_ids = {m.value for m in TextModelType}
-
-    # Simple check: if it's in text_model_ids, it's text;
-    # if 'siglip' is in the model ID, it's recognized as an image model.
-    if model_id in text_model_ids:
-        return ModelType.TEXT
-    elif "siglip" in model_id.lower():
-        return ModelType.IMAGE
-
-    error_msg = (
-        f"Unsupported model ID: '{model_id}'.\n"
-        "Valid text model IDs are: "
-        "'multilingual-e5-small', 'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'.\n"
-        "Valid image model ID contains 'siglip', for example: 'google/siglip-base-patch16-256-multilingual'."
-    )
-    raise ValueError(error_msg)
+    if model_id in [m.value for m in TextModelType]:
+        return ModelKind.TEXT
+    elif model_id in [m.value for m in ImageModelType]:
+        return ModelKind.IMAGE
+    else:
+        raise ValueError(
+            f"Unrecognized model ID: {model_id}.\n"
+            f"Valid text: {[m.value for m in TextModelType]}\n"
+            f"Valid image: {[m.value for m in ImageModelType]}"
+        )
 
 
-# Pydantic Models for request/response
 class EmbeddingRequest(BaseModel):
     """
-    Request body for embedding creation.
-
-    Model IDs (text):
-    - "multilingual-e5-small"
-    - "paraphrase-multilingual-MiniLM-L12-v2"
-    - "bge-m3"
-
-    Model ID (image):
-    - "google/siglip-base-patch16-256-multilingual"
+    Input to /v1/embeddings
     """
 
     model: str = Field(
         default=TextModelType.MULTILINGUAL_E5_SMALL.value,
         description=(
-            "Model ID to use. Possible text models include: 'multilingual-e5-small', "
-            "'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'. "
-            "For images, you can use: 'google/siglip-base-patch16-256-multilingual' "
-            "or any ID containing 'siglip'."
+            "Which model ID to use? "
+            "Text: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', 'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', 'paraphrase-multilingual-mpnet-base-v2', 'bge-m3']. "
            "Image: ['siglip-base-patch16-256-multilingual']."
         ),
     )
     input: Union[str, List[str]] = Field(
-        ...,
-        description=(
-            "Input text(s) or image path(s)/URL(s). "
-            "Accepts a single string or a list of strings."
-        ),
+        ..., description="Text(s) or Image URL(s)/path(s)."
     )
 
 
 class RankRequest(BaseModel):
     """
-    Request body for ranking candidates against queries.
-
-    Model IDs (text):
-    - "multilingual-e5-small"
-    - "paraphrase-multilingual-MiniLM-L12-v2"
-    - "bge-m3"
-
-    Model ID (image):
-    - "google/siglip-base-patch16-256-multilingual"
+    Input to /v1/rank
     """
 
     model: str = Field(
         default=TextModelType.MULTILINGUAL_E5_SMALL.value,
         description=(
-            "Model ID to use for the queries. Supported text models: "
-            "'multilingual-e5-small', 'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'. "
-            "For image queries, use an ID containing 'siglip' such as 'google/siglip-base-patch16-256-multilingual'."
+            "Model ID for the queries. "
+            "Text or Image model, e.g. 'siglip-base-patch16-256-multilingual' for images."
         ),
     )
     queries: Union[str, List[str]] = Field(
-        ...,
-        description=(
-            "Query input(s): can be text(s) or image path(s)/URL(s). "
-            "If using an image model, ensure your inputs reference valid image paths or URLs."
-        ),
+        ..., description="Query text or image(s) depending on the model type."
     )
     candidates: List[str] = Field(
-        ...,
-        description=(
-            "List of candidate texts to rank against the given queries. "
-            "Currently, all candidates must be text."
-        ),
+        ..., description="Candidate texts to rank. Must be text."
     )
 
 
 class EmbeddingResponse(BaseModel):
     """
-    Response structure for embedding creation.
+    Response of /v1/embeddings
     """
 
-    object: str = "list"
+    object: str
     data: List[dict]
     model: str
     usage: dict
@@ -168,14 +110,12 @@ class EmbeddingResponse(BaseModel):
 
 class RankResponse(BaseModel):
     """
-    Response structure for ranking results.
+    Response of /v1/rank
     """
 
     probabilities: List[List[float]]
     cosine_similarities: List[List[float]]
 
-
-# Initialize the service with default configuration
 service_config = ModelConfig()
 embeddings_service = EmbeddingsService(config=service_config)
 
@@ -183,114 +123,81 @@ embeddings_service = EmbeddingsService(config=service_config)
 @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
 async def create_embeddings(request: EmbeddingRequest):
     """
-    Generate embeddings for the provided input text(s) or image(s).
-
-    Supported Model IDs for text:
-    - "multilingual-e5-small"
-    - "paraphrase-multilingual-MiniLM-L12-v2"
-    - "bge-m3"
-
-    Supported Model ID for image:
-    - "google/siglip-base-patch16-256-multilingual"
-
-    Steps:
-    1. Detects model type (text or image) based on the model ID.
-    2. Adjusts the service configuration accordingly.
-    3. Produces embeddings via the EmbeddingsService.
-    4. Returns embedding vectors along with usage information.
-
-    Raises:
-        HTTPException: For any errors during model detection or embedding generation.
+    Generates embeddings for the given input (text or image).
     """
     try:
-        modality = detect_model_type(request.model)
+        # 1) Determine if it's text or image
+        mkind = detect_model_kind(request.model)
 
-        # Adjust global config based on the detected modality
-        if modality == ModelType.TEXT:
+        # 2) Update global service config so it uses the correct model
+        if mkind == ModelKind.TEXT:
             service_config.text_model_type = TextModelType(request.model)
         else:
-            service_config.image_model_id = request.model
+            service_config.image_model_type = ImageModelType(request.model)
 
-        # Generate embeddings asynchronously
+        # 3) Generate
         embeddings = await embeddings_service.generate_embeddings(
-            input_data=request.input, modality=modality.value
+            input_data=request.input, modality=mkind.value
         )
 
-        # Estimate tokens only if it's text
+        # 4) Estimate tokens for text only
         total_tokens = 0
-        if modality == ModelType.TEXT:
+        if mkind == ModelKind.TEXT:
             total_tokens = embeddings_service.estimate_tokens(request.input)
 
-        return {
+        resp = {
             "object": "list",
-            "data": [
-                {
-                    "object": "embedding",
-                    "index": idx,
-                    "embedding": emb.tolist(),
-                }
-                for idx, emb in enumerate(embeddings)
-            ],
+            "data": [],
             "model": request.model,
             "usage": {
                 "prompt_tokens": total_tokens,
                 "total_tokens": total_tokens,
             },
         }
+        for idx, emb in enumerate(embeddings):
+            resp["data"].append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": emb.tolist(),
+                }
+            )
+
+        return resp
 
     except Exception as e:
-        error_msg = (
-            "Failed to generate embeddings. Please verify your model ID, input data, and server logs.\n"
-            f"Error Details: {str(e)}"
+        msg = (
+            "Failed to generate embeddings. Check model ID, inputs, etc.\n"
+            f"Details: {str(e)}"
         )
-        logger.error(error_msg)
-        raise HTTPException(status_code=500, detail=error_msg)
+        logger.error(msg)
+        raise HTTPException(status_code=500, detail=msg)
 
 
 @router.post("/rank", response_model=RankResponse, tags=["rank"])
 async def rank_candidates(request: RankRequest):
     """
-    Rank the given candidate texts against the provided queries.
-
-    Supported Model IDs for text queries:
-    - "multilingual-e5-small"
-    - "paraphrase-multilingual-MiniLM-L12-v2"
-    - "bge-m3"
-
-    Supported Model ID for image queries:
-    - "google/siglip-base-patch16-256-multilingual"
-
-    Steps:
-    1. Detects model type (text or image) based on the query model ID.
-    2. Adjusts the service configuration accordingly.
-    3. Generates embeddings for the queries (text or image).
-    4. Generates embeddings for the candidates (always text).
-    5. Computes cosine similarities and returns softmax-normalized probabilities.
-
-    Raises:
-        HTTPException: For any errors during model detection or ranking.
+    Ranks candidate texts against the given queries (which can be text or image).
     """
     try:
-        modality = detect_model_type(request.model)
+        mkind = detect_model_kind(request.model)
 
-        # Adjust global config based on the detected modality
-        if modality == ModelType.TEXT:
+        if mkind == ModelKind.TEXT:
             service_config.text_model_type = TextModelType(request.model)
         else:
-            service_config.image_model_id = request.model
+            service_config.image_model_type = ImageModelType(request.model)
 
-        # Perform the ranking
         results = await embeddings_service.rank(
             queries=request.queries,
             candidates=request.candidates,
-            modality=modality.value,
+            modality=mkind.value,
         )
         return results
 
     except Exception as e:
-        error_msg = (
-            "Failed to rank candidates. Please verify your model ID, input data, and server logs.\n"
-            f"Error Details: {str(e)}"
+        msg = (
+            "Failed to rank candidates. Check model ID, inputs, etc.\n"
            f"Details: {str(e)}"
         )
-        logger.error(error_msg)
-        raise HTTPException(status_code=500, detail=error_msg)
+        logger.error(msg)
+        raise HTTPException(status_code=500, detail=msg)
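
One behavioural consequence of the revised `detect_model_kind` worth noting: model IDs are now matched exactly against the enum values, so the old `'siglip'` substring match is gone. A small illustrative sketch (not part of the commit; enum values come from the revised service.py below, and importing the router eagerly constructs `EmbeddingsService`, which pre-loads every model):

```python
# Illustrative only: importing the router also runs
# `embeddings_service = EmbeddingsService(config=service_config)` at module level.
from lightweight_embeddings.router import ModelKind, detect_model_kind

assert detect_model_kind("multilingual-e5-base") is ModelKind.TEXT
assert detect_model_kind("siglip-base-patch16-256-multilingual") is ModelKind.IMAGE

try:
    # The fully-qualified Hugging Face ID is no longer accepted for images.
    detect_model_kind("google/siglip-base-patch16-256-multilingual")
except ValueError as err:
    print(err)  # lists the valid text and image IDs
```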
lightweight_embeddings/service.py CHANGED
@@ -1,12 +1,10 @@
-# filename: service.py
-
 """
-Lightweight Embeddings Service Module
+Lightweight Embeddings Service Module (Revised & Simplified)
 
 This module provides a service for generating and comparing embeddings from text and images
 using state-of-the-art transformer models. It supports both CPU and GPU inference.
 
-Key Features:
+Features:
 - Text and image embedding generation
 - Cross-modal similarity ranking
 - Batch processing support
@@ -17,8 +15,8 @@ Supported Text Model IDs:
 - "paraphrase-multilingual-MiniLM-L12-v2"
 - "bge-m3"
 
-Supported Image Model ID (default):
-- "google/siglip-base-patch16-256-multilingual"
+Supported Image Model IDs:
+- "google/siglip-base-patch16-256-multilingual" (default, but extensible)
 """
 
 from __future__ import annotations
@@ -37,441 +35,351 @@ from PIL import Image
 from sentence_transformers import SentenceTransformer
 from transformers import AutoProcessor, AutoModel
 
-# Configure logging
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
 
-# Default Model IDs
-TEXT_MODEL_ID = "Xenova/multilingual-e5-small"
-IMAGE_MODEL_ID = "google/siglip-base-patch16-256-multilingual"
-
 
 class TextModelType(str, Enum):
     """
     Enumeration of supported text models.
-    Please ensure the ONNX files and Hugging Face model IDs are consistent
-    with your local or remote environment.
+    Adjust as needed for your environment.
     """
 
     MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
+    MULTILINGUAL_E5_BASE = "multilingual-e5-base"
+    MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
+    SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
     PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
+    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
     BGE_M3 = "bge-m3"
 
 
+class ImageModelType(str, Enum):
+    """
+    Enumeration of supported image models.
+    """
+
+    SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
+
+
 class ModelInfo(NamedTuple):
     """
-    Simple container for mapping a given text model type
-    to its Hugging Face model repository and the local ONNX file path.
+    Simple container that maps an enum to:
+    - model_id: Hugging Face model ID (or local path)
+    - onnx_file: Path to ONNX file (if available)
     """
 
     model_id: str
-    onnx_file: str
+    onnx_file: Optional[str] = None
 
 
 @dataclass
 class ModelConfig:
     """
-    Configuration settings for model providers, backends, and defaults.
+    Configuration for text and image models.
     """
 
-    provider: str = "CPUExecutionProvider"
-    backend: str = "onnx"
-    logit_scale: float = 4.60517
     text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
-    image_model_id: str = IMAGE_MODEL_ID
+    image_model_type: ImageModelType = (
+        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
+    )
+
+    # If you need extra parameters like `logit_scale`, etc., keep them here
+    logit_scale: float = 4.60517
 
     @property
     def text_model_info(self) -> ModelInfo:
         """
-        Retrieves the ModelInfo for the currently selected text_model_type.
+        Return ModelInfo for the configured text_model_type.
         """
-        model_configs = {
+        text_configs = {
             TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
-                "Xenova/multilingual-e5-small",
-                "onnx/model_quantized.onnx",
+                model_id="Xenova/multilingual-e5-small",
+                onnx_file="onnx/model_quantized.onnx",
+            ),
+            TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
+                model_id="Xenova/multilingual-e5-base",
+                onnx_file="onnx/model_quantized.onnx",
+            ),
+            TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
+                model_id="Xenova/multilingual-e5-large",
+                onnx_file="onnx/model_quantized.onnx",
+            ),
+            TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
+                model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
+                onnx_file="onnx/model_quantized.onnx",
             ),
             TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
-                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
-                "onnx/model_quint8_avx2.onnx",
+                model_id="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+                onnx_file="onnx/model_quint8_avx2.onnx",
+            ),
+            TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
+                model_id="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+                onnx_file="onnx/model_quint8_avx2.onnx",
             ),
             TextModelType.BGE_M3: ModelInfo(
-                "BAAI/bge-m3",
-                "model.onnx",
+                model_id="BAAI/bge-m3",
+                onnx_file="onnx/model.onnx",
             ),
         }
-        return model_configs[self.text_model_type]
+        return text_configs[self.text_model_type]
+
+    @property
+    def image_model_info(self) -> ModelInfo:
+        """
+        Return ModelInfo for the configured image_model_type.
+        """
+        image_configs = {
+            ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
+                model_id="google/siglip-base-patch16-256-multilingual"
+            ),
+        }
+        return image_configs[self.image_model_type]
 
 
 class EmbeddingsService:
     """
-    Service for generating and comparing text/image embeddings.
-
-    This service supports multiple text models and a single image model.
-    It provides methods for:
-    - Generating text embeddings
-    - Generating image embeddings
-    - Ranking candidates by similarity
+    Service for generating text/image embeddings and performing ranking.
     """
 
-    def __init__(self, config: Optional[ModelConfig] = None) -> None:
-        """
-        Initialize the EmbeddingsService.
-
-        Args:
-            config: Optional ModelConfig object to override default settings.
-        """
-        # Determine whether GPU (CUDA) is available
+    def __init__(self, config: Optional[ModelConfig] = None):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # Use the provided config or fall back to defaults
         self.config = config or ModelConfig()
 
-        # Dictionary to hold multiple text models
+        # Preloaded text & image models
         self.text_models: Dict[TextModelType, SentenceTransformer] = {}
+        self.image_models: Dict[ImageModelType, AutoModel] = {}
+        self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
 
-        # Load all models (text + image) into memory
-        self._load_models()
+        # Load all models
+        self._load_all_models()
 
-    def _load_models(self) -> None:
+    def _load_all_models(self) -> None:
         """
-        Load text and image models into memory.
-
-        This pre-loads all text models defined in the TextModelType enum
-        and a single image model, enabling quick switching at runtime.
+        Pre-load all known text and image models for quick switching.
         """
         try:
-            # Load all text models
-            for model_type in TextModelType:
-                model_info = ModelConfig(text_model_type=model_type).text_model_info
-                logger.info(f"Loading text model: {model_info.model_id}")
-
-                self.text_models[model_type] = SentenceTransformer(
-                    model_info.model_id,
-                    device=self.device,
-                    backend=self.config.backend,
-                    model_kwargs={
-                        "provider": self.config.provider,
-                        "file_name": model_info.onnx_file,
-                    },
-                )
-
-            logger.info(f"Loading image model: {self.config.image_model_id}")
-            self.image_model = AutoModel.from_pretrained(self.config.image_model_id).to(
-                self.device
-            )
-            self.image_processor = AutoProcessor.from_pretrained(
-                self.config.image_model_id
-            )
-
-            logger.info(f"All models loaded successfully on {self.device}.")
+            for t_model_type in TextModelType:
+                info = ModelConfig(text_model_type=t_model_type).text_model_info
+                logger.info("Loading text model: %s", info.model_id)
+
+                # If you have an ONNX file AND your SentenceTransformer supports ONNX
+                if info.onnx_file:
+                    logger.info("Using ONNX file: %s", info.onnx_file)
+                    # The following 'backend' & 'model_kwargs' parameters
+                    # are recognized only in special/certain distributions of SentenceTransformer
+                    self.text_models[t_model_type] = SentenceTransformer(
+                        info.model_id,
+                        device=self.device,
+                        backend="onnx",  # or "ort" in some custom forks
+                        model_kwargs={
+                            "provider": "CPUExecutionProvider",  # or "CUDAExecutionProvider"
+                            "file_name": info.onnx_file,
+                        },
+                    )
+                else:
+                    # Fallback: standard HF loading
+                    self.text_models[t_model_type] = SentenceTransformer(
+                        info.model_id, device=self.device
+                    )
 
+            for i_model_type in ImageModelType:
+                model_id = ModelConfig(
+                    image_model_type=i_model_type
+                ).image_model_info.model_id
+                logger.info("Loading image model: %s", model_id)
+
+                # Typically, for CLIP-like models:
+                model = AutoModel.from_pretrained(model_id).to(self.device)
+                processor = AutoProcessor.from_pretrained(model_id)
+
+                self.image_models[i_model_type] = model
+                self.image_processors[i_model_type] = processor
+
+            logger.info("All models loaded successfully.")
         except Exception as e:
-            logger.error(
-                "Model loading failed. Please ensure you have valid model IDs and local files.\n"
-                f"Error details: {str(e)}"
-            )
-            raise RuntimeError(f"Failed to load models: {str(e)}") from e
+            msg = f"Error loading models: {str(e)}"
+            logger.error(msg)
+            raise RuntimeError(msg) from e
 
     @staticmethod
     def _validate_text_input(input_text: Union[str, List[str]]) -> List[str]:
         """
-        Validate and standardize the input for text embeddings.
-
-        Args:
-            input_text: Either a single string or a list of strings.
-
-        Returns:
-            A list of strings to process.
-
-        Raises:
-            ValueError: If input_text is empty or not string-based.
+        Ensure input_text is a non-empty string or list of strings.
         """
         if isinstance(input_text, str):
+            if not input_text.strip():
+                raise ValueError("Text input cannot be empty.")
             return [input_text]
+
         if not isinstance(input_text, list) or not all(
             isinstance(x, str) for x in input_text
         ):
-            raise ValueError(
-                "Text input must be a single string or a list of strings. "
-                "Found a different data type instead."
-            )
-        if not input_text:
+            raise ValueError("Text input must be a string or a list of strings.")
+
+        if len(input_text) == 0:
             raise ValueError("Text input list cannot be empty.")
+
         return input_text
 
     @staticmethod
     def _validate_modality(modality: str) -> None:
-        """
-        Validate the input modality.
-
-        Args:
-            modality: Must be either 'text' or 'image'.
-
-        Raises:
-            ValueError: If modality is neither 'text' nor 'image'.
-        """
-        if modality not in ["text", "image"]:
-            raise ValueError(
-                "Invalid modality. Please specify 'text' or 'image' for embeddings."
-            )
-
-    def _process_image(self, image_path: Union[str, Path]) -> torch.Tensor:
-        """
-        Load and preprocess an image from either a local path or a URL.
-
-        Args:
-            image_path: Path to the local image file or a URL.
-
-        Returns:
-            Torch Tensor suitable for model input.
-
-        Raises:
-            ValueError: If the image file or URL cannot be loaded.
-        """
+        if modality not in ("text", "image"):
+            raise ValueError("Unsupported modality. Must be 'text' or 'image'.")
+
+    def _process_image(self, path_or_url: Union[str, Path]) -> torch.Tensor:
+        """
+        Download/Load image from path/URL and apply transformations.
+        """
         try:
-            if str(image_path).startswith("http"):
-                response = requests.get(image_path, timeout=10)
-                response.raise_for_status()
-                image_content = BytesIO(response.content)
+            if isinstance(path_or_url, Path) or not path_or_url.startswith("http"):
+                # Local file path
+                img = Image.open(path_or_url).convert("RGB")
             else:
-                image_content = image_path
-
-            image = Image.open(image_content).convert("RGB")
-            processed = self.image_processor(images=image, return_tensors="pt").to(
-                self.device
-            )
-            return processed
-
+                # URL
+                resp = requests.get(path_or_url, timeout=10)
+                resp.raise_for_status()
+                img = Image.open(BytesIO(resp.content)).convert("RGB")
+
+            proc = self.image_processors[self.config.image_model_type]
+            data = proc(images=img, return_tensors="pt").to(self.device)
+            return data
         except Exception as e:
-            raise ValueError(
-                f"Failed to process image at '{image_path}'. Check the path/URL and file format.\n"
-                f"Details: {str(e)}"
-            ) from e
+            raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
 
     def _generate_text_embeddings(self, texts: List[str]) -> np.ndarray:
         """
-        Helper method to generate text embeddings for a list of texts
-        using the currently configured text model.
-
-        Args:
-            texts: A list of text strings.
-
-        Returns:
-            Numpy array of shape (num_texts, embedding_dim).
-
-        Raises:
-            RuntimeError: If the text model fails to generate embeddings.
+        Generate text embeddings using the currently configured text model.
         """
         try:
-            logger.info(
-                f"Generating embeddings for {len(texts)} text items using model: "
-                f"{self.config.text_model_type}"
-            )
-            # Select the preloaded text model based on the current config
             model = self.text_models[self.config.text_model_type]
-            embeddings = model.encode(texts)
+            embeddings = model.encode(texts)  # shape: (num_items, emb_dim)
             return embeddings
         except Exception as e:
-            error_msg = (
-                f"Error generating text embeddings with model: {self.config.text_model_type}. "
-                f"Details: {str(e)}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg) from e
+            raise RuntimeError(
+                f"Error generating text embeddings for model '{self.config.text_model_type}': {e}"
+            ) from e
 
     def _generate_image_embeddings(
-        self, input_data: Union[str, List[str]], batch_size: Optional[int]
+        self,
+        images: Union[str, List[str]],
+        batch_size: Optional[int] = None,
     ) -> np.ndarray:
         """
-        Helper method to generate image embeddings.
-
-        Args:
-            input_data: Either a single image path/URL or a list of them.
-            batch_size: Batch size for processing images in chunks.
-                        If None, process all at once.
-
-        Returns:
-            Numpy array of shape (num_images, embedding_dim).
-
-        Raises:
-            RuntimeError: If the image model fails to generate embeddings.
+        Generate image embeddings using the currently configured image model.
+        If `batch_size` is None, all images are processed at once.
         """
         try:
-            if isinstance(input_data, str):
-                # Single image scenario
-                processed = self._process_image(input_data)
+            model = self.image_models[self.config.image_model_type]
+
+            # Single image
+            if isinstance(images, str):
+                processed = self._process_image(images)
                 with torch.no_grad():
-                    embedding = self.image_model.get_image_features(**processed)
-                return embedding.cpu().numpy()
+                    emb = model.get_image_features(**processed)
+                return emb.cpu().numpy()
 
-            # Multiple images scenario
-            logger.info(f"Generating embeddings for {len(input_data)} images.")
+            # Multiple images
             if batch_size is None:
-                # Process all images at once
-                processed_batches = [
-                    self._process_image(img_path) for img_path in input_data
-                ]
+                # Process them all in one batch
+                tensors = []
+                for img_path in images:
+                    tensors.append(self._process_image(img_path))
+                # Concatenate
+                keys = tensors[0].keys()
+                combined = {k: torch.cat([t[k] for t in tensors], dim=0) for k in keys}
                 with torch.no_grad():
-                    # Concatenate all images along the batch dimension
-                    batch_keys = processed_batches[0].keys()
-                    concatenated = {
-                        k: torch.cat([pb[k] for pb in processed_batches], dim=0)
-                        for k in batch_keys
-                    }
-                    embedding = self.image_model.get_image_features(**concatenated)
-                return embedding.cpu().numpy()
-
-            # Process images in smaller batches
-            embeddings_list = []
-            for i, img_path in enumerate(input_data):
-                if i % batch_size == 0:
-                    logger.debug(
-                        f"Processing image batch {i // batch_size + 1} with size up to {batch_size}."
-                    )
-                processed = self._process_image(img_path)
+                    emb = model.get_image_features(**combined)
+                return emb.cpu().numpy()
+
+            # Process in smaller batches
+            all_embeddings = []
+            for i in range(0, len(images), batch_size):
+                batch_images = images[i : i + batch_size]
+                # Process each sub-batch
+                tensors = []
+                for img_path in batch_images:
+                    tensors.append(self._process_image(img_path))
+                keys = tensors[0].keys()
+                combined = {k: torch.cat([t[k] for t in tensors], dim=0) for k in keys}
+
                 with torch.no_grad():
-                    embedding = self.image_model.get_image_features(**processed)
-                embeddings_list.append(embedding.cpu().numpy())
+                    emb = model.get_image_features(**combined)
+                all_embeddings.append(emb.cpu().numpy())
 
-            return np.vstack(embeddings_list)
+            return np.vstack(all_embeddings)
 
         except Exception as e:
-            error_msg = (
-                f"Error generating image embeddings with model: {self.config.image_model_id}. "
-                f"Details: {str(e)}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg) from e
+            raise RuntimeError(
+                f"Error generating image embeddings for model '{self.config.image_model_type}': {e}"
+            ) from e
 
     async def generate_embeddings(
         self,
         input_data: Union[str, List[str]],
-        modality: Literal["text", "image"] = "text",
+        modality: Literal["text", "image"],
         batch_size: Optional[int] = None,
     ) -> np.ndarray:
         """
-        Asynchronously generate embeddings for text or image inputs.
-
-        Args:
-            input_data: A string or list of strings (text/image paths/URLs).
-            modality: "text" for text data or "image" for image data.
-            batch_size: Optional batch size for processing images in chunks.
-
-        Returns:
-            Numpy array of embeddings.
-
-        Raises:
-            ValueError: If the modality is invalid.
+        Asynchronously generate embeddings for text or image.
         """
         self._validate_modality(modality)
-
         if modality == "text":
-            texts = self._validate_text_input(input_data)
-            return self._generate_text_embeddings(texts)
+            text_list = self._validate_text_input(input_data)
+            return self._generate_text_embeddings(text_list)
         else:
-            return self._generate_image_embeddings(input_data, batch_size)
+            return self._generate_image_embeddings(input_data, batch_size=batch_size)
 
     async def rank(
         self,
         queries: Union[str, List[str]],
         candidates: List[str],
-        modality: Literal["text", "image"] = "text",
+        modality: Literal["text", "image"],
         batch_size: Optional[int] = None,
     ) -> Dict[str, List[List[float]]]:
         """
-        Rank a set of candidate texts against one or more queries using cosine similarity
-        and a softmax to produce probability-like scores.
-
-        Args:
-            queries: Query text(s) or image path(s)/URL(s).
-            candidates: Candidate texts to be ranked.
-                        (Note: This implementation always treats candidates as text.)
-            modality: "text" for text queries or "image" for image queries.
-            batch_size: Batch size if images are processed in chunks.
-
-        Returns:
-            Dictionary containing:
-            - "probabilities": 2D list of softmax-normalized scores.
-            - "cosine_similarities": 2D list of raw cosine similarity values.
-
-        Raises:
-            RuntimeError: If the query or candidate embeddings fail to generate.
+        Rank candidates (always text) against the queries, which may be text or image.
+        Returns dict of { probabilities, cosine_similarities }.
         """
-        logger.info(
-            f"Ranking {len(candidates)} candidates against "
-            f"{len(queries) if isinstance(queries, list) else 1} query item(s)."
-        )
-
-        # Generate embeddings for queries
-        query_embeds = await self.generate_embeddings(
-            queries, modality=modality, batch_size=batch_size
-        )
-
-        # Generate embeddings for candidates (always text)
-        candidate_embeds = await self.generate_embeddings(
-            candidates, modality="text", batch_size=batch_size
-        )
-
-        # Compute cosine similarity and scaled probabilities
-        cosine_sims = self.cosine_similarity(query_embeds, candidate_embeds)
-        logit_scale = np.exp(self.config.logit_scale)
-        probabilities = self.softmax(logit_scale * cosine_sims)
+        # 1) Generate embeddings for queries
+        query_embeds = await self.generate_embeddings(queries, modality, batch_size)
+        # 2) Generate embeddings for text candidates
+        candidate_embeds = await self.generate_embeddings(candidates, "text")
+
+        # 3) Compute cosine sim
+        sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
+        # 4) Apply logit scale + softmax
+        scaled = np.exp(self.config.logit_scale) * sim_matrix
+        probs = self.softmax(scaled)
 
         return {
-            "probabilities": probabilities.tolist(),
-            "cosine_similarities": cosine_sims.tolist(),
+            "probabilities": probs.tolist(),
+            "cosine_similarities": sim_matrix.tolist(),
         }
 
     def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
         """
-        Roughly estimate the total number of tokens in the given text(s).
-
-        Args:
-            input_data: A string or list of strings representing text input.
-
-        Returns:
-            Estimated token count (int).
-
-        Raises:
-            ValueError: If the input is not valid text data.
+        Very rough heuristic: ~4 chars per token.
         """
         texts = self._validate_text_input(input_data)
-        # Very rough approximation: assume ~4 characters per token
         total_chars = sum(len(t) for t in texts)
         return max(1, round(total_chars / 4))
 
     @staticmethod
     def softmax(scores: np.ndarray) -> np.ndarray:
         """
-        Apply softmax along the last dimension of the scores array.
-
-        Args:
-            scores: Numpy array of shape (..., num_candidates).
-
-        Returns:
-            Numpy array of softmax-normalized values, same shape as scores.
+        Standard softmax along the last dimension.
         """
-        exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
-        return exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
+        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
+        return exps / np.sum(exps, axis=-1, keepdims=True)
 
     @staticmethod
-    def cosine_similarity(
-        query_embeds: np.ndarray, candidate_embeds: np.ndarray
-    ) -> np.ndarray:
+    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
         """
-        Compute the cosine similarity between two sets of vectors.
-
-        Args:
-            query_embeds: Numpy array of shape (num_queries, embed_dim).
-            candidate_embeds: Numpy array of shape (num_candidates, embed_dim).
-
-        Returns:
-            2D Numpy array of shape (num_queries, num_candidates)
-            containing cosine similarity scores.
+        a: (N, D)
+        b: (M, D)
+        Return: (N, M) of cos sim
         """
-        # Normalize embeddings
-        query_norm = query_embeds / np.linalg.norm(query_embeds, axis=1, keepdims=True)
-        candidate_norm = candidate_embeds / np.linalg.norm(
-            candidate_embeds, axis=1, keepdims=True
-        )
-        return np.dot(query_norm, candidate_norm.T)
+        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
+        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
+        return np.dot(a_norm, b_norm.T)
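
The ranking math in the revised `rank()` is unchanged in spirit: cosine similarities are scaled by `exp(logit_scale)` (4.60517 ≈ ln 100, the `ModelConfig` default) and passed through a softmax. A small self-contained NumPy check of that pipeline with made-up vectors (a sketch only, mirroring the `softmax` and `cosine_similarity` staticmethods above):

```python
import numpy as np

def softmax(scores: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis, as in EmbeddingsService.softmax
    exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    return exps / np.sum(exps, axis=-1, keepdims=True)

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Row-normalise both matrices, then take the dot product, as in the service
    a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
    b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
    return np.dot(a_norm, b_norm.T)

# One toy query embedding against two toy candidate embeddings
queries = np.array([[1.0, 0.0, 0.0]])
candidates = np.array([[0.9, 0.1, 0.0], [0.0, 1.0, 0.0]])

sims = cosine_similarity(queries, candidates)   # ~[[0.994, 0.0]]
probs = softmax(np.exp(4.60517) * sims)         # logit_scale default scales cosines by ~100
print(sims.round(3), probs.round(3))            # the near-duplicate candidate gets almost all the mass
```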