Maximofn commited on
Commit
4b4c28d
·
1 Parent(s): c61e41b

feat: :lock: Implement API Key authentication and CORS configuration for enhanced security. Don`t returns thread ID

Browse files
Files changed (1) hide show
  1. app.py +71 -13
app.py CHANGED
@@ -4,6 +4,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  from functools import partial
6
  from fastapi.responses import JSONResponse
 
 
 
7
 
8
  from langchain_core.messages import HumanMessage, AIMessage
9
  from langgraph.checkpoint.memory import MemorySaver
@@ -13,6 +16,19 @@ import os
13
  from dotenv import load_dotenv
14
  load_dotenv()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Initialize the model and tokenizer
17
  print("Loading model and tokenizer...")
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -120,7 +136,45 @@ class SummaryRequest(BaseModel):
120
  max_length: int = 200
121
 
122
  # Create the FastAPI application
123
- app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Add general exception handler
126
  @app.exception_handler(Exception)
@@ -138,7 +192,10 @@ async def api_home():
138
 
139
  # Generate endpoint
140
  @app.post("/generate")
141
- async def generate(request: QueryRequest):
 
 
 
142
  """
143
  Endpoint to generate text using the language model
144
 
@@ -147,9 +204,10 @@ async def generate(request: QueryRequest):
147
  query: str
148
  thread_id: str = "default"
149
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
 
150
 
151
  Returns:
152
- dict: A dictionary containing the generated text and the thread ID
153
  """
154
  try:
155
  # Configure the thread ID
@@ -175,21 +233,22 @@ async def generate(request: QueryRequest):
175
  response = output["messages"][-1].content
176
 
177
  return {
178
- "generated_text": response,
179
- "thread_id": request.thread_id
180
  }
181
  except Exception as e:
182
  return JSONResponse(
183
  status_code=500,
184
  content={
185
  "error": f"Error generando texto: {str(e)}",
186
- "type": type(e).__name__,
187
- "thread_id": request.thread_id
188
  }
189
  )
190
 
191
  @app.post("/summarize")
192
- async def summarize(request: SummaryRequest):
 
 
 
193
  """
194
  Endpoint to generate a summary using the language model
195
 
@@ -198,9 +257,10 @@ async def summarize(request: SummaryRequest):
198
  text: str - The text to summarize
199
  thread_id: str = "default"
200
  max_length: int = 200 - Maximum summary length
 
201
 
202
  Returns:
203
- dict: A dictionary containing the summary and the thread ID
204
  """
205
  try:
206
  # Configure the thread ID
@@ -228,16 +288,14 @@ async def summarize(request: SummaryRequest):
228
  response = output["messages"][-1].content
229
 
230
  return {
231
- "summary": response,
232
- "thread_id": request.thread_id
233
  }
234
  except Exception as e:
235
  return JSONResponse(
236
  status_code=500,
237
  content={
238
  "error": f"Error generando resumen: {str(e)}",
239
- "type": type(e).__name__,
240
- "thread_id": request.thread_id
241
  }
242
  )
243
 
 
4
  import torch
5
  from functools import partial
6
  from fastapi.responses import JSONResponse
7
+ from fastapi import Security, Depends
8
+ from fastapi.security.api_key import APIKeyHeader, APIKey
9
+ from fastapi.middleware.cors import CORSMiddleware
10
 
11
  from langchain_core.messages import HumanMessage, AIMessage
12
  from langgraph.checkpoint.memory import MemorySaver
 
16
  from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
+ # Configuración de API Key
20
+ API_KEY_NAME = "X-API-Key"
21
+ API_KEY = os.getenv("API_KEY")
22
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
23
+
24
+ async def get_api_key(api_key_header: str = Security(api_key_header)):
25
+ if api_key_header == API_KEY:
26
+ return api_key_header
27
+ raise HTTPException(
28
+ status_code=403,
29
+ detail="Could not validate API KEY"
30
+ )
31
+
32
  # Initialize the model and tokenizer
33
  print("Loading model and tokenizer...")
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
136
  max_length: int = 200
137
 
138
  # Create the FastAPI application
139
+ app = FastAPI(
140
+ title="LangChain FastAPI",
141
+ description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge",
142
+ version="1.0.0",
143
+ openapi_tags=[
144
+ {
145
+ "name": "Authentication",
146
+ "description": "Endpoints require API Key authentication via X-API-Key header"
147
+ }
148
+ ]
149
+ )
150
+
151
+ # Configure the security scheme in the OpenAPI documentation
152
+ app.openapi_tags = [
153
+ {"name": "Authentication", "description": "Protected endpoints that require API Key"}
154
+ ]
155
+
156
+ # Import and configure CORS
157
+ app.add_middleware(
158
+ CORSMiddleware,
159
+ allow_origins=["*"],
160
+ allow_credentials=True,
161
+ allow_methods=["*"],
162
+ allow_headers=["*"],
163
+ )
164
+
165
+ # Configure the security scheme
166
+ app.openapi_components = {
167
+ "securitySchemes": {
168
+ "api_key": {
169
+ "type": "apiKey",
170
+ "name": API_KEY_NAME,
171
+ "in": "header",
172
+ "description": "Enter your API key"
173
+ }
174
+ }
175
+ }
176
+
177
+ app.openapi_security = [{"api_key": []}]
178
 
179
  # Add general exception handler
180
  @app.exception_handler(Exception)
 
192
 
193
  # Generate endpoint
194
  @app.post("/generate")
195
+ async def generate(
196
+ request: QueryRequest,
197
+ api_key: APIKey = Depends(get_api_key)
198
+ ):
199
  """
200
  Endpoint to generate text using the language model
201
 
 
204
  query: str
205
  thread_id: str = "default"
206
  system_prompt: str = DEFAULT_SYSTEM_PROMPT
207
+ api_key: APIKey - API key for authentication
208
 
209
  Returns:
210
+ dict: A dictionary containing the generated text
211
  """
212
  try:
213
  # Configure the thread ID
 
233
  response = output["messages"][-1].content
234
 
235
  return {
236
+ "generated_text": response
 
237
  }
238
  except Exception as e:
239
  return JSONResponse(
240
  status_code=500,
241
  content={
242
  "error": f"Error generando texto: {str(e)}",
243
+ "type": type(e).__name__
 
244
  }
245
  )
246
 
247
  @app.post("/summarize")
248
+ async def summarize(
249
+ request: SummaryRequest,
250
+ api_key: APIKey = Depends(get_api_key)
251
+ ):
252
  """
253
  Endpoint to generate a summary using the language model
254
 
 
257
  text: str - The text to summarize
258
  thread_id: str = "default"
259
  max_length: int = 200 - Maximum summary length
260
+ api_key: APIKey - API key for authentication
261
 
262
  Returns:
263
+ dict: A dictionary containing the summary
264
  """
265
  try:
266
  # Configure the thread ID
 
288
  response = output["messages"][-1].content
289
 
290
  return {
291
+ "summary": response
 
292
  }
293
  except Exception as e:
294
  return JSONResponse(
295
  status_code=500,
296
  content={
297
  "error": f"Error generando resumen: {str(e)}",
298
+ "type": type(e).__name__
 
299
  }
300
  )
301