Spaces:
Running
Running
from typing import Any, Dict, List, Literal, Optional, Union | |
from exa_py import Exa # type: ignore | |
from exa_py.api import HighlightsContentsOptions, TextContentsOptions # type: ignore | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_exa._utilities import initialize_client | |
def _get_metadata(result: Any) -> Dict[str, Any]: | |
"""Get the metadata from a result object.""" | |
metadata = { | |
"title": result.title, | |
"url": result.url, | |
"id": result.id, | |
"score": result.score, | |
"published_date": result.published_date, | |
"author": result.author, | |
} | |
if getattr(result, "highlights"): | |
metadata["highlights"] = result.highlights | |
if getattr(result, "highlight_scores"): | |
metadata["highlight_scores"] = result.highlight_scores | |
return metadata | |
class ExaSearchRetriever(BaseRetriever):
    """Exa Search retriever.

    Each query is sent to ``client.search_and_contents`` together with the
    filters configured on this instance. Every returned result becomes a
    ``Document`` whose ``page_content`` is the result text and whose metadata
    is built by ``_get_metadata``.
    """

    k: int = 10  # num_results
    """The number of search results to return."""
    include_domains: Optional[List[str]] = None
    """A list of domains to include in the search."""
    exclude_domains: Optional[List[str]] = None
    """A list of domains to exclude from the search."""
    start_crawl_date: Optional[str] = None
    """The start date for the crawl (in YYYY-MM-DD format)."""
    end_crawl_date: Optional[str] = None
    """The end date for the crawl (in YYYY-MM-DD format)."""
    start_published_date: Optional[str] = None
    """The start date for when the document was published (in YYYY-MM-DD format)."""
    end_published_date: Optional[str] = None
    """The end date for when the document was published (in YYYY-MM-DD format)."""
    use_autoprompt: Optional[bool] = None
    """Whether to use autoprompt for the search."""
    type: str = "neural"
    """The type of search, 'keyword' or 'neural'. Default: neural"""
    highlights: Optional[Union[HighlightsContentsOptions, bool]] = None
    """Whether to set the page content to the highlights of the results."""
    text_contents_options: Union[TextContentsOptions, Literal[True]] = True
    """How to set the page content of the results"""

    # Client used to run searches. Defaults to None and is expected to be
    # filled in by ``initialize_client`` during validation.
    client: Exa = Field(default=None)
    # NOTE(review): presumably read by ``initialize_client`` to build the
    # client; confirm against ``langchain_exa._utilities``.
    exa_api_key: SecretStr = Field(default=None)
    exa_base_url: Optional[str] = None

    # ``pre=True`` runs this before field validation so the raw input values
    # reach ``initialize_client``. Without the decorator the method is never
    # invoked and ``client`` keeps its ``None`` default.
    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment.

        Args:
            values: The raw field values supplied at construction time.

        Returns:
            The values as returned by ``initialize_client``.
        """
        values = initialize_client(values)
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the search for ``query`` and wrap each result in a ``Document``.

        Args:
            query: The search query string.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            One ``Document`` per search result, in the order returned by the
            client.
        """
        response = self.client.search_and_contents(  # type: ignore[misc]
            query,
            num_results=self.k,
            text=self.text_contents_options,
            highlights=self.highlights,  # type: ignore
            include_domains=self.include_domains,
            exclude_domains=self.exclude_domains,
            start_crawl_date=self.start_crawl_date,
            end_crawl_date=self.end_crawl_date,
            start_published_date=self.start_published_date,
            end_published_date=self.end_published_date,
            use_autoprompt=self.use_autoprompt,
            # Forward the configured search type; previously the ``type``
            # field was declared but never sent, so it had no effect.
            type=self.type,
        )

        results = response.results

        return [
            Document(
                page_content=(result.text),
                metadata=_get_metadata(result),
            )
            for result in results
        ]