samarthagarwal23 commited on
Commit
1410c49
·
1 Parent(s): bad9a8f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import pickle
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+ mdl_name = 'sentence-transformers/all-distilroberta-v1'
7
+ model = SentenceTransformer(mdl_name)
8
+
9
+ embedding_cache_path = ""
10
+ with open(embedding_cache_path, "rb") as fIn:
11
+ cache_data = pickle.load(fIn)
12
+
13
+
14
+ def user_query_recommend(query, min_p, max_p, embedding_table = cache_data["embeddings"], reviews = cache_data["data"]):
15
+ # Embed user query
16
+ embedding = model.encode(query)
17
+
18
+ # Calculate similarity with all reviews
19
+ sim_scores = util.cos_sim(embedding, embedding_table)
20
+ #print(sim_scores.shape)
21
+
22
+ # Recommend
23
+ recommendations = reviews.copy()
24
+ recommendations['price'] =recommendations.price.apply(lambda x: re.findall("\d+", x.replace(",","").replace(".00","").replace("$",""))[0]).astype('int')
25
+ recommendations['sim'] = sim_scores.T
26
+ recommendations = recommendations.sort_values('sim', ascending=False)
27
+ recommendations = recommendations.loc[(recommendations.price >= min_p) &
28
+ (recommendations.price <= max_p),
29
+ ['name', 'category', 'price', 'description', 'sim']]
30
+
31
+ return recommendations
32
+
33
+ interface = gr.Interface(
34
+ user_query_recommend,
35
+ inputs=[gr.inputs.Textbox(),
36
+ gr.inputs.Slider(minimum=1, maximum=100, default=30, label='Min Price'),
37
+ gr.inputs.Slider(minimum=1, maximum=1000, default=70, label='Max Price')],
38
+ outputs=[
39
+ gr.outputs.Textbox(label="Recommendations"),
40
+ ],
41
+ title = "Scotch Recommendation",
42
+ examples=["very sweet with lemons and oranges and marmalades", "smoky peaty earthy and spicy"],
43
+ theme="huggingface",
44
+ )
45
+
46
+ interface.launch(
47
+ enable_queue=True,
48
+ cache_examples=True,
49
+ )