Maksimov-Dmitry commited on
Commit
d1a829e
·
1 Parent(s): eb025bc
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.sqlite filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src import streamlit_utils
2
+ from src.prompts import AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT, RAG_USER_PROMPT, TRAVERSIALAI_USER_PROMPT
3
+ from src.retriever import Retriever
4
+
5
+ import streamlit as st
6
+
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+ from langchain.memory import ChatMessageHistory
10
+ import re
11
+ import requests
12
+ import os
13
+ from qdrant_client import QdrantClient
14
+
15
+ collection_name = 'hotels'
16
+
17
+ st.set_page_config(page_title="Hotels search chatbot", page_icon="⭐")
18
+ st.header('Hotels search chatbot')
19
+ st.write('[![view source code and description](https://img.shields.io/badge/view_source_code-gray?logo=github)](https://github.com/Maksimov-Dmitry/traversaal-ai-hackathon)')
20
+ st.write('Developed by [Dmitry Maksimov](https://www.linkedin.com/in/maksimov-dmitry/), [email protected] and [Ilya Dudnik](https://www.linkedin.com/in/ilia-dudnik-5b8018271/), [email protected]')
21
+
22
+ st.sidebar.header('Choose your preferences')
23
+ n_hotels = st.sidebar.number_input('Number of hotels', min_value=1, max_value=10, value=3)
24
+
25
+
26
+ @st.cache_resource
27
+ def get_db_client(path='data/db'):
28
+ client = QdrantClient(path=path)
29
+ return client
30
+
31
+
32
+ def add_new_info(chat_history, queries):
33
+ """After the user has changed any parameters (city, price, rating), we notify the Agent about it.
34
+ The information is added to the chat history.
35
+ Args:
36
+ chat_history: history of the chat
37
+ queries (list): list of queries that the user has changed
38
+ """
39
+ for query in queries:
40
+ chat_history.add_user_message(query)
41
+ chat_history.add_ai_message('Ok, got it!')
42
+
43
+
44
+ def check_params(params):
45
+ """Check if the user has changed the parameters (city, price, rating).
46
+ If the user has changed the parameters, the corresponding queries are created.
47
+
48
+ Args:
49
+ params (dict): dictionary with the parameters
50
+
51
+ Returns:
52
+ list: list of queries that the user has changed
53
+ """
54
+ changed_params = []
55
+
56
+ if 'prev_params' not in st.session_state:
57
+ st.session_state.prev_params = {'city': '<BLANK>', 'price': '<BLANK>', 'rating': '<BLANK>'}
58
+
59
+ if st.session_state.prev_params['city'] != params['city']:
60
+ changed_params.append(f'I want to find hotels in {params["city"]}' if params['city'] else 'I want to find hotels in any city')
61
+
62
+ if st.session_state.prev_params['price'] != params['price']:
63
+ changed_params.append(f'I want to find hotels in price range {params["price"]}' if params['price'] else 'I want to find hotels in any price range')
64
+
65
+ if st.session_state.prev_params['rating'] != params['rating']:
66
+ changed_params.append(f'I want to find hotels with rating greater than {params["rating"]}')
67
+
68
+ st.session_state.prev_params = params
69
+
70
+ return changed_params
71
+
72
+
73
+ def get_parameters(db_client):
74
+ """Get the parameters from the user (city, price, rating),
75
+ The provided metadata (in case it was provided by the user) is used in the MixedRetrieval from Qdrant vector DB
76
+ """
77
+ points, _ = db_client.scroll(
78
+ collection_name=collection_name,
79
+ limit=1e9,
80
+ with_payload=True,
81
+ with_vectors=False,
82
+ )
83
+ cities = ['Doest not matter'] + list(set([point.payload['city'] for point in points]))
84
+ city = st.sidebar.selectbox('City', list(cities), index=0)
85
+ if city == 'Doest not matter':
86
+ city = None
87
+
88
+ prices = ['Doest not matter'] + list(set([point.payload['price'] for point in points]))
89
+ price = st.sidebar.selectbox('Price', list(prices), index=0)
90
+ if price == 'Doest not matter':
91
+ price = None
92
+
93
+ rating = st.sidebar.slider('Min hotel rating', min_value=.0, max_value=5.0, value=4.5, step=.5)
94
+ return dict(city=city, price=price, rating=rating)
95
+
96
+
97
+ class HotelsSearchChatbot:
98
+ """
99
+ This is the Agent class. It is responsible for the decision-making during conversation with the user.
100
+ Based on the user's query, the Agent decides which action to take and how to present result to the user.
101
+ """
102
+ def __init__(self, db_client):
103
+ streamlit_utils.configure_api_keys()
104
+
105
+ self.llm_model = "gpt-4-1106-preview"
106
+ self.temperature = 0.6
107
+
108
+ self.embeedings_model = "text-embedding-3-large"
109
+ self.rerank_model = 'rerank-multilingual-v2.0'
110
+
111
+ self.ares_api_key = os.environ.get("ARES_API_KEY")
112
+ self.db_client = db_client
113
+
114
+ def _traversialai(self, query):
115
+ """Acquiring information from the internet using the Traversaal.ai.
116
+
117
+ Args:
118
+ query (str): search query
119
+
120
+ Returns:
121
+ str: information from the internet based on the query
122
+ """
123
+ url = "https://api-ares.traversaal.ai/live/predict"
124
+
125
+ payload = {"query": [query]}
126
+ headers = {
127
+ "x-api-key": self.ares_api_key,
128
+ "content-type": "application/json"
129
+ }
130
+
131
+ response = requests.post(url, json=payload, headers=headers)
132
+ try:
133
+ return response.json()['data']['response_text']
134
+ except:
135
+ return None
136
+
137
+ def _get_action(self, text):
138
+ """Parse (read) the action and the action input from the response of the Agent
139
+ (after he made a decision what to do).
140
+ 'action' and 'action_input' indicate whether we need to query additional tools
141
+ (vector DB, Traversaal AI) and how.
142
+
143
+ Args:
144
+ text (str): response of the Agent, which contains the action and the action input
145
+
146
+ Returns:
147
+ tuple: action, action input
148
+ """
149
+ action_pattern = r"Action:\s*(.*)\n"
150
+ action_input_pattern = r"Action Input:\s*(.*)"
151
+
152
+ action_match = re.search(action_pattern, text)
153
+ action_input_match = re.search(action_input_pattern, text)
154
+
155
+ action = action_match.group(1) if action_match else None
156
+ action_input = action_input_match.group(1) if action_input_match else None
157
+ return action, action_input
158
+
159
+ def _make_action(self, action, action_input, retriever, chain, chat_history, config, retriever_params):
160
+ """Take the action corresponding to 'action' and 'action input'. The 'action' can be one of the following:
161
+ 'nothing' - Agent is capable of dealing on its own without use of additional tools,
162
+ 'hotels_data_base' - Agent decides to get the information from the hotels vector DB,
163
+ 'ares_api' - Agent requires additional information from the internet using the Traversaal.ai.
164
+
165
+ Args:
166
+ action (str): action to make
167
+ action_input (str): action input (formulated by Agent search query)
168
+ retriever (Retriever): Retriever object
169
+ chain (Chain): Chain object
170
+ chat_history (ChatMessageHistory): history of the chat
171
+ config (dict): handlers for a LangChain invoke method
172
+ retriever_params (dict): parameters for the Retriever
173
+ """
174
+ if action == 'nothing':
175
+ st.markdown(action_input)
176
+ return action_input
177
+
178
+ if action == 'hotels_data_base':
179
+ context = retriever(action_input, top_k=n_hotels, **retriever_params)
180
+ chat_history.add_user_message(RAG_USER_PROMPT.format(context=context, query=action_input))
181
+ response = chain.invoke({"messages": chat_history.messages}, config)
182
+ chat_history.messages.pop()
183
+ return response.content
184
+
185
+ if action == 'ares_api':
186
+ context = self._traversialai(action_input)
187
+ chat_history.add_user_message(TRAVERSIALAI_USER_PROMPT.format(context=context, query=action_input))
188
+ response = chain.invoke({"messages": chat_history.messages}, config)
189
+ chat_history.messages.pop()
190
+ return response.content
191
+
192
+ return None
193
+
194
+ @st.cache_resource
195
+ def setup_chain(_self):
196
+ retriever = Retriever(embedding_model=_self.embeedings_model, llm_model=_self.llm_model,
197
+ rerank_model=_self.rerank_model, db_client=_self.db_client, db_collection=collection_name)
198
+
199
+ chat_history = ChatMessageHistory()
200
+ prompt = ChatPromptTemplate.from_messages(
201
+ [
202
+ (
203
+ "system",
204
+ AGENT_SYSTEM_PROMPT,
205
+ ),
206
+ MessagesPlaceholder(variable_name="messages"),
207
+ ]
208
+ )
209
+ chat = ChatOpenAI(model=_self.llm_model, temperature=_self.temperature, streaming=True)
210
+ chain = prompt | chat
211
+
212
+ return chain, chat_history, retriever
213
+
214
+ @streamlit_utils.enable_chat_history
215
+ def main(self, params):
216
+ chain, chat_history, retriever = self.setup_chain()
217
+ user_query = st.chat_input(placeholder="Ask me anything!")
218
+ if user_query:
219
+ streamlit_utils.display_msg(user_query, 'user')
220
+
221
+ # add new info to the chat history
222
+ queries = check_params(params)
223
+ add_new_info(chat_history, queries)
224
+
225
+ # get the action and the action input based on the user's query
226
+ chat_history.add_user_message(AGENT_USER_PROMPT.format(input=user_query))
227
+ action_response = chain.invoke({"messages": chat_history.messages})
228
+ chat_history.messages.pop()
229
+ action, action_input = self._get_action(action_response.content)
230
+
231
+ with st.chat_message("assistant"):
232
+ st_cb = streamlit_utils.StreamHandler(st.empty())
233
+
234
+ # create response on the user's query
235
+ response = self._make_action(action, action_input,
236
+ retriever, chain, chat_history, {"callbacks": [st_cb]}, params)
237
+ chat_history.add_user_message(user_query)
238
+ if response is None:
239
+ response = 'Sorry, I cannot help you with it. Could you rephrase your question?'
240
+ st.markdown(response)
241
+
242
+ chat_history.add_ai_message(response)
243
+ st.session_state.messages.append({"role": "assistant", "content": response})
244
+
245
+
246
+ if __name__ == "__main__":
247
+ db_client = get_db_client()
248
+ params = get_parameters(db_client)
249
+ obj = HotelsSearchChatbot(db_client)
250
+ obj.main(params)
data/db/.lock ADDED
@@ -0,0 +1 @@
 
 
1
+ tmp lock file
data/db/collection/hotels/storage.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:deb2004afca01078aacc8036779b783e47f7f9c52d440a517b32eb81b892af97
3
+ size 4726784
data/db/meta.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"collections": {"hotels": {"vectors": {"size": 3072, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}}, "aliases": {}}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ langchain-openai
4
+ qdrant-client
5
+ openai
6
+ cohere
src/__pycache__/prompts.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
src/__pycache__/prompts.cpython-39.pyc ADDED
Binary file (3.36 kB). View file
 
src/__pycache__/rag.cpython-310.pyc ADDED
Binary file (2.6 kB). View file
 
src/__pycache__/rag.cpython-39.pyc ADDED
Binary file (3.07 kB). View file
 
src/__pycache__/retriever.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
src/__pycache__/streamlit_utils.cpython-310.pyc ADDED
Binary file (2.54 kB). View file
 
src/__pycache__/streamlit_utils.cpython-39.pyc ADDED
Binary file (1.69 kB). View file
 
src/create_vector_db.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ from qdrant_client import QdrantClient, models
3
+ from openai import OpenAI
4
+ from tqdm import tqdm
5
+ import json
6
+ import requests
7
+ import os
8
+ from prompts import REVIEWS_SYSTEM_PROMPT, REVIEWS_USER_PROMPT
9
+
10
+ TRIPADVISOR_API_KEY = os.environ.get('TRIPADVISOR_API_KEY')
11
+
12
+
13
+ def save_json(data, path):
14
+ with open(path, "w") as outfile:
15
+ json.dump(data, outfile)
16
+
17
+
18
+ def get_df(dataset_path, is_hf):
19
+ if is_hf:
20
+ from datasets import load_dataset
21
+ dataset = load_dataset(dataset_path)
22
+ return dataset['train'].to_pandas()
23
+ else:
24
+ import pandas as pd
25
+ return pd.read_csv(dataset_path)
26
+
27
+
28
+ def _concat_reviews(df):
29
+ text = ''
30
+ for _, row in df.iterrows():
31
+ text += '\n'
32
+ if row.review_title:
33
+ text += '\nTitle:\n' + row.review_title
34
+ if row.review_text:
35
+ text += '\nReview:\n' + row.review_text
36
+
37
+ return text
38
+
39
+
40
+ def create_reviews_symmary(df, model, hotels, pos_rate=4.0, neg_rate=4.0, n_reviews=6):
41
+ """Create a summary of reviews for each hotel, based on the most positive and most negative reviews.
42
+
43
+ Args:
44
+ df (pd.DataFrame): hotels dataset
45
+ model (str): OpenAI model name
46
+ hotels (list): list of hotels to create summaries for
47
+ pos_rate (float): minimum positive rate, inclusive
48
+ neg_rate (float): maximum negative rate, exclusive
49
+ n_reviews (int): number of reviews to consider for each category
50
+
51
+ Returns:
52
+ dict: hotel name -> reviews summary
53
+ """
54
+ df['review_text_len'] = df.review_text.str.len().fillna(value=0)
55
+ df['review_title_len'] = df.review_title.str.len().fillna(value=0)
56
+
57
+ client = OpenAI()
58
+ hotels_reviews_summary = {}
59
+ for hotel in tqdm(hotels):
60
+ temp = df[df.hotel_name.eq(hotel)]
61
+ temp_pos = temp[temp.rate >= pos_rate].nlargest(n_reviews, 'review_text_len')
62
+ temp_neg = temp[temp.rate < neg_rate].nlargest(n_reviews, 'review_text_len')
63
+ if len(temp_pos) == 0 and len(temp_neg) == 0:
64
+ temp_pos = temp.nlargest(n_reviews, 'review_title_len')
65
+
66
+ text = _concat_reviews(temp_pos) + _concat_reviews(temp_neg)
67
+
68
+ if text:
69
+ response = client.chat.completions.create(
70
+ model=model,
71
+ messages=[
72
+ {"role": "system", "content": REVIEWS_SYSTEM_PROMPT},
73
+ {"role": "user", "content": REVIEWS_USER_PROMPT.format(text=text)},
74
+ ]
75
+ )
76
+ hotels_reviews_summary[hotel] = response.choices[0].message.content
77
+ return hotels_reviews_summary
78
+
79
+
80
+ def _get_loc_id(hotel):
81
+ """ Given a hotel name, receive location id.
82
+ In order to get the hotel info, we need to get the location id first.
83
+
84
+ Args:
85
+ hotel (str): hotel name
86
+
87
+ Returns:
88
+ str: location id
89
+ """
90
+ url = "https://api.content.tripadvisor.com/api/v1/location/search?key={key}&searchQuery={hotel}&category=hotels&language=en"
91
+ headers = {"accept": "application/json"}
92
+
93
+ response = requests.get(url.format(hotel=hotel, key=TRIPADVISOR_API_KEY), headers=headers)
94
+ try:
95
+ return response.json()['data'][0]['location_id']
96
+ except Exception as e:
97
+ print(f'{response.status_code=}')
98
+ print(f'{response.text=}')
99
+ print(f'Error: {e}')
100
+ return None
101
+
102
+
103
+ def get_hotel_info(hotel):
104
+ """Get hotel info from TripAdvisor.
105
+ The following information is retrieved using the TripAdvisor API:
106
+ - rank
107
+ - ratings distributions
108
+ - subratings
109
+ - amenities
110
+
111
+ Args:
112
+ hotel (str): hotel name
113
+
114
+ Returns:
115
+ dict: hotel info
116
+ """
117
+ url = "https://api.content.tripadvisor.com/api/v1/location/{loc_id}/details?key={key}&language=en&currency=USD"
118
+ headers = {"accept": "application/json"}
119
+
120
+ loc_id = _get_loc_id(hotel)
121
+ if loc_id is None:
122
+ return None
123
+ response = requests.get(url.format(loc_id=loc_id, key=TRIPADVISOR_API_KEY), headers=headers)
124
+ try:
125
+ response = response.json()
126
+ except Exception as e:
127
+ print(f'{response.status_code=}')
128
+ print(f'{response.text=}')
129
+ print(f'Error: {e}')
130
+ return None
131
+ rank = response['ranking_data'].get('ranking_string')
132
+ reviews_ratings = response.get('review_rating_count')
133
+ subratings = {}
134
+ for d in response['subratings']:
135
+ subratings[response['subratings'][d]['name']] = response['subratings'][d]['value']
136
+ amenities = response.get('amenities', [])
137
+ return dict(
138
+ rank=rank,
139
+ reviews_ratings=reviews_ratings,
140
+ subratings=subratings,
141
+ amenities=amenities,
142
+ )
143
+
144
+
145
+ def get_desc(hotel, data):
146
+ """Create a text description of the hotel based on the retrieved data from TripAdvisor.
147
+
148
+ Args:
149
+ hotel (str): hotel name
150
+ data (dict): hotel info
151
+
152
+ Returns:
153
+ str: hotel text description
154
+ """
155
+ rating = "Rating: "+str(data[hotel]['rank'])+". "
156
+
157
+ distr_ranks = "Rating distribution "
158
+ for key in data[hotel]['reviews_ratings'].keys():
159
+ distr_ranks += str(key) + ": " + str(data[hotel]['reviews_ratings'][key] + ", ")
160
+ distr_ranks = distr_ranks[:-2]+". "
161
+
162
+ sub_ranks = "Specific ratings: "
163
+ if 'rate_location' in data[hotel]['subratings'].keys():
164
+ sub_ranks += "Location " + data[hotel]['subratings']['rate_location'] + ", "
165
+
166
+ if 'rate_sleep' in data[hotel]['subratings'].keys():
167
+ sub_ranks += "Sleep " + data[hotel]['subratings']['rate_sleep'] + ", "
168
+ if 'rate_room' in data[hotel]['subratings'].keys():
169
+ sub_ranks += "Room " + data[hotel]['subratings']['rate_room'] + ", "
170
+ if 'rate_service' in data[hotel]['subratings'].keys():
171
+ sub_ranks += "Service " + data[hotel]['subratings']['rate_service'] + ", "
172
+ if 'rate_cleanliness' in data[hotel]['subratings'].keys():
173
+ sub_ranks += "Cleanliness " + data[hotel]['subratings']['rate_cleanliness']
174
+ sub_ranks += ". "
175
+
176
+ amenities = "Amenities available: "
177
+ for i in data[hotel]['amenities']:
178
+ amenities += str(i) + ", "
179
+ amenities = amenities[:-2] + "."
180
+
181
+ total_desc = rating + distr_ranks + sub_ranks + amenities
182
+ return total_desc
183
+
184
+
185
+ def get_payload(hotel, df):
186
+ """Create a metadata which will be stored in the database.
187
+
188
+ Args:
189
+ hotel (str): hotel name
190
+ df (pd.DataFrame): hotels dataset
191
+
192
+ Returns:
193
+ dict: metadata
194
+ """
195
+ temp = df[df.hotel_name.eq(hotel)]
196
+ rating = temp.rating_value.value_counts().index[0]
197
+ city = temp.locality.value_counts().index[0]
198
+ country = temp.country.value_counts().index[0]
199
+ price = temp.price_range.str.split(' ').str[0].value_counts().index[0]
200
+ return dict(
201
+ hotel_name=hotel,
202
+ rating=rating,
203
+ city=city,
204
+ country=country,
205
+ price=price
206
+ )
207
+
208
+
209
+ @click.command()
210
+ @click.option('--dataset-path', default='traversaal-ai-hackathon/hotel_datasets', help='Path to the dataset.')
211
+ @click.option('--is-hf', is_flag=True, default=True, help='Whether the dataset is in huggingface format, csv otherwise.')
212
+ @click.option('--db-path', default='data/db', help='Path to the output database.')
213
+ @click.option('--collection-name', default='hotels', help='Name of the collection in the database.')
214
+ @click.option('--embeddings-model', default='text-embedding-3-large', help='Name of the model to use for embeddings.')
215
+ @click.option('--embeddings-size', default=3072, help='Size of the embeddings.')
216
+ @click.option('--reviews-model', default='gpt-3.5-turbo-0125', help='Name of the model to use for reviews summary.')
217
+ def create_vector_db(dataset_path, is_hf, db_path, collection_name, embeddings_model, embeddings_size, reviews_model):
218
+ REVIEW_SUMMARIES_PATH = 'reviews_summary.json'
219
+ HOTELS_INFO_PATH = 'hotels_info.json'
220
+
221
+ df = get_df(dataset_path, is_hf)
222
+
223
+ # Create a collection if it does not exist and filter out hotels that are already in the collection
224
+ qdrant_client = QdrantClient(path=db_path)
225
+ if not qdrant_client.collection_exists(collection_name):
226
+ qdrant_client.create_collection(
227
+ collection_name=collection_name,
228
+ vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
229
+ )
230
+ hotels = df.hotel_name.unique()
231
+ else:
232
+ docs, _ = qdrant_client.scroll(
233
+ collection_name=collection_name,
234
+ limit=1e9,
235
+ with_payload=True,
236
+ with_vectors=False,
237
+ )
238
+ hotels = set(df.hotel_name.unique()) - set([doc.payload['hotel_name'] for doc in docs])
239
+ if len(hotels) == 0:
240
+ return
241
+
242
+ # Create reviews summary using OpenAI
243
+ reviews_summary = create_reviews_symmary(df, reviews_model, hotels)
244
+ save_json(reviews_summary, REVIEW_SUMMARIES_PATH)
245
+
246
+ # Get hotel info from TripAdvisor
247
+ hotels_info = {}
248
+ for hotel in tqdm(hotels):
249
+ hotels_info[hotel] = get_hotel_info(hotel)
250
+ save_json(hotels_info, HOTELS_INFO_PATH)
251
+
252
+ # Create descriptions and payloads for each hotel
253
+ texts = []
254
+ payloads = []
255
+ for hotel in hotels:
256
+ trip_desc_hotel = get_desc(hotel, hotels_info)
257
+ review_hotel = reviews_summary.get(hotel)
258
+ payload = get_payload(hotel, df)
259
+ text = trip_desc_hotel if trip_desc_hotel else '' + '\n' + review_hotel if review_hotel else ''
260
+ payload['description'] = text
261
+ payloads.append(payload)
262
+ texts.append(text)
263
+
264
+ # Create description embeddings and upsert them to the database
265
+ openai_client = OpenAI()
266
+ embeddings = openai_client.embeddings.create(input=texts, model=embeddings_model)
267
+ points = [
268
+ models.PointStruct(
269
+ id=idx,
270
+ vector=data.embedding,
271
+ payload=payload,
272
+ )
273
+ for idx, (data, payload) in enumerate(zip(embeddings.data, payloads))
274
+ ]
275
+ qdrant_client.upsert(collection_name, points)
276
+
277
+
278
+ if __name__ == '__main__':
279
+ create_vector_db()
src/prompts.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ RAG_SYSTEM_PROMPT = "You are a helpful assistant, who recommends the hotels based only on my preferences."
2
+
3
+ RAG_CONTEXT_TEMPLATE = """
4
+ {id}: {hotel_name}
5
+ {description}
6
+ """
7
+
8
+ RAG_USER_PROMPT = """
9
+ Here are the information about most relevant hotels to my query
10
+ ---------------------
11
+ {context}
12
+ ---------------------
13
+ Present these results to me and justify the ranking (explain why a hotel matches my preferences). Don't draw ANY conclusion and don't based on own knowledge.
14
+ Query: {query}
15
+ Answer:
16
+ """
17
+
18
+ AGENT_USER_PROMPT = """
19
+ Answer the following question as best you can. You have access to the following tools:
20
+
21
+ hotels_data_base: A tool which present information about most relevant hotels based on the query. The information contains pros and cons of the hotel based on reviews, reviews ratings and ammenities. It is usefull when user want to get hotels recommendations. In this case Action Input should be query which will be complete and usefull to retrive the most relevant hotels.
22
+ ares_api: An API which performs real-time internet searches. It can be usefull than you need specific information about the hotel or the locataion or smth else from the internet. In this case Action Input should be query which will be complete and usefull to retrive the information from the Internet.
23
+ nothing: If you are sure you can answer the user's query without additional tools. In this case Action Input should be just an answer.
24
+
25
+ Use the following format:
26
+
27
+ Question: the input question you must answer
28
+ Thought: you should always think about what to do
29
+ Action: the action to take, should be one of [hotels_data_base, ares_api, nothing]
30
+ Action Input: the input to the action
31
+
32
+ Begin!
33
+
34
+ Question: {input}
35
+ Thought:
36
+ """
37
+
38
+ AGENT_SYSTEM_PROMPT = "You are a helpful assistant for a hotel recommendation system based on my preferences. Answer all questions to the best of your ability."
39
+
40
+ REVIEWS_SYSTEM_PROMPT = "You are a helpful assistant. Your goal is to underpin the strong and the weak points (features, amenities). If you can't find strong or weak points, don't write ANYTHING about them. The information consists of hotel reviews, i.e. Title of the review and the Review itself."
41
+ REVIEWS_USER_PROMPT = """{text} Good Example:
42
+ ### Strong Points:
43
+ - The hotel boasts a favorable location with sea views and proximity to Zeitinburnu train station.
44
+ - Upgraded rooms, fitness facilities, and the outdoor pool area are well-received.
45
+ - The staff, including specific individuals like Mr. Levent, Cihan, and Buse, have been commended for their service.
46
+ - Room cleanliness is frequently mentioned as a positive aspect.
47
+
48
+ ### Weak Points:
49
+ - Inconsistency in customer service, with some guests reporting a lack of assistance with luggage and unfriendly reception.
50
+ - Miscommunication regarding room rates and issues with overcharges.
51
+ - Some guests have found the hotel's amenities, such as the narrow balcony and the pool's restrictive rules, to be lacking.
52
+ - A few guests reported cleanliness issues in the bathroom and concerns with room repairs.
53
+ """
54
+
55
+ TRAVERSIALAI_USER_PROMPT = """
56
+ Based on the information retrived from the internet, answer the following question as best you can.
57
+ ---------------------
58
+ {context}
59
+ ---------------------
60
+ Query: {query}
61
+ Answer:
62
+ """
src/retriever.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import cohere
3
+ from qdrant_client import models
4
+ from src.prompts import RAG_CONTEXT_TEMPLATE
5
+
6
+
7
+ class Retriever:
8
+ """Retriever class for retrieving documents from the database
9
+ For retrieving documents, the following steps are performed:
10
+ 1. Create an embedding for the query
11
+ 2. Get n documents from the database based on the query and filters (Mixed retrieval)
12
+ 3. Rerank the documents based on the query and select top k documents, where k << n (ReRanking)
13
+ 4. Create a context from the selected documents
14
+ """
15
+ def __init__(self, embedding_model, llm_model, rerank_model, db_client, db_collection='hotels'):
16
+ self.db_collection = db_collection
17
+ self.db_client = db_client
18
+ self.rerank_model = rerank_model
19
+ self.openai_client = OpenAI()
20
+ self.co = cohere.Client()
21
+ self.embedding_model = embedding_model
22
+ self.llm_model = llm_model
23
+ self.max_retrieved_docs = 13
24
+
25
+ def _get_documents(self, query, top_k, city, price, rating):
26
+ """Retrieve top n documents from the database based on the query and filters
27
+
28
+ Args:
29
+ query (str): query
30
+ top_k (int): number of documents to retrieve
31
+ city (str): city name
32
+ price (str): price range
33
+ rating (float): rating
34
+
35
+ Returns:
36
+ list: list of documents
37
+ """
38
+ embedding = self.openai_client.embeddings.create(input=query, model=self.embedding_model)
39
+ filtr = []
40
+ if city:
41
+ filtr.append(models.FieldCondition(key="city", match=models.MatchValue(value=city)))
42
+ if price:
43
+ filtr.append(models.FieldCondition(key="price", match=models.MatchValue(value=price)))
44
+ if rating:
45
+ filtr.append(models.FieldCondition(key="rating", range=models.Range(gte=rating)))
46
+ response = self.db_client.search(
47
+ collection_name=self.db_collection,
48
+ query_vector=embedding.data[0].embedding,
49
+ limit=top_k,
50
+ query_filter=models.Filter(
51
+ must=filtr
52
+ ),
53
+ )
54
+ return response
55
+
56
+ def _get_context(self, docs):
57
+ """Create a context from the retrieved documents
58
+
59
+ Args:
60
+ docs (list): list of documents
61
+
62
+ Returns:
63
+ str: context
64
+ """
65
+ context = ''
66
+ for i, doc in enumerate(docs, 1):
67
+ context += RAG_CONTEXT_TEMPLATE.format(id=i, hotel_name=doc.payload['hotel_name'], description=doc.payload['description'])
68
+ return context
69
+
70
+ def _reranker(self, docs, query, top_k):
71
+ """Rerank the retrieved documents using Cohere based on the query and select top k documents
72
+
73
+ Args:
74
+ docs (list): list of documents
75
+ query (str): query
76
+ top_k (int): number of documents to select
77
+
78
+ Returns:
79
+ list: list of reranked documents
80
+ """
81
+ texts = [doc.payload['description'] for doc in docs]
82
+ rerank_hits = self.co.rerank(query=query, documents=texts, top_n=top_k, model=self.rerank_model)
83
+ result = [docs[hit.index] for hit in rerank_hits[:top_k]]
84
+ return result
85
+
86
+ def __call__(self, query, top_k=3, city=None, price=None, rating=None):
87
+ docs = self._get_documents(query, top_k=max(self.max_retrieved_docs, top_k), city=city, price=price, rating=rating)
88
+ if len(docs) == 0:
89
+ return 'There are no such hotels'
90
+ docs = self._reranker(docs, query, top_k)
91
+ context = self._get_context(docs)
92
+ return context
src/streamlit_utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain.callbacks.base import BaseCallbackHandler
4
+
5
+
6
+ class StreamHandler(BaseCallbackHandler):
7
+ def __init__(self, container, initial_text=""):
8
+ self.container = container
9
+ self.text = initial_text
10
+
11
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
12
+ self.text += token
13
+ self.container.markdown(self.text)
14
+
15
+
16
+ def enable_chat_history(func):
17
+ if os.environ.get("OPENAI_API_KEY"):
18
+
19
+ # to clear chat history after swtching chatbot
20
+ current_page = func.__qualname__
21
+ if "current_page" not in st.session_state:
22
+ st.session_state["current_page"] = current_page
23
+ if st.session_state["current_page"] != current_page:
24
+ try:
25
+ st.cache_resource.clear()
26
+ del st.session_state["current_page"]
27
+ del st.session_state["messages"]
28
+ except:
29
+ pass
30
+
31
+ # to show chat history on ui
32
+ if "messages" not in st.session_state:
33
+ st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
34
+ for msg in st.session_state["messages"]:
35
+ st.chat_message(msg["role"]).write(msg["content"])
36
+
37
+ def execute(*args, **kwargs):
38
+ func(*args, **kwargs)
39
+ return execute
40
+
41
+
42
+ def display_msg(msg, author):
43
+ """Method to display message on the UI
44
+
45
+ Args:
46
+ msg (str): message to display
47
+ author (str): author of the message -user/assistant
48
+ """
49
+ st.session_state.messages.append({"role": author, "content": msg})
50
+ st.chat_message(author).write(msg)
51
+
52
+
53
+ def configure_api_keys():
54
+ KEYS = ['OPENAI_API_KEY', 'CO_API_KEY', 'ARES_API_KEY']
55
+ st.sidebar.header('Api Keys Configuration')
56
+ st.markdown(
57
+ """
58
+ <style>
59
+ [title="Show password text"] {
60
+ display: none;
61
+ }
62
+ </style>
63
+ """,
64
+ unsafe_allow_html=True,
65
+ )
66
+ for key in KEYS:
67
+ if key in os.environ:
68
+ st.session_state[key] = os.environ[key]
69
+ api_key = st.sidebar.text_input(
70
+ label=key,
71
+ type="password",
72
+ value=st.session_state[key] if key in st.session_state else '',
73
+ placeholder="..."
74
+ )
75
+ if api_key:
76
+ st.session_state[key] = api_key
77
+ os.environ[key] = api_key
78
+ else:
79
+ st.error(f"Please add your {key} to continue.")
80
+ st.stop()