# main.py
"""Bofrost Chat entry point.

Exposes a FastAPI REST API (chat, data search, FAQ search) and a Gradio chat
UI. Both are backed by the same service instances created at import time.
"""
import asyncio
import logging

from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr

from services.chat_service import ChatService
from services.model_service import ModelService
from services.pdf_service import PDFService
from services.data_service import DataService
from services.faq_service import FAQService
from auth.auth_handler import get_api_key
from models.base_models import UserInput, SearchQuery

# Configure logging: everything at INFO and above goes to chatbot.log and stderr.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('chatbot.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Bofrost Chat API", version="2.0.0")

# Add CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any web origin make credentialed cross-origin requests (Starlette reflects
# the caller's Origin in that case). Consider restricting allow_origins to the
# known front-end host(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services. One ModelService instance is shared by the data, PDF
# and FAQ services, and ChatService composes all of them. The same instances
# serve both the REST endpoints and the Gradio UI.
model_service = ModelService()
data_service = DataService(model_service)
pdf_service = PDFService(model_service)
faq_service = FAQService(model_service)
chat_service = ChatService(model_service, data_service, pdf_service, faq_service)


def _server_error(where, exc):
    """Log *exc* with its traceback and build the matching HTTP 500 error.

    Args:
        where: Location used in the log line, e.g. ``"chat endpoint"``.
        exc: The unexpected exception that was caught.

    Returns:
        An ``HTTPException`` with status 500 whose detail is ``str(exc)``.
    """
    # logger.exception records the traceback, which logger.error(f"...") dropped.
    logger.exception("Error in %s: %s", where, exc)
    # NOTE(review): str(exc) exposes internal error text to API clients; kept
    # for backward compatibility, but a generic message would be safer.
    return HTTPException(status_code=500, detail=str(exc))


# API endpoints
@app.post("/api/chat")
async def chat_endpoint(
    background_tasks: BackgroundTasks,
    user_input: UserInput,
    api_key: str = Depends(get_api_key)
):
    """Run one chat turn through ChatService.

    Returns a dict with ``status``, the model ``response``, the updated
    ``chat_history`` and the ``search_results`` produced for this turn.
    """
    # background_tasks is injected by FastAPI and currently unused; api_key is
    # consumed only for its authentication side effect via get_api_key.
    try:
        response, updated_history, search_results = await chat_service.chat(
            user_input.user_input,
            user_input.chat_history
        )
    except HTTPException:
        # Deliberate HTTP errors raised by the service keep their status code.
        raise
    except Exception as e:
        raise _server_error("chat endpoint", e) from e
    return {
        "status": "success",
        "response": response,
        "chat_history": updated_history,
        "search_results": search_results
    }


@app.post("/api/search")
async def search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Search through DataService.search using ``query.query`` and ``query.top_k``."""
    try:
        results = await data_service.search(query.query, query.top_k)
    except HTTPException:
        raise
    except Exception as e:
        raise _server_error("search endpoint", e) from e
    return {"results": results}


@app.post("/api/faq/search")
async def faq_search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Search FAQs through FAQService.search_faqs using ``query.query`` and ``query.top_k``."""
    try:
        results = await faq_service.search_faqs(query.query, query.top_k)
    except HTTPException:
        raise
    except Exception as e:
        raise _server_error("FAQ search endpoint", e) from e
    return {"results": results}


def _to_chatbot_pairs(history):
    """Convert ChatService history entries into ``(user, bot)`` pairs for ``gr.Chatbot``.

    Each entry may be a dict with ``'user_input'`` and ``'response'`` keys, or
    a tuple/list whose first two items are the user message and the reply.
    Every entry is checked individually, so mixed histories are handled. An
    empty or ``None`` history yields ``[]`` (previously this raised
    ``IndexError``).

    Raises:
        TypeError: if an entry has any other structure.
    """
    pairs = []
    for item in history or []:
        if isinstance(item, dict):
            pairs.append((item['user_input'], item['response']))
        elif isinstance(item, (tuple, list)):
            pairs.append((item[0], item[1]))
        else:
            raise TypeError("Unexpected structure for updated_history")
    return pairs


# Gradio interface
def create_gradio_interface():
    """Build the Gradio Blocks chat UI, which talks to the shared ChatService.

    The UI has a chat display, a message box, a collapsible JSON panel showing
    the search results of the last turn, and "send" and "clear" buttons.
    Queuing is enabled before the app is returned.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🦙 * Chat Assistant\nFragen Sie nach Produkten, Rezepten und mehr!")

        with gr.Row():
            with gr.Column(scale=4):
                chat_display = gr.Chatbot(label="Chat-Verlauf", height=400)
                user_input = gr.Textbox(
                    label="Ihre Nachricht",
                    placeholder="Stellen Sie Ihre Frage...",
                    lines=2
                )
            with gr.Column(scale=2):
                with gr.Accordion("Zusätzliche Informationen", open=False):
                    product_info = gr.JSON(label="Produktdetails")

        with gr.Row():
            submit_btn = gr.Button("Senden", variant="primary")
            clear_btn = gr.Button("Chat löschen")

        # Per-session conversation history, in the format ChatService returns.
        chat_history = gr.State([])

        async def respond(message, history):
            """Run one chat turn and return (display pairs, new history state, search results)."""
            try:
                _, updated_history, search_results = await chat_service.chat(message, history)
            except Exception:
                # Log with traceback; Gradio then surfaces the error in the UI.
                logger.exception("Error in Gradio chat handler")
                raise
            return _to_chatbot_pairs(updated_history), updated_history, search_results

        submit_btn.click(
            respond,
            inputs=[user_input, chat_history],
            outputs=[chat_display, chat_history, product_info]
        )

        # Reset display, stored history and the info panel.
        clear_btn.click(
            lambda: ([], [], None),
            outputs=[chat_display, chat_history, product_info]
        )

    demo.queue()
    return demo


if __name__ == "__main__":
    import uvicorn

    # Create and launch the Gradio interface. By default launch() blocks the
    # calling thread, so uvicorn.run() below was never reached and the REST
    # API never started. prevent_thread_lock=True runs Gradio's server in the
    # background instead.
    # NOTE(review): Gradio and FastAPI now run their own servers concurrently
    # while sharing the same service instances; confirm the services hold no
    # event-loop-bound resources (e.g. async HTTP clients).
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)

    # Start FastAPI server. This blocks, which keeps the process (and with it
    # the background Gradio server) alive.
    uvicorn.run(app, host="0.0.0.0", port=8000)