File size: 2,816 Bytes
d029425 d102e03 27e2360 d102e03 27e2360 d102e03 8d0a0d3 d102e03 27e2360 d029425 27e2360 d029425 27e2360 d102e03 27e2360 d102e03 27e2360 d102e03 27e2360 d102e03 27e2360 d102e03 dfa084c 27e2360 dfa084c 27e2360 d102e03 8d0a0d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Generator, Dict, List
from googleapiclient.discovery import build
from streamlit import secrets
# Instruction text placed between the date line and the query in every
# prompt assembled by rewrite_prompt().
INSTRUCTIONS = (
    "Instructions: "
    "Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation after the reference. "
    "If the provided search results refer to multiple subjects with the same name, "
    "write separate answers for each subject."
)
def get_google_api_key():
    """Return the Google Custom Search API key.

    The key is read from Streamlit's ``secrets`` under
    ``"google_search_api_key"``. If no secrets file is available, or the
    file does not define that key, the environment variable with the same
    name is used instead.

    Raises:
        KeyError: if the key is found neither in the secrets nor in the
            environment.
    """
    try:
        return secrets["google_search_api_key"]
    except (FileNotFoundError, KeyError):
        # Fall back to the environment both when there is no secrets file and
        # when the file exists but omits this particular key.
        return os.environ["google_search_api_key"]
def get_google_cse_id():
    """Return the Google Custom Search Engine (CSE) ID.

    The ID is read from Streamlit's ``secrets`` under ``"google_cse_id"``.
    If no secrets file is available, or the file does not define that key,
    the environment variable with the same name is used instead.

    Raises:
        KeyError: if the ID is found neither in the secrets nor in the
            environment.
    """
    try:
        return secrets["google_cse_id"]
    except (FileNotFoundError, KeyError):
        # Fall back to the environment both when there is no secrets file and
        # when the file exists but omits this particular key.
        return os.environ["google_cse_id"]
def google_search(search_term, **kwargs) -> list:
    """Run a query against the Google Custom Search JSON API.

    Args:
        search_term: The query string, sent as the ``q`` parameter.
        **kwargs: Extra request parameters forwarded unchanged to
            ``cse().list()`` (for example ``num``).

    Returns:
        The list of result items from the response, or an empty list when
        the response contains no ``"items"`` entry.
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    response = service.cse().list(
        q=search_term, cx=get_google_cse_id(), **kwargs
    ).execute()
    # The response carries no "items" key when the query has no hits;
    # indexing it directly would raise KeyError instead of reporting zero
    # results.
    return response.get("items", [])
@dataclass
class SearchResult:
    """A single web search hit, ready to be formatted into a prompt.

    ``__slots__`` is declared by hand, so instances have no per-instance
    ``__dict__``; the fields have no defaults, which keeps that compatible
    with ``@dataclass``.
    """

    __slots__ = ("title", "body", "url")

    title: str  # page title of the hit
    body: str  # snippet text shown in the prompt
    url: str  # link to the page
def get_web_search_results(
    query: str,
    num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield up to ``num_results`` web search results for ``query``.

    Results come from ``google_search``. Each item's ``title``, ``snippet``
    and ``link`` become a ``SearchResult``'s ``title``, ``body`` and ``url``.
    A trailing non-breaking space followed by ``...`` is stripped from the
    snippet. Items without a ``snippet`` get an empty body rather than
    raising ``KeyError``.

    This is a generator: the search request is only issued once iteration
    starts.

    Args:
        query: The search query.
        num_results: Number of results requested and the maximum yielded.
    """
    raw_results: List[Dict[str, str]] = google_search(
        search_term=query,
        num=num_results,
    )[:num_results]
    for result in raw_results:
        # Work on a local copy instead of mutating the response dict.
        snippet = result.get("snippet", "")
        # A trailing "\xa0..." presumably marks a truncated snippet; drop
        # those 4 characters so the body ends on real text.
        if snippet.endswith("\xa0..."):
            snippet = snippet[:-4]
        yield SearchResult(
            title=result["title"],
            body=snippet,
            url=result["link"],
        )
def format_search_result(search_result: Generator["SearchResult", None, None]) -> str:
    """Format search results as a numbered text block for the prompt.

    Each result becomes ``"[i] <body>\\nURL: <url>\\n\\n"``, with ``i``
    counting from 0 in iteration order.

    Args:
        search_result: Iterable of objects with ``body`` and ``url``
            attributes (normally ``SearchResult`` instances).

    Returns:
        The concatenated entries, or ``""`` when there are no results.
    """
    # str.join builds the result in one pass instead of repeated
    # string concatenation.
    return "".join(
        f"[{index}] {result.body}\nURL: {result.url}\n\n"
        for index, result in enumerate(search_result)
    )
def rewrite_prompt(
    prompt: str,
    num_results: int = 5,
) -> str:
    """Build a prompt that grounds ``prompt`` in web search results.

    The returned text has four newline-separated parts:
    1. the formatted web search results for ``prompt``,
    2. the current local date as ``dd/mm/YYYY``,
    3. ``INSTRUCTIONS``,
    4. the original query.

    Args:
        prompt: The user's query; it is also used as the search query.
        num_results: How many search results to request. The default of 5
            matches the previously hard-coded value.

    Returns:
        The assembled prompt string.
    """
    raw_results = get_web_search_results(
        query=prompt,
        num_results=num_results,
    )
    formatted_results = "Web search results:\n" + format_search_result(raw_results)
    formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
    formatted_prompt = f"Query: {prompt}"
    return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
|