Zasha1 commited on
Commit
c17360f
·
verified ·
1 Parent(s): 1654127

Update product_recommender.py

Browse files
Files changed (1) hide show
  1. product_recommender.py +72 -19
product_recommender.py CHANGED
@@ -1,19 +1,72 @@
1
- import pandas as pd
2
- from sentence_transformers import SentenceTransformer
3
- import faiss
4
-
5
- class ProductRecommender:
6
- def __init__(self, product_data_path):
7
- self.data = pd.read_csv(product_data_path)
8
- self.model = SentenceTransformer('all-MiniLM-L6-v2')
9
- self.embeddings = self.model.encode(self.data['product_description'].tolist())
10
- self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
11
- self.index.add(self.embeddings)
12
-
13
- def get_recommendations(self, query, top_n=5):
14
- query_embedding = self.model.encode([query])
15
- distances, indices = self.index.search(query_embedding, top_n)
16
- recommendations = []
17
- for i in indices[0]:
18
- recommendations.append(self.data.iloc[i]['product_title'] + ": " + self.data.iloc[i]['product_description'])
19
- return recommendations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sentence_transformers import SentenceTransformer
3
+ import faiss
4
+
5
+ class ProductRecommender:
6
+ def __init__(self, product_data_path):
7
+ try:
8
+ # Attempt to load the product data CSV
9
+ self.data = pd.read_csv(product_data_path)
10
+ print("Product data loaded successfully.")
11
+ except Exception as e:
12
+ print(f"Error loading product data: {e}")
13
+ self.data = pd.DataFrame() # Create an empty DataFrame if loading fails
14
+ return
15
+
16
+ try:
17
+ # Initialize the sentence transformer model
18
+ self.model = SentenceTransformer('all-MiniLM-L6-v2')
19
+ print("Model loaded successfully.")
20
+ except Exception as e:
21
+ print(f"Error loading SentenceTransformer model: {e}")
22
+ self.model = None # Set model to None if loading fails
23
+ return
24
+
25
+ try:
26
+ # Check if 'product_description' column exists
27
+ if 'product_description' not in self.data.columns:
28
+ print("Error: 'product_description' column is missing in the data.")
29
+ return
30
+
31
+ # Generate embeddings for the product descriptions
32
+ self.embeddings = self.model.encode(self.data['product_description'].tolist())
33
+ print(f"Embeddings generated successfully. Shape: {self.embeddings.shape}")
34
+ except Exception as e:
35
+ print(f"Error generating embeddings: {e}")
36
+ self.embeddings = None # Set embeddings to None if generation fails
37
+ return
38
+
39
+ try:
40
+ # Initialize FAISS index and add the embeddings
41
+ self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
42
+ self.index.add(self.embeddings)
43
+ print("FAISS index created and embeddings added.")
44
+ except Exception as e:
45
+ print(f"Error creating FAISS index or adding embeddings: {e}")
46
+ self.index = None # Set index to None if creation fails
47
+ return
48
+
49
+ def get_recommendations(self, query, top_n=5):
50
+ if self.model is None or self.index is None:
51
+ print("Error: Model or FAISS index not initialized. Cannot make recommendations.")
52
+ return []
53
+
54
+ try:
55
+ # Generate the embedding for the query
56
+ query_embedding = self.model.encode([query])
57
+ print(f"Query embedding generated. Shape: {query_embedding.shape}")
58
+ except Exception as e:
59
+ print(f"Error generating query embedding: {e}")
60
+ return []
61
+
62
+ try:
63
+ # Search for top_n recommendations
64
+ distances, indices = self.index.search(query_embedding, top_n)
65
+ recommendations = []
66
+ for i in indices[0]:
67
+ recommendations.append(self.data.iloc[i]['product_title'] + ": " + self.data.iloc[i]['product_description'])
68
+ print(f"Recommendations generated successfully: {recommendations}")
69
+ return recommendations
70
+ except Exception as e:
71
+ print(f"Error during recommendation search: {e}")
72
+ return []