File size: 3,591 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import Any, Dict, List, Literal, Optional, Union

from exa_py import Exa  # type: ignore
from exa_py.api import HighlightsContentsOptions, TextContentsOptions  # type: ignore
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.retrievers import BaseRetriever

from langchain_exa._utilities import initialize_client


def _get_metadata(result: Any) -> Dict[str, Any]:
    """Build the ``Document`` metadata dict for a single search result.

    Args:
        result: A result object exposing ``title``, ``url``, ``id``,
            ``score``, ``published_date`` and ``author`` attributes, and
            optionally ``highlights`` / ``highlight_scores``.

    Returns:
        A dict holding the six core fields, plus ``highlights`` and
        ``highlight_scores`` when those attributes exist and are truthy.

    Raises:
        AttributeError: If one of the six core attributes is missing.
    """
    metadata: Dict[str, Any] = {
        "title": result.title,
        "url": result.url,
        "id": result.id,
        "score": result.score,
        "published_date": result.published_date,
        "author": result.author,
    }
    # The highlight attributes are optional: fall back to None so a result
    # object that lacks them is skipped instead of raising AttributeError.
    for key in ("highlights", "highlight_scores"):
        value = getattr(result, key, None)
        if value:
            metadata[key] = value
    return metadata


class ExaSearchRetriever(BaseRetriever):
    """Retriever that runs an Exa search and returns the hits as Documents.

    Each search result becomes one ``Document`` whose ``page_content`` is the
    result's ``text`` and whose ``metadata`` is produced by ``_get_metadata``.
    """

    k: int = 10  # num_results
    """The number of search results to return."""
    include_domains: Optional[List[str]] = None
    """A list of domains to include in the search."""
    exclude_domains: Optional[List[str]] = None
    """A list of domains to exclude from the search."""
    start_crawl_date: Optional[str] = None
    """The start date for the crawl (in YYYY-MM-DD format)."""
    end_crawl_date: Optional[str] = None
    """The end date for the crawl (in YYYY-MM-DD format)."""
    start_published_date: Optional[str] = None
    """The start date for when the document was published (in YYYY-MM-DD format)."""
    end_published_date: Optional[str] = None
    """The end date for when the document was published (in YYYY-MM-DD format)."""
    use_autoprompt: Optional[bool] = None
    """Whether to use autoprompt for the search."""
    type: str = "neural"
    """The type of search, 'keyword' or 'neural'. Default: neural"""
    highlights: Optional[Union[HighlightsContentsOptions, bool]] = None
    """Highlights option forwarded to the search call (options object or bool)."""
    text_contents_options: Union[TextContentsOptions, Literal[True]] = True
    """How to set the page content of the results."""

    client: Exa = Field(default=None)
    exa_api_key: SecretStr = Field(default=None)
    exa_base_url: Optional[str] = None

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment.

        Delegates to ``initialize_client`` and returns whatever it returns.
        NOTE(review): presumably this fills in ``client`` from the API key /
        base URL -- confirm against ``langchain_exa._utilities``.
        """
        values = initialize_client(values)
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search Exa for ``query`` and convert each result to a Document.

        Args:
            query: The search query string.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            One ``Document`` per search result, in the order returned by the
            client.
        """
        response = self.client.search_and_contents(  # type: ignore[misc]
            query,
            num_results=self.k,
            text=self.text_contents_options,
            highlights=self.highlights,  # type: ignore
            include_domains=self.include_domains,
            exclude_domains=self.exclude_domains,
            start_crawl_date=self.start_crawl_date,
            end_crawl_date=self.end_crawl_date,
            start_published_date=self.start_published_date,
            end_published_date=self.end_published_date,
            use_autoprompt=self.use_autoprompt,
            # Forward the configured search type; without this the ``type``
            # field above would be silently ignored.
            type=self.type,
        )

        return [
            Document(
                page_content=result.text,
                metadata=_get_metadata(result),
            )
            for result in response.results
        ]