CoI_Agent / searcher /sementic_search.py
jianghuyihei's picture
fix
a81bf47
raw
history blame
20.2 kB
import requests
import json
import yaml
import scipdf
import os
import time
import aiohttp
import asyncio
import numpy as np
def get_content_between_a_b(start_tag, end_tag, text):
extracted_text = ""
start_index = text.find(start_tag)
while start_index != -1:
end_index = text.find(end_tag, start_index + len(start_tag))
if end_index != -1:
extracted_text += text[start_index + len(start_tag) : end_index] + " "
start_index = text.find(start_tag, end_index + len(end_tag))
else:
break
return extracted_text.strip()
def extract(text, type):
if text:
target_str = get_content_between_a_b(f"<{type}>", f"</{type}>", text)
if target_str:
return target_str
else:
return text
else:
return ""
async def fetch(url):
await asyncio.sleep(1) # 异步的 sleep 而不是 time.sleep
try:
timeout = aiohttp.ClientTimeout(total=120)
connector = aiohttp.TCPConnector(limit_per_host=10) # 使用连接池
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
async with session.get(url) as response:
if response.status == 200:
content = await response.read() # Read the response content as bytes
return content
else:
print(f"Failed to fetch the URL: {url} with status code: {response.status}")
return None
except aiohttp.ClientError as e: # 更具体的异常捕获
print(f"An error occurred while fetching the URL: {url}")
print(e)
return None
except Exception as e:
print(f"An unexpected error occurred while fetching the URL: {url}")
print(e)
return None
def download(url):
try:
response = requests.get(url)
if response.status_code == 200:
return response.content
else:
print(f"Failed to download the file from the URL: {url}")
return None
except requests.RequestException as e:
print(f"An error occurred while downloading the file from the URL: {url}")
print(e)
return None
except Exception as e:
print(f"An unexpected error occurred while downloading the file from the URL: {url}")
print(e)
return None
class Result:
def __init__(self,title="",abstract="",article = "",citations_conut = 0,year = None) -> None:
self.title = title
self.abstract = abstract
self.article = article
self.citations_conut = citations_conut
self.year = year
# Define the API endpoint URL
semantic_fields = ["title", "abstract", "year", "authors.name", "authors.paperCount", "authors.citationCount","authors.hIndex","url","referenceCount","citationCount","influentialCitationCount","isOpenAccess","openAccessPdf","fieldsOfStudy","s2FieldsOfStudy","embedding.specter_v1","embedding.specter_v2","publicationDate","citations"]
fieldsOfStudy = ["Computer Science","Medicine","Chemistry","Biology","Materials Science","Physics","Geology","Art","History","Geography","Sociology","Business","Political Science","Philosophy","Art","Literature","Music","Economics","Philosophy","Mathematics","Engineering","Environmental Science","Agricultural and Food Sciences","Education","Law","Linguistics"]
# citations.paperId, citations.title, citations.year, citations.authors.name, citations.authors.paperCount, citations.authors.citationCount, citations.authors.hIndex, citations.url, citations.referenceCount, citations.citationCount, citations.influentialCitationCount, citations.isOpenAccess, citations.openAccessPdf, citations.fieldsOfStudy, citations.s2FieldsOfStudy, citations.publicationDate
# publicationDateOrYear: 2019-03-05 ; 2019-03 ; 2019 ; 2016-03-05:2020-06-06 ; 1981-08-25: ; :2020-06-06 ; 1981:2020
# publicationTypes: Review ; JournalArticle CaseReport ; ClinicalTrial ; Dataset ; Editorial ; LettersAndComments ; MetaAnalysis ; News ; Study ; Book ; BookSection
def process_fields(fields):
return ",".join(fields)
class SementicSearcher:
def __init__(self, ban_paper = []) -> None:
self.ban_paper = ban_paper
async def search_papers_async(self, query, limit=5, offset=0, fields=["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citations.title","citations.abstract","citations.isOpenAccess","citations.openAccessPdf","citations.citationCount","citationCount","citations.year"],
publicationDate=None, minCitationCount=0, year=None,
publicationTypes=None, fieldsOfStudy=None):
url = 'https://api.semanticscholar.org/graph/v1/paper/search'
fields = process_fields(fields) if isinstance(fields, list) else fields
# More specific query parameter
query_params = {
'query': query,
"limit": limit,
"offset": offset,
'fields': fields,
'publicationDateOrYear': publicationDate,
'minCitationCount': minCitationCount,
'year': year,
'publicationTypes': publicationTypes,
'fieldsOfStudy': fieldsOfStudy
}
# Load the API key from the configuration file
api_key = os.environ.get('SEMENTIC_SEARCH_API_KEY',None)
headers = {'x-api-key': api_key} if api_key else None
await asyncio.sleep(0.5)
try:
filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
response = requests.get(url, params=filtered_query_params, headers=headers)
if response.status_code == 200:
response_data = response.json()
return response_data
elif response.status_code == 429:
time.sleep(1)
print(f"Request failed with status code {response.status_code}: begin to retry")
return await self.search_papers_async(query, limit, offset, fields, publicationDate, minCitationCount, year, publicationTypes, fieldsOfStudy)
else:
print(f"Request failed with status code {response.status_code}: {response.text}")
return None
except requests.RequestException as e:
print(f"An error occurred: {e}")
return None
def cal_cosine_similarity(self, vec1, vec2):
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
def read_arxiv_from_path(self, pdf_path):
def is_pdf(binary_data):
pdf_header = b'%PDF-'
return binary_data.startswith(pdf_header)
try:
flag = is_pdf(pdf_path)
if not flag:
return None
except Exception as e:
pass
try:
article_dict = scipdf.parse_pdf_to_dict(pdf_path)
except Exception as e:
print(f"Failed to parse the PDF file: {pdf_path}")
return None
return article_dict
async def get_paper_embbeding_and_score_async(self,query_embedding, paper,llm):
paper_content = f"""
Title: {paper['title']}
Abstract: {paper['abstract']}
"""
paper_embbeding = await llm.get_embbeding_async(paper_content)
paper_embbeding = np.array(paper_embbeding)
score = self.cal_cosine_similarity(query_embedding,paper_embbeding)
return [paper,score]
async def rerank_papers_async(self, query_embedding, paper_list,llm):
if len(paper_list) >= 50:
paper_list = paper_list[:50]
results = await asyncio.gather(*[self.get_paper_embbeding_and_score_async(query_embedding, paper,llm) for paper in paper_list if paper])
reranked_papers = sorted(results,key = lambda x: x[1],reverse = True)
return reranked_papers
async def get_embbeding_and_score_async(self,query_embedding, text,llm):
text_embbeding = await llm.get_embbeding_async(text)
text_embbeding = np.array(text_embbeding)
score = self.cal_cosine_similarity(query_embedding,text_embbeding)
return score
async def get_embbeding_and_score_from_texts_async(self,query_embedding, texts,llm):
results = await asyncio.gather(*[self.get_embbeding_and_score_async(query_embedding, text,llm) for text in texts])
return results
async def get_paper_details_async(self, paper_id, fields = ["title", "abstract", "year","citationCount","isOpenAccess","openAccessPdf"]):
url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}'
fields = process_fields(fields)
paper_data_query_params = {'fields': fields}
try:
async with aiohttp.ClientSession() as session:
filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
async with session.get(url, params=filtered_query_params, headers=headers) as response:
if response.status == 200:
response_data = await response.json()
return response_data
else:
await asyncio.sleep(0.01)
print(f"Request failed with status code {response.status}: {await response.text()}")
return None
except Exception as e:
print(f"Failed to get paper details for paper ID: {paper_id}")
return None
async def batch_retrieve_papers_async(self, paper_ids, fields = semantic_fields):
url = 'https://api.semanticscholar.org/graph/v1/paper/batch'
paper_data_query_params = {'fields': process_fields(fields)}
paper_ids_json = {"ids": paper_ids}
try:
async with aiohttp.ClientSession() as session:
filtered_query_params = {key: value for key, value in paper_data_query_params.items() if value is not None}
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
async with session.post(url, json=paper_ids_json, params=filtered_query_params, headers=headers) as response:
if response.status == 200:
response_data = await response.json()
return response_data
else:
await asyncio.sleep(0.01)
print(f"Request failed with status code {response.status}: {await response.text()}")
return None
except Exception as e:
print(f"Failed to batch retrieve papers for paper IDs: {paper_ids}")
return None
async def search_paper_from_title_async(self, query,fields = ["title","paperId"]):
url = 'https://api.semanticscholar.org/graph/v1/paper/search/match'
fields = process_fields(fields)
query_params = {'query': query, 'fields': fields}
try:
async with aiohttp.ClientSession() as session:
filtered_query_params = {key: value for key, value in query_params.items() if value is not None}
headers = {'x-api-key': os.environ.get('SEMENTIC_SEARCH_API_KEY',None)}
async with session.get(url, params=filtered_query_params, headers=headers) as response:
if response.status == 200:
response_data = await response.json()
return response_data
else:
await asyncio.sleep(0.01)
print(f"Request failed with status code {response.status}: {await response.text()}")
return None
except Exception as e:
await asyncio.sleep(0.01)
print(f"Failed to search paper from title: {query}")
return None
async def search_async(self,query,max_results = 5 ,paper_list = None ,rerank_query = None,llm = None,year = None,publicationDate = None,need_download = True,fields = ["title", "paperId", "abstract", "isOpenAccess", 'openAccessPdf', "year","publicationDate","citationCount"]):
if rerank_query:
rerank_query_embbeding = llm.get_embbeding(rerank_query)
rerank_query_embbeding = np.array(rerank_query_embbeding)
readed_papers = []
if paper_list:
if isinstance(paper_list,set):
paper_list = list(paper_list)
if len(paper_list) == 0 :
pass
elif isinstance(paper_list[0], str):
readed_papers = paper_list
elif isinstance(paper_list[0], Result):
readed_papers = [paper.title for paper in paper_list]
print(f"Searching for papers related to the query: <{query}>")
results = await self.search_papers_async(query,limit = 10 * max_results,year=year,publicationDate = publicationDate,fields = fields)
if not results or "data" not in results:
return []
new_results = []
for result in results['data']:
if result['title'] in self.ban_paper:
continue
new_results.append(result)
results = new_results
final_results = []
if need_download:
paper_candidates = []
for result in results:
if not result['isOpenAccess'] or not result['openAccessPdf'] or result['title'] in readed_papers:
continue
else:
paper_candidates.append(result)
else:
paper_candidates = results
if llm and rerank_query:
paper_candidates = await self.rerank_papers_async(rerank_query_embbeding, paper_candidates,llm)
paper_candidates = [paper[0] for paper in paper_candidates if paper]
if need_download:
for result in paper_candidates:
pdf_link = result['openAccessPdf']["url"]
try:
content = self.download_pdf(pdf_link)
if not content:
continue
except Exception as e:
continue
title = result['title']
abstract = result['abstract']
citationCount = result['citationCount']
year = result['year']
article = self.read_arxiv_from_path(content)
if not article:
continue
final_results.append(Result(title,abstract,article,citationCount,year))
if len(final_results) >= max_results:
break
else:
for result in paper_candidates:
title = result['title']
abstract = result['abstract']
citationCount = result['citationCount']
year = result['year']
final_results.append(Result(title,abstract,None,citationCount,year))
if len(final_results) >= max_results:
break
return final_results
async def search_related_paper_async(self,title,need_citation = True,need_reference = True,rerank_query = None,llm = None,paper_list = []):
print(f"Searching for the related papers of <{title}>")
fileds = ["title","abstract","citations.title","citations.abstract","citations.citationCount","references.title","references.abstract","references.citationCount","citations.isOpenAccess","citations.openAccessPdf","references.isOpenAccess","references.openAccessPdf","citations.year","references.year"]
results = await self.search_papers_async(title,limit = 3,fields=fileds)
related_papers = []
related_papers_title = []
if not results or "data" not in results:
return None
for result in results["data"]:
if not result:
continue
if need_citation:
for citation in result["citations"]:
if "openAccessPdf" not in citation or not citation["openAccessPdf"]:
continue
elif citation["title"] in related_papers_title or citation["title"] in self.ban_paper or citation["title"] in paper_list:
continue
elif citation["isOpenAccess"] == False or citation["openAccessPdf"] == None:
continue
else:
related_papers.append(citation)
related_papers_title.append(citation["title"])
if need_reference:
for reference in result["references"]:
if "openAccessPdf" not in reference or not reference["openAccessPdf"]:
continue
elif reference["title"] in related_papers_title or reference["title"] in self.ban_paper or reference["title"] in paper_list:
continue
elif reference["isOpenAccess"] == False or reference["openAccessPdf"] == None:
continue
else:
related_papers.append(reference)
related_papers_title.append(reference["title"])
if result:
break
if len(related_papers) >= 200:
related_papers = related_papers[:200]
if rerank_query and llm:
rerank_query_embbeding = llm.get_embbeding(rerank_query)
rerank_query_embbeding = np.array(rerank_query_embbeding)
related_papers = await self.rerank_papers_async(rerank_query_embbeding, related_papers,llm)
related_papers = [paper[0] for paper in related_papers]
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
else:
related_papers = [[paper["title"],paper["abstract"],paper["openAccessPdf"]["url"],paper["citationCount"],paper['year']] for paper in related_papers]
related_papers = sorted(related_papers,key = lambda x: x[3],reverse = True)
print(f"Found {len(related_papers)} related papers")
for paper in related_papers:
url = paper[2]
content = self.download_pdf(url)
if content:
article = self.read_arxiv_from_path(content)
if not article:
continue
result = Result(paper[0],paper[1],article,paper[3],paper[4])
return result
return None
async def download_pdf_async(self, pdf_link):
content = await fetch(pdf_link)
if not content:
return None
else:
return content
def download_pdf(self, pdf_link):
content = download(pdf_link)
return content
def read_paper_title_abstract(self,article):
title = article["title"]
abstract = article["abstract"]
paper_content = f"""
Title: {title}
Abstract: {abstract}
"""
return paper_content
def read_paper_content(self,article):
paper_content = self.read_paper_title_abstract(article)
for section in article["sections"]:
paper_content += f"section: {section['heading']}\n content: {section['text']}\n ref_ids: {section['publication_ref']}\n"
return paper_content
def read_paper_content_with_ref(self,article):
paper_content = self.read_paper_content(article)
paper_content += "<References>\n"
i = 1
for refer in article["references"]:
ref_id = refer["ref_id"]
title = refer["title"]
year = refer["year"]
paper_content += f"Ref_id:{ref_id} Title: {title} Year: ({year})\n"
i += 1
paper_content += "</References>\n"
return paper_content