# backend/server/server_utils.py (GPT-Researcher)
import json
import os
import re
import time
import shutil
import logging
from typing import Dict, Any
from datetime import datetime

from fastapi.responses import JSONResponse

from gpt_researcher import GPTResearcher
from gpt_researcher.document.document import DocumentLoader
from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md

# Needed by execute_multi_agents below; these import paths follow the upstream
# gpt-researcher repository layout.
from gpt_researcher.actions import stream_output
from multi_agents.main import run_research_task

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

class CustomLogsHandler:
"""Custom handler to capture streaming logs from the research process"""
def __init__(self, websocket, task: str):
self.logs = []
self.websocket = websocket
sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")
self.log_file = os.path.join("/tmp/outputs", f"{sanitized_filename}.json")
self.timestamp = datetime.now().isoformat()
# Initialize log file with metadata
os.makedirs("/tmp/outputs", exist_ok=True)
with open(self.log_file, 'w') as f:
json.dump({
"timestamp": self.timestamp,
"events": [],
"content": {
"query": "",
"sources": [],
"context": [],
"report": "",
"costs": 0.0
}
}, f, indent=2)

    async def send_json(self, data: Dict[str, Any]) -> None:
"""Store log data and send to websocket"""
# Send to websocket for real-time display
if self.websocket:
await self.websocket.send_json(data)
# Read current log file
with open(self.log_file, 'r') as f:
log_data = json.load(f)
# Update appropriate section based on data type
if data.get('type') == 'logs':
log_data['events'].append({
"timestamp": datetime.now().isoformat(),
"type": "event",
"data": data
})
else:
# Update content section for other types of data
log_data['content'].update(data)
# Save updated log file
with open(self.log_file, 'w') as f:
json.dump(log_data, f, indent=2)
logger.debug(f"Log entry written to: {self.log_file}")
class Researcher:
def __init__(self, query: str, report_type: str = "research_report"):
self.query = query
self.report_type = report_type
# Generate unique ID for this research task
self.research_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{hash(query)}"
        # Initialize logs handler with the research ID; no websocket is
        # attached in this standalone path, so events go only to the JSON log
        self.logs_handler = CustomLogsHandler(None, self.research_id)
self.researcher = GPTResearcher(
query=query,
report_type=report_type,
websocket=self.logs_handler
)

    async def research(self) -> dict:
"""Conduct research and return paths to generated files"""
await self.researcher.conduct_research()
report = await self.researcher.write_report()
# Generate the files
sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{self.query}")
file_paths = await generate_report_files(report, sanitized_filename)
# Get the JSON log path that was created by CustomLogsHandler
json_relative_path = os.path.relpath(self.logs_handler.log_file)
return {
"output": {
**file_paths, # Include PDF, DOCX, and MD paths
"json": json_relative_path
}
}
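
# A minimal driver sketch for the class above (the asyncio wrapper is
# illustrative and not part of this module):
#
#     import asyncio
#
#     async def main():
#         researcher = Researcher("Is AI in a hype cycle?", "research_report")
#         result = await researcher.research()
#         print(result["output"])  # {"pdf": ..., "docx": ..., "md": ..., "json": ...}
#
#     asyncio.run(main())
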
def sanitize_filename(filename: str) -> str:
# Split into components
prefix, timestamp, *task_parts = filename.split('_')
task = '_'.join(task_parts)
    # Calculate max length for the task portion:
    # 255 - len("/tmp/outputs/") - len("task_") - len(timestamp) - len("_.json") - safety_margin
    max_task_length = 255 - 13 - 5 - 10 - 6 - 10  # ~211 chars for the task
# Truncate task if needed
truncated_task = task[:max_task_length] if len(task) > max_task_length else task
# Reassemble and clean the filename
sanitized = f"{prefix}_{timestamp}_{truncated_task}"
return re.sub(r"[^\w\s-]", "", sanitized).strip()
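
# Example behaviour (illustrative values):
#     sanitize_filename("task_1700000000_What is 'quantum' supremacy?")
#     -> "task_1700000000_What is quantum supremacy"
# The [^\w\s-] pattern strips quotes and '?', while letters, digits,
# underscores, whitespace, and hyphens survive.
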
async def handle_start_command(websocket, data: str, manager):
    json_data = json.loads(data[6:])  # strip the leading "start " prefix
task, report_type, source_urls, document_urls, tone, headers, report_source = extract_command_data(
json_data)
if not task or not report_type:
print("Error: Missing task or report_type")
return
# Create logs handler with websocket and task
logs_handler = CustomLogsHandler(websocket, task)
# Initialize log content with query
await logs_handler.send_json({
"query": task,
"sources": [],
"context": [],
"report": ""
})
sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")
report = await manager.start_streaming(
task,
report_type,
report_source,
source_urls,
document_urls,
tone,
websocket,
headers
)
report = str(report)
file_paths = await generate_report_files(report, sanitized_filename)
# Add JSON log path to file_paths
file_paths["json"] = os.path.relpath(logs_handler.log_file)
await send_file_paths(websocket, file_paths)
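
# The wire format handle_start_command expects: a "start " prefix followed by
# JSON whose fields match extract_command_data at the bottom of this module
# (field values here are illustrative):
#
#     start {"task": "Is AI in a hype cycle?", "report_type": "research_report",
#            "source_urls": [], "document_urls": [], "tone": "objective",
#            "headers": {}, "report_source": "web"}
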
async def handle_human_feedback(data: str):
feedback_data = json.loads(data[14:]) # Remove "human_feedback" prefix
print(f"Received human feedback: {feedback_data}")
# TODO: Add logic to forward the feedback to the appropriate agent or update the research state

async def handle_chat(websocket, data: str, manager):
    json_data = json.loads(data[4:])  # strip the leading "chat" prefix
print(f"Received chat message: {json_data.get('message')}")
await manager.chat(json_data.get("message"), websocket)

async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
pdf_path = await write_md_to_pdf(report, filename)
docx_path = await write_md_to_word(report, filename)
md_path = await write_text_to_md(report, filename)
return {"pdf": pdf_path, "docx": docx_path, "md": md_path}
async def send_file_paths(websocket, file_paths: Dict[str, str]):
await websocket.send_json({"type": "path", "output": file_paths})

def get_config_dict(
langchain_api_key: str, openai_api_key: str, tavily_api_key: str,
google_api_key: str, google_cx_key: str, bing_api_key: str,
searchapi_api_key: str, serpapi_api_key: str, serper_api_key: str, searx_url: str
) -> Dict[str, str]:
return {
"LANGCHAIN_API_KEY": langchain_api_key or os.getenv("LANGCHAIN_API_KEY", ""),
"OPENAI_API_KEY": openai_api_key or os.getenv("OPENAI_API_KEY", ""),
"TAVILY_API_KEY": tavily_api_key or os.getenv("TAVILY_API_KEY", ""),
"GOOGLE_API_KEY": google_api_key or os.getenv("GOOGLE_API_KEY", ""),
"GOOGLE_CX_KEY": google_cx_key or os.getenv("GOOGLE_CX_KEY", ""),
"BING_API_KEY": bing_api_key or os.getenv("BING_API_KEY", ""),
"SEARCHAPI_API_KEY": searchapi_api_key or os.getenv("SEARCHAPI_API_KEY", ""),
"SERPAPI_API_KEY": serpapi_api_key or os.getenv("SERPAPI_API_KEY", ""),
"SERPER_API_KEY": serper_api_key or os.getenv("SERPER_API_KEY", ""),
"SEARX_URL": searx_url or os.getenv("SEARX_URL", ""),
"LANGCHAIN_TRACING_V2": os.getenv("LANGCHAIN_TRACING_V2", "true"),
"DOC_PATH": os.getenv("DOC_PATH", "/tmp/my-docs"),
"RETRIEVER": os.getenv("RETRIEVER", ""),
"EMBEDDING_MODEL": os.getenv("OPENAI_EMBEDDING_MODEL", "")
}
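
# Typical usage pairs get_config_dict with update_environment_variables below;
# explicit arguments win over pre-existing environment values (the key values
# here are placeholders):
#
#     config = get_config_dict(
#         langchain_api_key="", openai_api_key="sk-...", tavily_api_key="tvly-...",
#         google_api_key="", google_cx_key="", bing_api_key="",
#         searchapi_api_key="", serpapi_api_key="", serper_api_key="", searx_url=""
#     )
#     update_environment_variables(config)
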
def update_environment_variables(config: Dict[str, str]):
for key, value in config.items():
os.environ[key] = value

async def handle_file_upload(file, DOC_PATH: str) -> Dict[str, str]:
    os.makedirs(DOC_PATH, exist_ok=True)  # ensure the upload directory exists
    file_path = os.path.join(DOC_PATH, os.path.basename(file.filename))
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
print(f"File uploaded to {file_path}")
document_loader = DocumentLoader(DOC_PATH)
await document_loader.load()
return {"filename": file.filename, "path": file_path}

async def handle_file_deletion(filename: str, DOC_PATH: str) -> JSONResponse:
file_path = os.path.join(DOC_PATH, os.path.basename(filename))
if os.path.exists(file_path):
os.remove(file_path)
print(f"File deleted: {file_path}")
return JSONResponse(content={"message": "File deleted successfully"})
else:
print(f"File not found: {file_path}")
return JSONResponse(status_code=404, content={"message": "File not found"})
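
# A sketch of how the two file helpers might be mounted as FastAPI routes; the
# `app` instance and DOC_PATH constant are assumed to be defined in the server
# module, not here:
#
#     from fastapi import UploadFile
#
#     @app.post("/files/upload")
#     async def upload_file(file: UploadFile):
#         return await handle_file_upload(file, DOC_PATH)
#
#     @app.delete("/files/{filename}")
#     async def delete_file(filename: str):
#         return await handle_file_deletion(filename, DOC_PATH)
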
async def execute_multi_agents(manager) -> Any:
websocket = manager.active_connections[0] if manager.active_connections else None
if websocket:
report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output)
return {"report": report}
else:
return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"})

async def handle_websocket_communication(websocket, manager):
while True:
data = await websocket.receive_text()
if data.startswith("start"):
await handle_start_command(websocket, data, manager)
elif data.startswith("human_feedback"):
await handle_human_feedback(data)
elif data.startswith("chat"):
await handle_chat(websocket, data, manager)
else:
print("Error: Unknown command or not enough parameters provided.")
def extract_command_data(json_data: Dict) -> tuple:
return (
json_data.get("task"),
json_data.get("report_type"),
json_data.get("source_urls"),
json_data.get("document_urls"),
json_data.get("tone"),
json_data.get("headers", {}),
json_data.get("report_source")
)
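
# Example: for the "start" payload shown above handle_start_command, this
# returns the 7-tuple
# ("Is AI in a hype cycle?", "research_report", [], [], "objective", {}, "web");
# missing keys default to None, except headers, which defaults to {}.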