aihuashanying committed on
Commit 04e426f · 1 Parent(s): 1e51d69

Upload folder using huggingface_hub
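For context, the upload referenced in this commit message is typically done with the huggingface_hub client along the lines of the sketch below; the repo id, folder path, and token handling are placeholders rather than values taken from this commit.

# Hedged sketch: folder upload via huggingface_hub (placeholders only)
from huggingface_hub import HfApi

api = HfApi()  # assumes a token is available via HF_TOKEN or `huggingface-cli login`
api.upload_folder(
    folder_path=".",                    # local folder holding app.py, test.py, requirements.txt, ...
    repo_id="<username>/<space-name>",  # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)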

Files changed (5)
  1. .gitattributes +36 -35
  2. README.md +14 -14
  3. app.py +421 -0
  4. requirements.txt +0 -0
  5. test.py +274 -0
.gitattributes CHANGED
@@ -1,35 +1,36 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ faiss_index_hnsw_new/index.faiss filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
- ---
- title: Aileeao
- emoji: 🏢
- colorFrom: purple
- colorTo: pink
- sdk: gradio
- sdk_version: 5.20.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: ai李敖
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Aileeao
+ emoji: 🏢
+ colorFrom: purple
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.20.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: ai李敖
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,421 @@
+ import os
+ import gradio as gr
+ from langchain_community.document_loaders import TextLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_openai import ChatOpenAI
+ from langchain.prompts import PromptTemplate
+ import requests
+ import numpy as np
+ import json
+ import faiss
+ from collections import deque
+ from langchain_core.embeddings import Embeddings
+ import threading
+ import queue
+ from langchain_core.messages import HumanMessage, AIMessage
+ from sentence_transformers import SentenceTransformer
+ import pickle
+ import torch
+ from langchain_core.documents import Document
+
+ # Global stop flag and output queue
+ stop_flag = threading.Event()
+ output_queue = queue.Queue()
+
+ # Custom SentenceTransformers embedding class
+ class SentenceTransformerEmbeddings(Embeddings):
+     def __init__(self, model_name="BAAI/bge-m3"):
+         self.model = SentenceTransformer(model_name)
+         self.batch_size = 64
+
+     def embed_documents(self, texts):
+         embeddings_file = "embeddings_temp.npy"
+         total_chunks = len(texts)
+         embeddings_shape = (total_chunks, 1024)
+
+         embeddings_array = np.memmap(embeddings_file, dtype='float32', mode='w+', shape=embeddings_shape)
+         with torch.cuda.amp.autocast():
+             for i in range(0, total_chunks, 1000):
+                 batch = texts[i:i+1000]
+                 batch_emb = self.model.encode(
+                     batch,
+                     normalize_embeddings=True,
+                     batch_size=self.batch_size,
+                     show_progress_bar=False
+                 )
+                 embeddings_array[i:i+len(batch)] = batch_emb
+                 if (i + len(batch)) % 100 == 0:
+                     print(f"嵌入进度: {i+len(batch)} / {total_chunks}")
+                 torch.cuda.empty_cache()
+         embeddings_array.flush()
+         return np.array(embeddings_array)
+
+     def embed_query(self, text):
+         with torch.cuda.amp.autocast():
+             return self.model.encode([text], normalize_embeddings=True, batch_size=1)[0]
+
+ # SiliconFlow reranking function (unchanged)
+ def rerank_documents(query, documents, api_key, top_n=10):
+     url = "https://api.siliconflow.cn/v1/rerank"
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json"
+     }
+     doc_texts = [doc.page_content for doc in documents]
+     payload = {
+         "model": "BAAI/bge-reranker-v2-m3",
+         "query": query,
+         "documents": doc_texts,
+         "top_n": top_n
+     }
+     response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30)
+     if response.status_code == 200:
+         result = response.json()
+         reranked_results = result.get("results", [])
+         if not reranked_results:
+             raise Exception("重排序结果为空")
+         reranked_docs_with_scores = [
+             (documents[res["index"]], res["relevance_score"])
+             for res in reranked_results
+         ]
+         return reranked_docs_with_scores
+     else:
+         raise Exception(f"重排序失败: {response.status_code}, {response.text}")
+
+ # Set the API keys
+ os.environ["SILICONFLOW_API_KEY"] = os.getenv("SILICONFLOW_API_KEY", "sk-cigytzyzghoziznvniugfihuicjcgmborusgodktydremtvd")
+ os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "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")
+
+ # Initialize the embedding model
+ embeddings = SentenceTransformerEmbeddings(model_name="BAAI/bge-m3")
+
+ # Build the HNSW index
+ def build_hnsw_index(knowledge_base_path, index_path):
+     print("开始加载文档...")
+     loader = DirectoryLoader(
+         knowledge_base_path,
+         glob="*.txt",
+         loader_cls=lambda path: TextLoader(path, encoding="utf-8"),
+         use_multithreading=True
+     )
+     documents = loader.load()
+     print(f"加载完成,共 {len(documents)} 个文档")
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     if not os.path.exists("chunks.pkl"):
+         print("开始分片...")
+         docs = text_splitter.split_documents(documents)
+         texts = [doc.page_content for doc in docs]
+         with open("chunks.pkl", "wb") as f:
+             pickle.dump(texts, f)
+         print(f"分片完成,共 {len(texts)} 个 chunk")
+     else:
+         with open("chunks.pkl", "rb") as f:
+             texts = pickle.load(f)
+         print(f"加载已有分片,共 {len(texts)} 个 chunk")
+
+     embeddings_file = "embeddings_temp.npy"
+     if os.path.exists(embeddings_file):
+         os.remove(embeddings_file)
+
+     if not os.path.exists("embeddings.npy"):
+         print("开始生成嵌入...")
+         embeddings_array = embeddings.embed_documents(texts)
+         np.save("embeddings.npy", embeddings_array)
+         if os.path.exists(embeddings_file):
+             os.remove(embeddings_file)
+         print(f"嵌入生成完成,维度: {embeddings_array.shape}")
+     else:
+         embeddings_array = np.load("embeddings.npy")
+         print(f"加载已有嵌入,维度: {embeddings_array.shape}")
+
+     dimension = embeddings_array.shape[1]
+     index = faiss.IndexHNSWFlat(dimension, 16)
+     index.hnsw.efConstruction = 100
+     print("开始构建 HNSW 索引...")
+
+     batch_size = 5000
+     total_vectors = embeddings_array.shape[0]
+     for i in range(0, total_vectors, batch_size):
+         batch = embeddings_array[i:i + batch_size]
+         index.add(batch)
+         print(f"索引构建进度: {min(i + batch_size, total_vectors)} / {total_vectors}")
+
+     print("开始构造 FAISS 对象...")
+     # Use FAISS.from_texts to set up the base structure without re-embedding everything
+     dummy_texts = [texts[0]]  # initialize with a single sample so only one text is embedded
+     vector_store = FAISS.from_texts(dummy_texts, embeddings)
+     # Swap in the HNSW index and rebuild the document store
+     vector_store.index = index
+     vector_store.docstore._dict.clear()  # clear the default docstore
+     vector_store.index_to_docstore_id.clear()  # clear the default mapping
+
+     # Manually populate the document store
+     for i, text in enumerate(texts):
+         doc_id = str(i)
+         vector_store.docstore._dict[doc_id] = Document(page_content=text)
+         vector_store.index_to_docstore_id[i] = doc_id
+
+     print(f"构造后 vector_store 类型: {type(vector_store)}")
+
+     print("开始保存索引...")
+     vector_store.save_local(index_path)
+     print(f"HNSW 索引已生成并保存到 '{index_path}'")
+
+     # Verify the saved result
+     loaded_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+     print(f"加载后 vector_store 类型: {type(loaded_store)}")
+     return loaded_store
+
+ # Convert an existing faiss_index to HNSW
+ def convert_to_hnsw(existing_index_path, new_index_path):
+     old_vector_store = FAISS.load_local(existing_index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+     doc_texts = [doc.page_content for doc in old_vector_store.docstore._dict.values()]
+     embeddings_array = embeddings.embed_documents(doc_texts)
+     dimension = embeddings_array.shape[1]
+     index = faiss.IndexHNSWFlat(dimension, 8)
+     index.hnsw.efConstruction = 40
+
+     batch_size = 5000
+     total_vectors = embeddings_array.shape[0]
+     for i in range(0, total_vectors, batch_size):
+         batch = embeddings_array[i:i + batch_size]
+         index.add(batch)
+         print(f"索引转换进度: {min(i + batch_size, total_vectors)} / {total_vectors}")
+
+     print("开始构造 FAISS 对象...")
+     dummy_texts = [doc_texts[0]]
+     new_vector_store = FAISS.from_texts(dummy_texts, embeddings)
+     new_vector_store.index = index
+     new_vector_store.docstore._dict.clear()
+     new_vector_store.index_to_docstore_id.clear()
+
+     for i, text in enumerate(doc_texts):
+         doc_id = str(i)
+         new_vector_store.docstore._dict[doc_id] = Document(page_content=text)
+         new_vector_store.index_to_docstore_id[i] = doc_id
+
+     print(f"构造后 vector_store 类型: {type(new_vector_store)}")
+
+     print("开始保存索引...")
+     new_vector_store.save_local(new_index_path)
+     print(f"已将 '{existing_index_path}' 转换为 HNSW 并保存到 '{new_index_path}'")
+
+     loaded_store = FAISS.load_local(new_index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+     print(f"加载后 vector_store 类型: {type(loaded_store)}")
+     return loaded_store
+
+ # Load or build the index
+ index_path = "faiss_index_hnsw_new"
+ knowledge_base_path = "knowledge_base"
+
+ if not os.path.exists(index_path):
+     if os.path.exists("faiss_index"):
+         print("检测到已有 faiss_index,正在转换为 HNSW...")
+         vector_store = convert_to_hnsw("faiss_index", index_path)
+     elif os.path.exists(knowledge_base_path):
+         print("检测到 knowledge_base,正在生成 HNSW 索引...")
+         vector_store = build_hnsw_index(knowledge_base_path, index_path)
+     else:
+         raise FileNotFoundError("未找到 'faiss_index' 或 'knowledge_base',请提供知识库数据")
+ else:
+     vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+     vector_store.index.hnsw.efSearch = 300
+     print("已加载 HNSW 索引 'faiss_index_hnsw_new',efSearch 设置为 300")
+     print(f"加载后 vector_store 类型: {type(vector_store)}")
+
+ # Initialize ChatOpenAI
+ llm = ChatOpenAI(
+     model="deepseek/deepseek-r1:free",
+     api_key=os.environ["OPENROUTER_API_KEY"],
+     base_url="https://openrouter.ai/api/v1",
+     timeout=60,
+     temperature=0.3,
+     max_tokens=88888,
+     streaming=True
+ )
+
+ # Define the prompt template (unchanged)
+ prompt_template = PromptTemplate(
+     input_variables=["context", "question", "chat_history"],
+     template="""
+ 你是一个研究李敖的专家,根据用户提出的问题{question}、最近10轮对话历史{chat_history}以及从李敖相关书籍和评论中检索的内容{context}回答问题。
+ 在回答时,请注意以下几点:
+ - 结合李敖的写作风格和思想,筛选出与问题和对话历史最相关的检索内容,避免无关信息。
+ - 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。
+ - 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。
+ - 如果检索内容和历史不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。
+ - 列出引用的书籍或文章名称及章节(如有),如《李敖大全集》第X卷或具体书名。
+ - 只能基于提供的知识库内容{context}和对话历史{chat_history}回答,不得引入外部信息。
+ - 对于列举类问题,控制在10个要点以内,并优先提供最相关项。
+ - 如果回答较长,结构化分段总结,分点作答控制在5个点以内。
+ - 根据对话历史调整回答,避免重复或矛盾。
+ """
+ )
+
+ # Conversation history management (unchanged)
+ class ConversationHistory:
+     def __init__(self, max_length=10):
+         self.history = deque(maxlen=max_length)
+
+     def add_turn(self, question, answer):
+         self.history.append((question, answer))
+
+     def get_history(self):
+         return [(turn[0], turn[1]) for turn in self.history]
+
+     def clear(self):
+         self.history.clear()
+
+ conversation = ConversationHistory()
+
+ # Cosine similarity helper (unchanged)
+ def compute_cosine_similarity(query_embedding, doc_embeddings):
+     query_embedding = np.array(query_embedding)
+     doc_embeddings = np.array(doc_embeddings)
+     dot_product = np.dot(doc_embeddings, query_embedding)
+     query_norm = np.linalg.norm(query_embedding)
+     doc_norms = np.linalg.norm(doc_embeddings, axis=1)
+     similarities = dot_product / (query_norm * doc_norms + 1e-8)
+     return similarities
+
+ # Worker thread that generates the answer
+ def generate_answer_thread(question, output_queue):
+     global stop_flag
+     stop_flag.clear()
+     try:
+         print(f"vector_store 类型: {type(vector_store)}")  # debug
+         history_list = conversation.get_history()
+         history_text = "\n".join([f"问: {q}\n答: {a}" for q, a in history_list]) if history_list else ""
+         query_with_context = f"{history_text}\n当前问题: {question}" if history_text else question
+         initial_docs_with_scores = vector_store.similarity_search_with_score(query_with_context, k=50)
+         print(f"初始检索数量: {len(initial_docs_with_scores)}")
+         output_queue.put(f"初始检索数量: {len(initial_docs_with_scores)}\n")
+
+         if stop_flag.is_set():
+             output_queue.put("生成已停止")
+             return
+
+         query_embedding = embeddings.embed_query(query_with_context)
+         doc_embeddings = [embeddings.embed_query(doc.page_content) for doc, _ in initial_docs_with_scores]
+         similarities = compute_cosine_similarity(query_embedding, doc_embeddings)
+         print(f"余弦相似度范围: {min(similarities):.4f} - {max(similarities):.4f}")
+         output_queue.put(f"余弦相似度范围: {min(similarities):.4f} - {max(similarities):.4f}\n")
+
+         if stop_flag.is_set():
+             output_queue.put("生成已停止")
+             return
+
+         similarity_threshold = max(similarities) * 0.8
+         filtered_docs_with_scores = [
+             (doc, sim)
+             for (doc, _), sim in zip(initial_docs_with_scores, similarities)
+             if sim >= similarity_threshold
+         ]
+         if len(filtered_docs_with_scores) < 5:
+             filtered_docs_with_scores = [(doc, sim) for (doc, _), sim in zip(initial_docs_with_scores[:10], similarities[:10])]
+             print(f"过滤后数量不足,保留前 10 个文档")
+             output_queue.put("过滤后数量不足,保留前 10 个文档\n")
+         else:
+             print(f"过滤后数量: {len(filtered_docs_with_scores)}")
+             output_queue.put(f"过滤后数量: {len(filtered_docs_with_scores)}\n")
+
+         if stop_flag.is_set():
+             output_queue.put("生成已停止")
+             return
+
+         initial_docs = [doc for doc, _ in filtered_docs_with_scores]
+         vector_similarities = [sim for _, sim in filtered_docs_with_scores]
+         reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs, os.environ["SILICONFLOW_API_KEY"], top_n=10)
+         reranked_docs = [doc for doc, score in reranked_docs_with_scores]
+         rerank_scores = [score for _, score in reranked_docs_with_scores]
+
+         if stop_flag.is_set():
+             output_queue.put("生成已停止")
+             return
+
+         combined_scores = [
+             0.2 * vector_similarities[i] + 0.8 * rerank_scores[i]
+             for i in range(len(reranked_docs))
+         ]
+         sorted_docs_with_scores = sorted(
+             zip(reranked_docs, combined_scores),
+             key=lambda x: x[1],
+             reverse=True
+         )
+         final_docs = [doc for doc, _ in sorted_docs_with_scores][:5]
+
+         if stop_flag.is_set():
+             output_queue.put("生成已停止")
+             return
+
+         context = "\n\n".join([doc.page_content for doc in final_docs])
+         chat_history = [HumanMessage(content=q) if i % 2 == 0 else AIMessage(content=a)
+                         for i, (q, a) in enumerate(history_list)]
+         prompt = prompt_template.format(context=context, question=question, chat_history=history_text)
+
+         answer = ""
+         for chunk in llm.stream([HumanMessage(content=prompt)]):
+             if stop_flag.is_set():
+                 output_queue.put(answer + "\n\n(生成已停止)")
+                 return
+             answer += chunk.content
+             output_queue.put(answer)
+
+         conversation.add_turn(question, answer)
+         output_queue.put(answer)
+
+     except Exception as e:
+         output_queue.put(f"Error: {str(e)}")
+
+ # Gradio interface function (unchanged)
+ def answer_question(question):
+     global stop_flag, output_queue
+     stop_flag.clear()
+     output_queue.queue.clear()
+
+     thread = threading.Thread(target=generate_answer_thread, args=(question, output_queue))
+     thread.start()
+
+     while thread.is_alive() or not output_queue.empty():
+         try:
+             output = output_queue.get(timeout=0.1)
+             yield output
+         except queue.Empty:
+             continue
+
+     while not output_queue.empty():
+         yield output_queue.get()
+
+ def stop_generation():
+     global stop_flag
+     stop_flag.set()
+     return "生成已停止,正在中止..."
+
+ def clear_conversation():
+     conversation.clear()
+     return "对话历史已清空,请开始新的对话。"
+
+ # Build the Gradio interface (unchanged)
+ with gr.Blocks(title="AI李敖助手") as interface:
+     gr.Markdown("### AI李敖助手")
+     gr.Markdown("基于李敖163本相关书籍构建的知识库,支持上下文关联,记住最近10轮对话,输入问题以获取李敖风格的回答。")
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             question_input = gr.Textbox(label="请输入您的问题", placeholder="输入您的问题...")
+             submit_button = gr.Button("提交")
+         with gr.Column(scale=1):
+             clear_button = gr.Button("新建对话")
+             stop_button = gr.Button("停止生成")
+
+     output_text = gr.Textbox(label="回答", interactive=False)
+
+     submit_button.click(fn=answer_question, inputs=question_input, outputs=output_text)
+     clear_button.click(fn=clear_conversation, inputs=None, outputs=output_text)
+     stop_button.click(fn=stop_generation, inputs=None, outputs=output_text)
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch(share=True)
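
For reference, a minimal sketch of querying the faiss_index_hnsw_new index that app.py saves, outside the Gradio app. QueryEmbeddings below is a hypothetical stand-in for app.py's SentenceTransformerEmbeddings (same BAAI/bge-m3 model with normalized vectors), and the sample query string is illustrative only.

# Hedged sketch: load the saved HNSW index and run a standalone similarity search
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer

class QueryEmbeddings(Embeddings):
    # Query-side embedder matching the vectors the index was built with
    def __init__(self, model_name="BAAI/bge-m3"):
        self.model = SentenceTransformer(model_name)
    def embed_documents(self, texts):
        return self.model.encode(texts, normalize_embeddings=True).tolist()
    def embed_query(self, text):
        return self.model.encode([text], normalize_embeddings=True)[0].tolist()

store = FAISS.load_local("faiss_index_hnsw_new", embeddings=QueryEmbeddings(), allow_dangerous_deserialization=True)
store.index.hnsw.efSearch = 300  # same search-time setting app.py applies after loading
for doc, score in store.similarity_search_with_score("李敖", k=5):
    print(f"{score:.4f}  {doc.page_content[:80]}")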
requirements.txt ADDED
Binary file (3.89 kB).
 
test.py ADDED
@@ -0,0 +1,274 @@
+ <<<<<<< HEAD
+ import torch
+ print(torch.__version__)              # e.g. 2.4.0+cu118
+ print(torch.cuda.is_available())      # should return True
+ print(torch.cuda.get_device_name(0))  # should return the GPU model
+ =======
+ import os
+ import gradio as gr
+ from langchain_community.document_loaders import TextLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_openai import ChatOpenAI
+ from langchain.chains import RetrievalQA
+ from langchain_core.embeddings import Embeddings
+ from langchain.prompts import PromptTemplate
+ import requests
+ import numpy as np
+ import json
+ import faiss
+ from langchain_community.embeddings import OllamaEmbeddings
+
+ # Custom SiliconFlow embedding class
+ class SiliconFlowEmbeddings(Embeddings):
+     def __init__(self, model="BAAI/bge-m3", api_key=None):
+         self.model = model
+         self.api_key = api_key
+
+     def embed_documents(self, texts):
+         return self._get_embeddings(texts)
+
+     def embed_query(self, text):
+         return self._get_embeddings([text])[0]
+
+     def _get_embeddings(self, texts):
+         url = "https://api.siliconflow.cn/v1/embeddings"
+         headers = {
+             "Authorization": f"Bearer {self.api_key}",
+             "Content-Type": "application/json"
+         }
+         payload = {
+             "model": self.model,
+             "input": texts
+         }
+         response = requests.post(url, json=payload, headers=headers, timeout=30)
+         if response.status_code == 200:
+             data = response.json()
+             return np.array([item["embedding"] for item in data["data"]])
+         else:
+             raise Exception(f"API 调用失败: {response.status_code}, {response.text}")
+
+ # SiliconFlow reranking function
+ def rerank_documents(query, documents, api_key, top_n=10):
+     url = "https://api.siliconflow.cn/v1/rerank"
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json"
+     }
+     doc_texts = [doc.page_content for doc in documents]
+     payload = {
+         "model": "BAAI/bge-reranker-v2-m3",
+         "query": query,
+         "documents": doc_texts,
+         "top_n": top_n
+     }
+     response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30)
+     if response.status_code == 200:
+         result = response.json()
+         reranked_results = result.get("results", [])
+         if not reranked_results:
+             raise Exception("重排序结果为空")
+         reranked_docs_with_scores = [
+             (documents[res["index"]], res["relevance_score"])
+             for res in reranked_results
+         ]
+         return reranked_docs_with_scores
+     else:
+         raise Exception(f"重排序失败: {response.status_code}, {response.text}")
+
+ # Set the API keys
+ os.environ["SILICONFLOW_API_KEY"] = os.getenv("SILICONFLOW_API_KEY", "sk-cigytzyzghoziznvniugfihuicjcgmborusgodktydremtvd")
+ os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "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")
+
+ # Initialize the embedding model
+ embeddings = OllamaEmbeddings(model="bge-m3", base_url="http://localhost:11434")
+
+ # Build an HNSW index from knowledge_base
+ def build_hnsw_index(knowledge_base_path, index_path):
+     loader = DirectoryLoader(
+         knowledge_base_path,
+         glob="*.txt",
+         loader_cls=lambda path: TextLoader(path, encoding="utf-8")
+     )
+     documents = loader.load()
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     texts = text_splitter.split_documents(documents)
+
+     # Create the vector store with FAISS.from_documents
+     vector_store = FAISS.from_documents(texts, embeddings)
+
+     # Compute the embeddings and convert to HNSW
+     embeddings_array = np.array(embeddings.embed_documents([doc.page_content for doc in texts]))
+     dimension = embeddings_array.shape[1]
+     index = faiss.IndexHNSWFlat(dimension, 16)  # M=16
+     index.hnsw.efConstruction = 100
+     index.hnsw.efSearch = 50
+     index.add(embeddings_array)
+
+     # Replace the FAISS index
+     vector_store.index = index
+     vector_store.save_local(index_path)
+     print(f"HNSW 索引已生成并保存到 '{index_path}'")
+     return vector_store
+
+ # Convert an existing faiss_index to HNSW
+ def convert_to_hnsw(existing_index_path, new_index_path):
+     # Load the existing index
+     old_vector_store = FAISS.load_local(existing_index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+
+     # Collect the stored document texts
+     if hasattr(old_vector_store, 'docstore') and hasattr(old_vector_store.docstore, '_dict'):
+         docs = list(old_vector_store.docstore._dict.values())
+         doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in docs]
+     else:
+         doc_ids = list(old_vector_store.index_to_docstore_id.keys())
+         doc_texts = [old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]].page_content
+                      if hasattr(old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]], 'page_content')
+                      else str(old_vector_store.docstore._dict[old_vector_store.index_to_docstore_id[i]])
+                      for i in doc_ids]
+
+     # Generate embeddings with the global embeddings object
+     embeddings_array = np.array(embeddings.embed_documents(doc_texts))
+
+     # Create the HNSW index
+     dimension = embeddings_array.shape[1]
+     index = faiss.IndexHNSWFlat(dimension, 16)  # M=16
+     index.hnsw.efConstruction = 100
+     index.hnsw.efSearch = 50
+     index.add(embeddings_array)
+
+     # Build a new FAISS vector store; the index is not passed in directly but assigned afterwards
+     new_vector_store = FAISS.from_texts(doc_texts, embeddings)
+     new_vector_store.index = index  # replace the index directly
+     new_vector_store.save_local(new_index_path)
+     print(f"已将 '{existing_index_path}' 转换为 HNSW 并保存到 '{new_index_path}'")
+     return new_vector_store
+
+ # Load or build the index
+ index_path = "faiss_index_hnsw"
+ knowledge_base_path = "knowledge_base"
+
+ if not os.path.exists(index_path):
+     if os.path.exists("faiss_index"):
+         print("检测到已有 faiss_index,正在转换为 HNSW...")
+         vector_store = convert_to_hnsw("faiss_index", index_path)
+     elif os.path.exists(knowledge_base_path):
+         print("检测到 knowledge_base,正在生成 HNSW 索引...")
+         vector_store = build_hnsw_index(knowledge_base_path, index_path)
+     else:
+         raise FileNotFoundError("未找到 'faiss_index' 或 'knowledge_base',请提供知识库数据")
+ else:
+     vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
+     print("已加载 HNSW 索引 'faiss_index_hnsw'")
+
+ # Initialize ChatOpenAI via OpenRouter
+ llm = ChatOpenAI(
+     model="deepseek/deepseek-r1:free",
+     api_key=os.environ["OPENROUTER_API_KEY"],
+     base_url="https://openrouter.ai/api/v1",
+     timeout=60,
+     temperature=0.3,
+     max_tokens=88888,
+ )
+
+ # Define the prompt template
+ prompt_template = PromptTemplate(
+     input_variables=["context", "question"],
+     template="""
+ 你是一个研究李敖的专家,根据用户提出的问题{question}以及从李敖相关书籍和评论中检索的内容{context}回答问题。
+
+ 在回答时,请注意以下几点:
+ - 结合李敖的写作风格和思想,筛选出与问题最相关的检索内容,避免无关信息。
+ - 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。
+ - 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。
+ - 如果检索内容不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。
+ - 列出引用的书籍或文章名称及章节(如有),如《李敖大全集》第X卷或具体书名。
+ - 只能基于提供的知识库内容{context}回答,不得引入外部信息。
+ - 并非搜索结果的所有内容都与用户的问题密切相关,你需要结合问题,对搜索结果进行甄别、筛选。
+ - 对于列举类的问题(如列举所有航班信息),尽量将答案控制在10个要点以内,并告诉用户可以查看搜索来源、获得完整信息。优先提供信息完整、最相关的列举项;如非必要,不要主动告诉用户搜索结果未提供的内容。
+ - 如果回答很长,请尽量结构化、分段落总结。如果需要分点作答,尽量控制在5个点以内,并合并相关的内容。
+ - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
+ - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
+ - 你的回答应该综合多个相关知识库内容来回答,不能重复引用一个知识库内容。
+ - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
+ """
+ )
+
+ # Create the retrieval QA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=vector_store.as_retriever(search_kwargs={"k": 30}),
+     return_source_documents=True,
+     chain_type_kwargs={"prompt": prompt_template}
+ )
+
+ # Define the Gradio interface function
+ def answer_question(question):
+     try:
+         # Step 1: initial FAISS retrieval
+         initial_docs_with_scores = vector_store.similarity_search_with_score(question, k=30)
+         print(f"初始检索数量: {len(initial_docs_with_scores)}")
+
+         # FAISS returns distances; convert them to similarities
+         similarities = [1 - score for _, score in initial_docs_with_scores]
+         print(f"相似度范围: {min(similarities):.4f} - {max(similarities):.4f}")
+
+         # Print the top-5 documents and their similarities
+         for i, (doc, score) in enumerate(initial_docs_with_scores[:5]):
+             print(f"Top {i+1} - 相似度: {1 - score:.4f}, 内容: {doc.page_content[:100]}")
+
+         # Step 2: dynamic threshold filtering
+         similarity_threshold = max(similarities) * 0.8
+         filtered_docs_with_scores = [
+             (doc, 1 - score)
+             for doc, score in initial_docs_with_scores
+             if (1 - score) >= similarity_threshold
+         ]
+         if len(filtered_docs_with_scores) < 5:
+             filtered_docs_with_scores = initial_docs_with_scores[:10]
+             print(f"过滤后数量不足,保留前 10 个文档")
+         else:
+             print(f"过滤后数量: {len(filtered_docs_with_scores)}")
+
+         initial_docs = [doc for doc, _ in filtered_docs_with_scores]
+         vector_similarities = [sim for _, sim in filtered_docs_with_scores]
+
+         # Step 3: rerank
+         reranked_docs_with_scores = rerank_documents(question, initial_docs, os.environ["SILICONFLOW_API_KEY"], top_n=10)
+         reranked_docs = [doc for doc, score in reranked_docs_with_scores]
+         rerank_scores = [score for _, score in reranked_docs_with_scores]
+
+         # Step 4: fuse the scores and sort
+         combined_scores = [
+             0.2 * vector_similarities[i] + 0.8 * rerank_scores[i]
+             for i in range(len(reranked_docs))
+         ]
+         sorted_docs_with_scores = sorted(
+             zip(reranked_docs, combined_scores),
+             key=lambda x: x[1],
+             reverse=True
+         )
+         final_docs = [doc for doc, _ in sorted_docs_with_scores][:5]
+
+         # Step 5: generate the answer
+         context = "\n\n".join([doc.page_content for doc in final_docs])
+         response = qa_chain.invoke({"query": question, "context": context})
+
+         return response["result"]
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ # Build the Gradio interface
+ interface = gr.Interface(
+     fn=answer_question,
+     inputs=gr.Textbox(label="请输入您的问题"),
+     outputs=gr.Textbox(label="回答"),
+     title="AI李敖助手",
+     description="基于李敖163本相关书籍构建的知识库,输入问题以获取李敖风格的回答。"
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch(share=True)
+ >>>>>>> 921dc7e73a28368974490d7eba946303cf2129ba