"""Interactive console for indexing an article and querying it with AI embeddings."""
import os

import xxhash

from ai import AI
from config import Config
from storage import Storage
from contents import *
def console(cfg: Config):
    """Drive the interactive console until the user quits or presses Ctrl-C.

    _console returns True when the user wants to start over with a new
    article (/reset) and False when they want to quit (/quit).
    """
    try:
        while _console(cfg):
            pass
    except KeyboardInterrupt:
        print("exit")
def _console(cfg: Config) -> bool:
    """Run one console session over a single article.

    Fetches an article, ensures it is indexed, then answers queries in a loop.

    Returns:
        True to restart with a new article (/reset), False to exit (/quit).
    """
    contents, lang, identify = _get_contents()
    print("The article has been retrieved, and the number of text fragments is:", len(contents))
    for content in contents:
        print('\t', content)
    ai = AI(cfg)
    storage = Storage.create_storage(cfg)
    print("=====================================")
    if storage.been_indexed(identify):
        print("The article has already been indexed, so there is no need to index it again.")
        print("=====================================")
    else:
        # 1. Generate an embedding for each paragraph of the article.
        _index_article(ai, storage, contents, identify)
        print("=====================================")
    while True:
        query = input("Please enter your query (/help to view commands):").strip()
        if query.startswith("/"):
            if query == "/quit":
                return False
            elif query == "/reset":
                print("=====================================")
                return True
            elif query == "/summary":
                # Generate an embedding-based summary; use a SIF-weighted
                # average except for languages where SIF works poorly,
                # which fall back to a plain average.
                ai.generate_summary(storage.get_all_embeddings(identify), num_candidates=100,
                                    use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
            elif query == "/reindex":
                # Re-index, which will clear the stored embeddings first.
                storage.clear(identify)
                _index_article(ai, storage, contents, identify)
            elif query == "/help":
                _print_help()
            else:
                print("Invalid command.")
                _print_help()
            print("=====================================")
            continue
        else:
            # 1. Generate keywords from the query.
            print("Generate keywords.")
            keywords = ai.get_keywords(query)
            # 2. Generate an embedding for the question.
            _, embedding = ai.create_embedding(keywords)
            # 3. Find the most similar fragments from the database.
            texts = storage.get_texts(embedding, identify)
            print("Related fragments found (first 5):")
            for text in texts[:5]:
                print('\t', text)
            # 4. Push the relevant fragments to the AI, which will answer
            #    the question based on these fragments.
            ai.completion(query, texts)
            print("=====================================")


def _index_article(ai: AI, storage, contents: list[str], identify: str) -> None:
    """Embed every text fragment and persist the embeddings under *identify*."""
    embeddings, tokens = ai.create_embeddings(contents)
    print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
          f"costing ${tokens / 1000 * 0.0004}")
    storage.add_all(embeddings, identify)
    print("The embeddings have been saved.")


def _print_help() -> None:
    """Print the list of available console commands."""
    print("Enter /summary to generate an embedding-based summary.")
    print("Enter /reindex to re-index the article.")
    print("Enter /reset to start over.")
    print("Enter /quit to exit.")
    print("Enter any other content for a query.")
def _get_contents() -> tuple[list[str], str, str]:
    """Prompt for an article URL or a PDF/TXT/DOCX path and load its text.

    Loops until some content is retrieved, then returns the text fragments,
    the detected language, and a stable xxh3-128 digest of the fragments
    used as the article's identity.
    """
    # Map supported file suffixes to their extraction functions.
    extractors = {
        '.pdf': extract_text_from_pdf,
        '.txt': extract_text_from_txt,
        '.docx': extract_text_from_docx,
    }
    while True:
        try:
            url = input("Please enter the link to the article or the file path of the PDF/TXT/DOCX document: ").strip()
            if os.path.exists(url):
                extractor = next((fn for ext, fn in extractors.items() if url.endswith(ext)), None)
                if extractor is None:
                    print("Unsupported file format.")
                    continue
                contents, data = extractor(url)
            else:
                contents, data = web_crawler_newspaper(url)
            if not contents:
                print("Unable to retrieve the content of the article. Please enter the link to the article or "
                      "the file path of the PDF/TXT/DOCX document again.")
                continue
            return contents, data, xxhash.xxh3_128_hexdigest('\n'.join(contents))
        except Exception as e:
            print("Error:", e)