llm / app.py
Chris4K's picture
Update app.py
5a15d4a verified
raw
history blame
5.06 kB
# main.py
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
from services.chat_service import ChatService
from services.model_service import ModelService
from services.pdf_service import PDFService
from services.data_service import DataService
from services.faq_service import FAQService
from auth.auth_handler import get_api_key
from models.base_models import UserInput, SearchQuery
import logging
import asyncio
# Logging setup: records at INFO and above are sent both to the file
# 'chatbot.log' and to a stream handler (stderr by default).
# basicConfig() only takes effect if the root logger has no handlers yet.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('chatbot.log'), logging.StreamHandler()],
)

# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title="Bofrost Chat API", version="2.0.0")

# Add CORS middleware.
# A wildcard origin must not be combined with credentialed requests: the
# CORS rules (and the FastAPI docs) require explicit origins whenever
# allow_credentials is True, otherwise any website could issue credentialed
# cross-origin calls. Credentials are therefore disabled while origins stay
# open.
# NOTE(review): assumes clients authenticate via the API key dependency and
# not via cookies. If cross-origin cookies are needed, list the trusted
# origins explicitly and re-enable allow_credentials - confirm with callers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Service wiring. One ModelService instance is shared by every other service.
model_service = ModelService()

# The three retrieval-style services each take only the shared model service;
# they are constructed in the order data -> pdf -> faq.
data_service, pdf_service, faq_service = (
    service_cls(model_service)
    for service_cls in (DataService, PDFService, FAQService)
)

# The chat service receives the model service plus all three helpers.
chat_service = ChatService(model_service, data_service, pdf_service, faq_service)
# API endpoints
@app.post("/api/chat")
async def chat_endpoint(
    background_tasks: BackgroundTasks,
    user_input: UserInput,
    api_key: str = Depends(get_api_key)
):
    """Run one chat turn through the chat service.

    Args:
        background_tasks: Injected by FastAPI; not used by this handler.
        user_input: Request body; its ``user_input`` and ``chat_history``
            fields are forwarded to ``chat_service.chat``.
        api_key: Value produced by the ``get_api_key`` dependency.

    Returns:
        A dict with ``status``, ``response``, ``chat_history`` and
        ``search_results`` keys.

    Raises:
        HTTPException: Re-raised unchanged if the service raised one;
            otherwise a 500 whose detail is the text of the original error.
    """
    try:
        response, updated_history, search_results = await chat_service.chat(
            user_input.user_input,
            user_input.chat_history
        )
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact instead
        # of rewrapping them as a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "response": response,
        "chat_history": updated_history,
        "search_results": search_results
    }
@app.post("/api/search")
async def search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Search via the data service.

    Args:
        query: Request body; its ``query`` and ``top_k`` fields are passed to
            ``data_service.search``.
        api_key: Value produced by the ``get_api_key`` dependency.

    Returns:
        A dict with a single ``results`` key holding the service's result.

    Raises:
        HTTPException: Re-raised unchanged if the service raised one;
            otherwise a 500 whose detail is the text of the original error.
    """
    try:
        results = await data_service.search(query.query, query.top_k)
    except HTTPException:
        # Do not downgrade intentional HTTP errors to a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in search endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"results": results}
@app.post("/api/faq/search")
async def faq_search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Search FAQs via the FAQ service.

    Args:
        query: Request body; its ``query`` and ``top_k`` fields are passed to
            ``faq_service.search_faqs``.
        api_key: Value produced by the ``get_api_key`` dependency.

    Returns:
        A dict with a single ``results`` key holding the service's result.

    Raises:
        HTTPException: Re-raised unchanged if the service raised one;
            otherwise a 500 whose detail is the text of the original error.
    """
    try:
        results = await faq_service.search_faqs(query.query, query.top_k)
    except HTTPException:
        # Do not downgrade intentional HTTP errors to a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error in FAQ search endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"results": results}
# Gradio interface
def _format_history_for_display(history):
    """Convert chat history entries into (user, bot) pairs for the Chatbot.

    Each entry may be a dict with ``user_input`` and ``response`` keys, or a
    list/tuple whose first two items are the user message and the reply.
    An empty or ``None`` history yields an empty list.

    Raises:
        TypeError: If an entry has neither of the supported shapes.
    """
    pairs = []
    for entry in history or []:
        if isinstance(entry, dict):
            pairs.append((entry['user_input'], entry['response']))
        elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
            pairs.append((entry[0], entry[1]))
        else:
            raise TypeError("Unexpected structure for updated_history")
    return pairs


def create_gradio_interface():
    """Build the Gradio Blocks chat UI and return it with queuing enabled.

    The layout has a chat column (chat display and message box), a side
    column with a collapsed accordion holding a JSON view, and a row with
    the send and clear buttons. Nothing is launched here; the caller decides
    how to serve the returned object.

    Returns:
        The ``gr.Blocks`` instance after ``queue()`` has been called on it.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🦙 * Chat Assistant\nFragen Sie nach Produkten, Rezepten und mehr!")

        with gr.Row():
            with gr.Column(scale=4):
                chat_display = gr.Chatbot(label="Chat-Verlauf", height=400)
                user_input = gr.Textbox(
                    label="Ihre Nachricht",
                    placeholder="Stellen Sie Ihre Frage...",
                    lines=2
                )

            with gr.Column(scale=2):
                with gr.Accordion("Zusätzliche Informationen", open=False):
                    product_info = gr.JSON(label="Produktdetails")

        with gr.Row():
            submit_btn = gr.Button("Senden", variant="primary")
            clear_btn = gr.Button("Chat löschen")

        # Session state holding the history in the chat service's own format.
        chat_history = gr.State([])

        async def respond(message, history):
            """Send one message to the chat service and refresh the UI."""
            response, updated_history, search_results = await chat_service.chat(message, history)
            # The display needs (user, bot) pairs; the state keeps the
            # service's history untouched so it can be passed back next turn.
            formatted_history = _format_history_for_display(updated_history)
            return formatted_history, updated_history, search_results

        submit_btn.click(
            respond,
            inputs=[user_input, chat_history],
            outputs=[chat_display, chat_history, product_info]
        )

        # Reset the display, the stored history and the JSON panel together.
        clear_btn.click(
            lambda: ([], [], None),
            outputs=[chat_display, chat_history, product_info]
        )

    demo.queue()
    return demo
if __name__ == "__main__":
    import uvicorn

    # Create and launch Gradio interface.
    demo = create_gradio_interface()
    # When run as a script, launch() blocks the main thread by default, so
    # the FastAPI server below would never start while the UI is running.
    # prevent_thread_lock=True returns immediately and leaves Gradio serving
    # in the background.
    demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)

    try:
        # Start FastAPI server; this call blocks and keeps the process (and
        # with it the background Gradio server) alive.
        uvicorn.run(app, host="0.0.0.0", port=8000)
    finally:
        # Shut the Gradio server down once the API server exits.
        demo.close()