Richard Fan committed

Commit ef10e9f · 1 Parent(s): f7d455a

initial commit

.github/workflows/.daily_pipeline.yaml.swp ADDED
Binary file (12.3 kB)
 
.github/workflows/daily_pipeline.yaml ADDED
@@ -0,0 +1,80 @@
+ # This workflow will install Python dependencies, run tests and lint with a single version of Python
+ # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+ name: Daily pipeline
+
+ on:
+   workflow_dispatch: {}
+   schedule:
+     # * is a special character in YAML so you have to quote this string
+     # Feel free to change this cron schedule
+     # Currently it's scheduled for 1:25 pm UTC, Sun-Thurs
+     - cron: '25 13 * * 0-4'
+
+ jobs:
+   build_and_test:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v2
+       - name: Set up Python 3.8
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.8
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r src/requirements.txt
+       - name: Generate Digest
+         run: |
+           python src/action.py
+         env:
+           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+           SENDGRID_API_KEY: ${{ secrets.SENDGRID_API_KEY }}
+           FROM_EMAIL: ${{ secrets.FROM_EMAIL }}
+           TO_EMAIL: ${{ secrets.TO_EMAIL }}
+       - name: Upload Artifact
+         uses: actions/upload-artifact@v3
+         with:
+           name: digest.html
+           path: digest.html
+       # Only send mail when SendGrid or SMTP secrets are actually configured
+       - name: check
+         id: check
+         env:
+           SENDGRID_API_KEY: ${{ secrets.SENDGRID_API_KEY }}
+           MAIL_USERNAME: ${{ secrets.MAIL_USERNAME }}
+           MAIL_PASSWORD: ${{ secrets.MAIL_PASSWORD }}
+           MAIL_CONNECTION: ${{ secrets.MAIL_CONNECTION }}
+         if: "${{ env.SENDGRID_API_KEY != '' && (env.MAIL_CONNECTION || env.MAIL_USERNAME != '' && env.MAIL_PASSWORD != '') }}"
+         run: echo "DEFINED=true" >> $GITHUB_OUTPUT
+       - name: Test step
+         env:
+           DEFINED: ${{ steps.check.outputs.DEFINED }}
+         run: echo "$DEFINED"
+       - name: Send mail
+         uses: dawidd6/action-send-mail@v3
+         env:
+           DEFINED: ${{ steps.check.outputs.DEFINED }}
+         if: ${{ env.DEFINED == 'true' }}
+         with:
+           # Specify connection via URL (replaces server_address, server_port, secure,
+           # username and password)
+           #
+           # Format:
+           #
+           # * smtp://user:password@server:port
+           # * smtp+starttls://user:password@server:port
+           connection_url: ${{ secrets.MAIL_CONNECTION }}
+           # Required mail server address if not connection_url:
+           server_address: smtp.gmail.com
+           # Server port, default 25:
+           server_port: 465
+           username: ${{ secrets.MAIL_USERNAME }}
+           password: ${{ secrets.MAIL_PASSWORD }}
+           secure: true
+           subject: Personalized arXiv Digest
+           to: ${{ secrets.TO_EMAIL }}
+           from: "Personalized arxiv digest"
+           html_body: file://digest.html
+           ignore_cert: true
+           convert_markdown: true
+           priority: normal
README.md CHANGED
@@ -1,2 +1,105 @@
- # Arxiv-Digest
- Personalized Arxiv Digest using Large Language Models
+ # Personalized-Arxiv-digest
+ This repo aims to provide a better daily digest for newly published arxiv papers, based on your own research interests and descriptions.
+
+ ## What this repo does
+
+ Staying up to date on [arxiv](https://arxiv.org) papers can take a considerable amount of time, with on the order of hundreds of new papers each day to filter through. There is an [official daily digest service](https://info.arxiv.org/help/subscribe.html); however, large subtopics like [cs.AI](https://arxiv.org/list/cs.AI/recent) still see 50-100 papers a day. Determining whether these papers are relevant and important to you means reading through each title and abstract.
+
+ This repository provides a way to have this daily digest sorted by relevance via large language models (sketched below):
+
+ * You modify the configuration file `config.yaml` with an arxiv topic, some set of subtopics, and a natural language statement about the type of papers you are interested in.
+ * The code pulls all the abstracts for papers in those subtopics and ranks how relevant they are to your interest on a scale of 1-10 using gpt-3.5-turbo.
+ * The code then emits an HTML digest listing all the relevant papers, and optionally emails it to you using [SendGrid](https://sendgrid.com). You will need a SendGrid account with an API key for this functionality to work.
+
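+ Under the hood, this amounts to roughly the following (a minimal sketch using the helpers in `src/`; it assumes you run it from the repository root with `src` on your `PYTHONPATH` and `OPENAI_API_KEY` set):
+
+ ```python
+ from download_new_papers import get_papers
+ from relevancy import generate_relevance_score
+
+ papers = get_papers("cs")  # scrape today's new arXiv submissions for a topic
+ relevant, hallucinated = generate_relevance_score(
+     papers,
+     query={"interest": "Large language model pretraining and finetuning"},
+     threshold_score=7,      # papers scoring below this are dropped
+     num_paper_in_prompt=8,  # abstracts batched into each gpt-3.5-turbo call
+ )
+ ```
+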
+ ### Some examples:
+
+ - Topic: cs.AI, cs.CL
+ - Interest:
+   - Large language model pretraining and finetuning
+   - Multimodal machine learning
+   - Do not care about specific applications, for example, information extraction, summarization, etc.
+   - Not interested in papers focusing on specific languages, e.g., Arabic, Chinese, etc.
+
+ ![example1](./readme_images/example_1.png)
+
+
+ - Topic: q-fin
+ - Interest: "making lots of money"
+
+ ![example2](./readme_images/example_2.png)
+
+ ## Usage
+
+ ### Running as a GitHub Action using SendGrid
+
+ The recommended way to get started using this repository is to:
+
+ 1. Fork the repository.
+ 2. Modify `config.yaml` and merge the changes into your main branch. If you want a different schedule than Sunday through Thursday at 1:25PM UTC, then also modify the file `.github/workflows/daily_pipeline.yaml`.
+ 3. Create or fetch your API key for [OpenAI](https://platform.openai.com/account/api-keys). Note: you will need an OpenAI account.
+ 4. Create or fetch your API key for [SendGrid](https://app.SendGrid.com/settings/api_keys). You will need a SendGrid account. The free tier will generally suffice.
+ 5. Set the following secrets in your repository settings:
+    - `OPENAI_API_KEY`
+    - `SENDGRID_API_KEY`
+    - `FROM_EMAIL` (only if you don't have it set in `config.yaml`)
+    - `TO_EMAIL` (only if you don't have it set in `config.yaml`)
+ 6. Manually trigger the action or wait until the scheduled action takes place.
+
+ ![artifact](./readme_images/trigger.png)
+
+
+ ### Running as a GitHub Action with SMTP credentials
+
+ An alternative way to get started using this repository is to:
+
+ 1. Fork the repository.
+ 2. Modify `config.yaml` and merge the changes into your main branch. If you want a different schedule than Sunday through Thursday at 1:25PM UTC, then also modify the file `.github/workflows/daily_pipeline.yaml`.
+ 3. Create or fetch your API key for [OpenAI](https://platform.openai.com/account/api-keys). Note: you will need an OpenAI account.
+ 4. Find your email provider's SMTP settings and set the secret `MAIL_CONNECTION` to that. It should be in the form `smtp://user:password@server:port` or `smtp+starttls://user:password@server:port`. Alternatively, if you are using Gmail, you can set `MAIL_USERNAME` and `MAIL_PASSWORD` instead. If you are (understandably) apprehensive about using your email credentials here, you can create something like an [application password](https://support.google.com/accounts/answer/185833) instead.
+ 5. Set the following secrets in your repository settings:
+    - `OPENAI_API_KEY`
+    - `MAIL_CONNECTION` (see above)
+    - `MAIL_PASSWORD` (only if you don't have `MAIL_CONNECTION` set)
+    - `MAIL_USERNAME` (only if you don't have `MAIL_CONNECTION` set)
+    - `FROM_EMAIL` (only if you don't have it set in `config.yaml`)
+    - `TO_EMAIL` (only if you don't have it set in `config.yaml`)
+ 6. Manually trigger the action or wait until the scheduled action takes place.
+
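+ For example, with a Gmail application password, `MAIL_CONNECTION` could look something like `smtp+starttls://your.name%40gmail.com:[email protected]:587` (hypothetical credentials; special characters such as `@` in the username or password must be URL-encoded).
+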
69
+ #### Running as a github action without emails
70
+
71
+ If you do not wish to create a SendGrid account or use your email authentication, the action will also emit an artifact containing the HTML output. Simply do not create the SendGrid or SMTP secrets.
72
+
73
+ You can access this digest as part of the github action artifact.
74
+
75
+ ![artifact](./readme_images/artifact.png)
76
+
77
+ ### Running from the command line
78
+
79
+ If you do not wish to fork this repository, and would prefer to clone and run it locally instead:
80
+
81
+ 1. Install the requirements in `src/requirements.txt`
82
+ 2. Modify the configuration file `config.yaml`
83
+ 3. Create or fetch your api key for [OpenAI](https://platform.openai.com/account/api-keys). Note: you will need an OpenAI account.
84
+ 4. Create or fetch your api key for [SendGrid](https://app.SendGrid.com/settings/api_keys) (optional, if you want the script to email you)
85
+ 5. Set the following secrets:
86
+ - `OPENAI_API_KEY`
87
+ - `SENDGRID_API_KEY` (only if using SendGrid)
88
+ - `FROM_EMAIL` (only if using SendGrid and if you don't have them set in `config.yaml`)
89
+ - `TO_EMAIL` (only if using SendGrid and if you don't have them set in `config.yaml`)
90
+ 6. Run `python action.py`.
91
+ 7. If you are not using SendGrid, the html of the digest will be written to `digest.html`. You can then use your favorite webbrowser to view it.
92
+
93
+ You may want to use something like crontab to schedule the digest.
94
+
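+ For example, a crontab entry mirroring the GitHub Actions schedule might look like `25 13 * * 0-4 cd ~/ArxivDigest && python src/action.py` (a hypothetical path; the environment variables above must also be visible to cron).
+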
+ ### Running with a user interface
+
+ Install the requirements in `src/requirements.txt` as well as `gradio`. Set the environment variables `OPENAI_API_KEY`, `FROM_EMAIL` and `SENDGRID_API_KEY`.
+
+ Run `python src/app.py` and go to the local URL. From there you will be able to preview the papers from today, as well as the generated digests.
+
+ ## Extending and Contributing
+
+ You may (and are encouraged to) modify the code in this repository to suit your personal needs. If you think your modifications would be in any way useful to others, please submit a pull request.
+
+ These types of modifications include things like changes to the prompt, different language models, or additional ways for the digest to be delivered to you.
config.yaml ADDED
@@ -0,0 +1,37 @@
+ # For physics topics, use the specific subtopics, e.g. "Astrophysics"
+ topic: "Computer Science"
+ # An empty list here will include all categories in a topic
+ # Including more categories will result in more calls to the large language model
+ categories: ["Artificial Intelligence", "Computation and Language"]
+
+ # The email address that the digest will be sent from. Must be the address matching
+ # your SendGrid API key.
+ # Leaving this empty will cause the script to use the
+ # FROM_EMAIL environment variable instead
+ from_email: ""
+
+ # The email address you are going to send the digest to
+ # Leaving this empty will cause the script to use the
+ # TO_EMAIL environment variable instead
+ to_email: ""
+
+ # Relevance score threshold. Papers whose abstracts receive a score below this
+ # from the large language model will be filtered out of the digest.
+ #
+ # Must be within 1-10
+ threshold: 7
+
+ # A natural language statement that the large language model will use to judge which papers are relevant
+ #
+ # For example:
+ # "I am interested in complexity theory papers that establish upper bounds"
+ # "gas chromatography, mass spectrometry"
+ # "making lots of money"
+ #
+ # This can be empty, which will just return the full list of papers with no judgment or filtering,
+ # in whatever order arXiv responds with.
+ interest: |
+   1. Large language model pretraining and finetuning
+   2. Multimodal machine learning
+   3. Do not care about specific applications, for example, information extraction, summarization, etc.
+   4. Not interested in papers focusing on specific languages, e.g., Arabic, Chinese, etc.
readme_images/artifact.png ADDED
readme_images/example_1.png ADDED
readme_images/example_2.png ADDED
readme_images/trigger.png ADDED
src/action.py ADDED
@@ -0,0 +1,142 @@
+ from sendgrid import SendGridAPIClient
+ from sendgrid.helpers.mail import Mail, Email, To, Content
+
+ from datetime import date
+
+ import argparse
+ import yaml
+ import os
+
+ from relevancy import generate_relevance_score, process_subject_fields
+ from download_new_papers import get_papers
+
+
+ # Hackathon quality code. Don't judge too harshly.
+ # Feel free to submit pull requests to improve the code.
+
+ # Top-level arXiv topics and their URL abbreviations
+ topics = {
+     "Physics": "",
+     "Mathematics": "math",
+     "Computer Science": "cs",
+     "Quantitative Biology": "q-bio",
+     "Quantitative Finance": "q-fin",
+     "Statistics": "stat",
+     "Electrical Engineering and Systems Science": "eess",
+     "Economics": "econ"
+ }
+
+ physics_topics = {
+     "Astrophysics": "astro-ph",
+     "Condensed Matter": "cond-mat",
+     "General Relativity and Quantum Cosmology": "gr-qc",
+     "High Energy Physics - Experiment": "hep-ex",
+     "High Energy Physics - Lattice": "hep-lat",
+     "High Energy Physics - Phenomenology": "hep-ph",
+     "High Energy Physics - Theory": "hep-th",
+     "Mathematical Physics": "math-ph",
+     "Nonlinear Sciences": "nlin",
+     "Nuclear Experiment": "nucl-ex",
+     "Nuclear Theory": "nucl-th",
+     "Physics": "physics",
+     "Quantum Physics": "quant-ph"
+ }
+
+
+ # TODO: surely there's a better way
+ category_map = {
+     "Astrophysics": ["Astrophysics of Galaxies", "Cosmology and Nongalactic Astrophysics", "Earth and Planetary Astrophysics", "High Energy Astrophysical Phenomena", "Instrumentation and Methods for Astrophysics", "Solar and Stellar Astrophysics"],
+     "Condensed Matter": ["Disordered Systems and Neural Networks", "Materials Science", "Mesoscale and Nanoscale Physics", "Other Condensed Matter", "Quantum Gases", "Soft Condensed Matter", "Statistical Mechanics", "Strongly Correlated Electrons", "Superconductivity"],
+     "General Relativity and Quantum Cosmology": ["None"],
+     "High Energy Physics - Experiment": ["None"],
+     "High Energy Physics - Lattice": ["None"],
+     "High Energy Physics - Phenomenology": ["None"],
+     "High Energy Physics - Theory": ["None"],
+     "Mathematical Physics": ["None"],
+     "Nonlinear Sciences": ["Adaptation and Self-Organizing Systems", "Cellular Automata and Lattice Gases", "Chaotic Dynamics", "Exactly Solvable and Integrable Systems", "Pattern Formation and Solitons"],
+     "Nuclear Experiment": ["None"],
+     "Nuclear Theory": ["None"],
+     "Physics": ["Accelerator Physics", "Applied Physics", "Atmospheric and Oceanic Physics", "Atomic and Molecular Clusters", "Atomic Physics", "Biological Physics", "Chemical Physics", "Classical Physics", "Computational Physics", "Data Analysis, Statistics and Probability", "Fluid Dynamics", "General Physics", "Geophysics", "History and Philosophy of Physics", "Instrumentation and Detectors", "Medical Physics", "Optics", "Physics and Society", "Physics Education", "Plasma Physics", "Popular Physics", "Space Physics"],
+     "Quantum Physics": ["None"],
+     "Mathematics": ["Algebraic Geometry", "Algebraic Topology", "Analysis of PDEs", "Category Theory", "Classical Analysis and ODEs", "Combinatorics", "Commutative Algebra", "Complex Variables", "Differential Geometry", "Dynamical Systems", "Functional Analysis", "General Mathematics", "General Topology", "Geometric Topology", "Group Theory", "History and Overview", "Information Theory", "K-Theory and Homology", "Logic", "Mathematical Physics", "Metric Geometry", "Number Theory", "Numerical Analysis", "Operator Algebras", "Optimization and Control", "Probability", "Quantum Algebra", "Representation Theory", "Rings and Algebras", "Spectral Theory", "Statistics Theory", "Symplectic Geometry"],
+     "Computer Science": ["Artificial Intelligence", "Computation and Language", "Computational Complexity", "Computational Engineering, Finance, and Science", "Computational Geometry", "Computer Science and Game Theory", "Computer Vision and Pattern Recognition", "Computers and Society", "Cryptography and Security", "Data Structures and Algorithms", "Databases", "Digital Libraries", "Discrete Mathematics", "Distributed, Parallel, and Cluster Computing", "Emerging Technologies", "Formal Languages and Automata Theory", "General Literature", "Graphics", "Hardware Architecture", "Human-Computer Interaction", "Information Retrieval", "Information Theory", "Logic in Computer Science", "Machine Learning", "Mathematical Software", "Multiagent Systems", "Multimedia", "Networking and Internet Architecture", "Neural and Evolutionary Computing", "Numerical Analysis", "Operating Systems", "Other Computer Science", "Performance", "Programming Languages", "Robotics", "Social and Information Networks", "Software Engineering", "Sound", "Symbolic Computation", "Systems and Control"],
+     "Quantitative Biology": ["Biomolecules", "Cell Behavior", "Genomics", "Molecular Networks", "Neurons and Cognition", "Other Quantitative Biology", "Populations and Evolution", "Quantitative Methods", "Subcellular Processes", "Tissues and Organs"],
+     "Quantitative Finance": ["Computational Finance", "Economics", "General Finance", "Mathematical Finance", "Portfolio Management", "Pricing of Securities", "Risk Management", "Statistical Finance", "Trading and Market Microstructure"],
+     "Statistics": ["Applications", "Computation", "Machine Learning", "Methodology", "Other Statistics", "Statistics Theory"],
+     "Electrical Engineering and Systems Science": ["Audio and Speech Processing", "Image and Video Processing", "Signal Processing", "Systems and Control"],
+     "Economics": ["Econometrics", "General Economics", "Theoretical Economics"]
+ }
+
+
+ def generate_body(topic, categories, interest, threshold):
+     # Resolve the topic to its arXiv URL abbreviation
+     if topic == "Physics":
+         raise RuntimeError("You must choose a physics subtopic.")
+     elif topic in physics_topics:
+         abbr = physics_topics[topic]
+     elif topic in topics:
+         abbr = topics[topic]
+     else:
+         raise RuntimeError(f"Invalid topic {topic}")
+     if categories:
+         for category in categories:
+             if category not in category_map[topic]:
+                 raise RuntimeError(f"{category} is not a category of {topic}")
+         # Keep only papers whose subjects intersect the requested categories
+         papers = get_papers(abbr)
+         papers = [
+             t for t in papers
+             if bool(set(process_subject_fields(t['subjects'])) & set(categories))]
+     else:
+         papers = get_papers(abbr)
+     if interest:
+         # Score each paper against the interest statement; keep those above the threshold
+         relevancy, hallucination = generate_relevance_score(
+             papers,
+             query={"interest": interest},
+             threshold_score=threshold,
+             num_paper_in_prompt=8)
+         body = "<br><br>".join(
+             [f'Title: <a href="{paper["main_page"]}">{paper["title"]}</a><br>Authors: {paper["authors"]}<br>Score: {paper["Relevancy score"]}<br>Reason: {paper["Reasons for match"]}'
+              for paper in relevancy])
+         if hallucination:
+             body = "Warning: the model hallucinated some papers. We have tried to remove them, but the scores may not be accurate.<br><br>" + body
+     else:
+         # No interest statement: list everything, unranked
+         body = "<br><br>".join(
+             [f'Title: <a href="{paper["main_page"]}">{paper["title"]}</a><br>Authors: {paper["authors"]}'
+              for paper in papers])
+     return body
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", help="yaml config file to use", default="config.yaml")
+     args = parser.parse_args()
+     with open(args.config, "r") as f:
+         config = yaml.safe_load(f)
+     if "OPENAI_API_KEY" not in os.environ:
+         raise RuntimeError("No OpenAI API key found")
+
+     topic = config["topic"]
+     categories = config["categories"]
+     from_email = config.get("from_email") or os.environ.get("FROM_EMAIL")
+     to_email = config.get("to_email") or os.environ.get("TO_EMAIL")
+     threshold = config["threshold"]
+     interest = config["interest"]
+     with open("digest.html", "w") as f:
+         body = generate_body(topic, categories, interest, threshold)
+         f.write(body)
+     if os.environ.get('SENDGRID_API_KEY', None):
+         sg = SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
+         from_email = Email(from_email)  # Must be your verified sender
+         to_email = To(to_email)
+         subject = date.today().strftime("Personalized arXiv Digest, %d %b %Y")
+         content = Content("text/html", body)
+         mail = Mail(from_email, to_email, subject, content)
+         mail_json = mail.get()
+
+         # Send an HTTP POST request to /mail/send
+         response = sg.client.mail.send.post(request_body=mail_json)
+         if 200 <= response.status_code < 300:
+             print("Send email: Success!")
+         else:
+             print(f"Send email: Failure ({response.status_code}, {response.text})")
+     else:
+         print("No SendGrid API key found. Skipping email.")
src/app.py ADDED
@@ -0,0 +1,170 @@
+ import gradio as gr
+ from download_new_papers import get_papers
+ from relevancy import generate_relevance_score, process_subject_fields
+ from sendgrid.helpers.mail import Mail, Email, To, Content
+ import sendgrid
+ import os
+
+ topics = {
+     "Physics": "",
+     "Mathematics": "math",
+     "Computer Science": "cs",
+     "Quantitative Biology": "q-bio",
+     "Quantitative Finance": "q-fin",
+     "Statistics": "stat",
+     "Electrical Engineering and Systems Science": "eess",
+     "Economics": "econ"
+ }
+
+ physics_topics = {
+     "Astrophysics": "astro-ph",
+     "Condensed Matter": "cond-mat",
+     "General Relativity and Quantum Cosmology": "gr-qc",
+     "High Energy Physics - Experiment": "hep-ex",
+     "High Energy Physics - Lattice": "hep-lat",
+     "High Energy Physics - Phenomenology": "hep-ph",
+     "High Energy Physics - Theory": "hep-th",
+     "Mathematical Physics": "math-ph",
+     "Nonlinear Sciences": "nlin",
+     "Nuclear Experiment": "nucl-ex",
+     "Nuclear Theory": "nucl-th",
+     "Physics": "physics",
+     "Quantum Physics": "quant-ph"
+ }
+
+ categories_map = {
+     "Astrophysics": ["Astrophysics of Galaxies", "Cosmology and Nongalactic Astrophysics", "Earth and Planetary Astrophysics", "High Energy Astrophysical Phenomena", "Instrumentation and Methods for Astrophysics", "Solar and Stellar Astrophysics"],
+     "Condensed Matter": ["Disordered Systems and Neural Networks", "Materials Science", "Mesoscale and Nanoscale Physics", "Other Condensed Matter", "Quantum Gases", "Soft Condensed Matter", "Statistical Mechanics", "Strongly Correlated Electrons", "Superconductivity"],
+     "General Relativity and Quantum Cosmology": ["None"],
+     "High Energy Physics - Experiment": ["None"],
+     "High Energy Physics - Lattice": ["None"],
+     "High Energy Physics - Phenomenology": ["None"],
+     "High Energy Physics - Theory": ["None"],
+     "Mathematical Physics": ["None"],
+     "Nonlinear Sciences": ["Adaptation and Self-Organizing Systems", "Cellular Automata and Lattice Gases", "Chaotic Dynamics", "Exactly Solvable and Integrable Systems", "Pattern Formation and Solitons"],
+     "Nuclear Experiment": ["None"],
+     "Nuclear Theory": ["None"],
+     "Physics": ["Accelerator Physics", "Applied Physics", "Atmospheric and Oceanic Physics", "Atomic and Molecular Clusters", "Atomic Physics", "Biological Physics", "Chemical Physics", "Classical Physics", "Computational Physics", "Data Analysis, Statistics and Probability", "Fluid Dynamics", "General Physics", "Geophysics", "History and Philosophy of Physics", "Instrumentation and Detectors", "Medical Physics", "Optics", "Physics and Society", "Physics Education", "Plasma Physics", "Popular Physics", "Space Physics"],
+     "Quantum Physics": ["None"],
+     "Mathematics": ["Algebraic Geometry", "Algebraic Topology", "Analysis of PDEs", "Category Theory", "Classical Analysis and ODEs", "Combinatorics", "Commutative Algebra", "Complex Variables", "Differential Geometry", "Dynamical Systems", "Functional Analysis", "General Mathematics", "General Topology", "Geometric Topology", "Group Theory", "History and Overview", "Information Theory", "K-Theory and Homology", "Logic", "Mathematical Physics", "Metric Geometry", "Number Theory", "Numerical Analysis", "Operator Algebras", "Optimization and Control", "Probability", "Quantum Algebra", "Representation Theory", "Rings and Algebras", "Spectral Theory", "Statistics Theory", "Symplectic Geometry"],
+     "Computer Science": ["Artificial Intelligence", "Computation and Language", "Computational Complexity", "Computational Engineering, Finance, and Science", "Computational Geometry", "Computer Science and Game Theory", "Computer Vision and Pattern Recognition", "Computers and Society", "Cryptography and Security", "Data Structures and Algorithms", "Databases", "Digital Libraries", "Discrete Mathematics", "Distributed, Parallel, and Cluster Computing", "Emerging Technologies", "Formal Languages and Automata Theory", "General Literature", "Graphics", "Hardware Architecture", "Human-Computer Interaction", "Information Retrieval", "Information Theory", "Logic in Computer Science", "Machine Learning", "Mathematical Software", "Multiagent Systems", "Multimedia", "Networking and Internet Architecture", "Neural and Evolutionary Computing", "Numerical Analysis", "Operating Systems", "Other Computer Science", "Performance", "Programming Languages", "Robotics", "Social and Information Networks", "Software Engineering", "Sound", "Symbolic Computation", "Systems and Control"],
+     "Quantitative Biology": ["Biomolecules", "Cell Behavior", "Genomics", "Molecular Networks", "Neurons and Cognition", "Other Quantitative Biology", "Populations and Evolution", "Quantitative Methods", "Subcellular Processes", "Tissues and Organs"],
+     "Quantitative Finance": ["Computational Finance", "Economics", "General Finance", "Mathematical Finance", "Portfolio Management", "Pricing of Securities", "Risk Management", "Statistical Finance", "Trading and Market Microstructure"],
+     "Statistics": ["Applications", "Computation", "Machine Learning", "Methodology", "Other Statistics", "Statistics Theory"],
+     "Electrical Engineering and Systems Science": ["Audio and Speech Processing", "Image and Video Processing", "Signal Processing", "Systems and Control"],
+     "Economics": ["Econometrics", "General Economics", "Theoretical Economics"]
+ }
+
+
+ def sample(email, topic, physics_topic, categories, interest):
+     # Preview a handful of today's papers for the chosen topic/categories
+     if topic == "Physics":
+         if isinstance(physics_topic, list):
+             raise gr.Error("You must choose a physics topic.")
+         topic = physics_topic
+         abbr = physics_topics[topic]
+     else:
+         abbr = topics[topic]
+     if categories:
+         papers = get_papers(abbr)
+         papers = [
+             t for t in papers
+             if bool(set(process_subject_fields(t['subjects'])) & set(categories))][:4]
+     else:
+         papers = get_papers(abbr, limit=4)
+     if interest:
+         relevancy, _ = generate_relevance_score(
+             papers,
+             query={"interest": interest},
+             threshold_score=0,
+             num_paper_in_prompt=4)
+         return "\n\n".join([paper["summarized_text"] for paper in relevancy])
+     else:
+         return "\n\n".join(f"Title: {paper['title']}\nAuthors: {paper['authors']}" for paper in papers)
+
+
+ def change_subsubject(subject, physics_subject):
+     # Populate the subtopic dropdown for the currently selected subject
+     if subject != "Physics":
+         return gr.Dropdown.update(choices=categories_map[subject], value=[], visible=True)
+     else:
+         if physics_subject and not isinstance(physics_subject, list):
+             return gr.Dropdown.update(choices=categories_map[physics_subject], value=[], visible=True)
+         else:
+             return gr.Dropdown.update(choices=[], value=[], visible=False)
+
+
+ def change_physics(subject):
+     # Show the physics-category dropdown only when "Physics" is selected
+     if subject != "Physics":
+         return gr.Dropdown.update(visible=False, value=[])
+     else:
+         return gr.Dropdown.update(choices=list(physics_topics.keys()), visible=True)
+
+
+ def test(email, topic, physics_topic, categories, interest):
+     # Generate a digest and send it to the given address via SendGrid
+     if topic == "Physics":
+         if isinstance(physics_topic, list):
+             raise gr.Error("You must choose a physics topic.")
+         topic = physics_topic
+         abbr = physics_topics[topic]
+     else:
+         abbr = topics[topic]
+     if categories:
+         papers = get_papers(abbr)
+         papers = [
+             t for t in papers
+             if bool(set(process_subject_fields(t['subjects'])) & set(categories))][:4]
+     else:
+         papers = get_papers(abbr, limit=4)
+     if interest:
+         relevancy, hallucination = generate_relevance_score(
+             papers,
+             query={"interest": interest},
+             threshold_score=7,
+             num_paper_in_prompt=8)
+         body = "<br><br>".join([f'Title: <a href="{paper["main_page"]}">{paper["title"]}</a><br>Authors: {paper["authors"]}<br>Score: {paper["Relevancy score"]}<br>Reason: {paper["Reasons for match"]}' for paper in relevancy])
+         if hallucination:
+             body = "Warning: the model hallucinated some papers. We have tried to remove them, but the scores may not be accurate.<br><br>" + body
+     else:
+         body = "<br><br>".join([f'Title: <a href="{paper["main_page"]}">{paper["title"]}</a><br>Authors: {paper["authors"]}' for paper in papers])
+     sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
+     from_email = Email("[email protected]")  # Change to your verified sender
+     to_email = To(email)
+     subject = "arXiv digest"
+     content = Content("text/html", body)
+     mail = Mail(from_email, to_email, subject, content)
+     mail_json = mail.get()
+
+     # Send an HTTP POST request to /mail/send
+     response = sg.client.mail.send.post(request_body=mail_json)
+     if 200 <= response.status_code < 300:
+         return "Send test email: Success!"
+     else:
+         return f"Send test email: Failure ({response.status_code})"
+
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         email = gr.Textbox(label="Email address")
+         subject = gr.Radio(
+             list(topics.keys()), label="Topic to subscribe to"
+         )
+         physics_subject = gr.Dropdown(physics_topics, value=[], multiselect=False, label="Physics category", visible=False, info="")
+         subsubject = gr.Dropdown(
+             [], value=[], multiselect=True, label="Subtopic", info="", visible=False)
+         subject.change(fn=change_physics, inputs=[subject], outputs=physics_subject)
+         subject.change(fn=change_subsubject, inputs=[subject, physics_subject], outputs=subsubject)
+         physics_subject.change(fn=change_subsubject, inputs=[subject, physics_subject], outputs=subsubject)
+
+         interest = gr.Textbox(label="A natural language description of what you are interested in. Press enter to update.")
+         sample_output = gr.Textbox(label="Examples")
+         test_btn = gr.Button("Send email")
+         output = gr.Textbox(label="Test email status")
+         test_btn.click(fn=test, inputs=[email, subject, physics_subject, subsubject, interest], outputs=output)
+         subject.change(fn=sample, inputs=[email, subject, physics_subject, subsubject, interest], outputs=sample_output)
+         physics_subject.change(fn=sample, inputs=[email, subject, physics_subject, subsubject, interest], outputs=sample_output)
+         subsubject.change(fn=sample, inputs=[email, subject, physics_subject, subsubject, interest], outputs=sample_output)
+         interest.submit(fn=sample, inputs=[email, subject, physics_subject, subsubject, interest], outputs=sample_output)
+
+ demo.launch()
src/download_new_papers.py ADDED
@@ -0,0 +1,64 @@
+ # encoding: utf-8
+ import os
+ import tqdm
+ from bs4 import BeautifulSoup as bs
+ import urllib.request
+ import json
+ import datetime
+ import pytz
+
+
+ def _download_new_papers(field_abbr):
+     NEW_SUB_URL = f'https://arxiv.org/list/{field_abbr}/new'  # e.g. https://arxiv.org/list/cs/new
+     page = urllib.request.urlopen(NEW_SUB_URL)
+     soup = bs(page, "html.parser")
+     content = soup.body.find("div", {'id': 'content'})
+
+     # find the first h3 element in content
+     h3 = content.find("h3").text  # e.g.: New submissions for Wed, 10 May 23
+     date = h3.replace("New submissions for", "").strip()
+
+     # each <dt> holds the arXiv id and links; the matching <dd> holds title, authors, subjects and abstract
+     dt_list = content.dl.find_all("dt")
+     dd_list = content.dl.find_all("dd")
+     arxiv_base = "https://arxiv.org/abs/"
+
+     assert len(dt_list) == len(dd_list)
+     new_paper_list = []
+     for i in tqdm.tqdm(range(len(dt_list))):
+         paper = {}
+         paper_number = dt_list[i].text.strip().split(" ")[2].split(":")[-1]
+         paper['main_page'] = arxiv_base + paper_number
+         paper['pdf'] = arxiv_base.replace('abs', 'pdf') + paper_number
+
+         paper['title'] = dd_list[i].find("div", {"class": "list-title mathjax"}).text.replace("Title: ", "").strip()
+         paper['authors'] = dd_list[i].find("div", {"class": "list-authors"}).text \
+             .replace("Authors:\n", "").replace("\n", "").strip()
+         paper['subjects'] = dd_list[i].find("div", {"class": "list-subjects"}).text.replace("Subjects: ", "").strip()
+         paper['abstract'] = dd_list[i].find("p", {"class": "mathjax"}).text.replace("\n", " ").strip()
+         new_paper_list.append(paper)
+
+     # check if ./data exists; if not, create it
+     if not os.path.exists("./data"):
+         os.makedirs("./data")
+
+     # save new_paper_list to a jsonl file, with one paper dictionary per line
+     date = datetime.date.fromtimestamp(datetime.datetime.now(tz=pytz.timezone("America/New_York")).timestamp())
+     date = date.strftime("%a, %d %b %y")
+     with open(f"./data/{field_abbr}_{date}.jsonl", "w") as f:
+         for paper in new_paper_list:
+             f.write(json.dumps(paper) + "\n")
+
+
+ def get_papers(field_abbr, limit=None):
+     # Serve today's papers from the on-disk cache, downloading them first if needed
+     date = datetime.date.fromtimestamp(datetime.datetime.now(tz=pytz.timezone("America/New_York")).timestamp())
+     date = date.strftime("%a, %d %b %y")
+     if not os.path.exists(f"./data/{field_abbr}_{date}.jsonl"):
+         _download_new_papers(field_abbr)
+     results = []
+     with open(f"./data/{field_abbr}_{date}.jsonl", "r") as f:
+         for i, line in enumerate(f.readlines()):
+             if limit and i == limit:
+                 return results
+             results.append(json.loads(line))
+     return results
src/relevancy.py ADDED
@@ -0,0 +1,172 @@
+ """
+ run:
+ python -m relevancy run_all_day_paper \
+   --output_dir ./data \
+   --model_name="gpt-3.5-turbo" \
+ """
+ import time
+ import json
+ import os
+ import pprint
+ import random
+ import re
+ import string
+ from datetime import datetime
+
+ import numpy as np
+ import tqdm
+ import utils
+
+
+ def encode_prompt(query, prompt_papers):
+     """Encode multiple prompt instructions into a single string."""
+     prompt = open("src/relevancy_prompt.txt").read() + "\n"
+     prompt += query['interest']
+
+     for idx, task_dict in enumerate(prompt_papers):
+         (title, authors, abstract) = task_dict["title"], task_dict["authors"], task_dict["abstract"]
+         if not title:
+             raise ValueError(f"paper {idx + 1} is missing a title")
+         prompt += f"###\n"
+         prompt += f"{idx + 1}. Title: {title}\n"
+         prompt += f"{idx + 1}. Authors: {authors}\n"
+         prompt += f"{idx + 1}. Abstract: {abstract}\n"
+     prompt += f"\n Generate response:\n1."
+     print(prompt)
+     return prompt
+
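+ # The model is expected to reply with one JSON object per paper, in input order, e.g.
+ # (illustrative): 1. {"Relevancy score": "8/10", "Reasons for match": "..."}
+ # post_process_chat_gpt_response() parses those lines back onto the paper dicts below.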
+
+
+ def post_process_chat_gpt_response(paper_data, response, threshold_score=8):
+     selected_data = []
+     if response is None:
+         return [], False
+     json_items = response['message']['content'].replace("\n\n", "\n").split("\n")
+     pattern = r"^\d+\. "
+     try:
+         score_items = [json.loads(re.sub(pattern, "", line)) for line in json_items if "relevancy score" in line.lower()]
+     except Exception:
+         pprint.pprint([re.sub(pattern, "", line) for line in json_items if "relevancy score" in line.lower()])
+         raise RuntimeError("failed to parse the model response as JSON")
+     pprint.pprint(score_items)
+     scores = []
+     for item in score_items:
+         temp = item["Relevancy score"]
+         if "/" in temp:
+             scores.append(int(temp.split("/")[0]))
+         else:
+             scores.append(int(temp))
+     # If the model returned more items than papers, it hallucinated extra entries
+     if len(score_items) != len(paper_data):
+         score_items = score_items[:len(paper_data)]
+         hallucination = True
+     else:
+         hallucination = False
+
+     for idx, inst in enumerate(score_items):
+         # if the decoding stops due to length, the last example is likely truncated so we discard it
+         if scores[idx] < threshold_score:
+             continue
+         output_str = "Title: " + paper_data[idx]["title"] + "\n"
+         output_str += "Authors: " + paper_data[idx]["authors"] + "\n"
+         output_str += "Link: " + paper_data[idx]["main_page"] + "\n"
+         for key, value in inst.items():
+             paper_data[idx][key] = value
+             output_str += key + ": " + str(value) + "\n"
+         paper_data[idx]['summarized_text'] = output_str
+         selected_data.append(paper_data[idx])
+     return selected_data, hallucination
+
+
+ def find_word_in_string(w, s):
+     return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s)
+
+
+ def process_subject_fields(subjects):
+     all_subjects = subjects.split(";")
+     all_subjects = [s.split(" (")[0] for s in all_subjects]
+     return all_subjects
+
+
+ def generate_relevance_score(
+     all_papers,
+     query,
+     model_name="gpt-3.5-turbo",
+     threshold_score=8,
+     num_paper_in_prompt=4,
+     temperature=0.4,
+     top_p=1.0,
+     sorting=True
+ ):
+     ans_data = []
+     request_idx = 1
+     hallucination = False
+     for id in tqdm.tqdm(range(0, len(all_papers), num_paper_in_prompt)):
+         prompt_papers = all_papers[id:id+num_paper_in_prompt]
+         # only sampling from the seed tasks
+         prompt = encode_prompt(query, prompt_papers)
+
+         decoding_args = utils.OpenAIDecodingArguments(
+             temperature=temperature,
+             n=1,
+             max_tokens=1072,  # hard-code to maximize the length. the requests will be automatically adjusted
+             top_p=top_p,
+         )
+         request_start = time.time()
+         response = utils.openai_completion(
+             prompts=prompt,
+             model_name=model_name,
+             batch_size=1,
+             decoding_args=decoding_args,
+             logit_bias={"100257": -100},  # prevent the <|endoftext|> token from being generated
+             # "100265":-100, "100276":-100 for <|im_end|> and <endofprompt> tokens
+         )
+         print("response", response['message']['content'])
+         request_duration = time.time() - request_start
+
+         process_start = time.time()
+         batch_data, hallu = post_process_chat_gpt_response(prompt_papers, response, threshold_score=threshold_score)
+         hallucination = hallucination or hallu
+         ans_data.extend(batch_data)
+
+         print(f"Request {request_idx} took {request_duration:.2f}s")
+         print(f"Post-processing took {time.time() - process_start:.2f}s")
+         request_idx += 1
+
+     if sorting:
+         # "Relevancy score" may be "8" or "8/10"; sort numerically on the leading integer
+         ans_data = sorted(ans_data, key=lambda x: int(str(x["Relevancy score"]).split("/")[0]), reverse=True)
+
+     return ans_data, hallucination
+
+
+ def run_all_day_paper(
+     query={"interest": "", "subjects": ["Computation and Language", "Artificial Intelligence"]},
+     date=None,
+     data_dir="../data",
+     model_name="gpt-3.5-turbo",
+     threshold_score=8,
+     num_paper_in_prompt=8,
+     temperature=0.4,
+     top_p=1.0
+ ):
+     if date is None:
+         date = datetime.today().strftime('%a, %d %b %y')
+         # string format such as Wed, 10 May 23
+     print("the date for the arxiv data is: ", date)
+
+     all_papers = [json.loads(l) for l in open(f"{data_dir}/{date}.jsonl", "r")]
+     print(f"We found {len(all_papers)} papers.")
+
+     all_papers_in_subjects = [
+         t for t in all_papers
+         if bool(set(process_subject_fields(t['subjects'])) & set(query['subjects']))
+     ]
+     print(f"After filtering subjects, we have {len(all_papers_in_subjects)} papers left.")
+     ans_data, _ = generate_relevance_score(all_papers_in_subjects, query, model_name, threshold_score, num_paper_in_prompt, temperature, top_p)
+     utils.write_ans_to_file(ans_data, date, output_dir="../outputs")
+     return ans_data
+
+
+ if __name__ == "__main__":
+     query = {"interest": """
+ 1. Large language model pretraining and finetuning
+ 2. Multimodal machine learning
+ 3. Do not care about specific applications, for example, information extraction, summarization, etc.
+ 4. Not interested in papers focusing on specific languages, e.g., Arabic, Chinese, etc.\n""",
+              "subjects": ["Computation and Language"]}
+     ans_data = run_all_day_paper(query)
src/relevancy_prompt.txt ADDED
@@ -0,0 +1,7 @@
+ You have been asked to read a list of a few arxiv papers, each with title, authors and abstract.
+ Based on my specific research interests, provide a relevancy score out of 10 for each paper, with a higher score indicating greater relevance. A relevancy score of more than 7 means the paper deserves the person's attention for details.
+ Additionally, please generate a 1-2 sentence summary for each paper explaining why it's relevant to my research interests.
+ Please keep the paper order the same as in the input list, with one json format per line. Example is:
+ 1. {"Relevancy score": "an integer score out of 10", "Reasons for match": "1-2 sentence short reasonings"}
+
+ My research interests are:
src/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ beautifulsoup4==4.12.2
+ tqdm==4.65.0
+ pytz==2023.3
+ numpy==1.24.2
+ openai==0.27.4
+ sendgrid==6.10.0
+ pyyaml==6.0
src/utils.py ADDED
@@ -0,0 +1,148 @@
+ import dataclasses
+ import logging
+ import math
+ import os
+ import io
+ import sys
+ import time
+ import json
+ from typing import Optional, Sequence, Union
+
+ import openai
+ import tqdm
+ from openai import openai_object
+ import copy
+
+ StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]
+
+ openai_org = os.getenv("OPENAI_ORG")
+ if openai_org is not None:
+     openai.organization = openai_org
+     logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
+
+
+ @dataclasses.dataclass
+ class OpenAIDecodingArguments(object):
+     max_tokens: int = 1800
+     temperature: float = 0.2
+     top_p: float = 1.0
+     n: int = 1
+     stream: bool = False
+     stop: Optional[Sequence[str]] = None
+     presence_penalty: float = 0.0
+     frequency_penalty: float = 0.0
+     # logprobs: Optional[int] = None
+
+
+ def openai_completion(
+     prompts,  # : Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
+     decoding_args: OpenAIDecodingArguments,
+     model_name="text-davinci-003",
+     sleep_time=2,
+     batch_size=1,
+     max_instances=sys.maxsize,
+     max_batches=sys.maxsize,
+     return_text=False,
+     **decoding_kwargs,
+ ) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
+     """Decode with OpenAI API.
+
+     Args:
+         prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
+             as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
+             it can also be a dictionary (or list thereof) as explained here:
+             https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+         decoding_args: Decoding arguments.
+         model_name: Model name. Can be either in the format of "org/model" or just "model".
+         sleep_time: Time to sleep once the rate-limit is hit.
+         batch_size: Number of prompts to send in a single request. Only for non chat model.
+         max_instances: Maximum number of prompts to decode.
+         max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
+         return_text: If True, return text instead of full completion object (which contains things like logprob).
+         decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.
+
+     Returns:
+         A completion or a list of completions.
+         Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
+             - a string (if return_text is True)
+             - an openai_object.OpenAIObject object (if return_text is False)
+             - a list of objects of the above types (if decoding_args.n > 1)
+     """
+     is_chat_model = "gpt-3.5" in model_name or "gpt-4" in model_name
+     is_single_prompt = isinstance(prompts, (str, dict))
+     if is_single_prompt:
+         prompts = [prompts]
+
+     if max_batches < sys.maxsize:
+         logging.warning(
+             "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
+             "Setting `max_instances` to `max_batches * batch_size` for now."
+         )
+         max_instances = max_batches * batch_size
+
+     prompts = prompts[:max_instances]
+     num_prompts = len(prompts)
+     prompt_batches = [
+         prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
+         for batch_id in range(int(math.ceil(num_prompts / batch_size)))
+     ]
+
+     completions = []
+     for batch_id, prompt_batch in tqdm.tqdm(
+         enumerate(prompt_batches),
+         desc="prompt_batches",
+         total=len(prompt_batches),
+     ):
+         batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args
+
+         while True:
+             try:
+                 shared_kwargs = dict(
+                     model=model_name,
+                     **batch_decoding_args.__dict__,
+                     **decoding_kwargs,
+                 )
+                 if is_chat_model:
+                     completion_batch = openai.ChatCompletion.create(
+                         messages=[
+                             {"role": "system", "content": "You are a helpful assistant."},
+                             {"role": "user", "content": prompt_batch[0]}
+                         ],
+                         **shared_kwargs
+                     )
+                 else:
+                     completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
+
+                 choices = completion_batch.choices
+
+                 for choice in choices:
+                     choice["total_tokens"] = completion_batch.usage.total_tokens
+                 completions.extend(choices)
+                 break
+             except openai.error.OpenAIError as e:
+                 logging.warning(f"OpenAIError: {e}.")
+                 if "Please reduce your prompt" in str(e):
+                     batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
+                     logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
+                 else:
+                     logging.warning("Hit request rate limit; retrying...")
+                     time.sleep(sleep_time)  # Annoying rate limit on requests.
+
+     if return_text:
+         completions = [completion.text for completion in completions]
+     if decoding_args.n > 1:
+         # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
+         completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
+     if is_single_prompt:
+         # Return non-tuple if only 1 input and 1 generation.
+         (completions,) = completions
+     return completions
+
+
+ def write_ans_to_file(ans_data, file_prefix, output_dir="./output"):
+     # Entries may be dicts (scored papers), so serialize each one as a JSON line
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     filename = os.path.join(output_dir, file_prefix + ".txt")
+     with open(filename, "w") as f:
+         for ans in ans_data:
+             f.write(json.dumps(ans) + "\n")