# semantic_github / deal_data.py
# Aniun's picture
# Upload 6 files
# f2acc7f verified
import asyncio
import json
import os
import sys

import requests
from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv()
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
# Resolve every data location relative to the directory containing this script.
current_file_path = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(current_file_path)
# data_path: crawled repository JSON files; db_path: folder of the saved FAISS index.
data_path, db_path = (
    os.path.join(root_path, "data_simple"),
    os.path.join(root_path, "database", "init"),
)
# 1. 根据 star 数量区间获取 GitHub 仓库,同时根据 star 数量从多到少排序(闭区间)并保存 GitHub 仓库
def get_top_repo_by_star(per_page=1000, page=1, min_star_num=0, max_star_num=500000):
query = f'stars:{min_star_num}..{max_star_num} pushed:>2021-01-01'
sort = 'stars'
order = 'desc'
search_url = f'{os.getenv('GITHUB_API_URL')}/search/repositories?q={query}&sort={sort}&order={order}&per_page={per_page}&page={page}'
headers = {"Authorization": f"token {os.getenv('GITHUB_TOKEN')}"}
response = requests.get(search_url, headers=headers)
if response.status_code == 200:
total_count = response.json()['total_count']
total_page = total_count // per_page + 1
print(f"Total page: {total_page}, current page: {page}")
if response.json()['incomplete_results']: print("Incomplete results")
return response.json()['items'], response.json()['items'][-1]['stargazers_count'], total_count
else:
print(f"Failed to retrieve repositories: {response.status_code}")
print("")
# 直接退出
exit(1)
def save_repo_by_star(max_star=500000):
    """Fetch the top-starred repositories at or below ``max_star`` and save them.

    Requests page 1 of the search for ``1000 <= stars <= max_star``. GitHub
    returns at most 100 repositories per request, so the crawl always reads
    page 1 and lowers the star ceiling between rounds instead of paging. Each
    repository's raw API record is written to ``data_path`` as
    ``"<name> -- <owner>.json"``.

    Args:
        max_star: Inclusive upper bound on the star count for this round.

    Returns:
        The star ceiling for the next round: the lowest star count seen on
        this page, or ``max_star - 1`` when that would not make progress.

    Raises:
        SystemExit: With status 1 once fewer than 100 repositories match the
            query (the crawl is finished), or if the API request fails.
    """
    top_repositories, lowest_star, count = get_top_repo_by_star(
        per_page=100, page=1, min_star_num=1000, max_star_num=max_star
    )
    os.makedirs(data_path, exist_ok=True)
    for i, repo in enumerate(top_repositories):
        owner = repo["owner"]["login"]
        name = repo["name"]
        unique_id = f"{name} -- {owner}"
        stars = repo["stargazers_count"]
        print(f"Repository {i}: {name}, Stars: {stars}")
        # Store the raw API record as JSON (read back later as UTF-8).
        with open(os.path.join(data_path, f"{unique_id}.json"), "w", encoding="utf-8") as f:
            json.dump(repo, f, indent=4)
    if count < 100:
        sys.exit(1)
    # The star range is inclusive, so the next round re-fetches the boundary
    # repositories (they simply overwrite their files). If every repository on
    # this page had exactly `max_star` stars, the ceiling would never move and
    # the crawl would loop forever; step below it instead. Any further repos
    # with exactly `max_star` stars are then skipped, which is the price of
    # guaranteed termination.
    if lowest_star >= max_star:
        return max_star - 1
    return lowest_star
def main_repo():
    """Crawl GitHub repositories round by round, from the most-starred down.

    The first round uses a 500k-star ceiling; the most-starred repository has
    about 500k stars. Each round calls ``save_repo_by_star`` and passes the
    star bound it returns into the next round. The loop has no exit condition
    of its own; it ends only when ``save_repo_by_star`` terminates the
    process.
    """
    max_star = 500000
    round_no = 0
    while True:
        round_no += 1
        print("=" * 50)
        print(f"Round {round_no}, Max star: {max_star}")
        max_star = save_repo_by_star(max_star)
# 2. Convert the crawled data into vectors.
async def create_vector_db(docs, embeddings, batch_size=800):
    """Embed ``docs`` into a single FAISS vector store, batch by batch.

    The documents are split into consecutive slices of at most ``batch_size``.
    Every slice is embedded concurrently via ``FAISS.afrom_documents``, and
    the per-slice stores are then merged, in document order, into the first.

    Args:
        docs: Sequence of ``Document`` objects to index.
        embeddings: Embedding model handed to ``FAISS.afrom_documents``.
        batch_size: Maximum number of documents per embedding batch.

    Returns:
        The merged ``FAISS`` vector store.

    Raises:
        ValueError: If ``docs`` is empty or ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if not docs:
        raise ValueError("create_vector_db() needs at least one document")
    # One task per slice. The first slice runs concurrently with the rest
    # rather than being awaited on its own beforehand. Slicing past the end is
    # safe, so no min() clamp is needed.
    tasks = [
        FAISS.afrom_documents(docs[start:start + batch_size], embeddings)
        for start in range(0, len(docs), batch_size)
    ]
    # gather() preserves task order, so the merge below keeps document order.
    vector_db, *rest = await asyncio.gather(*tasks)
    for temp_db in rest:
        vector_db.merge_from(temp_db)
    return vector_db
def _load_repo_documents():
    """Build one ``Document`` per crawled repository JSON file in ``data_path``."""
    docs = []
    # Sorted so the document (and therefore index) order is reproducible.
    for file_name in tqdm(sorted(os.listdir(data_path))):
        if not file_name.endswith(".json"):
            continue
        with open(os.path.join(data_path, file_name), "r", encoding="utf-8") as f:
            data = json.load(f)
        # Only name + description are embedded. ensure_ascii=False keeps
        # non-ASCII descriptions as real characters; the default would embed
        # "\uXXXX" escape sequences instead of the text itself.
        content = json.dumps(
            {"name": data["name"], "description": data["description"]},
            ensure_ascii=False,
        )
        docs.append(
            Document(
                page_content=content,
                metadata={
                    "html_url": data["html_url"],
                    "topics": data["topics"],
                    "created_at": data["created_at"],
                    "updated_at": data["updated_at"],
                    "star_count": data["stargazers_count"],
                },
            )
        )
    return docs


async def main_convert_to_vector():
    """Load the FAISS index from ``db_path``, or build and save it if missing.

    When ``init.faiss`` already exists in ``db_path`` it is loaded as-is.
    Otherwise every crawled repository JSON file is turned into a
    ``Document``, embedded with OpenAI ``text-embedding-3-small``, and the
    resulting index is saved under ``db_path`` with index name ``"init"``.

    Returns:
        The loaded or newly built ``FAISS`` vector store.
    """
    # Initialise the embedding client.
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"),
                                  base_url=os.getenv("OPENAI_BASE_URL"),
                                  model="text-embedding-3-small")
    print("Embedding model success: text-embedding-3-small")
    if os.path.exists(os.path.join(db_path, "init.faiss")):
        # Reuse the saved index; the JSON files are not read in this case.
        # NOTE: load_local unpickles the docstore. That is acceptable only
        # because the file is produced by this script, never by untrusted users.
        return FAISS.load_local(db_path, embeddings=embeddings,
                                index_name="init",
                                allow_dangerous_deserialization=True)
    docs = _load_repo_documents()
    print(f"Total {len(docs)} documents.")
    vector_db = await create_vector_db(docs, embeddings=embeddings)
    vector_db.save_local(db_path, index_name="init")
    return vector_db
if __name__ == "__main__":
    # Pipeline entry point. Step 1, crawling repository metadata from GitHub,
    # is disabled; re-enable main_repo() to refresh the JSON files.
    # main_repo()
    # Step 2: build (or load) the FAISS vector database from those files.
    pipeline = main_convert_to_vector()
    asyncio.run(pipeline)