Spaces:
Sleeping
Sleeping
feat(SECURITY): :lock: Implement rate limiting for API endpoints and update request handling for text generation and summarization.
Browse files- app.py +30 -13
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,9 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4 |
import torch
|
5 |
from functools import partial
|
6 |
from fastapi.responses import JSONResponse
|
7 |
-
from fastapi import Security, Depends
|
8 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
10 |
|
11 |
from langchain_core.messages import HumanMessage, AIMessage
|
12 |
from langgraph.checkpoint.memory import MemorySaver
|
@@ -16,7 +19,10 @@ import os
|
|
16 |
from dotenv import load_dotenv
|
17 |
load_dotenv()
|
18 |
|
19 |
-
#
|
|
|
|
|
|
|
20 |
API_KEY_NAME = "X-API-Key"
|
21 |
API_KEY = os.getenv("API_KEY")
|
22 |
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
@@ -148,6 +154,10 @@ app = FastAPI(
|
|
148 |
]
|
149 |
)
|
150 |
|
|
|
|
|
|
|
|
|
151 |
# Configure the security scheme in the OpenAPI documentation
|
152 |
app.openapi_tags = [
|
153 |
{"name": "Authentication", "description": "Protected endpoints that require API Key"}
|
@@ -186,21 +196,25 @@ async def general_exception_handler(request, exc):
|
|
186 |
|
187 |
# Welcome endpoint
|
188 |
@app.get("/")
|
189 |
-
|
|
|
190 |
"""Welcome endpoint"""
|
191 |
return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
|
192 |
|
193 |
# Generate endpoint
|
194 |
@app.post("/generate")
|
|
|
195 |
async def generate(
|
196 |
-
request:
|
|
|
197 |
api_key: APIKey = Depends(get_api_key)
|
198 |
):
|
199 |
"""
|
200 |
Endpoint to generate text using the language model
|
201 |
|
202 |
Args:
|
203 |
-
request:
|
|
|
204 |
query: str
|
205 |
thread_id: str = "default"
|
206 |
system_prompt: str = DEFAULT_SYSTEM_PROMPT
|
@@ -211,16 +225,16 @@ async def generate(
|
|
211 |
"""
|
212 |
try:
|
213 |
# Configure the thread ID
|
214 |
-
config = {"configurable": {"thread_id":
|
215 |
|
216 |
# Create the input message
|
217 |
-
input_messages = [HumanMessage(content=
|
218 |
|
219 |
# Invoke the graph with custom system prompt
|
220 |
# Combine config parameters into a single dictionary
|
221 |
combined_config = {
|
222 |
**config,
|
223 |
-
"model": {"system_prompt":
|
224 |
}
|
225 |
|
226 |
# Invoke the graph with proper argument count
|
@@ -245,15 +259,18 @@ async def generate(
|
|
245 |
)
|
246 |
|
247 |
@app.post("/summarize")
|
|
|
248 |
async def summarize(
|
249 |
-
request:
|
|
|
250 |
api_key: APIKey = Depends(get_api_key)
|
251 |
):
|
252 |
"""
|
253 |
Endpoint to generate a summary using the language model
|
254 |
|
255 |
Args:
|
256 |
-
request:
|
|
|
257 |
text: str - The text to summarize
|
258 |
thread_id: str = "default"
|
259 |
max_length: int = 200 - Maximum summary length
|
@@ -264,13 +281,13 @@ async def summarize(
|
|
264 |
"""
|
265 |
try:
|
266 |
# Configure the thread ID
|
267 |
-
config = {"configurable": {"thread_id":
|
268 |
|
269 |
# Create a specific system prompt for summarization
|
270 |
-
summary_system_prompt = f"Make a summary of the following text in no more than {
|
271 |
|
272 |
# Create the input message
|
273 |
-
input_messages = [HumanMessage(content=
|
274 |
|
275 |
# Combine config parameters into a single dictionary
|
276 |
combined_config = {
|
|
|
4 |
import torch
|
5 |
from functools import partial
|
6 |
from fastapi.responses import JSONResponse
|
7 |
+
from fastapi import Security, Depends, Request
|
8 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
11 |
+
from slowapi.util import get_remote_address
|
12 |
+
from slowapi.errors import RateLimitExceeded
|
13 |
|
14 |
from langchain_core.messages import HumanMessage, AIMessage
|
15 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
19 |
from dotenv import load_dotenv
|
20 |
load_dotenv()
|
21 |
|
22 |
+
# Rate Limiter configuration
|
23 |
+
limiter = Limiter(key_func=get_remote_address)
|
24 |
+
|
25 |
+
# API Key configuration
|
26 |
API_KEY_NAME = "X-API-Key"
|
27 |
API_KEY = os.getenv("API_KEY")
|
28 |
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
|
|
154 |
]
|
155 |
)
|
156 |
|
157 |
+
# Configure the rate limiter in the application
|
158 |
+
app.state.limiter = limiter
|
159 |
+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
160 |
+
|
161 |
# Configure the security scheme in the OpenAPI documentation
|
162 |
app.openapi_tags = [
|
163 |
{"name": "Authentication", "description": "Protected endpoints that require API Key"}
|
|
|
196 |
|
197 |
# Welcome endpoint
|
198 |
@app.get("/")
|
199 |
+
@limiter.limit("10/minute")
|
200 |
+
async def api_home(request: Request):
|
201 |
"""Welcome endpoint"""
|
202 |
return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
|
203 |
|
204 |
# Generate endpoint
|
205 |
@app.post("/generate")
|
206 |
+
@limiter.limit("5/minute")
|
207 |
async def generate(
|
208 |
+
request: Request,
|
209 |
+
query_request: QueryRequest,
|
210 |
api_key: APIKey = Depends(get_api_key)
|
211 |
):
|
212 |
"""
|
213 |
Endpoint to generate text using the language model
|
214 |
|
215 |
Args:
|
216 |
+
request: Request - FastAPI request object for rate limiting
|
217 |
+
query_request: QueryRequest
|
218 |
query: str
|
219 |
thread_id: str = "default"
|
220 |
system_prompt: str = DEFAULT_SYSTEM_PROMPT
|
|
|
225 |
"""
|
226 |
try:
|
227 |
# Configure the thread ID
|
228 |
+
config = {"configurable": {"thread_id": query_request.thread_id}}
|
229 |
|
230 |
# Create the input message
|
231 |
+
input_messages = [HumanMessage(content=query_request.query)]
|
232 |
|
233 |
# Invoke the graph with custom system prompt
|
234 |
# Combine config parameters into a single dictionary
|
235 |
combined_config = {
|
236 |
**config,
|
237 |
+
"model": {"system_prompt": query_request.system_prompt}
|
238 |
}
|
239 |
|
240 |
# Invoke the graph with proper argument count
|
|
|
259 |
)
|
260 |
|
261 |
@app.post("/summarize")
|
262 |
+
@limiter.limit("5/minute")
|
263 |
async def summarize(
|
264 |
+
request: Request,
|
265 |
+
summary_request: SummaryRequest,
|
266 |
api_key: APIKey = Depends(get_api_key)
|
267 |
):
|
268 |
"""
|
269 |
Endpoint to generate a summary using the language model
|
270 |
|
271 |
Args:
|
272 |
+
request: Request - FastAPI request object for rate limiting
|
273 |
+
summary_request: SummaryRequest
|
274 |
text: str - The text to summarize
|
275 |
thread_id: str = "default"
|
276 |
max_length: int = 200 - Maximum summary length
|
|
|
281 |
"""
|
282 |
try:
|
283 |
# Configure the thread ID
|
284 |
+
config = {"configurable": {"thread_id": summary_request.thread_id}}
|
285 |
|
286 |
# Create a specific system prompt for summarization
|
287 |
+
summary_system_prompt = f"Make a summary of the following text in no more than {summary_request.max_length} words. Keep the most important information and eliminate unnecessary details."
|
288 |
|
289 |
# Create the input message
|
290 |
+
input_messages = [HumanMessage(content=summary_request.text)]
|
291 |
|
292 |
# Combine config parameters into a single dictionary
|
293 |
combined_config = {
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ langgraph>=0.2.27
|
|
8 |
python-dotenv>=1.0.0
|
9 |
transformers>=4.36.0
|
10 |
torch>=2.0.0
|
11 |
-
accelerate>=0.26.0
|
|
|
|
8 |
python-dotenv>=1.0.0
|
9 |
transformers>=4.36.0
|
10 |
torch>=2.0.0
|
11 |
+
accelerate>=0.26.0
|
12 |
+
slowapi>=0.1.10
|