Spaces:
Configuration error
Configuration error
import os | |
import json | |
import asyncio | |
import requests | |
from tqdm import tqdm | |
from dotenv import load_dotenv | |
load_dotenv() | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.documents import Document | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_community.vectorstores import FAISS | |
# 获取当前目录根路径 | |
current_file_path = os.path.dirname(os.path.abspath(__file__)) | |
root_path = os.path.abspath(current_file_path) | |
data_path = os.path.join(root_path, "data_simple") | |
db_path = os.path.join(root_path, "database", "init") | |
# 1. 根据 star 数量区间获取 GitHub 仓库,同时根据 star 数量从多到少排序(闭区间)并保存 GitHub 仓库 | |
def get_top_repo_by_star(per_page=1000, page=1, min_star_num=0, max_star_num=500000): | |
query = f'stars:{min_star_num}..{max_star_num} pushed:>2021-01-01' | |
sort = 'stars' | |
order = 'desc' | |
search_url = f'{os.getenv('GITHUB_API_URL')}/search/repositories?q={query}&sort={sort}&order={order}&per_page={per_page}&page={page}' | |
headers = {"Authorization": f"token {os.getenv('GITHUB_TOKEN')}"} | |
response = requests.get(search_url, headers=headers) | |
if response.status_code == 200: | |
total_count = response.json()['total_count'] | |
total_page = total_count // per_page + 1 | |
print(f"Total page: {total_page}, current page: {page}") | |
if response.json()['incomplete_results']: print("Incomplete results") | |
return response.json()['items'], response.json()['items'][-1]['stargazers_count'], total_count | |
else: | |
print(f"Failed to retrieve repositories: {response.status_code}") | |
print("") | |
# 直接退出 | |
exit(1) | |
def save_repo_by_star(max_star=500000): | |
# github 限制每次请求最多得到 100 个仓库,因此 page 固定为 1 | |
top_repositories, max_star, count = get_top_repo_by_star(per_page=1000, page=1, min_star_num=1000, max_star_num=max_star) | |
for i, repo in enumerate(top_repositories): | |
owner = repo['owner']['login'] | |
name = repo['name'] | |
unique_id = f"{name} -- {owner}" | |
stars = repo['stargazers_count'] | |
print(f"Repository {i}: {name}, Stars: {stars}") | |
# 存储为 json 格式 | |
with open(os.path.join(data_path, f'{unique_id}.json'), 'w') as f: | |
json.dump(repo, f, indent=4) | |
if count < 100: exit(1) | |
return max_star | |
def main_repo(): | |
max_star = 500000 # 最多 star 的仓库有 500k | |
num = 1 | |
while True: | |
print("=" * 50) | |
print(f"Round {num}, Max star: {max_star}") | |
max_star = save_repo_by_star(max_star) | |
num += 1 | |
# 2. 将数据转换为向量 | |
async def create_vector_db(docs, embeddings, batch_size=800): | |
# 初始化第一批数据 | |
vector_db = await FAISS.afrom_documents(docs[0:batch_size], embeddings) | |
if len(docs) < batch_size: return vector_db | |
# 创建任务x`` | |
tasks = [] | |
for start_idx in range(batch_size, len(docs), batch_size): | |
end_idx = min(start_idx + batch_size, len(docs)) | |
tasks.append(FAISS.afrom_documents(docs[start_idx:end_idx], embeddings)) | |
# 执行任务 | |
results = await asyncio.gather(*tasks) | |
# 合并结果 | |
for temp_db in results: | |
vector_db.merge_from(temp_db) | |
return vector_db | |
async def main_convert_to_vector(): | |
# 读取文件 | |
files = os.listdir(data_path) | |
# 构建 document | |
docs = [] | |
for file in tqdm(files): | |
if not file.endswith(".json"): continue | |
with open(os.path.join(data_path, file), "r", encoding="utf-8") as f: | |
data = json.load(f) | |
content_map = { | |
"name": data["name"], | |
"description": data["description"], | |
} | |
content = json.dumps(content_map) | |
doc = Document(page_content=content, metadata={"html_url": data["html_url"], | |
"topics": data["topics"], | |
"created_at": data["created_at"], | |
"updated_at": data["updated_at"], | |
"star_count": data["stargazers_count"]}) | |
docs.append(doc) | |
print(f"Total {len(docs)} documents.") | |
# 初始化 Embedding 实例 | |
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"), | |
base_url=os.getenv("OPENAI_BASE_URL"), | |
model="text-embedding-3-small") | |
print("Embedding model success: text-embedding-3-small") | |
# 文档嵌入 | |
if os.path.exists(os.path.join(db_path, "init.faiss")): | |
vector_db = FAISS.load_local(db_path, embeddings=embeddings, | |
index_name="init", | |
allow_dangerous_deserialization=True) | |
else: | |
vector_db = await create_vector_db(docs, embeddings=embeddings) | |
vector_db.save_local(db_path, index_name="init") | |
return vector_db | |
if __name__ == "__main__": | |
# 1. 获取仓库信息 | |
# main_repo() | |
# 2. 构建向量数据库 | |
asyncio.run(main_convert_to_vector()) | |