# Source: web-search-api / apis/search_api.py
# Last change: MAsad789565 - "Update apis/search_api.py" (commit cde1a76, verified)
import argparse
import os
import sys
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from typing import Union
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from networks.google_searcher import GoogleSearcher
from networks.webpage_fetcher import BatchWebpageFetcher
from documents.query_results_extractor import QueryResultsExtractor
from documents.webpage_content_extractor import BatchWebpageContentExtractor
from utils.logger import logger
class SearchAPIApp:
    """FastAPI application wrapper for the Web Search API.

    Constructing an instance creates the ``FastAPI`` app (interactive docs are
    served at ``/``), registers the search route and installs the middleware.
    The configured application is available as ``self.app``.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Web Search API",
            # -1 collapses/hides the "Schemas" section of the Swagger UI page.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        self.setup_middleware()

    class QueriesToSearchResultsPostItem(BaseModel):
        """Request body accepted by ``POST /queries_to_search_results``."""

        queries: list = Field(
            default=[""],
            description="(list[str]) Queries to search",
        )
        result_num: int = Field(
            default=10,
            description="(int) Number of search results",
        )
        safe: bool = Field(
            default=False,
            description="(bool) Enable SafeSearch",
        )
        # NOTE(review): `types` is accepted in the request body but is not read
        # anywhere in this class.
        types: list = Field(
            default=["web"],
            description="(list[str]) Types of search results: `web`, `image`, `videos`, `news`",
        )
        extract_webpage: bool = Field(
            default=False,
            description="(bool) Enable extracting main text contents from webpage, will add `text` filed in each `query_result` dict",
        )
        overwrite_query_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML file of query results",
        )
        overwrite_webpage_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML files of webpages from query results",
        )

    def queries_to_search_results(self, item: QueriesToSearchResultsPostItem):
        """Run every query in ``item`` and return the extracted search results.

        Each usable query is searched with ``GoogleSearcher`` and the saved
        result page is parsed with ``QueryResultsExtractor``. When
        ``item.extract_webpage`` is set, the results are additionally passed
        through :meth:`extract_webpages`.

        Args:
            item: Parsed request body.

        Returns:
            list: One extracted-results object per processed query, in the
            order the queries were given. Blank queries and non-string
            entries are skipped and produce no element.
        """
        google_searcher = GoogleSearcher()
        queries_search_results = []
        for query in item.queries:
            # `queries` is an untyped list, so an entry may be any JSON value;
            # guard before calling str methods instead of failing the request.
            if not isinstance(query, str) or not query.strip():
                continue
            query_html_path = google_searcher.search(
                query=query,
                result_num=item.result_num,
                safe=item.safe,
                overwrite=item.overwrite_query_html,
            )
            # A fresh extractor per query, created only once the query is
            # known to be usable.
            query_search_results = QueryResultsExtractor().extract(query_html_path)
            queries_search_results.append(query_search_results)
        logger.note(queries_search_results)

        if item.extract_webpage:
            queries_search_results = self.extract_webpages(
                queries_search_results,
                overwrite_webpage_html=item.overwrite_webpage_html,
            )
        return queries_search_results

    def extract_webpages(self, queries_search_results, overwrite_webpage_html=False):
        """Fetch every result URL and attach its extracted content as ``text``.

        The result dicts inside ``queries_search_results`` are updated in
        place and the same list is returned.

        Args:
            queries_search_results: List of per-query results; each element is
                indexed with ``"query"`` and ``"query_results"``, and every
                entry of ``"query_results"`` is indexed with ``"url"``.
            overwrite_webpage_html: Forwarded to the fetcher as ``overwrite``.

        Returns:
            list: ``queries_search_results`` with a ``"text"`` key on every
            query result. A result whose URL has no extracted content gets an
            empty string rather than aborting the whole request.
        """
        for query_search_results in queries_search_results:
            query_results = query_search_results["query_results"]

            # Fetch webpages with urls
            urls = [query_result["url"] for query_result in query_results]
            url_and_html_path_list = BatchWebpageFetcher().fetch(
                urls,
                overwrite=overwrite_webpage_html,
                output_parent=query_search_results["query"],
            )

            # Extract webpage contents from htmls
            html_paths = [str(entry["html_path"]) for entry in url_and_html_path_list]
            extracted_list = BatchWebpageContentExtractor().extract(html_paths)

            # Build the map of url to extracted_content
            html_path_to_url = {
                str(entry["html_path"]): entry["url"]
                for entry in url_and_html_path_list
            }
            url_to_extracted_content = {}
            for extracted in extracted_list:
                # Normalise to str so the key matches however the extractor
                # reports the path; unknown paths are ignored.
                url = html_path_to_url.get(str(extracted["html_path"]))
                if url is not None:
                    url_to_extracted_content[url] = extracted["extracted_content"]

            # Write extracted contents (as 'text' field) to query_search_results
            # NOTE(review): whether the helpers can drop a page is not visible
            # here; the "" fallback keeps one missing page from failing all.
            for query_result in query_results:
                query_result["text"] = url_to_extracted_content.get(
                    query_result["url"], ""
                )
        return queries_search_results

    def setup_routes(self):
        """Register the POST endpoint that serves search requests."""
        self.app.post(
            "/queries_to_search_results",
            summary="Search queries, and extract contents from results",
        )(self.queries_to_search_results)

    def setup_middleware(self):
        """Install the CORS middleware and a response no-cache middleware."""
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        class NoCacheMiddleware(BaseHTTPMiddleware):
            """Stamp cache-disabling headers on every outgoing response."""

            async def dispatch(self, request, call_next):
                response = await call_next(request)
                response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
                response.headers['Pragma'] = 'no-cache'
                response.headers['Expires'] = '0'
                return response

        self.app.add_middleware(NoCacheMiddleware)
class ArgParser(argparse.ArgumentParser):
    """Command-line options for launching the Web Search API server.

    The command line (``sys.argv[1:]``) is parsed as soon as the parser is
    constructed; the resulting namespace is stored in ``self.args`` with the
    attributes ``server`` (str), ``port`` (int) and ``dev`` (bool).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for Web Search API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=21111,
            help="Server Port for Web Search API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        # Parse eagerly so callers can simply write `ArgParser().args`.
        self.args = self.parse_args(sys.argv[1:])
# Module-level application object; the launcher below hands it to uvicorn via
# the "__main__:app" import string.
app = SearchAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # Dev mode only switches on uvicorn's auto-reload; host and port are the
    # same in both modes, so a single call covers both.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

    # python -m apis.search_api      # [Docker] production mode
    # python -m apis.search_api -d   # [Dev] development mode