File size: 6,104 Bytes
7a6fa31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
from typing import List
from langchain_core.documents import Document
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_community.vectorstores import (
    MyScale,
    MyScaleSettings,
)
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_core.callbacks.manager import (
    CallbackManagerForRetrieverRun,
)
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant.vectorstores import Qdrant

from .metadata import CUISINES, DIETS, EQUIPMENT, KEY_INGREDIENTS, OCCASIONS

# OpenAI credential read from the environment; None when the variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Qdrant Cloud connection settings (currently disabled):
#QDRANT_CLOUD_KEY = os.environ.get("QDRANT_CLOUD_KEY")
#QDRANT_CLOUD_URL = "https://30591e3d-7092-41c4-95e1-4d3c7ef6e894.us-east4-0.gcp.cloud.qdrant.io"


# Embedding model handed to every vector store created in this module.
base_embeddings_model = OpenAIEmbeddings(
    openai_api_key=OPENAI_API_KEY,
    model="text-embedding-3-small",
)


def get_ensemble_retriever():
    """Build a weighted ensemble retriever over the Qdrant recipe collections.

    One retriever is created per existing Qdrant collection
    (`recipe_descriptions`, `recipe_nutrition`, `recipe_ingredients`), each
    returning up to 20 documents, and they are combined with weights
    0.5 / 0.25 / 0.25 respectively.

    The Qdrant connection settings are read from the `QDRANT_CLOUD_URL` and
    `QDRANT_CLOUD_KEY` environment variables when the function is called.

    Returns:
        An `EnsembleRetriever` wrapping the three collection retrievers.

    Raises:
        KeyError: if `QDRANT_CLOUD_URL` or `QDRANT_CLOUD_KEY` is not set.
    """
    # Resolve credentials here rather than relying on module-level names, so a
    # missing setting fails with the name of the variable that must be set.
    qdrant_url = os.environ["QDRANT_CLOUD_URL"]
    qdrant_api_key = os.environ["QDRANT_CLOUD_KEY"]

    # (collection name, ensemble weight) - one entry per vector index.
    weighted_collections = [
        ("recipe_descriptions", 0.5),
        ("recipe_nutrition", 0.25),
        ("recipe_ingredients", 0.25),
    ]

    retrievers = []
    for collection_name, _weight in weighted_collections:
        # Attach to the already-populated collection; nothing is re-embedded here.
        store = Qdrant.from_existing_collection(
            embedding=base_embeddings_model,
            collection_name=collection_name,
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
        retrievers.append(store.as_retriever(search_kwargs={"k": 20}))

    ensemble_retriever = EnsembleRetriever(
        retrievers=retrievers,
        weights=[weight for _name, weight in weighted_collections],
    )

    return ensemble_retriever


def _list_to_string(l: list) -> str:
    return ", ".join([f"`{item}`" for item in l])


class ModifiedSelfQueryRetriever(SelfQueryRetriever):
    """Self-query retriever that stores the search kwargs of its latest query."""

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: callback manager for this run; its child manager is
                passed to the query constructor

        Returns:
            List of relevant documents
        """
        constructor_config = {"callbacks": run_manager.get_child()}
        structured_query = self.query_constructor.invoke({"query": query}, config=constructor_config)

        new_query, search_kwargs = self._prepare_query(query, structured_query)

        print("search_kwargs", search_kwargs)
        # Store the kwargs used for this search on the instance.
        # NOTE(review): this replaces the configured search_kwargs; if the base
        # `_prepare_query` merges `self.search_kwargs` into its result, filters
        # from one query could carry over into the next - confirm.
        self.search_kwargs = search_kwargs

        return self._get_docs_with_query(new_query, search_kwargs)


def _contain_only_attribute(name: str, summary: str, allowed_values: list) -> AttributeInfo:
    """Describe a list-valued metadata field that only supports `contain` filters."""
    description = (
        # Explicit trailing space: the sentences used to run together
        # ("...recipe.It should be one of...") in the prompt sent to the LLM.
        f"{summary} "
        f"It should be one of {_list_to_string(allowed_values)}. "
        "It only supports contain comparisons. "
        f"Here are some examples: contain ({name}, '{allowed_values[0]}')"
    )
    return AttributeInfo(name=name, description=description, type="list[string]")


def get_self_retriever(llm_model):
    """Build a self-querying retriever over the MyScale recipe vector store.

    The retriever is given the filterable metadata fields (`cuisine`, `diet`,
    `equipment`, `occasion`, `time`) so the LLM can turn a natural-language
    query into a structured query with metadata filters.

    Args:
        llm_model: language model passed to `ModifiedSelfQueryRetriever.from_llm`
            to construct the structured queries.

    Returns:
        A `ModifiedSelfQueryRetriever` configured with `k=10`.

    Raises:
        KeyError: if `MYSCALE_HOST`, `MYSCALE_USERNAME` or `MYSCALE_PASSWORD`
            is not set in the environment.
    """
    metadata_field_info = [
        _contain_only_attribute(
            "cuisine",
            "The national / ethnic cuisine categories of the recipe.",
            CUISINES,
        ),
        _contain_only_attribute(
            "diet",
            "The diets / dietary restrictions satisfied by this recipe.",
            DIETS,
        ),
        _contain_only_attribute(
            "equipment",
            "The equipment required by this recipe.",
            EQUIPMENT,
        ),
        _contain_only_attribute(
            "occasion",
            "The occasions, holidays, celebrations that are well suited for this recipe.",
            OCCASIONS,
        ),
        # An `ingredients` attribute is currently disabled:
        # _contain_only_attribute(
        #     "ingredients",
        #     "The ingredients used to make this recipe.",
        #     KEY_INGREDIENTS,
        # ),
        AttributeInfo(
            name="time",
            description="The estimated time in minutes required to cook and prepare the recipe",
            type="integer",
        ),
    ]

    # Connection settings come from the environment; a missing variable raises KeyError.
    config = MyScaleSettings(
        host=os.environ["MYSCALE_HOST"],
        port=443,
        username=os.environ["MYSCALE_USERNAME"],
        password=os.environ["MYSCALE_PASSWORD"],
    )
    vectorstore = MyScale(base_embeddings_model, config)

    retriever = ModifiedSelfQueryRetriever.from_llm(
        llm_model,
        vectorstore,
        "Brief summary and key attributes of a recipe, including ingredients, cooking time, occasion, cuisine and diet",
        metadata_field_info,
        verbose=True,
        search_kwargs={"k": 10},
    )
    return retriever