import json
import os
from time import sleep

import streamlit as st
from bson import ObjectId
from openai import OpenAI
from pymongo import MongoClient

openai_client = OpenAI()

DB_NAME = 'whatscooking'
RESTAURANTS_COLLECTION = 'restaurants'
TRIPS_COLLECTION = 'smart_trips'

SYSTEM_PROMPT = (
    "You are a helpful restaurant assistant. Answer shortly and quickly. "
    "You will get a context if the context is not relevant to the user query "
    "please address that and not provide by default the restaurants as is."
)

# Choices offered by the "Location" radio; each value is a GeoJSON point
# that is passed as-is to the $geoNear stage.
LOCATION_OPTIONS = [
    {"label": "Timesquare Manhattan", "value": {"type": "Point", "coordinates": [-73.98527039999999, 40.7589099]}},
    {"label": "Westside Manhattan", "value": {"type": "Point", "coordinates": [-74.013686, 40.701975]}},
    {"label": "Downtown Manhattan", "value": {"type": "Point", "coordinates": [-74.000468, 40.720777]}},
]


def get_restaurants(search, location, meters):
    """Find restaurants near ``location`` matching ``search`` and stream a recommendation.

    Steps: copy restaurants within ``meters`` of ``location`` into the
    ``smart_trips`` collection tagged with a fresh trip id, embed ``search``
    with OpenAI, run ``$vectorSearch`` filtered to that trip id, and stream a
    chat completion built from the hits to the Streamlit page.

    Args:
        search: Free-text description of what the user is looking for.
        location: GeoJSON point used as the ``$geoNear`` origin.
        meters: Maximum distance from ``location``, passed as ``maxDistance``.

    Returns:
        A 4-tuple ``(answer, iframe, pre_agg, vector_query)``. ``answer`` is
        the streamed chat text, or a fixed message when nothing was found or
        an error occurred. ``iframe`` is always ``''``. The last two are the
        ``str()`` of the pipelines that were run (``''`` if the failure
        happened before they were built). Returns four ``None`` values if the
        MongoDB client could not be set up.
    """
    try:
        uri = os.environ.get('MONGODB_ATLAS_URI')
        client = MongoClient(uri)
        restaurants_collection = client[DB_NAME][RESTAURANTS_COLLECTION]
        trips_collection = client[DB_NAME][TRIPS_COLLECTION]
    except Exception:
        st.error("Error Connecting to the MongoDB Atlas Cluster")
        return None, None, None, None

    # Bound before the try so the error path can always report them, even
    # when the failure happens before either pipeline is built.
    new_trip = None
    pre_agg = ''
    vector_query = ''
    try:
        with st.status("Search data..."):
            new_trip, pre_agg = pre_aggregate_meters(restaurants_collection, location, meters)

            response = openai_client.embeddings.create(
                input=search,
                model="text-embedding-3-small",
                dimensions=256
            )

            vector_query = {
                "$vectorSearch": {
                    "index": "vector_index",
                    "queryVector": response.data[0].embedding,
                    "path": "embedding",
                    "numCandidates": 10,
                    "limit": 3,
                    # Restrict the search to the documents merged for this request.
                    "filter": {"searchTrip": new_trip}
                }
            }
            st.write("Vector query")
            restaurant_docs = list(trips_collection.aggregate(
                [vector_query, {"$project": {"_id": 0, "embedding": 0}}]
            ))

            st.write("RAG...")
            user_prompt = (
                f"Find me the 2 best restaurant and why based on {search} and {restaurant_docs}. "
                "Shortly explain trades offs and why I should go to each one. "
                "You can mention the third option as a possible alternative in one sentence."
            )
            stream_response = openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt}
                ],
                stream=True
            )

        # NOTE(review): rendered outside the status container so the answer is
        # visible on the page; the collapsed source does not show where the
        # original `with` block ended -- confirm.
        chat_response = st.write_stream(stream_response)

        if not restaurant_docs:
            return "No restaurants found", '', str(pre_agg), str(vector_query)

        iframe = ''
        return chat_response, iframe, str(pre_agg), str(vector_query)
    except Exception as e:
        st.error(f"Your query caused an error: {e}")
        return ("Your query caused an error, please retry with allowed input only ...",
                '', str(pre_agg), str(vector_query))
    finally:
        # Always drop this request's temporary documents and release the
        # connection, whichever path returned above.
        try:
            if new_trip is not None:
                trips_collection.delete_many({"searchTrip": new_trip})
        except Exception as cleanup_error:
            st.warning(f"Could not clean up temporary search data: {cleanup_error}")
        finally:
            client.close()


def pre_aggregate_meters(restaurants_collection, location, meters):
    """Merge restaurants within ``meters`` of ``location`` into ``smart_trips``.

    Each merged document gets a ``distance`` field from ``$geoNear`` plus
    ``searchTrip`` (a freshly generated ObjectId) and ``date`` (that id's
    generation time), so a later query can select exactly this batch.

    Args:
        restaurants_collection: Collection the aggregation is run on.
        location: Value passed as ``near`` to ``$geoNear``.
        meters: Value passed as ``maxDistance`` to ``$geoNear``.

    Returns:
        ``(trip_id, pipeline)``: the generated ObjectId and the pipeline list
        that was executed.
    """
    trip_id = ObjectId()
    pre_aggregate_pipeline = [
        {
            "$geoNear": {
                "near": location,
                "distanceField": "distance",
                "maxDistance": meters,
                "spherical": True,
            },
        },
        {
            "$addFields": {
                "searchTrip": trip_id,
                "date": trip_id.generation_time
            }
        },
        {
            "$merge": {
                "into": "smart_trips"
            }
        }
    ]

    # The pipeline ends in $merge, so the call is made for its write side
    # effect; the returned cursor is not used.
    # NOTE(review): the original carried a disabled `sleep(3)` here, presumably
    # to let the vector index catch up with the merged documents -- confirm
    # whether a wait is needed.
    restaurants_collection.aggregate(pre_aggregate_pipeline)

    return trip_id, pre_aggregate_pipeline


st.markdown(
    """
# MongoDB's Vector Restaurant Planner
Start typing below to see the results. You can search a specific cuisine for you and choose 3 predefined locations.
The radius specifies the distance from the start search location. This space uses the dataset called [whatscooking.restaurants](https://huggingface.co/datasets/AIatMongoDB/whatscooking.restaurants)
"""
)

search = st.text_input("What type of dinner are you looking for?")
location = st.radio("Location", options=LOCATION_OPTIONS, format_func=lambda x: x['label'])
meters = st.slider("Radius in meters", min_value=500, max_value=10000, step=5)

if st.button("Get Restaurants"):
    location_value = location['value']
    result, iframe, pre_agg, vectorQuery = get_restaurants(search, location_value, meters)
    # `result` is None only when the MongoDB client could not be set up; in
    # that case the error has already been shown and there is nothing to render.
    if result:
        st.subheader("Map")
        st.markdown(iframe, unsafe_allow_html=True)
        st.subheader("Geo pre aggregation")
        st.code(pre_agg)
        st.subheader("Vector query")
        st.code(vectorQuery)