goodmodeler commited on
Commit
c99bc7a
·
1 Parent(s): c1c9e88

UPDATE: RAG process

Browse files
retrieval_augmented_generation/build_embeddings.py CHANGED
@@ -1,11 +1,261 @@
 
 
 
 
 
 
 
 
 
1
  from sentence_transformers import SentenceTransformer
2
- import faiss, json, glob, os, numpy as np
3
-
4
- model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
5
- texts=[]; vecs=[]
6
- for f in glob.glob("nyc_ads_dataset/*.json"):
7
- cap=json.load(open(f))["caption"]
8
- texts.append(cap); vecs.append(model.encode(cap,normalize_embeddings=True))
9
- vecs=np.vstack(vecs).astype("float32")
10
- index=faiss.IndexFlatIP(vecs.shape[1]); index.add(vecs)
11
- faiss.write_index(index,"prompt.index"); json.dump(texts,open("prompt.txt","w"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 使用BERT + FAISS构建产品描述和Slogan的嵌入数据库
4
+ 支持相似性搜索和检索
5
+ """
6
+
7
+ import faiss
8
+ import numpy as np
9
+ import pandas as pd
10
  from sentence_transformers import SentenceTransformer
11
+ from datasets import Dataset
12
+ import pickle
13
+ import json
14
+ from typing import List, Dict, Tuple
15
+ import os
16
+
17
+ class SloganEmbeddingDB:
18
+ def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
19
+ """
20
+ 初始化BERT+FAISS数据库
21
+
22
+ Args:
23
+ model_name: 多语言BERT模型,支持中英文
24
+ """
25
+ print(f"📥 Loading BERT model: {model_name}")
26
+ self.model = SentenceTransformer(model_name)
27
+ self.dimension = self.model.get_sentence_embedding_dimension()
28
+
29
+ # 初始化FAISS索引
30
+ self.index = faiss.IndexFlatIP(self.dimension) # 内积相似度
31
+ self.data = [] # 存储原始数据
32
+
33
+ print(f"✅ Model loaded. Embedding dimension: {self.dimension}")
34
+
35
+ def create_sample_dataset(self) -> Dataset:
36
+ """创建示例数据集"""
37
+ sample_data = [
38
+ # 中文品牌
39
+ {"business": "肯德基", "category": "快餐", "description": "美式炸鸡快餐连锁", "slogan": "有了肯德基生活好滋味"},
40
+ {"business": "麦当劳", "category": "快餐", "description": "全球知名汉堡快餐", "slogan": "我就喜欢"},
41
+ {"business": "星巴克", "category": "咖啡", "description": "全球连锁咖啡店", "slogan": "启发并滋润人类精神"},
42
+ {"business": "小米", "category": "电子产品", "description": "智能手机和科技产品", "slogan": "让每个人都能享受科技的乐趣"},
43
+ {"business": "华为", "category": "电子产品", "description": "通信设备和智能手机", "slogan": "构建万物互联的智能世界"},
44
+
45
+ # 英文品牌
46
+ {"business": "Nike", "category": "运动用品", "description": "Athletic footwear and apparel", "slogan": "Just Do It"},
47
+ {"business": "Apple", "category": "科技", "description": "Consumer electronics and software", "slogan": "Think Different"},
48
+ {"business": "Coca-Cola", "category": "饮料", "description": "Carbonated soft drinks", "slogan": "Open Happiness"},
49
+ {"business": "BMW", "category": "汽车", "description": "Luxury automobiles", "slogan": "The Ultimate Driving Machine"},
50
+ {"business": "Amazon", "category": "电商", "description": "E-commerce and cloud services", "slogan": "Earth's Most Customer-Centric Company"},
51
+
52
+ # 产品描述
53
+ {"business": "智能手表", "category": "可穿戴设备", "description": "健康监测和通知功能的智能手表", "slogan": "时刻关注您的健康"},
54
+ {"business": "电动汽车", "category": "新能源汽车", "description": "零排放环保电动车", "slogan": "绿色出行,智享未来"},
55
+ {"business": "在线教育平台", "category": "教育科技", "description": "AI驱动的个性化学习平台", "slogan": "让学习更智能"},
56
+ {"business": "健身APP", "category": "健康应用", "description": "AI私教健身指导应用", "slogan": "随时随地,专业健身"},
57
+ {"business": "外卖平台", "category": "生活服务", "description": "快速便捷的餐食配送服务", "slogan": "美食到家,生活更美好"},
58
+ ]
59
+
60
+ return Dataset.from_pandas(pd.DataFrame(sample_data))
61
+
62
+ def build_embeddings(self, dataset: Dataset):
63
+ """构建嵌入向量并建立FAISS索引"""
64
+ print("🔨 Building embeddings and FAISS index...")
65
+
66
+ # 准备数据
67
+ texts = []
68
+ for item in dataset:
69
+ # 组合文本:业务名称 + 类别 + 描述
70
+ combined_text = f"{item['business']} {item['category']} {item['description']}"
71
+ texts.append(combined_text)
72
+
73
+ # 保存原始数据
74
+ self.data.append({
75
+ "business": item["business"],
76
+ "category": item["category"],
77
+ "description": item["description"],
78
+ "slogan": item["slogan"],
79
+ "combined_text": combined_text
80
+ })
81
+
82
+ # 生成嵌入向量
83
+ print(f"📊 Generating embeddings for {len(texts)} items...")
84
+ embeddings = self.model.encode(texts, show_progress_bar=True)
85
+
86
+ # 标准化向量(用于余弦相似度)
87
+ embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
88
+
89
+ # 添加到FAISS索引
90
+ self.index.add(embeddings.astype('float32'))
91
+
92
+ print(f"✅ Built FAISS index with {self.index.ntotal} vectors")
93
+
94
+ def search_similar(self, query: str, top_k: int = 5) -> List[Dict]:
95
+ """搜索相似的业务描述"""
96
+ print(f"🔍 Searching for: '{query}'")
97
+
98
+ # 生成查询向量
99
+ query_embedding = self.model.encode([query])
100
+ query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
101
+
102
+ # FAISS搜索
103
+ scores, indices = self.index.search(query_embedding.astype('float32'), top_k)
104
+
105
+ # 整理结果
106
+ results = []
107
+ for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
108
+ if idx < len(self.data):
109
+ result = self.data[idx].copy()
110
+ result["similarity_score"] = float(score)
111
+ result["rank"] = i + 1
112
+ results.append(result)
113
+
114
+ return results
115
+
116
+ def save_database(self, save_path: str = "./slogan_db"):
117
+ """保存数据库"""
118
+ os.makedirs(save_path, exist_ok=True)
119
+
120
+ # 保存FAISS索引
121
+ faiss.write_index(self.index, f"{save_path}/faiss.index")
122
+
123
+ # 保存数据
124
+ with open(f"{save_path}/data.pkl", "wb") as f:
125
+ pickle.dump(self.data, f)
126
+
127
+ # 保存配置
128
+ config = {
129
+ "model_name": self.model._modules['0'].auto_model.config.name_or_path,
130
+ "dimension": self.dimension,
131
+ "total_items": len(self.data)
132
+ }
133
+ with open(f"{save_path}/config.json", "w", encoding="utf-8") as f:
134
+ json.dump(config, f, ensure_ascii=False, indent=2)
135
+
136
+ print(f"💾 Database saved to {save_path}")
137
+
138
+ def load_database(self, load_path: str = "./slogan_db"):
139
+ """加载数据库"""
140
+ print(f"📂 Loading database from {load_path}")
141
+
142
+ # 加载FAISS索引
143
+ self.index = faiss.read_index(f"{load_path}/faiss.index")
144
+
145
+ # 加载数据
146
+ with open(f"{load_path}/data.pkl", "rb") as f:
147
+ self.data = pickle.load(f)
148
+
149
+ print(f"✅ Loaded database with {len(self.data)} items")
150
+
151
+ def add_new_item(self, business: str, category: str, description: str, slogan: str):
152
+ """动态添加新项目"""
153
+ combined_text = f"{business} {category} {description}"
154
+
155
+ # 生成嵌入
156
+ embedding = self.model.encode([combined_text])
157
+ embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
158
+
159
+ # 添加到索引
160
+ self.index.add(embedding.astype('float32'))
161
+
162
+ # 添加到数据
163
+ self.data.append({
164
+ "business": business,
165
+ "category": category,
166
+ "description": description,
167
+ "slogan": slogan,
168
+ "combined_text": combined_text
169
+ })
170
+
171
+ print(f"➕ Added new item: {business}")
172
+
173
+ def generate_slogan_suggestions(self, business_description: str, top_k: int = 3) -> List[str]:
174
+ """根据业务描述生成Slogan建议"""
175
+ similar_items = self.search_similar(business_description, top_k)
176
+
177
+ suggestions = []
178
+ for item in similar_items:
179
+ suggestions.append({
180
+ "slogan": item["slogan"],
181
+ "reference": f"{item['business']} ({item['category']})",
182
+ "similarity": item["similarity_score"]
183
+ })
184
+
185
+ return suggestions
186
+
187
+ def main():
188
+ """主函数演示"""
189
+ # 初始化数据库
190
+ db = SloganEmbeddingDB()
191
+
192
+ # 创建或加载数据
193
+ if os.path.exists("./slogan_db"):
194
+ print("📂 Found existing database, loading...")
195
+ db.load_database()
196
+ else:
197
+ print("🆕 Creating new database...")
198
+ dataset = db.create_sample_dataset()
199
+ db.build_embeddings(dataset)
200
+ db.save_database()
201
+
202
+ # 测试搜索
203
+ test_queries = [
204
+ "智能穿戴设备健康监测",
205
+ "环保新能源汽车",
206
+ "人工智能学习平台",
207
+ "美式快餐炸鸡",
208
+ "luxury sports car",
209
+ "mobile phone technology"
210
+ ]
211
+
212
+ print("\n" + "="*60)
213
+ print("🔍 SEARCH RESULTS")
214
+ print("="*60)
215
+
216
+ for query in test_queries:
217
+ print(f"\n🔍 Query: {query}")
218
+ results = db.search_similar(query, top_k=3)
219
+
220
+ for result in results:
221
+ print(f" {result['rank']}. {result['business']} ({result['category']})")
222
+ print(f" 描述: {result['description']}")
223
+ print(f" Slogan: {result['slogan']}")
224
+ print(f" 相似度: {result['similarity_score']:.3f}")
225
+ print()
226
+
227
+ # 测试Slogan��成建议
228
+ print("\n" + "="*60)
229
+ print("💡 SLOGAN SUGGESTIONS")
230
+ print("="*60)
231
+
232
+ new_business = "AI智能音箱语音助手设备"
233
+ print(f"\n💡 为 '{new_business}' 生成Slogan建议:")
234
+
235
+ suggestions = db.generate_slogan_suggestions(new_business)
236
+ for i, suggestion in enumerate(suggestions, 1):
237
+ print(f" {i}. \"{suggestion['slogan']}\"")
238
+ print(f" 参考: {suggestion['reference']}")
239
+ print(f" 相似度: {suggestion['similarity']:.3f}")
240
+ print()
241
+
242
+ # 演示动态添加
243
+ print("\n" + "="*60)
244
+ print("➕ ADDING NEW ITEM")
245
+ print("="*60)
246
+
247
+ db.add_new_item(
248
+ business="智能眼镜",
249
+ category="AR设备",
250
+ description="增强现实智能眼镜产品",
251
+ slogan="看见未来,触手可及"
252
+ )
253
+
254
+ # 重新搜索测试
255
+ print(f"\n🔍 搜索 'AR增强现实产品':")
256
+ results = db.search_similar("AR增强现实产品", top_k=2)
257
+ for result in results:
258
+ print(f" - {result['business']}: {result['slogan']} (相似度: {result['similarity_score']:.3f})")
259
+
260
+ if __name__ == "__main__":
261
+ main()