|
|
|
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import gradio as gr |
|
from services.chat_service import ChatService |
|
from services.model_service import ModelService |
|
from services.pdf_service import PDFService |
|
from services.data_service import DataService |
|
from services.faq_service import FAQService |
|
from auth.auth_handler import get_api_key |
|
from models.base_models import UserInput, SearchQuery |
|
import logging |
|
import asyncio |
|
|
|
|
|
# Configure root logging once at import time: INFO and above go both to the
# file 'chatbot.log' and to a StreamHandler (stderr by default).
# The file handler is pinned to UTF-8 because logged text (user messages,
# error details) can contain umlauts/emoji; the platform default encoding
# (e.g. cp1252 on Windows) could fail to encode those.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('chatbot.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)
|
|
|
|
|
# The ASGI application that serves the /api/* routes defined below.
app = FastAPI(version="2.0.0", title="Bofrost Chat API")

# Fully permissive CORS: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): combining a "*" origin with allow_credentials=True is
# reportedly handled by Starlette by echoing the caller's Origin back, which
# would let any website make credentialed cross-origin calls. Consider
# restricting allow_origins to the real front-end origin(s) — TODO confirm.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
# A single ModelService instance is shared by every other service.
model_service = ModelService()

# Services that each receive the shared model_service. They are built in this
# exact order (tuple elements evaluate left to right), matching the original
# construction sequence.
data_service, pdf_service, faq_service = (
    DataService(model_service),
    PDFService(model_service),
    FAQService(model_service),
)

# The chat service receives the model plus all three services above; the
# /api/chat endpoint and the Gradio UI both call chat_service.chat(...).
chat_service = ChatService(
    model_service,
    data_service,
    pdf_service,
    faq_service,
)
|
|
|
|
|
@app.post("/api/chat")
async def chat_endpoint(
    background_tasks: BackgroundTasks,
    user_input: UserInput,
    api_key: str = Depends(get_api_key),
):
    """Answer one chat turn via ``chat_service`` and return the new state.

    Args:
        background_tasks: Injected by FastAPI; not used by this handler.
        user_input: Request body; ``user_input.user_input`` is the message and
            ``user_input.chat_history`` the prior conversation.
        api_key: Resolved by the ``get_api_key`` dependency (presumably it
            rejects invalid keys — see auth.auth_handler).

    Returns:
        A dict with ``status`` ("success"), ``response``, ``chat_history``
        and ``search_results``.

    Raises:
        HTTPException: re-raised unchanged when raised downstream; any other
            error is logged with its traceback and mapped to a 500.
    """
    try:
        response, updated_history, search_results = await chat_service.chat(
            user_input.user_input,
            user_input.chat_history,
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. 4xx) instead of turning them into 500s.
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") dropped.
        logger.exception("Error in chat endpoint: %s", e)
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "response": response,
        "chat_history": updated_history,
        "search_results": search_results,
    }
|
|
|
@app.post("/api/search")
async def search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key),
):
    """Run ``data_service.search`` for the request and return its results.

    Args:
        query: Request body; ``query.query`` is the search text and
            ``query.top_k`` the number of results requested.
        api_key: Resolved by the ``get_api_key`` dependency (presumably it
            rejects invalid keys — see auth.auth_handler).

    Returns:
        ``{"results": <whatever data_service.search returned>}``.

    Raises:
        HTTPException: re-raised unchanged when raised downstream; any other
            error is logged with its traceback and mapped to a 500.
    """
    try:
        results = await data_service.search(query.query, query.top_k)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of turning them into 500s.
        raise
    except Exception as e:
        logger.exception("Error in search endpoint: %s", e)
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"results": results}
|
|
|
@app.post("/api/faq/search")
async def faq_search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key),
):
    """Run ``faq_service.search_faqs`` for the request and return its results.

    Args:
        query: Request body; ``query.query`` is the search text and
            ``query.top_k`` the number of results requested.
        api_key: Resolved by the ``get_api_key`` dependency (presumably it
            rejects invalid keys — see auth.auth_handler).

    Returns:
        ``{"results": <whatever faq_service.search_faqs returned>}``.

    Raises:
        HTTPException: re-raised unchanged when raised downstream; any other
            error is logged with its traceback and mapped to a 500.
    """
    try:
        results = await faq_service.search_faqs(query.query, query.top_k)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of turning them into 500s.
        raise
    except Exception as e:
        logger.exception("Error in FAQ search endpoint: %s", e)
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"results": results}
|
|
|
|
|
def _format_chat_history(history):
    """Convert chat-service history entries into (user, bot) pairs for gr.Chatbot.

    Each entry may be a dict with ``'user_input'`` and ``'response'`` keys or
    a 2-item tuple/list. Every entry is checked individually, so mixed
    histories work. An empty or ``None`` history yields ``[]``.

    Raises:
        TypeError: if an entry has any other structure.
    """
    formatted = []
    for item in history or []:
        if isinstance(item, dict):
            formatted.append((item['user_input'], item['response']))
        elif isinstance(item, (tuple, list)):
            formatted.append((item[0], item[1]))
        else:
            raise TypeError("Unexpected structure for updated_history")
    return formatted


def create_gradio_interface():
    """Build the Gradio Blocks UI that calls ``chat_service`` directly.

    Layout: a chat display and message box on the left, a collapsible JSON
    panel for search results on the right, and Send / Clear buttons. The raw
    history returned by ``chat_service`` is kept in a ``gr.State`` and fed
    back on the next turn, while a (user, bot) pair view of it is shown in
    the ``gr.Chatbot``.

    Returns:
        The configured ``gr.Blocks`` instance with its queue enabled.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🦙 * Chat Assistant\nFragen Sie nach Produkten, Rezepten und mehr!")

        with gr.Row():
            with gr.Column(scale=4):
                chat_display = gr.Chatbot(label="Chat-Verlauf", height=400)
                user_input = gr.Textbox(
                    label="Ihre Nachricht",
                    placeholder="Stellen Sie Ihre Frage...",
                    lines=2,
                )

            with gr.Column(scale=2):
                with gr.Accordion("Zusätzliche Informationen", open=False):
                    product_info = gr.JSON(label="Produktdetails")

        with gr.Row():
            submit_btn = gr.Button("Senden", variant="primary")
            clear_btn = gr.Button("Chat löschen")

        # Raw chat_service history, round-tripped between turns.
        chat_history = gr.State([])

        async def respond(message, history):
            """Send *message* to chat_service and refresh display, state and info panel."""
            _response, updated_history, search_results = await chat_service.chat(message, history)
            # The helper tolerates an empty history; the old code indexed [0] and crashed on it.
            return _format_chat_history(updated_history), updated_history, search_results

        submit_btn.click(
            respond,
            inputs=[user_input, chat_history],
            outputs=[chat_display, chat_history, product_info],
        )

        # Reset the chat display, the stored history and the info panel.
        clear_btn.click(
            lambda: ([], [], None),
            outputs=[chat_display, chat_history, product_info],
        )

    demo.queue()
    return demo
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Gradio UI on port 7860, REST API on port 8000, both bound to all interfaces.
    demo = create_gradio_interface()

    # Blocks.launch() blocks the calling thread when run as a script, so the
    # original uvicorn.run() call was never reached and the REST API never
    # started. prevent_thread_lock=True returns immediately and leaves the
    # Gradio server running in the background.
    demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)

    try:
        # Blocks the main thread until uvicorn is stopped (e.g. Ctrl+C).
        uvicorn.run(app, host="0.0.0.0", port=8000)
    finally:
        # Shut the Gradio server down together with the API server.
        demo.close()