File size: 2,816 Bytes
d029425
d102e03
 
27e2360
d102e03
27e2360
 
d102e03
8d0a0d3
 
 
 
 
 
 
d102e03
27e2360
 
d029425
 
 
 
27e2360
 
 
 
d029425
 
 
 
27e2360
 
 
 
 
 
 
 
 
d102e03
 
27e2360
 
 
 
d102e03
 
 
27e2360
d102e03
 
27e2360
 
 
 
 
 
 
 
d102e03
 
27e2360
 
d102e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfa084c
27e2360
dfa084c
 
27e2360
d102e03
 
8d0a0d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Generator, Dict, List

from googleapiclient.discovery import build
from streamlit import secrets

# Instruction text that rewrite_prompt inserts between the current date and the
# query; it asks for a cited, comprehensive answer based on the search results.
INSTRUCTIONS = (
    "Instructions: "
    "Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation after the reference. "
    "If the provided search results refer to multiple subjects with the same name, "
    "write separate answers for each subject."
)


def _get_secret(name):
    """Return the value of ``name`` from streamlit's secrets or the environment.

    streamlit's ``secrets`` is tried first. If it raises FileNotFoundError
    (no secrets file available) or KeyError (a secrets file exists but does
    not define ``name``), the value is read from ``os.environ`` instead.

    Raises:
        KeyError: if ``name`` is found in neither source.
    """
    try:
        return secrets[name]
    except (FileNotFoundError, KeyError):
        # Catching KeyError too means a secrets file that defines only some
        # keys still lets the remaining ones come from the environment.
        return os.environ[name]


def get_google_api_key():
    """Returns the Google API key from streamlit's secrets, else the environment."""
    return _get_secret("google_search_api_key")


def get_google_cse_id():
    """Returns the Google CSE ID from streamlit's secrets, else the environment."""
    return _get_secret("google_cse_id")


def google_search(search_term, **kwargs) -> list:
    """Run a Google Custom Search query and return the raw result items.

    Args:
        search_term: The query string, sent as ``q``.
        **kwargs: Extra request parameters forwarded to ``cse().list``
            (for example ``num``).

    Returns:
        The ``items`` list of the API response, or an empty list when the
        response contains no ``items`` key.
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    response = service.cse().list(
        q=search_term, cx=get_google_cse_id(), **kwargs
    ).execute()
    # The Custom Search JSON API leaves "items" out of the response when a
    # query has no hits; indexing it directly would raise KeyError there.
    return response.get("items", [])


@dataclass
class SearchResult:
    """A single web search hit: its title, snippet text and URL."""

    # Hand-written slots drop the per-instance __dict__. This works with
    # @dataclass here because none of the fields has a default value.
    __slots__ = ("title", "body", "url")

    title: str  # title of the result
    body: str   # snippet text shown for the result
    url: str    # link to the result


def get_web_search_results(
        query: str,
        num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield up to ``num_results`` web search results for ``query``.

    Each raw item returned by ``google_search`` becomes a ``SearchResult``.
    A trailing non-breaking space followed by "..." is stripped from the
    snippet before it is stored as ``body``.

    Args:
        query: The search query.
        num_results: Maximum number of results to yield; also sent to the
            API as ``num``.

    Yields:
        SearchResult objects, in the order the API returned them.
    """
    # NOTE(review): the Custom Search API documents ``num`` as 1-10; larger
    # values are presumably rejected by the API - confirm before raising it.
    raw_results: List[Dict[str, str]] = google_search(
        search_term=query,
        num=num_results
    )[:num_results]
    truncation_marker = "\xa0..."
    for result in raw_results:
        # Read into a local instead of rewriting the API's dict in place, and
        # treat a missing snippet as empty text rather than raising KeyError.
        snippet = result.get("snippet", "")
        if snippet.endswith(truncation_marker):
            snippet = snippet[:-len(truncation_marker)]
        yield SearchResult(
            title=result["title"],
            body=snippet,
            url=result["link"],
        )


def format_search_result(
        search_result: "Generator[SearchResult, None, None]",
        start: int = 0,
) -> str:
    """Format search results as numbered text to be added to the prompt.

    Each result becomes an ``[i] <body>`` line, then a ``URL: <url>`` line,
    then a blank line. ``i`` counts up from ``start``. Only the ``body`` and
    ``url`` attributes are read.

    Args:
        search_result: The results to format, in citation order.
        start: Number given to the first result (default 0).

    Returns:
        The concatenated entries, or ``""`` when there are no results.
    """
    # A single join is linear; repeated ``str +=`` can be quadratic.
    return "".join(
        f"[{i}] {result.body}\nURL: {result.url}\n\n"
        for i, result in enumerate(search_result, start)
    )


def rewrite_prompt(
        prompt: str,
        num_results: int = 5,
) -> str:
    """Rewrite ``prompt`` by adding web search results to it.

    The result joins four parts with newlines, in this order: the formatted
    search results, the current date, ``INSTRUCTIONS``, and the original
    query.

    Args:
        prompt: The user's query. It is used as the search query and is also
            repeated at the end of the returned text.
        num_results: How many search results to request (default 5, the
            value previously hard-coded here).

    Returns:
        The augmented prompt string.
    """
    search_results = get_web_search_results(
        query=prompt,
        num_results=num_results,
    )
    sections = [
        "Web search results:\n" + format_search_result(search_results),
        # Day/month/year, taken from the local (naive) clock.
        "Current date: " + datetime.now().strftime("%d/%m/%Y"),
        INSTRUCTIONS,
        f"Query: {prompt}",
    ]
    return "\n".join(sections)