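"""Web Search API server.

Searches Google for the given queries, extracts the result entries, and
optionally fetches each result webpage and extracts its main text content.
Run with `python -m apis.search_api` (add `-d` for dev mode with auto-reload).
"""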
import argparse
import os
import sys
import uvicorn

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from typing import Union
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

from networks.google_searcher import GoogleSearcher
from networks.webpage_fetcher import BatchWebpageFetcher
from documents.query_results_extractor import QueryResultsExtractor
from documents.webpage_content_extractor import BatchWebpageContentExtractor
from utils.logger import logger

class SearchAPIApp:
    """FastAPI app that searches Google for queries, extracts result entries,
    and optionally extracts the main text content of each result webpage."""

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Web Search API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        self.setup_middleware()
    class QueriesToSearchResultsPostItem(BaseModel):
        queries: list = Field(
            default=[""],
            description="(list[str]) Queries to search",
        )
        result_num: int = Field(
            default=10,
            description="(int) Number of search results",
        )
        safe: bool = Field(
            default=False,
            description="(bool) Enable SafeSearch",
        )
        types: list = Field(
            default=["web"],
            description="(list[str]) Types of search results: `web`, `image`, `videos`, `news`",
        )
        extract_webpage: bool = Field(
            default=False,
            description="(bool) Extract main text contents from webpages; adds a `text` field to each `query_result` dict",
        )
        overwrite_query_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML file of query results",
        )
        overwrite_webpage_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML files of webpages from query results",
        )
    def queries_to_search_results(self, item: QueriesToSearchResultsPostItem):
        """Search each query on Google, extract the result entries, and
        optionally attach the extracted text of each result webpage."""
        google_searcher = GoogleSearcher()
        queries_search_results = []
        for query in item.queries:
            query_results_extractor = QueryResultsExtractor()
            if not query.strip():
                continue
            query_html_path = google_searcher.search(
                query=query,
                result_num=item.result_num,
                safe=item.safe,
                overwrite=item.overwrite_query_html,
            )
            query_search_results = query_results_extractor.extract(query_html_path)
            queries_search_results.append(query_search_results)
        logger.note(queries_search_results)

        if item.extract_webpage:
            queries_search_results = self.extract_webpages(
                queries_search_results,
                overwrite_webpage_html=item.overwrite_webpage_html,
            )
        return queries_search_results
    def extract_webpages(self, queries_search_results, overwrite_webpage_html=False):
        """Fetch the webpages of all query results and attach their extracted
        main text as a `text` field on each `query_result`."""
        for query_idx, query_search_results in enumerate(queries_search_results):
            # Fetch webpages with urls
            batch_webpage_fetcher = BatchWebpageFetcher()
            urls = [
                query_result["url"]
                for query_result in query_search_results["query_results"]
            ]
            url_and_html_path_list = batch_webpage_fetcher.fetch(
                urls,
                overwrite=overwrite_webpage_html,
                output_parent=query_search_results["query"],
            )
            # Extract webpage contents from htmls
            html_paths = [
                str(url_and_html_path["html_path"])
                for url_and_html_path in url_and_html_path_list
            ]
            batch_webpage_content_extractor = BatchWebpageContentExtractor()
            html_path_and_extracted_content_list = (
                batch_webpage_content_extractor.extract(html_paths)
            )
            # Build the map of url to extracted_content
            html_path_to_url_dict = {
                str(url_and_html_path["html_path"]): url_and_html_path["url"]
                for url_and_html_path in url_and_html_path_list
            }
            url_to_extracted_content_dict = {
                html_path_to_url_dict[
                    html_path_and_extracted_content["html_path"]
                ]: html_path_and_extracted_content["extracted_content"]
                for html_path_and_extracted_content in html_path_and_extracted_content_list
            }
            # Write extracted contents (as `text` field) to query_search_results
            for query_result_idx, query_result in enumerate(
                query_search_results["query_results"]
            ):
                url = query_result["url"]
                extracted_content = url_to_extracted_content_dict[url]
                queries_search_results[query_idx]["query_results"][query_result_idx][
                    "text"
                ] = extracted_content
        return queries_search_results
    def setup_routes(self):
        self.app.post(
            "/queries_to_search_results",
            summary="Search queries, and extract contents from results",
        )(self.queries_to_search_results)

    def setup_middleware(self):
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        class NoCacheMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request, call_next):
                response = await call_next(request)
                response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"
                return response

        self.app.add_middleware(NoCacheMiddleware)

class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super(ArgParser, self).__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for Web Search API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=21111,
            help="Server Port for Web Search API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )
        self.args = self.parse_args(sys.argv[1:])

app = SearchAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    if args.dev:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
    else:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)

# python -m apis.search_api    # [Docker] in production mode
# python -m apis.search_api -d # [Dev] in development mode
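# Example request (illustrative only; adjust host and port to your deployment):
#   curl -X POST http://127.0.0.1:21111/queries_to_search_results \
#     -H "Content-Type: application/json" \
#     -d '{"queries": ["python"], "result_num": 10, "extract_webpage": true}'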