lamhieu commited on
Commit
86d6248
·
1 Parent(s): b2c7d24

chore: update something

Browse files
lightweight_embeddings/__init__.py CHANGED
@@ -124,11 +124,23 @@ def call_embeddings_api(user_input: str, selected_model: str) -> str:
124
 
125
  try:
126
  data = response.json()
127
- return json.dumps(data, indent=2)
128
  except ValueError:
129
  return "❌ Failed to parse JSON from API response."
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def create_main_interface():
133
  """
134
  Creates a Gradio Blocks interface showing project info and an embeddings playground.
@@ -147,10 +159,7 @@ def create_main_interface():
147
  ]
148
 
149
  with gr.Blocks(title="Lightweight Embeddings", theme="default") as demo:
150
- # Project Info
151
- gr.Markdown(APP_DESCRIPTION)
152
-
153
- # Split Layout: Playground and cURL Examples
154
  with gr.Row():
155
  with gr.Column():
156
  gr.Markdown("### 🔬 Try the Embeddings Playground")
@@ -171,7 +180,6 @@ def create_main_interface():
171
  interactive=False,
172
  )
173
 
174
- # Link button to inference function
175
  generate_btn.click(
176
  fn=call_embeddings_api,
177
  inputs=[input_text, model_dropdown],
@@ -214,6 +222,14 @@ def create_main_interface():
214
  """
215
  )
216
 
 
 
 
 
 
 
 
 
217
  return demo
218
 
219
 
 
124
 
125
  try:
126
  data = response.json()
127
+ return json.dumps(data, indent=2, ensure_ascii=False)
128
  except ValueError:
129
  return "❌ Failed to parse JSON from API response."
130
 
131
 
132
+ def call_stats_api() -> str:
133
+ """
134
+ Calls the /v1/stats endpoint to retrieve analytics data.
135
+ Returns the JSON response as a formatted string.
136
+ """
137
+ url = "https://lamhieu-lightweight-embeddings.hf.space/v1/stats"
138
+ response = requests.get(url)
139
+ if response.status_code != 200:
140
+ raise ValueError(f"Failed to fetch stats: {response.text}")
141
+ return json.dumps(response.json(), indent=2, ensure_ascii=False)
142
+
143
+
144
  def create_main_interface():
145
  """
146
  Creates a Gradio Blocks interface showing project info and an embeddings playground.
 
159
  ]
160
 
161
  with gr.Blocks(title="Lightweight Embeddings", theme="default") as demo:
162
+ # ...existing code...
 
 
 
163
  with gr.Row():
164
  with gr.Column():
165
  gr.Markdown("### 🔬 Try the Embeddings Playground")
 
180
  interactive=False,
181
  )
182
 
 
183
  generate_btn.click(
184
  fn=call_embeddings_api,
185
  inputs=[input_text, model_dropdown],
 
222
  """
223
  )
224
 
225
+ # NEW STATS SECTION
226
+ with gr.Accordion("Analytics Stats"):
227
+ stats_btn = gr.Button("Get Stats")
228
+ stats_json = gr.Textbox(
229
+ label="Stats API Response", lines=10, interactive=False
230
+ )
231
+ stats_btn.click(fn=call_stats_api, inputs=[], outputs=stats_json)
232
+
233
  return demo
234
 
235
 
lightweight_embeddings/analytics.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import redis.asyncio as redis
3
+ from datetime import datetime
4
+ from collections import defaultdict
5
+ from typing import Dict
6
+
7
+
8
+ class Analytics:
9
+ def __init__(self, redis_url: str, sync_interval: int = 60):
10
+ """
11
+ Initializes the Analytics class with an async Redis connection and sync interval.
12
+
13
+ Parameters:
14
+ - redis_url: Redis connection URL (e.g., 'redis://localhost:6379/0')
15
+ - sync_interval: Interval in seconds for syncing with Redis.
16
+ """
17
+ self.pool = redis.ConnectionPool.from_url(redis_url, decode_responses=True)
18
+ self.redis_client = redis.Redis(connection_pool=self.pool)
19
+ self.local_buffer = defaultdict(
20
+ lambda: defaultdict(int)
21
+ ) # {period: {model_id: count}}
22
+ self.sync_interval = sync_interval
23
+ self.lock = asyncio.Lock() # Async lock for thread-safe updates
24
+ asyncio.create_task(self._start_sync_task())
25
+
26
+ def _get_period_keys(self) -> tuple:
27
+ """
28
+ Returns keys for day, week, month, and year based on the current date.
29
+ """
30
+ now = datetime.utcnow()
31
+ day_key = now.strftime("%Y-%m-%d")
32
+ week_key = f"{now.year}-W{now.strftime('%U')}"
33
+ month_key = now.strftime("%Y-%m")
34
+ year_key = now.strftime("%Y")
35
+ return day_key, week_key, month_key, year_key
36
+
37
+ async def access(self, model_id: str):
38
+ """
39
+ Records an access for a specific model_id.
40
+ """
41
+ day_key, week_key, month_key, year_key = self._get_period_keys()
42
+
43
+ async with self.lock:
44
+ self.local_buffer[day_key][model_id] += 1
45
+ self.local_buffer[week_key][model_id] += 1
46
+ self.local_buffer[month_key][model_id] += 1
47
+ self.local_buffer[year_key][model_id] += 1
48
+ self.local_buffer["total"][model_id] += 1
49
+
50
+ async def stats(self) -> Dict[str, Dict[str, int]]:
51
+ """
52
+ Returns statistics for all models from the local buffer.
53
+ """
54
+ async with self.lock:
55
+ return {
56
+ period: dict(models) for period, models in self.local_buffer.items()
57
+ }
58
+
59
+ async def _sync_to_redis(self):
60
+ """
61
+ Synchronizes local buffer data with Redis.
62
+ """
63
+ async with self.lock:
64
+ pipeline = self.redis_client.pipeline()
65
+ for period, models in self.local_buffer.items():
66
+ for model_id, count in models.items():
67
+ redis_key = f"analytics:{period}"
68
+ pipeline.hincrby(redis_key, model_id, count)
69
+ await pipeline.execute()
70
+ self.local_buffer.clear() # Clear the buffer after sync
71
+
72
+ async def _start_sync_task(self):
73
+ """
74
+ Starts a background task that periodically syncs data to Redis.
75
+ """
76
+ while True:
77
+ await asyncio.sleep(self.sync_interval)
78
+ await self._sync_to_redis()
lightweight_embeddings/router.py CHANGED
@@ -20,12 +20,15 @@ Supported Image Model IDs:
20
  from __future__ import annotations
21
 
22
  import logging
23
- from typing import List, Union
 
24
  from enum import Enum
 
25
 
26
- from fastapi import APIRouter, HTTPException
27
  from pydantic import BaseModel, Field
28
 
 
29
  from .service import (
30
  ModelConfig,
31
  TextModelType,
@@ -120,12 +123,29 @@ class RankResponse(BaseModel):
120
  probabilities: List[List[float]]
121
  cosine_similarities: List[List[float]]
122
 
 
 
 
 
 
 
 
 
 
 
 
123
  service_config = ModelConfig()
124
  embeddings_service = EmbeddingsService(config=service_config)
125
 
 
 
 
 
126
 
127
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
128
- async def create_embeddings(request: EmbeddingRequest):
 
 
129
  """
130
  Generates embeddings for the given input (text or image).
131
  """
@@ -144,6 +164,8 @@ async def create_embeddings(request: EmbeddingRequest):
144
  input_data=request.input, modality=mkind.value
145
  )
146
 
 
 
147
  # 4) Estimate tokens for text only
148
  total_tokens = 0
149
  if mkind == ModelKind.TEXT:
@@ -158,6 +180,7 @@ async def create_embeddings(request: EmbeddingRequest):
158
  "total_tokens": total_tokens,
159
  },
160
  }
 
161
  for idx, emb in enumerate(embeddings):
162
  resp["data"].append(
163
  {
@@ -179,7 +202,7 @@ async def create_embeddings(request: EmbeddingRequest):
179
 
180
 
181
  @router.post("/rank", response_model=RankResponse, tags=["rank"])
182
- async def rank_candidates(request: RankRequest):
183
  """
184
  Ranks candidate texts against the given queries (which can be text or image).
185
  """
@@ -196,6 +219,9 @@ async def rank_candidates(request: RankRequest):
196
  candidates=request.candidates,
197
  modality=mkind.value,
198
  )
 
 
 
199
  return results
200
 
201
  except Exception as e:
@@ -205,3 +231,25 @@ async def rank_candidates(request: RankRequest):
205
  )
206
  logger.error(msg)
207
  raise HTTPException(status_code=500, detail=msg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from __future__ import annotations
21
 
22
  import logging
23
+ import os
24
+ from typing import Dict, Any, List, Union
25
  from enum import Enum
26
+ from datetime import datetime
27
 
28
+ from fastapi import APIRouter, BackgroundTasks, HTTPException
29
  from pydantic import BaseModel, Field
30
 
31
+ from .analytics import Analytics
32
  from .service import (
33
  ModelConfig,
34
  TextModelType,
 
123
  probabilities: List[List[float]]
124
  cosine_similarities: List[List[float]]
125
 
126
+
127
+ class StatsResponse(BaseModel):
128
+ """Analytics stats response model"""
129
+
130
+ total: Dict[str, int]
131
+ daily: Dict[str, int]
132
+ weekly: Dict[str, int]
133
+ monthly: Dict[str, int]
134
+ yearly: Dict[str, int]
135
+
136
+
137
  service_config = ModelConfig()
138
  embeddings_service = EmbeddingsService(config=service_config)
139
 
140
+ analytics = Analytics(
141
+ redis_url=os.environ.get("REDIS_URL", "redis://localhost:6379/0"), sync_interval=60
142
+ )
143
+
144
 
145
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
146
+ async def create_embeddings(
147
+ request: EmbeddingRequest, background_tasks: BackgroundTasks
148
+ ):
149
  """
150
  Generates embeddings for the given input (text or image).
151
  """
 
164
  input_data=request.input, modality=mkind.value
165
  )
166
 
167
+ background_tasks.add_task(analytics.access, request.model)
168
+
169
  # 4) Estimate tokens for text only
170
  total_tokens = 0
171
  if mkind == ModelKind.TEXT:
 
180
  "total_tokens": total_tokens,
181
  },
182
  }
183
+
184
  for idx, emb in enumerate(embeddings):
185
  resp["data"].append(
186
  {
 
202
 
203
 
204
  @router.post("/rank", response_model=RankResponse, tags=["rank"])
205
+ async def rank_candidates(request: RankRequest, background_tasks: BackgroundTasks):
206
  """
207
  Ranks candidate texts against the given queries (which can be text or image).
208
  """
 
219
  candidates=request.candidates,
220
  modality=mkind.value,
221
  )
222
+
223
+ background_tasks.add_task(analytics.access, request.model)
224
+
225
  return results
226
 
227
  except Exception as e:
 
231
  )
232
  logger.error(msg)
233
  raise HTTPException(status_code=500, detail=msg)
234
+
235
+
236
+ @router.get("/stats", response_model=StatsResponse, tags=["stats"])
237
+ async def get_stats():
238
+ """Get usage statistics for all models"""
239
+ try:
240
+ stats = await analytics.stats()
241
+
242
+ return {
243
+ "total": stats.get("total", {}),
244
+ "daily": stats.get(datetime.utcnow().strftime("%Y-%m-%d"), {}),
245
+ "weekly": stats.get(
246
+ f"{datetime.utcnow().year}-W{datetime.utcnow().strftime('%U')}", {}
247
+ ),
248
+ "monthly": stats.get(datetime.utcnow().strftime("%Y-%m"), {}),
249
+ "yearly": stats.get(datetime.utcnow().strftime("%Y"), {}),
250
+ }
251
+
252
+ except Exception as e:
253
+ msg = f"Failed to fetch analytics stats: {str(e)}"
254
+ logger.error(msg)
255
+ raise HTTPException(status_code=500, detail=msg)
requirements.txt CHANGED
@@ -7,3 +7,4 @@ sentence-transformers[onnx]==3.3.1
7
  sentencepiece==0.2.0
8
  torch==2.4.0
9
  transformers==4.45.0
 
 
7
  sentencepiece==0.2.0
8
  torch==2.4.0
9
  transformers==4.45.0
10
+ redis-py=5.2.1