from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi.responses import JSONResponse
from fastapi import Security, Depends, Request
from fastapi.security.api_key import APIKeyHeader, APIKey
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
import os
from dotenv import load_dotenv

load_dotenv()
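# Example .env file read by load_dotenv() above (illustrative value only, not a real secret):
#
#   API_KEY=change-me-to-a-long-random-secret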
# Rate limiter configuration
limiter = Limiter(key_func=get_remote_address)

# API key configuration
API_KEY_NAME = "X-API-Key"
API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

async def get_api_key(api_key_header: str = Security(api_key_header)):
    # Reject requests when the key is missing or does not match the configured one
    if API_KEY and api_key_header == API_KEY:
        return api_key_header
    raise HTTPException(
        status_code=403,
        detail="Could not validate API key"
    )
# Initialize the model and tokenizer
print("Loading model and tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        # Load the model in bfloat16 on GPU for better performance and lower memory usage
        print("Using GPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
    else:
        print("Using CPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": device},
            torch_dtype=torch.float32
        )
    print(f"Model loaded successfully on: {device}")
except Exception as e:
    print(f"Error loading the model: {str(e)}")
    raise
# Define the function that calls the model
def call_model(state: MessagesState, config: RunnableConfig):
    """
    Call the model with the accumulated messages.

    Args:
        state: MessagesState with the conversation history
        config: RunnableConfig; config["configurable"] may carry a "system_prompt" override

    Returns:
        dict: The updated messages list including the model's reply
    """
    # Read the system prompt from the run config, falling back to the default
    system_prompt = config.get("configurable", {}).get("system_prompt", DEFAULT_SYSTEM_PROMPT)
    # Convert LangChain messages to chat format
    messages = [
        {"role": "system", "content": system_prompt}
    ]
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})
    # Prepare the input using the chat template
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Generate the response
    outputs = model.generate(
        inputs,
        max_new_tokens=512,  # allow up to 512 new tokens for longer responses
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # Keep only the newly generated tokens (exclude the prompt tokens)
    input_length = inputs.shape[1]
    generated_tokens = outputs[0][input_length:]
    # Decode only the new tokens to get just the assistant's response
    assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    # Convert the response to LangChain format
    ai_message = AIMessage(content=assistant_response)
    return {"messages": state["messages"] + [ai_message]}
# Define the graph
workflow = StateGraph(state_schema=MessagesState)

# Define the default system prompt
DEFAULT_SYSTEM_PROMPT = "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."

# Register the single model node and the edge from START
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Add memory so each thread_id keeps its own conversation history
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)
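# Illustrative example (not executed here): the compiled graph can be invoked directly.
# "thread_id" selects the conversation stored by the MemorySaver checkpointer and
# "system_prompt" is read by call_model; the values below are placeholders.
#
#   result = graph_app.invoke(
#       {"messages": [HumanMessage(content="Hello!")]},
#       {"configurable": {"thread_id": "demo", "system_prompt": DEFAULT_SYSTEM_PROMPT}},
#   )
#   print(result["messages"][-1].content)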
# Define the data model for the request
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
    system_prompt: str = DEFAULT_SYSTEM_PROMPT

# Define the model for summary requests
class SummaryRequest(BaseModel):
    text: str
    thread_id: str = "default"
    max_length: int = 200
# Create the FastAPI application
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge",
    version="1.0.0",
    openapi_tags=[
        {
            "name": "Authentication",
            "description": "Endpoints require API Key authentication via X-API-Key header"
        }
    ]
)

# Configure the rate limiter in the application
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Override the OpenAPI tags metadata
app.openapi_tags = [
    {"name": "Authentication", "description": "Protected endpoints that require API Key"}
]
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configure the security scheme
app.openapi_components = {
    "securitySchemes": {
        "api_key": {
            "type": "apiKey",
            "name": API_KEY_NAME,
            "in": "header",
            "description": "Enter your API key"
        }
    }
}
app.openapi_security = [{"api_key": []}]
# Add a general exception handler
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={"error": f"Internal error: {str(exc)}", "type": type(exc).__name__}
    )
# Welcome endpoint
@app.get("/")  # route path assumed
async def api_home(request: Request):
    """Welcome endpoint"""
    return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
# Generate endpoint
@app.post("/generate")  # route path assumed
@limiter.limit("10/minute")  # rate limit value assumed
async def generate(
    request: Request,
    query_request: QueryRequest,
    api_key: APIKey = Depends(get_api_key)
):
    """
    Endpoint to generate text using the language model

    Args:
        request: Request - FastAPI request object (required by the rate limiter)
        query_request: QueryRequest
            query: str
            thread_id: str = "default"
            system_prompt: str = DEFAULT_SYSTEM_PROMPT
        api_key: APIKey - API key for authentication

    Returns:
        dict: A dictionary containing the generated text
    """
    try:
        # Pass the thread ID and the custom system prompt through the run config
        config = {
            "configurable": {
                "thread_id": query_request.thread_id,
                "system_prompt": query_request.system_prompt,
            }
        }
        # Create the input message
        input_messages = [HumanMessage(content=query_request.query)]
        # Invoke the graph; the checkpointer restores the history for this thread
        output = graph_app.invoke({"messages": input_messages}, config)
        # Get the model response
        response = output["messages"][-1].content
        return {"generated_text": response}
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": f"Error generating text: {str(e)}",
                "type": type(e).__name__
            }
        )
# Summarize endpoint
@app.post("/summarize")  # route path assumed
@limiter.limit("10/minute")  # rate limit value assumed
async def summarize(
    request: Request,
    summary_request: SummaryRequest,
    api_key: APIKey = Depends(get_api_key)
):
    """
    Endpoint to generate a summary using the language model

    Args:
        request: Request - FastAPI request object (required by the rate limiter)
        summary_request: SummaryRequest
            text: str - The text to summarize
            thread_id: str = "default"
            max_length: int = 200 - Maximum summary length in words
        api_key: APIKey - API key for authentication

    Returns:
        dict: A dictionary containing the summary
    """
    try:
        # Create a specific system prompt for summarization
        summary_system_prompt = (
            f"Make a summary of the following text in no more than "
            f"{summary_request.max_length} words. Keep the most important "
            f"information and eliminate unnecessary details."
        )
        # Pass the thread ID and the summarization prompt through the run config
        config = {
            "configurable": {
                "thread_id": summary_request.thread_id,
                "system_prompt": summary_system_prompt,
            }
        }
        # Create the input message
        input_messages = [HumanMessage(content=summary_request.text)]
        # Invoke the graph; the checkpointer restores the history for this thread
        output = graph_app.invoke({"messages": input_messages}, config)
        # Get the model response
        response = output["messages"][-1].content
        return {"summary": response}
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": f"Error generating summary: {str(e)}",
                "type": type(e).__name__
            }
        )
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=7860) | |
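# Illustrative usage once the server is running (the routes and rate limit above are
# assumed values; replace $API_KEY with the key from your .env):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: $API_KEY" \
#     -d '{"query": "Hello!", "thread_id": "demo"}'
#
#   curl -X POST http://localhost:7860/summarize \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: $API_KEY" \
#     -d '{"text": "Long text to summarize...", "max_length": 100}'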