import streamlit as st
from time import sleep
import json
from pymongo import MongoClient
from bson import ObjectId
from openai import OpenAI
openai_client = OpenAI()
import os
# Get the restaurants based on the search and location
def get_restaurants(search, location, meters):
    """Answer a restaurant query with geo-filtered vector search plus an LLM summary.

    Steps:
      1. Copy the restaurants within ``meters`` of ``location`` into the
         ``smart_trips`` collection, tagged with a fresh trip id
         (see ``pre_aggregate_meters``).
      2. Embed ``search`` with OpenAI and run ``$vectorSearch`` over the
         documents tagged with that trip id.
      3. Stream a chat completion that recommends restaurants from the hits.
      4. Delete the temporary trip documents and close the MongoDB client.

    Args:
        search: Free-text description of what the user wants to eat.
        location: Value passed as ``near`` to ``$geoNear`` (the UI passes a
            GeoJSON ``Point`` dict).
        meters: Value passed as ``maxDistance`` to ``$geoNear``.

    Returns:
        A 4-tuple ``(answer, iframe, pre_agg, vector_query)``:
        ``answer`` is the streamed chat response, or a fixed message when
        nothing was found or an error occurred; ``iframe`` is currently always
        an empty string; ``pre_agg`` and ``vector_query`` are the string forms
        of the two pipelines (empty strings if the failure happened before
        they were built). All four are ``None`` when the MongoDB client could
        not be set up.
    """
    try:
        uri = os.environ.get('MONGODB_ATLAS_URI')
        client = MongoClient(uri)
        db_name = 'whatscooking'
        collection_name = 'restaurants'
        restaurants_collection = client[db_name][collection_name]
        trips_collection = client[db_name]['smart_trips']
    except Exception:
        st.error("Error Connecting to the MongoDB Atlas Cluster")
        return None, None, None, None

    # Bound up front so the error and cleanup paths never touch an unassigned
    # name, whichever step fails.
    trip_id = None
    pre_agg = ''
    vector_query = ''

    try:
        with st.status("Search data..."):
            trip_id, pre_agg = pre_aggregate_meters(restaurants_collection, location, meters)

            response = openai_client.embeddings.create(
                input=search,
                model="text-embedding-3-small",
                dimensions=256
            )

            vector_query = {
                "$vectorSearch": {
                    "index": "vector_index",
                    "queryVector": response.data[0].embedding,
                    "path": "embedding",
                    "numCandidates": 10,
                    "limit": 3,
                    # Restrict the search to the documents merged for this trip.
                    "filter": {"searchTrip": trip_id}
                }
            }
            st.write("Vector query")
            restaurant_docs = list(trips_collection.aggregate(
                [vector_query, {"$project": {"_id": 0, "embedding": 0}}]
            ))

            st.write("RAG...")
            stream_response = openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a helpful restaurant assistant. Answer shortly and quickly. You will get a context if the context is not relevant to the user query please address that and not provide by default the restaurants as is."},
                    {"role": "user", "content": f"Find me the 2 best restaurant and why based on {search} and {restaurant_docs}. Shortly explain trades offs and why I should go to each one. You can mention the third option as a possible alternative in one sentence."}
                ],
                stream=True
            )

        # Render the streamed answer after the status block has closed.
        chat_response = st.write_stream(stream_response)

        if not restaurant_docs:
            return "No restaurants found", '', str(pre_agg), str(vector_query)

        # The vector search may return fewer than three hits, so nothing here
        # indexes into restaurant_docs by position.
        iframe = ''
        return chat_response, iframe, str(pre_agg), str(vector_query)
    except Exception as e:
        st.error(f"Your query caused an error: {e}")
        return "Your query caused an error, please retry with allowed input only ...", '', str(pre_agg), str(vector_query)
    finally:
        # Runs on every exit path: drop this search's temporary documents and
        # release the connection, even after a failure or an early return.
        try:
            if trip_id is not None:
                trips_collection.delete_many({"searchTrip": trip_id})
        except Exception as cleanup_error:
            st.warning(f"Could not remove temporary search data: {cleanup_error}")
        finally:
            client.close()
def pre_aggregate_meters(restaurants_collection, location, meters):
    """Stage the restaurants near ``location`` in ``smart_trips`` under a new trip id.

    Builds and runs a three-stage aggregation on ``restaurants_collection``:
    ``$geoNear`` (documents within ``meters`` of ``location``, distance stored
    in ``distance``), ``$addFields`` (stamps each document with ``searchTrip``
    and ``date``), and ``$merge`` into the ``smart_trips`` collection.

    Args:
        restaurants_collection: Collection object whose ``aggregate`` method
            is called with the pipeline.
        location: Value used as ``near`` in ``$geoNear``.
        meters: Value used as ``maxDistance`` in ``$geoNear``.

    Returns:
        ``(trip_id, pipeline)``: the freshly generated ``ObjectId`` tagging
        this search, and the pipeline list that was executed.
    """
    trip_id = ObjectId()

    geo_near_stage = {
        "$geoNear": {
            "near": location,
            "distanceField": "distance",
            "maxDistance": meters,
            "spherical": True,
        },
    }
    tag_stage = {
        "$addFields": {
            "searchTrip": trip_id,
            # Timestamp taken from the trip id's own generation_time.
            "date": trip_id.generation_time,
        },
    }
    merge_stage = {
        "$merge": {
            "into": "smart_trips",
        },
    }
    pipeline = [geo_near_stage, tag_stage, merge_stage]

    # The call is made for its side effect (the $merge); its return value is
    # not used.
    restaurants_collection.aggregate(pipeline)

    return trip_id, pipeline
# --- Streamlit page: inputs, search trigger, and result panels ---------------

# Predefined search origins; each "value" is the GeoJSON Point handed to
# get_restaurants as the geo search centre.
LOCATION_OPTIONS = [
    {"label": "Timesquare Manhattan", "value": {"type": "Point", "coordinates": [-73.98527039999999, 40.7589099]}},
    {"label": "Westside Manhattan", "value": {"type": "Point", "coordinates": [-74.013686, 40.701975]}},
    {"label": "Downtown Manhattan", "value": {"type": "Point", "coordinates": [-74.000468, 40.720777]}},
]

st.markdown(
    """
# MongoDB's Vector Restaurant Planner
Start typing below to see the results. You can search a specific cuisine for you and choose 3 predefined locations.
The radius specifies the distance from the start search location. This space uses the dataset called [whatscooking.restaurants](https://huggingface.co/datasets/AIatMongoDB/whatscooking.restaurants)
"""
)

search = st.text_input("What type of dinner are you looking for?")
location = st.radio("Location", options=LOCATION_OPTIONS, format_func=lambda x: x['label'])
meters = st.slider("Radius in meters", min_value=500, max_value=10000, step=5)

if st.button("Get Restaurants"):
    if not search.strip():
        # An empty query would still run the geo pre-aggregation and send an
        # empty string to the embeddings API, so stop before calling the backend.
        st.warning("Please describe what you are looking for before searching.")
    else:
        location_value = location['value']
        result, iframe, pre_agg, vectorQuery = get_restaurants(search, location_value, meters)
        # result is None when the MongoDB client could not be created; in that
        # case get_restaurants has already shown the error.
        if result:
            st.subheader("Map")
            st.markdown(iframe, unsafe_allow_html=True)
            st.subheader("Geo pre aggregation")
            st.code(pre_agg)
            st.subheader("Vector query")
            st.code(vectorQuery)