import argparse
import os
import sys
from typing import Union

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from starlette.middleware.base import BaseHTTPMiddleware

from documents.query_results_extractor import QueryResultsExtractor
from documents.webpage_content_extractor import BatchWebpageContentExtractor
from networks.google_searcher import GoogleSearcher
from networks.webpage_fetcher import BatchWebpageFetcher
from utils.logger import logger


class SearchAPIApp:
    """FastAPI application exposing a single web-search endpoint.

    The constructor builds the ``FastAPI`` instance (Swagger UI served at
    ``/``), registers the routes, and then installs the middleware.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Web Search API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        self.setup_middleware()

    class QueriesToSearchResultsPostItem(BaseModel):
        """Request body of ``POST /queries_to_search_results``."""

        queries: list = Field(
            default=[""],
            description="(list[str]) Queries to search",
        )
        result_num: int = Field(
            default=10,
            description="(int) Number of search results",
        )
        safe: bool = Field(
            default=False,
            description="(bool) Enable SafeSearch",
        )
        # NOTE(review): `types` is accepted in the request body but is never
        # read by `queries_to_search_results` below -- confirm whether it
        # should be forwarded to the searcher.
        types: list = Field(
            default=["web"],
            description="(list[str]) Types of search results: `web`, `image`, `videos`, `news`",
        )
        extract_webpage: bool = Field(
            default=False,
            description="(bool) Enable extracting main text contents from webpage, will add `text` filed in each `query_result` dict",
        )
        overwrite_query_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML file of query results",
        )
        overwrite_webpage_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML files of webpages from query results",
        )

    def queries_to_search_results(self, item: QueriesToSearchResultsPostItem):
        """Search every non-blank query and return the collected results.

        Args:
            item: Request body; see ``QueriesToSearchResultsPostItem``.

        Returns:
            A list with one entry per non-blank query, each being whatever
            ``QueryResultsExtractor.extract`` returns for that query. When
            ``item.extract_webpage`` is set, the list is additionally passed
            through ``extract_webpages``.
        """
        google_searcher = GoogleSearcher()
        queries_search_results = []
        for query in item.queries:
            # Blank / whitespace-only queries are skipped entirely.
            if not query.strip():
                continue
            # Presumably `search` returns the path of the saved results page
            # (judging by how the value is consumed) -- verify against
            # GoogleSearcher.
            query_html_path = google_searcher.search(
                query=query,
                result_num=item.result_num,
                safe=item.safe,
                overwrite=item.overwrite_query_html,
            )
            # A fresh extractor is used for each query, as before.
            query_results_extractor = QueryResultsExtractor()
            queries_search_results.append(
                query_results_extractor.extract(query_html_path)
            )
        logger.note(queries_search_results)

        if item.extract_webpage:
            queries_search_results = self.extract_webpages(
                queries_search_results,
                overwrite_webpage_html=item.overwrite_webpage_html,
            )
        return queries_search_results

    def extract_webpages(self, queries_search_results, overwrite_webpage_html=False):
        """Attach extracted page text to every query result, in place.

        For each entry of ``queries_search_results`` the URLs listed under
        ``"query_results"`` are fetched, their contents extracted, and the
        result stored under the ``"text"`` key of the matching query result.

        Args:
            queries_search_results: List of dicts, each providing a
                ``"query"`` value and a ``"query_results"`` list whose items
                carry a ``"url"`` key.
            overwrite_webpage_html: Forwarded to the fetcher as ``overwrite``.

        Returns:
            The same ``queries_search_results`` list, mutated in place.
        """
        for query_search_results in queries_search_results:
            query_results = query_search_results["query_results"]

            # Fetch webpages with urls
            urls = [query_result["url"] for query_result in query_results]
            url_and_html_path_list = BatchWebpageFetcher().fetch(
                urls,
                overwrite=overwrite_webpage_html,
                output_parent=query_search_results["query"],
            )

            # Build the map of html_path to url; the same str() form of the
            # path is what gets handed to the content extractor.
            html_path_to_url_dict = {
                str(url_and_html_path["html_path"]): url_and_html_path["url"]
                for url_and_html_path in url_and_html_path_list
            }

            # Extract webpage contents from htmls
            html_path_and_extracted_content_list = (
                BatchWebpageContentExtractor().extract(
                    list(html_path_to_url_dict)
                    if len(html_path_to_url_dict) == len(url_and_html_path_list)
                    else [
                        str(url_and_html_path["html_path"])
                        for url_and_html_path in url_and_html_path_list
                    ]
                )
            )

            # Build the map of url to extracted_content
            url_to_extracted_content_dict = {
                html_path_to_url_dict[entry["html_path"]]: entry["extracted_content"]
                for entry in html_path_and_extracted_content_list
            }

            # Write extracted contents (as 'text' field) to query_search_results.
            # A url with no extracted content (e.g. it never made it through
            # the fetch/extract round trip) gets an empty text instead of
            # raising KeyError and failing the whole request.
            for query_result in query_results:
                query_result["text"] = url_to_extracted_content_dict.get(
                    query_result["url"], ""
                )

        return queries_search_results

    def setup_routes(self):
        """Register the POST endpoint on the FastAPI app."""
        self.app.post(
            "/queries_to_search_results",
            summary="Search queries, and extract contents from results",
        )(self.queries_to_search_results)

    def setup_middleware(self):
        """Install CORS handling and a middleware that disables caching."""
        # NOTE(review): wildcard origins together with allow_credentials=True
        # is a very permissive CORS policy -- confirm this is intended.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        class NoCacheMiddleware(BaseHTTPMiddleware):
            """Stamp every response with headers that forbid caching."""

            async def dispatch(self, request, call_next):
                response = await call_next(request)
                response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"
                return response

        self.app.add_middleware(NoCacheMiddleware)


class ArgParser(argparse.ArgumentParser):
    """Command-line options of the server.

    The arguments are parsed from ``sys.argv[1:]`` at construction time and
    exposed as ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for Web Search API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=21111,
            help="Server Port for Web Search API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


# Module-level ASGI app; uvicorn loads it through the "__main__:app" import
# string used below.
app = SearchAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # Auto-reload is enabled only in dev mode (-d / --dev).
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

    # python -m apis.search_api      # [Docker] production mode
    # python -m apis.search_api -d   # [Dev] development mode