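# Paper-reading assistant: extracts the key sections of a PDF paper, asks ChatGPT which
# additional sections are worth reading, and then lets the user ask questions about the
# paper in an interactive terminal session.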
import numpy as np
import os
import re
import datetime
import time
import openai, tenacity
import argparse
import configparser
import json
import tiktoken
from get_paper_from_pdf import Paper


class Assistant:
    def __init__(self, args=None):
        if args.language == 'en':
            self.language = 'English'
        elif args.language == 'zh':
            self.language = 'Chinese'
        else:
            self.language = 'Chinese'
        self.config = configparser.ConfigParser()
        self.config.read('apikey.ini')
        self.chat_api_list = self.config.get('OpenAI', 'OPENAI_API_KEYS')[1:-1].replace('\'', '').split(',')
        self.chat_api_list = [api.strip() for api in self.chat_api_list if len(api) > 5]
        self.cur_api = 0
        self.file_format = args.file_format
        self.max_token_num = 4096
        self.encoding = tiktoken.get_encoding("gpt2")
        self.result_backup = ''
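
    # Replace characters that are illegal in file names so the paper title can be used as an output file name.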
    def validateTitle(self, title):
        rstr = r"[\/\\\:\*\?\"\<\>\|]"
        new_title = re.sub(rstr, "_", title)
        return new_title
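
    # Main loop: for each paper, let ChatGPT pick the important sections, run the interactive
    # Q&A session, and write the questions and answers to an output file.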
    def assist_reading_by_chatgpt(self, paper_list):
        htmls = []
        for paper_index, paper in enumerate(paper_list):
            sections_of_interest = self.extract_paper(paper)
            # Extract the essential parts of the paper.
            text = ''
            text += 'Title: ' + paper.title + '. '
            text += 'Abstract: ' + paper.section_texts['Abstract']
            intro_title = next((item for item in paper.section_names if 'ntroduction' in item.lower()), None)
            if intro_title is not None:
                text += 'Introduction: ' + paper.section_texts[intro_title]
            # Do the same for the conclusion section.
            conclusion_title = next((item for item in paper.section_names if 'onclusion' in item), None)
            if conclusion_title is not None:
                text += 'Conclusion: ' + paper.section_texts[conclusion_title]
            for heading in sections_of_interest:
                if heading in paper.section_names:
                    text += heading + ': ' + paper.section_texts[heading]
            chat_review_text = self.chat_assist(text=text)
            htmls.append('## Paper:' + str(paper_index + 1))
            htmls.append('\n\n\n')
            htmls.append(chat_review_text)
            # Save the questions and answers to a file.
            date_str = str(datetime.datetime.now())[:19].replace(' ', '-').replace(':', '-')
            export_path = os.path.join('./', 'output_file')
            os.makedirs(export_path, exist_ok=True)
            mode = 'w' if paper_index == 0 else 'a'
            file_name = os.path.join(export_path, date_str + '-' + self.validateTitle(paper.title) + "." + self.file_format)
            self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)
            htmls = []
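
    # Send the title, abstract and section headings to ChatGPT and ask it which (at most two)
    # sections are the most important ones to read in full.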
    def extract_paper(self, paper):
        htmls = []
        text = ''
        text += 'Title: ' + paper.title + '. '
        text += 'Abstract: ' + paper.section_texts['Abstract']
        text_token = len(self.encoding.encode(text))
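        # If the title and abstract alone are too long, truncate the text so it fits within roughly half of the context budget.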
        if text_token > self.max_token_num / 2 - 800:
            input_text_index = int(len(text) * ((self.max_token_num / 2) - 800) / text_token)
            text = text[:input_text_index]
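        # Rotate through the configured API keys in round-robin fashion.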
        openai.api_key = self.chat_api_list[self.cur_api]
        self.cur_api += 1
        self.cur_api = 0 if self.cur_api >= len(self.chat_api_list) else self.cur_api
        print("\n\n" + "********" * 10)
        print("Extracting content from PDF.")
        print("********" * 10)
        messages = [
            {"role": "system",
             "content": f"You are a professional researcher in the field of {args.research_fields}. You are the mentor of a student who is new to this field. "
                        f"I will give you a paper. You need to help your student read this paper by instructing him to read the important sections in this paper and by answering his questions about these sections. "
                        f"Due to length limitations, I am only allowed to provide you with the abstract, introduction, conclusion and at most two sections of this paper. "
                        f"Now I will give you the title, the abstract and the headings of the potential sections. "
                        f"You need to reply with at most two headings. Then I will further provide you with the full information, including the aforementioned sections and the at most two sections you asked for.\n\n"
                        f"Title: {paper.title}\n\n"
                        f"Abstract: {paper.section_texts['Abstract']}\n\n"
                        f"Potential Sections: {paper.section_names[2:-1]}\n\n"
                        f"Use the following format to output your choice of sections: "
                        f"{{chosen section 1}}, {{chosen section 2}}\n\n"},
            {"role": "user", "content": text},
        ]
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
        )
        result = ''
        for choice in response.choices:
            result += choice.message.content
        print("\n\n" + "********" * 10)
        print("Important sections of this paper:")
        print(result)
        print("********" * 10)
        print("prompt_token_used:", response.usage.prompt_tokens)
        print("completion_token_used:", response.usage.completion_tokens)
        print("total_token_used:", response.usage.total_tokens)
        print("response_time:", response.response_ms / 1000.0, 's')
        return [section.strip() for section in result.split(',')]
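
    # Interactive Q&A loop: the student types questions in the terminal and ChatGPT answers them
    # based on the extracted paper text. All questions and answers are accumulated in result_backup.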
    def chat_assist(self, text):
        openai.api_key = self.chat_api_list[self.cur_api]
        self.cur_api += 1
        self.cur_api = 0 if self.cur_api >= len(self.chat_api_list) else self.cur_api
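        # Reserve about 1000 tokens for the prompt and the response, and truncate the paper text to fit the model's context window.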
        review_prompt_token = 1000
        text_token = len(self.encoding.encode(text))
        input_text_index = int(len(text) * (self.max_token_num - review_prompt_token) / text_token)
        input_text = "This is the paper for your review:" + text[:input_text_index] + "\n\n"
        input_text_backup = input_text
        while True:
            print("\n\n" + "********" * 10)
            print("Ask ChatGPT questions about the important sections. Type \"quit\" to exit the program. "
                  "To receive better responses, please explain why you are asking the question.\n"
                  "For example, ask \"Why does the author use residual connections? I want to know how the residual connections work in the model structure.\" "
                  "instead of \"Why does the author use residual connections?\"")
            print("********" * 10)
            student_question = input()
            if student_question == "quit":
                break
            input_text = input_text_backup
            input_text = input_text + "The question from your student is: " + student_question
            messages = [
                {"role": "system", "content": "You are a professional researcher in the field of " + args.research_fields + ". You are the mentor of a student who is new to this field. Now I will give you a paper. You need to help your student read this paper by instructing him to read the important sections in this paper and by answering his questions about these sections. Please answer in {}.".format(self.language)},
                {"role": "user", "content": input_text},
            ]
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
            )
            result = ''
            for choice in response.choices:
                result += choice.message.content
            self.result_backup = self.result_backup + "\n\n" + student_question + "\n"
            self.result_backup += result
            print("\n\n" + "********" * 10)
            print(result)
            print("********" * 10)
            print("prompt_token_used:", response.usage.prompt_tokens)
            print("completion_token_used:", response.usage.completion_tokens)
            print("total_token_used:", response.usage.total_tokens)
            print("response_time:", response.response_ms / 1000.0, 's')
        return self.result_backup
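
    # Write the accumulated Q&A text to the output file.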
    def export_to_markdown(self, text, file_name, mode='w'):
        # The markdown module could be used to convert the text to HTML:
        # html = markdown.markdown(text)
        # Open the file in the requested mode and write the content.
        with open(file_name, mode, encoding="utf-8") as f:
            f.write(text)
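

# Entry point: build the paper list from a single PDF or a directory of PDFs, then start the assistant.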
def main(args):
    # Paper reading assistant instructions
    print("********" * 10)
    print("Extracting content from PDF.")
    print("********" * 10)
    assistant1 = Assistant(args=args)
    # Determine whether the path is a single file or a directory.
    paper_list = []
    if args.paper_path.endswith(".pdf"):
        paper_list.append(Paper(path=args.paper_path))
    else:
        for root, dirs, files in os.walk(args.paper_path):
            print("root:", root, "dirs:", dirs, 'files:', files)  # current directory path
            for filename in files:
                # If a PDF file is found, add it to the paper list.
                if filename.endswith(".pdf"):
                    paper_list.append(Paper(path=os.path.join(root, filename)))
    print("------------------paper_num: {}------------------".format(len(paper_list)))
    [print(paper_index, os.path.basename(paper_name.path)) for paper_index, paper_name in enumerate(paper_list)]
    assistant1.assist_reading_by_chatgpt(paper_list=paper_list)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--paper_path", type=str, default='', help="path of papers")
    parser.add_argument("--file_format", type=str, default='txt', help="output file format")
    parser.add_argument("--research_fields", type=str, default='computer science, artificial intelligence and transfer learning', help="the research fields of the paper")
    parser.add_argument("--language", type=str, default='en', help="output language, en or zh")
    args = parser.parse_args()

    start_time = time.time()
    main(args=args)
    print("total time:", time.time() - start_time)