Maximofn committed on
Commit
0065184
·
1 Parent(s): 4b4c28d

feat(SECURITY): :lock: Implement rate limiting for API endpoints and update request handling for text generation and summarization.

Browse files
Files changed (2) hide show
  1. app.py +30 -13
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,9 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  from functools import partial
6
  from fastapi.responses import JSONResponse
7
- from fastapi import Security, Depends
8
  from fastapi.security.api_key import APIKeyHeader, APIKey
9
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
10
 
11
  from langchain_core.messages import HumanMessage, AIMessage
12
  from langgraph.checkpoint.memory import MemorySaver
@@ -16,7 +19,10 @@ import os
16
  from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
- # Configuración de API Key
 
 
 
20
  API_KEY_NAME = "X-API-Key"
21
  API_KEY = os.getenv("API_KEY")
22
  api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
@@ -148,6 +154,10 @@ app = FastAPI(
148
  ]
149
  )
150
 
 
 
 
 
151
  # Configure the security scheme in the OpenAPI documentation
152
  app.openapi_tags = [
153
  {"name": "Authentication", "description": "Protected endpoints that require API Key"}
@@ -186,21 +196,25 @@ async def general_exception_handler(request, exc):
186
 
187
  # Welcome endpoint
188
  @app.get("/")
189
- async def api_home():
 
190
  """Welcome endpoint"""
191
  return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
192
 
193
  # Generate endpoint
194
  @app.post("/generate")
 
195
  async def generate(
196
- request: QueryRequest,
 
197
  api_key: APIKey = Depends(get_api_key)
198
  ):
199
  """
200
  Endpoint to generate text using the language model
201
 
202
  Args:
203
- request: QueryRequest
 
204
  query: str
205
  thread_id: str = "default"
206
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
@@ -211,16 +225,16 @@ async def generate(
211
  """
212
  try:
213
  # Configure the thread ID
214
- config = {"configurable": {"thread_id": request.thread_id}}
215
 
216
  # Create the input message
217
- input_messages = [HumanMessage(content=request.query)]
218
 
219
  # Invoke the graph with custom system prompt
220
  # Combine config parameters into a single dictionary
221
  combined_config = {
222
  **config,
223
- "model": {"system_prompt": request.system_prompt}
224
  }
225
 
226
  # Invoke the graph with proper argument count
@@ -245,15 +259,18 @@ async def generate(
245
  )
246
 
247
  @app.post("/summarize")
 
248
  async def summarize(
249
- request: SummaryRequest,
 
250
  api_key: APIKey = Depends(get_api_key)
251
  ):
252
  """
253
  Endpoint to generate a summary using the language model
254
 
255
  Args:
256
- request: SummaryRequest
 
257
  text: str - The text to summarize
258
  thread_id: str = "default"
259
  max_length: int = 200 - Maximum summary length
@@ -264,13 +281,13 @@ async def summarize(
264
  """
265
  try:
266
  # Configure the thread ID
267
- config = {"configurable": {"thread_id": request.thread_id}}
268
 
269
  # Create a specific system prompt for summarization
270
- summary_system_prompt = f"Make a summary of the following text in no more than {request.max_length} words. Keep the most important information and eliminate unnecessary details."
271
 
272
  # Create the input message
273
- input_messages = [HumanMessage(content=request.text)]
274
 
275
  # Combine config parameters into a single dictionary
276
  combined_config = {
 
4
  import torch
5
  from functools import partial
6
  from fastapi.responses import JSONResponse
7
+ from fastapi import Security, Depends, Request
8
  from fastapi.security.api_key import APIKeyHeader, APIKey
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from slowapi import Limiter, _rate_limit_exceeded_handler
11
+ from slowapi.util import get_remote_address
12
+ from slowapi.errors import RateLimitExceeded
13
 
14
  from langchain_core.messages import HumanMessage, AIMessage
15
  from langgraph.checkpoint.memory import MemorySaver
 
19
  from dotenv import load_dotenv
20
  load_dotenv()
21
 
22
+ # Rate Limiter configuration
23
+ limiter = Limiter(key_func=get_remote_address)
24
+
25
+ # API Key configuration
26
  API_KEY_NAME = "X-API-Key"
27
  API_KEY = os.getenv("API_KEY")
28
  api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
 
154
  ]
155
  )
156
 
157
+ # Configure the rate limiter in the application
158
+ app.state.limiter = limiter
159
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
160
+
161
  # Configure the security scheme in the OpenAPI documentation
162
  app.openapi_tags = [
163
  {"name": "Authentication", "description": "Protected endpoints that require API Key"}
 
196
 
197
  # Welcome endpoint
198
  @app.get("/")
199
+ @limiter.limit("10/minute")
200
+ async def api_home(request: Request):
201
  """Welcome endpoint"""
202
  return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
203
 
204
  # Generate endpoint
205
  @app.post("/generate")
206
+ @limiter.limit("5/minute")
207
  async def generate(
208
+ request: Request,
209
+ query_request: QueryRequest,
210
  api_key: APIKey = Depends(get_api_key)
211
  ):
212
  """
213
  Endpoint to generate text using the language model
214
 
215
  Args:
216
+ request: Request - FastAPI request object for rate limiting
217
+ query_request: QueryRequest
218
  query: str
219
  thread_id: str = "default"
220
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
 
225
  """
226
  try:
227
  # Configure the thread ID
228
+ config = {"configurable": {"thread_id": query_request.thread_id}}
229
 
230
  # Create the input message
231
+ input_messages = [HumanMessage(content=query_request.query)]
232
 
233
  # Invoke the graph with custom system prompt
234
  # Combine config parameters into a single dictionary
235
  combined_config = {
236
  **config,
237
+ "model": {"system_prompt": query_request.system_prompt}
238
  }
239
 
240
  # Invoke the graph with proper argument count
 
259
  )
260
 
261
  @app.post("/summarize")
262
+ @limiter.limit("5/minute")
263
  async def summarize(
264
+ request: Request,
265
+ summary_request: SummaryRequest,
266
  api_key: APIKey = Depends(get_api_key)
267
  ):
268
  """
269
  Endpoint to generate a summary using the language model
270
 
271
  Args:
272
+ request: Request - FastAPI request object for rate limiting
273
+ summary_request: SummaryRequest
274
  text: str - The text to summarize
275
  thread_id: str = "default"
276
  max_length: int = 200 - Maximum summary length
 
281
  """
282
  try:
283
  # Configure the thread ID
284
+ config = {"configurable": {"thread_id": summary_request.thread_id}}
285
 
286
  # Create a specific system prompt for summarization
287
+ summary_system_prompt = f"Make a summary of the following text in no more than {summary_request.max_length} words. Keep the most important information and eliminate unnecessary details."
288
 
289
  # Create the input message
290
+ input_messages = [HumanMessage(content=summary_request.text)]
291
 
292
  # Combine config parameters into a single dictionary
293
  combined_config = {
requirements.txt CHANGED
@@ -8,4 +8,5 @@ langgraph>=0.2.27
8
  python-dotenv>=1.0.0
9
  transformers>=4.36.0
10
  torch>=2.0.0
11
- accelerate>=0.26.0
 
 
8
  python-dotenv>=1.0.0
9
  transformers>=4.36.0
10
  torch>=2.0.0
11
+ accelerate>=0.26.0
12
+ slowapi>=0.1.10