File size: 1,376 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import functools
import json
from typing import List

from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever, Document
from langchain.tools import Tool

from backend.chat_bot.json_decoder import CustomJSONEncoder


class RetrieverInput(BaseModel):
    # Argument schema for the retriever tool: one required string field.
    # Documented with comments rather than a class docstring on purpose --
    # pydantic copies a class docstring into the generated schema.
    # The description text is part of the runtime schema; keep it verbatim.
    query: str = Field(..., description="query to look up in retriever")


def create_retriever_tool(
        retriever: BaseRetriever,
        tool_name: str,
        description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    Both the synchronous (``func``) and asynchronous (``coroutine``) entry
    points return the retrieved documents serialized as a JSON string: a
    list of ``Document.dict()`` payloads encoded with ``CustomJSONEncoder``.
    The agent therefore receives the same output shape regardless of which
    execution path runs.

    Args:
        retriever: The retriever to use for the retrieval
        tool_name: The name for the tool. This will be passed to the language model,
            so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the language
            model, so should be descriptive.

    Returns:
        Tool class to pass to an agent
    """

    def _to_json(docs: List[Document]) -> str:
        # Shared serializer so the sync and async paths cannot drift apart.
        # NOTE(review): CustomJSONEncoder is project-defined; presumably it
        # handles metadata values the stdlib encoder rejects -- confirm.
        return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

    # functools.wraps sets __wrapped__, so inspect.signature() reports the
    # retriever method's real signature instead of (*args, **kwargs).
    # NOTE(review): LangChain's Tool checks that signature for a `callbacks`
    # parameter to decide whether to forward callbacks -- verify against the
    # installed langchain version.
    @functools.wraps(retriever.get_relevant_documents)
    def _retrieve(*args, **kwargs) -> str:
        return _to_json(retriever.get_relevant_documents(*args, **kwargs))

    # Previously the raw coroutine was passed through, so the async path
    # returned List[Document] while the sync path returned a JSON string.
    @functools.wraps(retriever.aget_relevant_documents)
    async def _aretrieve(*args, **kwargs) -> str:
        docs = await retriever.aget_relevant_documents(*args, **kwargs)
        return _to_json(docs)

    return Tool(
        name=tool_name,
        description=description,
        func=_retrieve,
        coroutine=_aretrieve,
        args_schema=RetrieverInput,
    )