from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi.responses import JSONResponse
from fastapi import Security, Depends, Request
from fastapi.security.api_key import APIKeyHeader, APIKey
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
import os
from dotenv import load_dotenv

load_dotenv()
# Rate limiter configuration
limiter = Limiter(key_func=get_remote_address)

# API key configuration
API_KEY_NAME = "X-API-Key"
API_KEY = os.getenv("API_KEY")

api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

async def get_api_key(api_key_header: str = Security(api_key_header)):
    # Also reject requests when no API_KEY is configured server-side;
    # otherwise a missing header (None) would match a missing key (None)
    if API_KEY and api_key_header == API_KEY:
        return api_key_header
    raise HTTPException(
        status_code=403,
        detail="Could not validate API key"
    )
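# A minimal .env file for local development might look like this
# (the value below is a placeholder, not a real key):
#
#   API_KEY=change-me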
# Initialize the model and tokenizer
print("Loading model and tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        print("Using GPU for the model...")
        # Load the model in BF16 for better performance and lower memory usage
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
    else:
        print("Using CPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": device},
            torch_dtype=torch.float32
        )
    print(f"Model loaded successfully on: {device}")
except Exception as e:
    print(f"Error loading the model: {str(e)}")
    raise
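# Optional sanity check (commented out): confirm where the weights landed
# and which dtype was used, e.g. "cuda:0 torch.bfloat16".
#
# p = next(model.parameters())
# print(p.device, p.dtype)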
# Define the function that calls the model
def call_model(state: MessagesState, config: RunnableConfig):
    """
    Call the model with the accumulated messages.

    Args:
        state: MessagesState with the conversation so far
        config: RunnableConfig; config["configurable"]["system_prompt"]
            overrides the default system prompt
    Returns:
        dict: A dictionary with the new AIMessage to append to the state
    """
    system_prompt = config.get("configurable", {}).get("system_prompt", DEFAULT_SYSTEM_PROMPT)

    # Convert LangChain messages to chat format
    messages = [
        {"role": "system", "content": system_prompt}
    ]
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})

    # Render the chat template, appending the assistant header so the model
    # generates a reply instead of continuing the user turn
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The template already contains the special tokens, so don't add them again;
    # tokenizer(...) also returns the attention mask that generate() expects
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(device)

    # Generate the response
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,  # Allow reasonably long responses
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # Keep only the newly generated tokens (exclude the prompt tokens)
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]

    # Decode only the new tokens to get just the assistant's response
    assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Return only the new message; MessagesState appends it to the history
    ai_message = AIMessage(content=assistant_response)
    return {"messages": [ai_message]}
# Define the default system prompt
DEFAULT_SYSTEM_PROMPT = "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."

# Define the graph
workflow = StateGraph(state_schema=MessagesState)

# Register the model node and wire it to the graph's entry point
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Add memory so each thread_id keeps its own conversation history
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)
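# Quick smoke test (commented out; assumes the model loaded successfully):
#
# result = graph_app.invoke(
#     {"messages": [HumanMessage(content="Hi!")]},
#     {"configurable": {"thread_id": "smoke-test"}},
# )
# print(result["messages"][-1].content)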
# Define the data model for the request
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
    system_prompt: str = DEFAULT_SYSTEM_PROMPT

# Define the model for summary requests
class SummaryRequest(BaseModel):
    text: str
    thread_id: str = "default"
    max_length: int = 200
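# Example JSON bodies these models accept (field names as defined above;
# omitted fields fall back to their defaults):
#
#   /generate:  {"query": "What is FastAPI?", "thread_id": "docs-chat"}
#   /summarize: {"text": "<long text to summarize>", "max_length": 100}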
# Create the FastAPI application
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge",
    version="1.0.0",
    openapi_tags=[
        {
            "name": "Authentication",
            "description": "Endpoints require API Key authentication via X-API-Key header"
        }
    ]
)

# Configure the rate limiter in the application
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Configure CORS (CORSMiddleware is imported above)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Register the API key scheme in the OpenAPI docs. Assigning
# app.openapi_components / app.openapi_security has no effect in FastAPI,
# so override app.openapi() instead
def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "api_key": {
            "type": "apiKey",
            "name": API_KEY_NAME,
            "in": "header",
            "description": "Enter your API key"
        }
    }
    openapi_schema["security"] = [{"api_key": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi
# Add a general exception handler and register it for unhandled errors
async def general_exception_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={"error": f"Internal error: {str(exc)}", "type": type(exc).__name__}
    )

app.add_exception_handler(Exception, general_exception_handler)
# Welcome endpoint at the root path
@app.get("/")
async def api_home(request: Request):
    """Welcome endpoint"""
    return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}
# Generate endpoint
@app.post("/generate")
@limiter.limit("10/minute")  # illustrative rate; tune to your needs
async def generate(
    request: Request,
    query_request: QueryRequest,
    api_key: APIKey = Depends(get_api_key)
):
    """
    Endpoint to generate text using the language model

    Args:
        request: Request - FastAPI request object, required by the rate limiter
        query_request: QueryRequest
            query: str
            thread_id: str = "default"
            system_prompt: str = DEFAULT_SYSTEM_PROMPT
        api_key: APIKey - API key for authentication
    Returns:
        dict: A dictionary containing the generated text
    """
    try:
        # Pass the thread ID and the custom system prompt to the graph;
        # call_model reads both from config["configurable"]
        config = {
            "configurable": {
                "thread_id": query_request.thread_id,
                "system_prompt": query_request.system_prompt,
            }
        }

        # Create the input message
        input_messages = [HumanMessage(content=query_request.query)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "generated_text": response
        }
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": f"Error generating text: {str(e)}",
                "type": type(e).__name__
            }
        )
# Summarize endpoint
@app.post("/summarize")
@limiter.limit("10/minute")  # illustrative rate; tune to your needs
async def summarize(
    request: Request,
    summary_request: SummaryRequest,
    api_key: APIKey = Depends(get_api_key)
):
    """
    Endpoint to generate a summary using the language model

    Args:
        request: Request - FastAPI request object, required by the rate limiter
        summary_request: SummaryRequest
            text: str - The text to summarize
            thread_id: str = "default"
            max_length: int = 200 - Maximum summary length in words
        api_key: APIKey - API key for authentication
    Returns:
        dict: A dictionary containing the summary
    """
    try:
        # Create a task-specific system prompt for summarization
        summary_system_prompt = (
            f"Make a summary of the following text in no more than "
            f"{summary_request.max_length} words. Keep the most important "
            f"information and eliminate unnecessary details."
        )

        # Pass the thread ID and the summarization prompt to the graph
        config = {
            "configurable": {
                "thread_id": summary_request.thread_id,
                "system_prompt": summary_system_prompt,
            }
        }

        # Create the input message
        input_messages = [HumanMessage(content=summary_request.text)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "summary": response
        }
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": f"Error generating summary: {str(e)}",
                "type": type(e).__name__
            }
        )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
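# Example usage, assuming the server runs locally on port 7860 and API_KEY
# is set in the environment (paths and payloads match the endpoints above):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: $API_KEY" \
#     -d '{"query": "Hello!", "thread_id": "demo"}'
#
#   curl -X POST http://localhost:7860/summarize \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: $API_KEY" \
#     -d '{"text": "<long text to summarize>", "max_length": 100}'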