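"""Research conduction skill: coordinates the end-to-end research flow.

Plans sub-queries for a research task, retrieves and scrapes sources (web,
local documents, or vector stores), and compresses the results into the
context used for report generation.
"""
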
import asyncio
import logging
import random
from typing import Optional

from ..actions.query_processing import plan_research_outline, get_search_results
from ..actions.utils import stream_output
from ..document import DocumentLoader, OnlineDocumentLoader, LangChainDocumentLoader
from ..utils.enum import ReportSource
from ..utils.logging_config import get_json_handler


class ResearchConductor:
    """Manages and coordinates the research process."""

    def __init__(self, researcher):
        self.researcher = researcher
        self.logger = logging.getLogger('research')
        self.json_handler = get_json_handler()

    async def plan_research(self, query):
        """Plans the research outline (a list of sub-queries) for the given query."""
        self.logger.info(f"Planning research for query: {query}")
        await stream_output(
            "logs",
            "planning_research",
            f"🌐 Browsing the web to learn more about the task: {query}...",
            self.researcher.websocket,
        )

        search_results = await get_search_results(query, self.researcher.retrievers[0])
        self.logger.info(f"Initial search results obtained: {len(search_results)} results")

        await stream_output(
            "logs",
            "planning_research",
            "🤔 Planning the research strategy and subtasks...",
            self.researcher.websocket,
        )

        outline = await plan_research_outline(
            query=query,
            search_results=search_results,
            agent_role_prompt=self.researcher.role,
            cfg=self.researcher.cfg,
            parent_query=self.researcher.parent_query,
            report_type=self.researcher.report_type,
            cost_callback=self.researcher.add_costs,
        )
        self.logger.info(f"Research outline planned: {outline}")
        return outline

    async def conduct_research(self):
        """Runs the GPT Researcher to conduct research."""
        if self.json_handler:
            self.json_handler.update_content("query", self.researcher.query)

        self.logger.info(f"Starting research for query: {self.researcher.query}")

        # Reset visited_urls at the start of each research task
        self.researcher.visited_urls.clear()
        research_data = []

        if self.researcher.verbose:
            await stream_output(
                "logs",
                "starting_research",
                f"🔎 Starting the research task for '{self.researcher.query}'...",
                self.researcher.websocket,
            )
            await stream_output("logs", "agent_generated", self.researcher.agent, self.researcher.websocket)

        # Research for relevant sources based on the source types below:
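        #   - source_urls: scrape only the user-provided URLs (optionally
        #     complemented by a web search)
        #   - web: retriever-driven web search
        #   - local: documents loaded from cfg.doc_path
        #   - hybrid: local documents plus web search
        #   - langchain documents / vector store: pre-loaded corpora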
        if self.researcher.source_urls:
            self.logger.info("Using provided source URLs")
            research_data = await self._get_context_by_urls(self.researcher.source_urls)
            if not research_data and self.researcher.verbose:
                await stream_output(
                    "logs",
                    "answering_from_memory",
                    "🧐 I was unable to find relevant context in the provided sources...",
                    self.researcher.websocket,
                )
            if self.researcher.complement_source_urls:
                self.logger.info("Complementing with web search")
                additional_research = await self._get_context_by_web_search(self.researcher.query)
                # Append the complementary web context (both are plain strings)
                research_data += " " + additional_research

        elif self.researcher.report_source == ReportSource.Web.value:
            self.logger.info("Using web search")
            research_data = await self._get_context_by_web_search(self.researcher.query)

        elif self.researcher.report_source == ReportSource.Local.value:
            self.logger.info("Using local search")
            document_data = await DocumentLoader(self.researcher.cfg.doc_path).load()
            self.logger.info(f"Loaded {len(document_data)} documents")
            if self.researcher.vector_store:
                self.researcher.vector_store.load(document_data)
            research_data = await self._get_context_by_web_search(self.researcher.query, document_data)

        # Hybrid search including both local documents and web sources
        elif self.researcher.report_source == ReportSource.Hybrid.value:
            if self.researcher.document_urls:
                document_data = await OnlineDocumentLoader(self.researcher.document_urls).load()
            else:
                document_data = await DocumentLoader(self.researcher.cfg.doc_path).load()
            if self.researcher.vector_store:
                self.researcher.vector_store.load(document_data)
            docs_context = await self._get_context_by_web_search(self.researcher.query, document_data)
            web_context = await self._get_context_by_web_search(self.researcher.query)
            research_data = f"Context from local documents: {docs_context}\n\nContext from web sources: {web_context}"

        elif self.researcher.report_source == ReportSource.LangChainDocuments.value:
            langchain_documents_data = await LangChainDocumentLoader(
                self.researcher.documents
            ).load()
            if self.researcher.vector_store:
                self.researcher.vector_store.load(langchain_documents_data)
            research_data = await self._get_context_by_web_search(
                self.researcher.query, langchain_documents_data
            )

        elif self.researcher.report_source == ReportSource.LangChainVectorStore.value:
            research_data = await self._get_context_by_vectorstore(
                self.researcher.query, self.researcher.vector_store_filter
            )

        # Rank and curate the sources
        self.researcher.context = research_data
        if self.researcher.cfg.curate_sources:
            self.logger.info("Curating sources")
            self.researcher.context = await self.researcher.source_curator.curate_sources(research_data)

        if self.researcher.verbose:
            await stream_output(
                "logs",
                "research_step_finalized",
                f"Finalized research step.\n💸 Total Research Costs: ${self.researcher.get_costs()}",
                self.researcher.websocket,
            )

        if self.json_handler:
            self.json_handler.update_content("costs", self.researcher.get_costs())
            self.json_handler.update_content("context", self.researcher.context)

        self.logger.info(f"Research completed. Context size: {len(str(self.researcher.context))}")
        return self.researcher.context

    async def _get_context_by_urls(self, urls):
        """Scrapes and compresses the context from the given urls."""
        self.logger.info(f"Getting context from URLs: {urls}")

        new_search_urls = await self._get_new_urls(urls)
        self.logger.info(f"New URLs to process: {new_search_urls}")

        scraped_content = await self.researcher.scraper_manager.browse_urls(new_search_urls)
        self.logger.info(f"Scraped content from {len(scraped_content)} URLs")

        if self.researcher.vector_store:
            self.logger.info("Loading content into vector store")
            self.researcher.vector_store.load(scraped_content)

        context = await self.researcher.context_manager.get_similar_content_by_query(
            self.researcher.query, scraped_content
        )
        self.logger.info(f"Generated context length: {len(context)}")
        return context

    async def _get_context_by_vectorstore(self, query, filter: Optional[dict] = None):
        """
        Generates the context for the research task by searching the vectorstore.
        Returns:
            context: List of context strings, one per sub-query
        """
        # Generate sub-queries, including the original query
        sub_queries = await self.plan_research(query)
        # If this is not part of a sub researcher, add the original query for better results
        if self.researcher.report_type != "subtopic_report":
            sub_queries.append(query)

        if self.researcher.verbose:
            await stream_output(
                "logs",
                "subqueries",
                f"🗂️ I will conduct my research based on the following queries: {sub_queries}...",
                self.researcher.websocket,
                True,
                sub_queries,
            )
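
        # asyncio.gather preserves input order, so the returned list of
        # contexts lines up one-to-one with sub_queries.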
        # Using asyncio.gather to process the sub_queries asynchronously
        context = await asyncio.gather(
            *[
                self._process_sub_query_with_vectorstore(sub_query, filter)
                for sub_query in sub_queries
            ]
        )
        return context

    async def _get_context_by_web_search(self, query, scraped_data: Optional[list] = None):
        """
        Generates the context for the research task by searching the query and scraping the results.
        Returns:
            str: The combined context, or "" if nothing relevant was found
        """
        self.logger.info(f"Starting web search for query: {query}")
        if scraped_data is None:
            scraped_data = []

        # Generate sub-queries, including the original query
        sub_queries = await self.plan_research(query)
        self.logger.info(f"Generated sub-queries: {sub_queries}")
        # If this is not part of a sub researcher, add the original query for better results
        if self.researcher.report_type != "subtopic_report":
            sub_queries.append(query)

        if self.researcher.verbose:
            await stream_output(
                "logs",
                "subqueries",
                f"🗂️ I will conduct my research based on the following queries: {sub_queries}...",
                self.researcher.websocket,
                True,
                sub_queries,
            )
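
        # Sub-queries are researched concurrently via asyncio.gather; an
        # exception raised in any task aborts the whole batch, so failures are
        # caught here and an empty context is returned rather than crashing
        # the research task.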
        try:
            context = await asyncio.gather(
                *[
                    self._process_sub_query(sub_query, scraped_data)
                    for sub_query in sub_queries
                ]
            )
            self.logger.info(f"Gathered context from {len(context)} sub-queries")

            # Filter out empty results and join the context
            context = [c for c in context if c]
            if context:
                combined_context = " ".join(context)
                self.logger.info(f"Combined context size: {len(combined_context)}")
                return combined_context
            return ""
        except Exception as e:
            self.logger.error(f"Error during web search: {e}", exc_info=True)
            return ""

    async def _process_sub_query(self, sub_query: str, scraped_data: Optional[list] = None):
        """Takes in a sub-query, scrapes urls based on it, and gathers context."""
        if self.json_handler:
            self.json_handler.log_event("sub_query", {
                "query": sub_query,
                "scraped_data_size": len(scraped_data) if scraped_data else 0
            })

        if self.researcher.verbose:
            await stream_output(
                "logs",
                "running_subquery_research",
                f"\n🔍 Running research for '{sub_query}'...",
                self.researcher.websocket,
            )

        try:
            if not scraped_data:
                scraped_data = await self._scrape_data_by_urls(sub_query)
                self.logger.info(f"Scraped data size: {len(scraped_data)}")

            content = await self.researcher.context_manager.get_similar_content_by_query(sub_query, scraped_data)
            self.logger.info(f"Content found for sub-query: {len(str(content)) if content else 0} chars")

            if content and self.researcher.verbose:
                await stream_output(
                    "logs", "subquery_context_window", f"📃 {content}", self.researcher.websocket
                )
            elif self.researcher.verbose:
                await stream_output(
                    "logs",
                    "subquery_context_not_found",
                    f"🤷 No content found for '{sub_query}'...",
                    self.researcher.websocket,
                )

            if content:
                if self.json_handler:
                    self.json_handler.log_event("content_found", {
                        "sub_query": sub_query,
                        "content_size": len(content)
                    })
            return content
        except Exception as e:
            self.logger.error(f"Error processing sub-query {sub_query}: {e}", exc_info=True)
            return ""

    async def _process_sub_query_with_vectorstore(self, sub_query: str, filter: Optional[dict] = None):
        """Takes in a sub-query and gathers context from the user-provided vector store.

        Args:
            sub_query (str): The sub-query generated from the original query
        Returns:
            str: The context gathered from search
        """
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "running_subquery_with_vectorstore_research",
                f"\n🔍 Running research for '{sub_query}'...",
                self.researcher.websocket,
            )

        content = await self.researcher.context_manager.get_similar_content_by_query_with_vectorstore(sub_query, filter)

        if content and self.researcher.verbose:
            await stream_output(
                "logs", "subquery_context_window", f"📃 {content}", self.researcher.websocket
            )
        elif self.researcher.verbose:
            await stream_output(
                "logs",
                "subquery_context_not_found",
                f"🤷 No content found for '{sub_query}'...",
                self.researcher.websocket,
            )
        return content

    async def _get_new_urls(self, url_set_input):
        """Gets the new urls from the given urls.

        Args:
            url_set_input (iterable[str]): The urls to check against visited_urls
        Returns:
            list[str]: The urls that have not been visited yet
        """
        new_urls = []
        for url in url_set_input:
            if url not in self.researcher.visited_urls:
                self.researcher.visited_urls.add(url)
                new_urls.append(url)
                if self.researcher.verbose:
                    await stream_output(
                        "logs",
                        "added_source_url",
                        f"✅ Added source url to research: {url}\n",
                        self.researcher.websocket,
                        True,
                        url,
                    )
        return new_urls

    async def _search_relevant_source_urls(self, query):
        new_search_urls = []

        # Iterate through all retrievers
        for retriever_class in self.researcher.retrievers:
            # Instantiate the retriever with the sub-query
            retriever = retriever_class(query)

            # Perform the (blocking) search in a worker thread
            search_results = await asyncio.to_thread(
                retriever.search, max_results=self.researcher.cfg.max_search_results_per_query
            )

            # Collect new URLs from search results, skipping entries without a link
            search_urls = [url.get("href") for url in search_results if url.get("href")]
            new_search_urls.extend(search_urls)

        # Keep only unseen URLs, then shuffle so the scrape order does not
        # follow any single retriever's ranking
        new_search_urls = await self._get_new_urls(new_search_urls)
        random.shuffle(new_search_urls)

        return new_search_urls

    async def _scrape_data_by_urls(self, sub_query):
        """
        Runs a sub-query across multiple retrievers and scrapes the resulting URLs.

        Args:
            sub_query (str): The sub-query to search for.
        Returns:
            list: A list of scraped content results.
        """
        new_search_urls = await self._search_relevant_source_urls(sub_query)

        # Log the research process if verbose mode is on
        if self.researcher.verbose:
            await stream_output(
                "logs",
                "researching",
                "🤔 Researching for relevant information across multiple sources...\n",
                self.researcher.websocket,
            )

        # Scrape the new URLs
        scraped_content = await self.researcher.scraper_manager.browse_urls(new_search_urls)

        if self.researcher.vector_store:
            self.researcher.vector_store.load(scraped_content)

        return scraped_content
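

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public API).
# Assumes a researcher object that exposes the attributes referenced above
# (query, cfg, websocket, retrievers, ...), such as the GPTResearcher agent
# that normally owns this skill:
#
#     researcher = GPTResearcher(query="What is quantum error correction?")
#     conductor = ResearchConductor(researcher)
#     context = asyncio.run(conductor.conduct_research())
# ---------------------------------------------------------------------------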