initial

Files changed:
- .chainlit/config.toml +84 -0
- .env +1 -0
- Dockerfile +11 -0
- __pycache__/app.cpython-311.pyc +0 -0
- aims.png +0 -0
- app.py +151 -0
- chainlit.md +87 -0
- data/airbnb_midterm.pdf +0 -0
- public/aims.png +0 -0
- public/airbnb.svg +1 -0
- public/barfin.svg +12 -0
- public/favicon.ico +0 -0
- public/fund.svg +3 -0
- public/light.svg +3 -0
- public/logo_dark.png +0 -0
- public/logo_light.png +0 -0
- public/soccer.svg +86 -0
- public/stylesheet.css +35 -0
- requirements.txt +12 -0
- static/aims.png +0 -0
- utils/__pycache__/custom_retriver.cpython-311.pyc +0 -0
- utils/custom_retriver.py +116 -0
.chainlit/config.toml
ADDED
@@ -0,0 +1,84 @@
```toml
[project]
# Whether to enable telemetry (default: true). No personal data is collected.
enable_telemetry = true

# List of environment variables to be provided by each user to use the app.
user_env = []

# Duration (in seconds) during which the session is saved when the connection is lost
session_timeout = 3600

# Enable third parties caching (e.g LangChain cache)
cache = false

# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false

[features]
# Show the prompt playground
prompt_playground = true

# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
unsafe_allow_html = false

# Process and display mathematical expressions. This can clash with "$" characters in messages.
latex = false

# Authorize users to upload files with messages
multi_modal = true

# Allows user to use speech to text
[features.speech_to_text]
enabled = false
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
# language = "en-US"

[UI]
# Name of the app and chatbot.
name = "Chatbot"

# Show the readme while the conversation is empty.
show_readme_as_default = true

# Description of the app and chatbot. This is used for HTML tags.
# description = ""

# Large size content are by default collapsed for a cleaner ui
default_collapse_content = true

# The default value for the expand messages settings.
default_expand_messages = false

# Hide the chain of thought details from the user in the UI.
hide_cot = false

# Link to your github repo. This will add a github button in the UI's header.
# github = ""

# Specify a CSS file that can be used to customize the user interface.
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"

# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.light.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"

# Override default MUI dark theme. (Check theme.ts)
[UI.theme.dark]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.dark.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"


[meta]
generated_by = "0.7.700"
```
.env
ADDED
@@ -0,0 +1 @@
```
OPENAI_API_KEY=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
```
Dockerfile
ADDED
@@ -0,0 +1,11 @@
```dockerfile
FROM python:3.11.9
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR $HOME/app
COPY --chown=user . $HOME/app
COPY ./requirements.txt ~/app/requirements.txt
RUN pip install -r requirements.txt
COPY . .
CMD ["chainlit", "run", "app.py", "--port", "7860"]
```
__pycache__/app.cpython-311.pyc
ADDED
Binary file (5.29 kB).

aims.png
ADDED

app.py
ADDED
@@ -0,0 +1,151 @@
```python
# You can find this code for Chainlit python streaming here (https://docs.chainlit.io/concepts/streaming/python)

import chainlit as cl  # importing chainlit for our app
from dotenv import load_dotenv

from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from utils.custom_retriver import CustomQDrant
#from starters import set_starters


load_dotenv()


RAG_PROMPT = """
CONTEXT:
{context}

QUERY:
{question}

Answer questions first based on provided context and if you can't find answer in provided context, use your previous knowledge.
In your answer never mention phrases like Based on provided context, From the context etc.

At the end of each answer add CONTEXT CONFIDENCE tag -> answer vs. context similarity score -> faithfulness - answer in percent e.g. 85%.
Also add CONTEXT vs PRIOR tag: break answer to what you find in provided context and what you build from your prior knowledge.
"""

data_path = "data/airbnb_midterm.pdf"
docs = PyPDFLoader(data_path).load()
openai_chat_model = ChatOpenAI(model="gpt-4o", streaming=True)

def tiktoken_len(text):
    tokens = tiktoken.encoding_for_model("gpt-4o").encode(text)
    return len(tokens)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=10,
    length_function=tiktoken_len,
)

split_chunks = text_splitter.split_documents(docs)

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")


qdrant_vectorstore = CustomQDrant.from_documents(
    split_chunks,
    embedding_model,
    location=":memory:",
    collection_name="air bnb data",
    score_threshold=0.3,
)

qdrant_retriever = qdrant_vectorstore.as_retriever()

from operator import itemgetter
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

retrieval_augmented_qa_chain = (
    # INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}
    # "question" : populated by getting the value of the "question" key
    # "context"  : populated by getting the value of the "question" key and chaining it into the base_retriever
    {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
    # "context" : assigned to a RunnablePassthrough object (will not be called or considered in the next step)
    #             by getting the value of the "context" key from the previous step
    | RunnablePassthrough.assign(context=itemgetter("context"))
    # "response" : the "context" and "question" values are used to format our prompt object and then piped
    #              into the LLM and stored in a key called "response"
    # "context"  : populated by getting the value of the "context" key from the previous step
    | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
)


"""@cl.author_rename
async def rename(orig_author: str):
    rename_dict = {"User": "You", "Chatbot": "Airbnb"}
    return rename_dict.get(orig_author, orig_author)"""


@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    cl.user_session.set("chain", retrieval_augmented_qa_chain)


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")

    resp = await chain.ainvoke({"question": message.content})
    source_documents = resp["context"]

    text_elements = []  # type: List[cl.Text]

    resp_msg = resp["response"].content

    #print(source_documents)

    if source_documents:
        # each source_doc is a (Document, score) pair returned by the custom retriever
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"

            # Create the text element referenced in the message
            #text_elements.append(
            #    cl.Text(content=source_doc.page_content, name="{}".format(source_name), display="side")
            #)
            text_elements.append(
                cl.Text(
                    content=source_doc[0].page_content,
                    name="{} (scr: {})".format(source_name, round(source_doc[1], 2)),
                    display="side",
                )
            )
        source_names = [text_el.name for text_el in text_elements]

        if source_names:
            resp_msg += f"\n\nSources: {', '.join(source_names)}"
        else:
            resp_msg += "\nNo sources found"

    msg = cl.Message(content=resp_msg, elements=text_elements)

    #print(msg.content)
    await msg.send()

    """async for chunk in msg.content:
        if token := chunk.choices[0].delta.content or "":
            await msg.stream_token(token)
    await msg.update()"""

    #async for chunk in chain:
    #    if token:=
```
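For orientation (not part of the commit): the chain above can also be invoked directly, outside the Chainlit callbacks. A minimal sketch, assuming the repo root is on `PYTHONPATH` and `OPENAI_API_KEY` is set; note that importing `app` runs its module top level, i.e. it loads and embeds the PDF:

```python
# Hypothetical smoke test, not in this commit: invoke the LCEL chain directly.
import asyncio

from app import retrieval_augmented_qa_chain  # importing app loads + embeds the PDF

async def smoke_test():
    resp = await retrieval_augmented_qa_chain.ainvoke(
        {"question": "What is Airbnb's 'Description of Business'?"}
    )
    print(resp["response"].content)      # AIMessage content from gpt-4o
    for doc, score in resp["context"]:   # CustomQDrant yields (Document, score) pairs
        print(round(score, 2), doc.page_content[:80])

asyncio.run(smoke_test())
```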
chainlit.md
ADDED
@@ -0,0 +1,87 @@
````markdown
# Welcome to <img src="public/aims.png" width="40" height="40"/> Midterm App! 🚀🤖

This app was developed for the **AI Engineering bootcamp!**

### Prompt approach:

**Classic:**

```
CONTEXT:
{{context}}

QUERY:
{{question}}

Answer questions only based on provided context and not your previous knowledge.
In your answer never mention phrases like Based on provided context, From the context etc.
If you don't know the answer say I don't know!
```

### My idea was to move away from this 101 RAG prompt approach and have the LLM report how each answer was constructed: context vs. prior knowledge!

This is still not the best approach, and I would prefer to use more advanced evaluation techniques, but it helps us understand hallucinations better.
Additionally, GPT-4o can follow this type of prompt, while GPT-3.5 gets confused between prior knowledge and provided context.

**My test approach:**

```
CONTEXT:
{{context}}

QUERY:
{{question}}

Answer questions first based on provided context and if you can't find answer in provided context, use your previous knowledge.
In your answer never mention phrases like Based on provided context, From the context etc.

At the end of each answer add CONTEXT CONFIDENCE tag -> answer vs. context similarity score -> faithfulness - answer in percent e.g. 85%.
Also add CONTEXT vs PRIOR tag: break answer to what you find in provided context and what you build from your prior knowledge.
```

## Question answers 🔗

1. **"What is Airbnb's 'Description of Business'?"**

   Airbnb, Inc. was incorporated in Delaware in June 2008 and is headquartered in San Francisco, California. The company operates a global platform for unique stays and experiences. Its marketplace model connects hosts and guests (collectively referred to as "customers") online or through mobile devices to book spaces and experiences around the world.

   CONTEXT CONFIDENCE: 100%
   CONTEXT vs PRIOR: The answer is entirely based on the provided context.

   Sources: source_0 (scr: 0.58), source_1 (scr: 0.47), source_2 (scr: 0.45), source_3 (scr: 0.45)

2. **"What was the total value of 'Cash and cash equivalents' as of December 31, 2023?"**

   The total value of 'Cash and cash equivalents' as of December 31, 2023, was $12,667 million.

   CONTEXT CONFIDENCE: 100%
   CONTEXT vs PRIOR:
   - From context: The specific figure of $12,667 million for 'Cash and cash equivalents' as of December 31, 2023.
   - Prior knowledge: None used.

   Sources: source_0 (scr: 0.57), source_1 (scr: 0.54), source_2 (scr: 0.54), source_3 (scr: 0.53)

3. **"What is the 'maximum number of shares to be sold under the 10b5-1 Trading plan' by Brian Chesky?"**

   The maximum number of shares to be sold under the 10b5-1 Trading Plan by Brian Chesky is 1,146,000.

   CONTEXT CONFIDENCE: 100%
   CONTEXT vs PRIOR:
   - Context: The maximum number of shares to be sold under the 10b5-1 Trading Plan by Brian Chesky.
   - Prior: None.

   Sources: source_0 (scr: 0.64), source_1 (scr: 0.47), source_2 (scr: 0.45), source_3 (scr: 0.44)

## Example: multiple context question

4. **"What club does Luka Modrić play for, and who is Brian Chesky?"**

   Luka Modrić plays for Real Madrid, a professional football club based in Madrid, Spain. Brian Chesky is the Chief Executive Officer (CEO) of Airbnb, Inc.

   CONTEXT CONFIDENCE: 100% for Brian Chesky, 0% for Luka Modrić.
   CONTEXT vs PRIOR:
   - **Context**: Brian Chesky is identified as the CEO of Airbnb, Inc.
   - **Prior Knowledge**: Luka Modrić plays for Real Madrid.

   Sources: source_0 (scr: 0.36), source_1 (scr: 0.32), source_2 (scr: 0.32), source_3 (scr: 0.32)
````
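On the "more advanced evaluation techniques" point in `chainlit.md`: one possible next step is to measure the answer-vs-context similarity directly instead of letting the model self-report its CONTEXT CONFIDENCE. A hedged sketch, not part of this commit; it assumes `numpy` is available (it is not pinned in `requirements.txt`) and reuses the same embedding model as `app.py`:

```python
# Hypothetical helper: compute a measured "context confidence" as the cosine
# similarity between the generated answer and the retrieved context, in percent.
import numpy as np
from langchain_openai.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

def context_confidence(answer: str, context: str) -> float:
    a, c = (np.array(v) for v in embeddings.embed_documents([answer, context]))
    cosine = float(a @ c) / float(np.linalg.norm(a) * np.linalg.norm(c))
    return round(100 * max(cosine, 0.0), 1)  # clip negatives to 0%
```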
data/airbnb_midterm.pdf
ADDED
Binary file (596 kB).

public/aims.png
ADDED

public/airbnb.svg
ADDED

public/barfin.svg
ADDED

public/favicon.ico
ADDED

public/fund.svg
ADDED

public/light.svg
ADDED

public/logo_dark.png
ADDED

public/logo_light.png
ADDED

public/soccer.svg
ADDED

public/stylesheet.css
ADDED
@@ -0,0 +1,35 @@
```css
img {
    max-height: 70px !important;
}

.css-8mm1u0 {
    background-color: #288b8f !important;
    color: white;
}

/* Hide the original text */
.css-pcmo6i {
    visibility: hidden;
}

/* Insert new text using ::before */
.css-pcmo6i::before {
    content: 'Built by Marko P.';
    font-weight: bold;
    color: #8ac7c7;
    visibility: visible;
    display: block;
}

svg[viewBox="0 0 1143 266"] {
    visibility: hidden;
}

img[alt="watermark"] {
    visibility: hidden;
}

.css-2y0hwe {
    color: #49a3a3 !important;
}
```
requirements.txt
ADDED
@@ -0,0 +1,12 @@
```
openai==1.34.0
chainlit==0.7.700
langchain
langchain_openai
langchain_community
langchain_core
langchain_huggingface
langchain_text_splitters
python-dotenv==1.0.1
qdrant-client
pypdf
tiktoken
```
static/aims.png
ADDED

utils/__pycache__/custom_retriver.cpython-311.pyc
ADDED
Binary file (6.22 kB).

utils/custom_retriver.py
ADDED
@@ -0,0 +1,116 @@
```python
from langchain_core.vectorstores import VectorStoreRetriever, VectorStore
from langchain_community.vectorstores import Qdrant
from langchain_core.documents import Document
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
from typing import List, Any


class CustomVectorStoreRetriever(VectorStoreRetriever):
    """Custom Retriever class that overrides the _get_relevant_documents method."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            # returns (Document, score) pairs instead of bare Documents
            docs = self.vectorstore.similarity_search_with_score(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")

        # Custom logic for changing the output of the relevant documents

        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search_with_score(
                query, **self.search_kwargs
            )
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs


def as_retriever(self, **kwargs: Any) -> CustomVectorStoreRetriever:
    """Return VectorStoreRetriever initialized from this VectorStore.

    Args:
        search_type (Optional[str]): Defines the type of search that
            the Retriever should perform.
            Can be "similarity" (default), "mmr", or
            "similarity_score_threshold".
        search_kwargs (Optional[Dict]): Keyword arguments to pass to the
            search function. Can include things like:
                k: Amount of documents to return (Default: 4)
                score_threshold: Minimum relevance threshold
                    for similarity_score_threshold
                fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                lambda_mult: Diversity of results returned by MMR;
                    1 for minimum diversity and 0 for maximum. (Default: 0.5)
                filter: Filter by document metadata

    Returns:
        VectorStoreRetriever: Retriever class for VectorStore.

    Examples:

    .. code-block:: python

        # Retrieve more documents with higher diversity
        # Useful if your dataset has many similar documents
        docsearch.as_retriever(
            search_type="mmr",
            search_kwargs={'k': 6, 'lambda_mult': 0.25}
        )

        # Fetch more documents for the MMR algorithm to consider
        # But only return the top 5
        docsearch.as_retriever(
            search_type="mmr",
            search_kwargs={'k': 5, 'fetch_k': 50}
        )

        # Only retrieve documents that have a relevance score
        # Above a certain threshold
        docsearch.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={'score_threshold': 0.8}
        )

        # Only get the single most similar document from the dataset
        docsearch.as_retriever(search_kwargs={'k': 1})

        # Use a filter to only retrieve documents from a specific paper
        docsearch.as_retriever(
            search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
        )
    """
    tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
    return CustomVectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)


class CustomQDrant(Qdrant):
    pass


# Patch the class so CustomQDrant.as_retriever builds the custom retriever
CustomQDrant.as_retriever = as_retriever
```
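A usage sketch (hypothetical, not in this commit) of what the patched `as_retriever` changes: with the default `"similarity"` search type, the retriever yields `(Document, score)` pairs rather than bare `Document`s, which is what `app.py` relies on for its `(scr: ...)` source labels:

```python
# Hypothetical usage of CustomQDrant; requires OPENAI_API_KEY for embeddings.
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings
from utils.custom_retriver import CustomQDrant

docs = [Document(page_content="Airbnb operates a marketplace for stays and experiences.")]
store = CustomQDrant.from_documents(
    docs,
    OpenAIEmbeddings(model="text-embedding-3-small"),
    location=":memory:",
    collection_name="demo",
)
retriever = store.as_retriever()  # a CustomVectorStoreRetriever
for doc, score in retriever.get_relevant_documents("What does Airbnb do?"):
    print(round(score, 2), doc.page_content)
```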