# console.py
# Interactive console: index an article's text fragments as embeddings,
# then answer queries, generate summaries, or re-index on demand.
import os

import xxhash

from ai import AI
from config import Config
from contents import *
from storage import Storage
def console(cfg: Config):
    """Top-level entry point: keep running console sessions until the user quits.

    A session ends either by returning False from ``_console`` (/quit) or by
    a Ctrl-C, which is caught here and reported as a clean exit.
    """
    try:
        # _console returns True to start over with a new article, False to stop.
        while _console(cfg):
            pass
    except KeyboardInterrupt:
        print("exit")
def _console(cfg: Config) -> bool:
    """Run one console session over a single article.

    Fetches an article, indexes it if needed, then serves an interactive
    query loop supporting /summary, /reindex, /reset, /quit and free-text
    questions answered from the most similar indexed fragments.

    Returns:
        True to restart with a new article (/reset), False to exit (/quit).
    """
    contents, lang, identify = _get_contents()
    print("The article has been retrieved, and the number of text fragments is:", len(contents))
    for content in contents:
        print('\t', content)
    ai = AI(cfg)
    storage = Storage.create_storage(cfg)
    print("=====================================")
    if storage.been_indexed(identify):
        print("The article has already been indexed, so there is no need to index it again.")
        print("=====================================")
    else:
        # Generate an embedding for each paragraph of the article and persist it.
        _index(ai, storage, contents, identify)
        print("=====================================")
    while True:
        query = input("Please enter your query (/help to view commands):").strip()
        if query.startswith("/"):
            if query == "/quit":
                return False
            elif query == "/reset":
                print("=====================================")
                return True
            elif query == "/summary":
                # Embedding-based summary; SIF weighting only for languages
                # where it applies (skipped for zh/ja/ko/hi/ar/fa).
                ai.generate_summary(storage.get_all_embeddings(identify), num_candidates=100,
                                    use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
            elif query == "/reindex":
                # Re-index from scratch: clear the stored embeddings first.
                storage.clear(identify)
                _index(ai, storage, contents, identify)
            elif query == "/help":
                _print_help()
            else:
                print("Invalid command.")
                _print_help()
            print("=====================================")
            continue
        else:
            # 1. Generate keywords from the raw query.
            print("Generate keywords.")
            keywords = ai.get_keywords(query)
            # 2. Generate an embedding for the question.
            _, embedding = ai.create_embedding(keywords)
            # 3. Find the most similar fragments in the database.
            texts = storage.get_texts(embedding, identify)
            print("Related fragments found (first 5):")
            for text in texts[:5]:
                print('\t', text)
            # 4. Feed the relevant fragments to the AI, which answers from them.
            ai.completion(query, texts)
            print("=====================================")


def _index(ai, storage, contents, identify):
    """Create embeddings for every fragment, report cost, and persist them."""
    embeddings, tokens = ai.create_embeddings(contents)
    print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
          f"costing ${tokens / 1000 * 0.0004}")
    storage.add_all(embeddings, identify)
    print("The embeddings have been saved.")


def _print_help():
    """Print the list of supported console commands."""
    print("Enter /summary to generate an embedding-based summary.")
    print("Enter /reindex to re-index the article.")
    print("Enter /reset to start over.")
    print("Enter /quit to exit.")
    print("Enter any other content for a query.")
def _get_contents() -> tuple[list[str], str, str]:
    """Prompt for an article URL or a local PDF/TXT/DOCX path and extract its text.

    Re-prompts on unsupported formats, empty extractions, or any extraction
    error (best effort: errors are printed, not raised).

    Returns:
        (fragments, lang, identify) where ``identify`` is the xxh3-128 hex
        digest of the joined fragments, used as the storage key.
    """
    while True:
        try:
            url = input("Please enter the link to the article or the file path of the PDF/TXT/DOCX document: ").strip()
            if os.path.exists(url):
                # Case-insensitive extension match so '.PDF'/'.Txt' etc. work too.
                lowered = url.lower()
                if lowered.endswith('.pdf'):
                    contents, lang = extract_text_from_pdf(url)
                elif lowered.endswith('.txt'):
                    contents, lang = extract_text_from_txt(url)
                elif lowered.endswith('.docx'):
                    contents, lang = extract_text_from_docx(url)
                else:
                    print("Unsupported file format.")
                    continue
            else:
                # Not a local file: treat the input as a URL and crawl it.
                contents, lang = web_crawler_newspaper(url)
            if not contents:
                print("Unable to retrieve the content of the article. Please enter the link to the article or "
                      "the file path of the PDF/TXT/DOCX document again.")
                continue
            return contents, lang, xxhash.xxh3_128_hexdigest('\n'.join(contents))
        except Exception as e:
            # Best effort: report the failure and re-prompt instead of crashing.
            print("Error:", e)