Pash1986 commited on
Commit
255eac3
1 Parent(s): 2a9818f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +455 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from time import sleep
3
+ import json
4
+ from pymongo import MongoClient
5
+ from bson import ObjectId
6
+ from openai import OpenAI
7
+ import os
8
+ from PIL import Image
9
+ import time
10
+ import traceback
11
+ import asyncio
12
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
13
+ from langchain_openai import OpenAIEmbeddings
14
+ from langchain_openai import ChatOpenAI
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+ from langchain_core.output_parsers import StrOutputParser
17
+ import base64
18
+ import io
19
+ from reportlab.pdfgen import canvas
20
+ from reportlab.lib.pagesizes import letter
21
+ from reportlab.lib.utils import ImageReader
22
+ import boto3
23
+ import re
24
+
25
+ output_parser = StrOutputParser()
26
+
27
+ import json
28
+ import requests
29
+
30
+ openai_client = OpenAI()
31
+
32
def fetch_url_data(url, timeout=10):
    """Download ``url`` with an HTTP GET and return the body as text.

    Args:
        url: Address to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block forever, and this helper is called
            while the UI is being built, so a stalled server would hang
            start-up.

    Returns:
        The response text on success, or the string ``"Error: <details>"``
        when the request fails (connection error, timeout or non-2xx status).
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Turn 4xx/5xx answers into an exception so they are reported below.
        response.raise_for_status()
    except requests.RequestException as e:
        return f"Error: {e}"
    return response.text
39
+
40
+
41
+ uri = os.environ.get('MONGODB_ATLAS_URI')
42
+ email = "[email protected]"
43
+ email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
44
+
45
+ # AWS Bedrock client setup
46
+ bedrock_runtime = boto3.client('bedrock-runtime',
47
+ aws_access_key_id=os.environ.get('AWS_ACCESS_KEY'),
48
+ aws_secret_access_key=os.environ.get('AWS_SECRET_KEY'),
49
+ region_name="us-east-1")
50
+
51
+
52
+
53
+ chatClient = MongoClient(uri)
54
+ db_name = 'sample_mflix'
55
+ collection_name = 'embedded_movies'
56
+ collection = chatClient[db_name][collection_name]
57
+
58
+
59
+ ## Chat RAG Functions
60
# Build the vector store and the LLM chain once at import time.  Any failure
# (most likely a missing/invalid OpenAI key) is logged and leaves both
# ``vector_store`` and ``chain`` set to None so the module still imports.
try:
    vector_store = MongoDBAtlasVectorSearch(
        embedding=OpenAIEmbeddings(),
        collection=collection,
        index_name='vector_index',
        text_key='plot',
        embedding_key='plot_embedding',
    )
    llm = ChatOpenAI(temperature=0)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a movie recommendation engine which post a concise and short summary on relevant movies."),
        ("user", "List of movies: {input}")
    ])
    chain = prompt | llm | output_parser

except Exception:
    # If open ai key is wrong
    print('Open AI key is wrong')
    vector_store = None
    chain = None
    # The original handler printed an undefined ``error_message`` and thereby
    # raised NameError itself; format the active traceback instead.
    print("An error occurred: \n" + traceback.format_exc())
74
+
75
def get_movies(message, history):
    """Stream an LLM summary of the movies most similar to ``message``.

    Args:
        message: The user's chat message, used as the similarity query.
        history: Previous chat turns supplied by the chat UI (unused).

    Yields:
        Progressively longer prefixes of ``"Found: \\n\\n<summary>"`` so the
        answer appears to be typed out. If anything fails, a single setup
        help message is yielded instead.
    """
    try:
        movies = vector_store.similarity_search(query=message, k=3, embedding_key='plot_embedding')
        return_text = ''.join(
            'Title : ' + movie.metadata['title'] + '\n------------\n' + 'Plot: ' + movie.page_content + '\n\n'
            for movie in movies
        )

        print_llm_text = chain.invoke({"input": return_text})

        # Emit one more character per step to produce a typing effect.
        for end in range(1, len(print_llm_text) + 1):
            time.sleep(0.05)
            yield "Found: " + "\n\n" + print_llm_text[:end]
    except Exception:
        print("An error occurred: \n" + traceback.format_exc())
        # The secret name must match the variable the app actually reads
        # (MONGODB_ATLAS_URI), otherwise following this hint cannot work.
        yield "Please clone the repo and add your open ai key as well as your MongoDB Atlas URI in the Secret Section of your Space\n OPENAI_API_KEY (your Open AI key) and MONGODB_ATLAS_URI (0.0.0.0/0 whitelisted instance with Vector index created) \n\n For more information : https://mongodb.com/products/platform/atlas-vector-search"
92
+
93
+
94
+ ## Restaurant Advisor RAG Functions
95
def get_restaurants(search, location, meters):
    """Recommend restaurants near ``location`` that match ``search``.

    Restaurants within ``meters`` of ``location`` are first copied into the
    ``smart_trips`` collection under a fresh trip id, then a vector search is
    run over just those documents and the hits are handed to the chat model.

    Args:
        search: Free-text description of what the user wants to eat.
        location: Point passed as ``near`` to the ``$geoNear`` stage.
        meters: Search radius passed as ``maxDistance``.

    Returns:
        A 4-tuple ``(answer_text, map_iframe_html, pre_aggregate_pipeline_str,
        vector_query_str)``. On failure the first element is an error message
        and the two pipeline strings hold whatever had been built so far
        (empty strings if nothing had).
    """
    chart_prefix = (
        '<iframe style="background: #FFFFFF;border: none;border-radius: 2px;'
        'box-shadow: 0 2px 10px 0 rgba(70, 76, 79, .2);" width="640" height="480" '
        'src="https://charts.mongodb.com/charts-paveldev-wiumf/embed/charts'
        '?id=65c24b0c-2215-4e6f-829c-f484dfd8a90c&filter='
    )
    chart_suffix = '&maxDataAge=3600&theme=light&autoRefresh=true"></iframe>'
    # Chart filtered on an empty id, i.e. a map without any restaurant.
    empty_map = chart_prefix + "{'restaurant_id':''}" + chart_suffix

    # Pre-initialise everything the except/finally clauses read, so that an
    # early failure cannot turn into an UnboundLocalError in the handler.
    pre_agg = ""
    vector_query = ""
    client = None
    trips_collection = None
    new_trip = None

    try:
        try:
            client = MongoClient(uri)
            database = client['whatscooking']
            restaurants_collection = database['restaurants']
            trips_collection = database['smart_trips']
        except Exception:
            print("Error Connecting to the MongoDB Atlas Cluster")
            raise

        # Pre aggregate restaurants collection based on chosen location and
        # radius, the output is stored into trips_collection
        new_trip, pre_agg = pre_aggregate_meters(restaurants_collection, location, meters)

        ## Get openai embeddings
        response = openai_client.embeddings.create(
            input=search,
            model="text-embedding-3-small",
            dimensions=256
        )

        ## prepare the similarity search on current trip
        vector_query = {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": response.data[0].embedding,
                "path": "embedding",
                "numCandidates": 10,
                "limit": 3,
                "filter": {"searchTrip": new_trip}
            }}

        restaurant_docs = list(trips_collection.aggregate([
            vector_query,
            {"$project": {"_id": 0, "embedding": 0}},
        ]))

        # Nothing nearby matched: skip the (paid) chat completion entirely.
        if not restaurant_docs:
            return "No restaurants found", empty_map, str(pre_agg), str(vector_query)

        ## Run the retrieved documents through a RAG system.
        chat_response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful restaurant assistant. You will get a context if the context is not relevat to the user query please address that and not provide by default the restaurants as is."},
                {"role": "user", "content": f"Find me the 2 best restaurant and why based on {search} and {restaurant_docs}. explain trades offs and why I should go to each one. You can mention the third option as a possible alternative."}
            ]
        )

        ## Build the map filter from however many hits came back (1 to 3),
        ## instead of indexing [0], [1], [2] unconditionally.
        restaurant_string = ", ".join(f"'{doc['restaurant_id']}'" for doc in restaurant_docs)
        iframe = chart_prefix + "{'restaurant_id':{$in:[" + restaurant_string + "]}}" + chart_suffix

        return chat_response.choices[0].message.content, iframe, str(pre_agg), str(vector_query)
    except Exception as e:
        print(e)
        return "Your query caused an error, please retry with allowed input only ...", empty_map, str(pre_agg), str(vector_query)
    finally:
        ## Remove the temporary documents and release the connection on every
        ## path (success, empty result, error).
        if trips_collection is not None and new_trip is not None:
            try:
                trips_collection.delete_many({"searchTrip": new_trip})
            except Exception as cleanup_error:
                print(cleanup_error)
        if client is not None:
            client.close()
165
+
166
+
167
def pre_aggregate_meters(restaurants_collection, location, meters):
    """Copy the restaurants around ``location`` into ``smart_trips``.

    Every copied document is tagged with a newly generated trip id
    (``searchTrip``) and that id's generation time (``date``).

    Args:
        restaurants_collection: Collection the aggregation is run on.
        location: Value used as ``near`` in the ``$geoNear`` stage.
        meters: Value used as ``maxDistance`` in the ``$geoNear`` stage.

    Returns:
        A tuple ``(trip_id, pipeline)`` with the generated id and the
        aggregation pipeline that was executed.
    """
    ## Do the geo location preaggregate and assign the search trip id.
    trip_id = ObjectId()

    geo_stage = {
        "$geoNear": {
            "near": location,
            "distanceField": "distance",
            "maxDistance": meters,
            "spherical": True,
        },
    }
    tag_stage = {
        "$addFields": {
            "searchTrip": trip_id,
            "date": trip_id.generation_time,
        },
    }
    merge_stage = {"$merge": {"into": "smart_trips"}}
    stages = [geo_stage, tag_stage, merge_stage]

    restaurants_collection.aggregate(stages)

    # NOTE(review): presumably this pause lets the merged documents become
    # searchable before the caller queries them - confirm.
    sleep(3)

    return trip_id, stages
196
+
197
+ ## Celeb Matcher RAG Functions
198
def construct_bedrock_body(base64_string, text):
    """Serialise the JSON request body for the embedding model.

    Args:
        base64_string: Base64 text of the image, sent as ``inputImage``.
        text: Optional text sent as ``inputText``; left out when falsy.

    Returns:
        A JSON string requesting an embedding of length 1024.
    """
    payload = {
        "inputImage": base64_string,
        "embeddingConfig": {"outputEmbeddingLength": 1024},
    }
    # Only include the text field when the caller actually supplied text.
    if text:
        payload["inputText"] = text
    return json.dumps(payload)
209
+
210
+ # Function to get the embedding from Bedrock model
211
def get_embedding_from_titan_multimodal(body):
    """Invoke the ``amazon.titan-embed-image-v1`` model and return its embedding.

    Args:
        body: JSON request string, as produced by ``construct_bedrock_body``.

    Returns:
        The value stored under ``"embedding"`` in the decoded response body.
    """
    raw_response = bedrock_runtime.invoke_model(
        modelId="amazon.titan-embed-image-v1",
        contentType="application/json",
        accept="application/json",
        body=body,
    )
    # The response body is a stream; read it fully before decoding the JSON.
    payload = raw_response.get("body").read()
    return json.loads(payload)["embedding"]
220
+
221
+ # MongoDB setup
222
+ uri = os.environ.get('MONGODB_ATLAS_URI')
223
+ client = MongoClient(uri)
224
+ db_name = 'celebrity_1000_embeddings'
225
+ collection_name = 'celeb_images'
226
+ celeb_images = client[db_name][collection_name]
227
+
228
+ participants_db = client[db_name]['participants']
229
+
230
+ # Function to record participant details
231
def record_participant(email, company, description, images):
    """Store a participant and build the PDF with their search results.

    Args:
        email: Participant email; must be present and match ``email_pattern``.
        company: Participant company name; must be present.
        description: AI generated comparison text to print in the PDF.
        images: Gallery value holding the matched images; must be non-empty.

    Returns:
        The file name of the generated PDF.

    Raises:
        gr.Error: If email/company is missing, the email is malformed, or no
            image search has been run yet.
    """
    if not email or not company:
        raise gr.Error("Please enter your email and company name to record the participant details.")

    ## regex to validate email
    # This check used to sit inside the "missing field" branch, so a filled-in
    # but malformed address was never validated (and a None email crashed
    # re.match).  It now runs for every submitted address.
    if not re.match(email_pattern, email):
        raise gr.Error("Please enter a valid email address")

    if not images:
        raise gr.Error("Please search for an image first before recording the participant.")

    participants_db.insert_one({'email': email, 'company': company})

    # Create PDF after recording participant
    return create_pdf(images, description, email, company)
247
+
248
def create_pdf(images, description, email, company):
    """Write the celeb-match results to a PDF in the working directory.

    Args:
        images: Iterable of gallery entries; each entry's image file is read
            from ``entry[1][0].image.path``.
        description: Summary text to print; ``None`` is treated as empty.
        email: Participant email, used in the greeting and the file name.
        company: Participant company (currently not printed).

    Returns:
        The name of the PDF file that was written.
    """
    filename = f"image_search_results_{email}.pdf"
    c = canvas.Canvas(filename, pagesize=letter)
    _, height = letter
    y_position = height

    c.drawString(50, y_position - 30, f"Thanks for participating, {email}! Here are your celeb match results:")

    c.drawString(50, y_position - 70, "Claude 3 summary of the MongoDB celeb comparison:")

    # Group the description into lines of at most 10 words each.
    words = (description or "").split()
    lines = [" ".join(words[start:start + 10]) for start in range(0, len(words), 10)]

    # Write each line of the description
    y_position -= 90  # Initial Y position
    for line in lines:
        # Continue on a new page instead of drawing below the bottom margin.
        if y_position < 50:
            c.showPage()
            y_position = height - 50
        c.drawString(50, y_position, line)
        y_position -= 15  # Adjust for line spacing

    for image in images:
        y_position -= 300  # Adjust this based on your image sizes
        if y_position <= 150:
            c.showPage()
            y_position = height - 50

        buffered = io.BytesIO()

        # NOTE(review): assumes each gallery entry exposes the file path at
        # image[1][0].image.path - confirm against the Gradio version in use.
        with Image.open(image[1][0].image.path) as pil_image:
            # JPEG cannot store alpha/palette modes, so normalise to RGB.
            pil_image.convert("RGB").save(buffered, format='JPEG')
        buffered.seek(0)
        c.drawImage(ImageReader(buffered), 50, y_position - 150, width=200, height=200)

    c.save()

    return filename
299
+
300
+
301
+ # Function to generate image description using Claude 3 Sonnet
302
def generate_image_description_with_claude(images_base64_strs, image_base64):
    """Ask Claude 3 Sonnet how the user's image compares to the matches.

    Args:
        images_base64_strs: Base64 JPEG strings of the matched images. Any
            number is accepted; the original code required exactly three.
        image_base64: Base64 JPEG string of the user's own image.

    Returns:
        The text of the first content item in the model response, or
        ``"No description available"`` when that item has no text or when
        there are no matched images to compare against.
    """
    if not images_base64_strs:
        return "No description available"

    def _jpeg_part(data):
        # One base64 JPEG image entry of the message content.
        return {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": data}}

    # The user's image goes first, followed by every match, then the question.
    content = [_jpeg_part(image_base64)]
    content.extend(_jpeg_part(match) for match in images_base64_strs)
    content.append({
        "type": "text",
        "text": f"Please let the user know how his first image is similar to the other {len(images_base64_strs)} and which one is the most similar?",
    })

    claude_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": "Please act as face comperison analyzer.",
        "messages": [{"role": "user", "content": content}],
    })

    claude_response = bedrock_runtime.invoke_model(
        body=claude_body,
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(claude_response.get("body").read())
    # Assuming the response contains a field 'content' with the description
    return response_body["content"][0].get("text", "No description available")
328
+
329
+ # Main function to start image search
330
def start_image_search(image, text):
    """Find the stored images most similar to ``image`` and describe them.

    Args:
        image: PIL image uploaded by the user.
        text: Optional text sent to the embedding model with the image.

    Returns:
        A tuple ``(images, description)``: the matched PIL images and the
        generated comparison text. When the vector search returns nothing,
        an empty list and an explanatory message are returned.

    Raises:
        gr.Error: If no image was supplied.
    """
    if not image:
        raise gr.Error("Please upload an image first, make sure to press the 'Submit' button after selecting the image.")

    # Encode the upload as a base64 JPEG; JPEG needs RGB, so convert first.
    buffered = io.BytesIO()
    image = image.convert("RGB").resize((800, 600))
    image.save(buffered, format="JPEG", quality=85)
    img_base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    body = construct_bedrock_body(img_base64_str, text)
    embedding = get_embedding_from_titan_multimodal(body)

    doc = list(celeb_images.aggregate([
        {
            "$vectorSearch": {
                "index": "vector_index",
                "path": "embeddings",
                "queryVector": embedding,
                "numCandidates": 15,
                "limit": 3
            }
        }, {"$project": {"image": 1}}
    ]))

    # Without matches there is nothing to compare, so skip the model call.
    if not doc:
        return [], "No similar images were found."

    images = []
    images_base64_strs = []
    for image_doc in doc:
        pil_image = Image.open(io.BytesIO(base64.b64decode(image_doc['image'])))
        img_byte = io.BytesIO()
        # Re-encode as JPEG for the model; convert so non-RGB images save too.
        pil_image.convert("RGB").save(img_byte, format='JPEG')
        images_base64_strs.append(base64.b64encode(img_byte.getvalue()).decode('utf-8'))
        images.append(pil_image)

    description = generate_image_description_with_claude(images_base64_strs, img_base64_str)
    return images, description
366
+
367
+
368
+ with gr.Blocks() as demo:
369
+
370
+ with gr.Tab("Chat RAG Demo"):
371
+ with gr.Tab("Demo"):
372
+ gr.ChatInterface(get_movies, examples=["What movies are scary?", "Find me a comedy", "Movies for kids"], title="Movies Atlas Vector Search",description="This small chat uses a similarity search to find relevant movies, it uses MongoDB Atlas Vector Search read more here: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-tutorial",submit_btn="Search").queue()
373
+ with gr.Tab("Code"):
374
+ gr.Code(label="Code", language="python", value=fetch_url_data('https://huggingface.co/spaces/MongoDB/MongoDB-Movie-Search/raw/main/app.py'))
375
+
376
+ with gr.Tab("Restaruant advisor RAG Demo"):
377
+ with gr.Tab("Demo"):
378
+ gr.Markdown(
379
+ """
380
+ # MongoDB's Vector Restaurant planner
381
+ Start typing below to see the results. You can search a specific cuisine for you and choose 3 predefined locations.
382
+ The radius specify the distance from the start search location. This space uses the dataset called [whatscooking.restaurants](https://huggingface.co/datasets/AIatMongoDB/whatscooking.restaurants)
383
+ """)
384
+
385
+ # Create the interface
386
+ gr.Interface(
387
+ get_restaurants,
388
+ [gr.Textbox(placeholder="What type of dinner are you looking for?"),
389
+ gr.Radio(choices=[
390
+ ("Timesquare Manhattan", {
391
+ "type": "Point",
392
+ "coordinates": [-73.98527039999999, 40.7589099]
393
+ }),
394
+ ("Westside Manhattan", {
395
+ "type": "Point",
396
+ "coordinates": [-74.013686, 40.701975]
397
+ }),
398
+ ("Downtown Manhattan", {
399
+ "type": "Point",
400
+ "coordinates": [-74.000468, 40.720777]
401
+ })
402
+ ], label="Location", info="What location you need?"),
403
+ gr.Slider(minimum=500, maximum=10000, randomize=False, step=5, label="Radius in meters")],
404
+ [gr.Textbox(label="MongoDB Vector Recommendations", placeholder="Results will be displayed here"), "html",
405
+ gr.Code(label="Pre-aggregate pipeline",language="json" ),
406
+ gr.Code(label="Vector Query", language="json")]
407
+ )
408
+ with gr.Tab("Code"):
409
+ gr.Code(label="Code", language="python", value=fetch_url_data('https://huggingface.co/spaces/MongoDB/whatscooking-advisor/raw/main/app.py'))
410
+
411
+ with gr.Tab("Celeb Matcher Demo"):
412
+ with gr.Tab("Demo"):
413
+ gr.Markdown("""
414
+ # MongoDB's Vector Celeb Image Matcher
415
+
416
+ Upload an image and find the most similar celeb image from the database, along with an AI-generated description.
417
+
418
+ 💪 Make a great pose to impact the search! 🤯
419
+ """)
420
+ with gr.Row():
421
+ with gr.Column():
422
+ image_input = gr.Image(type="pil", label="Upload an image")
423
+ text_input = gr.Textbox(label="Enter an adjustment to the image")
424
+ search_button = gr.Button("Search")
425
+
426
+
427
+ with gr.Column():
428
+ output_gallery = gr.Gallery(label="Located images", show_label=False, elem_id="gallery",
429
+ columns=[3], rows=[1], object_fit="contain", height="auto")
430
+ output_description = gr.Textbox(label="AI Based vision description")
431
+ gr.Markdown("""
432
+
433
+ """)
434
+ with gr.Row():
435
+ email_input = gr.Textbox(label="Enter your email")
436
+ company_input = gr.Textbox(label="Enter your company name")
437
+ record_button = gr.Button("Record & Download PDF")
438
+
439
+ search_button.click(
440
+ fn=start_image_search,
441
+ inputs=[image_input, text_input],
442
+ outputs=[output_gallery, output_description]
443
+ )
444
+
445
+ record_button.click(
446
+ fn=record_participant,
447
+ inputs=[email_input, company_input, output_description, output_gallery],
448
+ outputs=gr.File(label="Download Search Results as PDF")
449
+ )
450
+ with gr.Tab("Code"):
451
+ gr.Code(label="Code", language="python", value=fetch_url_data('https://huggingface.co/spaces/MongoDB/aws-bedrock-celeb-matcher/raw/main/app.py'))
452
+
453
+
454
+ if __name__ == "__main__":
455
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pymongo
2
+ huggingface_hub
3
+ langchain-community
4
+ langchain-core
5
+ langchain-openai
6
+ tiktoken
7
+ datasets