oldcai committed
Commit d7a7846 (verified)
1 Parent(s): 2df351d

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +22 -0
  2. .gitignore +174 -0
  3. .vscode/launch.json +15 -0
  4. Dockerfile +26 -0
  5. LICENSE +201 -0
  6. README.md +251 -7
  7. assets/demo_fast.gif +3 -0
  8. assets/demo_fast.mp4 +3 -0
  9. assets/medrax_logo.jpg +3 -0
  10. assets/medrax_logo.png +3 -0
  11. benchmark/__init__.py +0 -0
  12. benchmark/create_benchmark.py +352 -0
  13. benchmark/llm.py +42 -0
  14. benchmark/utils.py +78 -0
  15. data/eurorad_metadata.json +0 -0
  16. data/figures.py +74 -0
  17. data/get_cases.py +51 -0
  18. data/stats/age_distribution.png +3 -0
  19. data/stats/area_of_interest_distribution.png +3 -0
  20. data/stats/gender_distribution.png +3 -0
  21. demo/chest/LIDC.dcm +3 -0
  22. demo/chest/Pseudo.dcm +3 -0
  23. demo/chest/RIDER.dcm +3 -0
  24. demo/chest/TCGAA.dcm +3 -0
  25. demo/chest/__init__.py +0 -0
  26. demo/chest/effusion1.png +3 -0
  27. demo/chest/normal1.jpg +3 -0
  28. demo/chest/normal2.jpg +3 -0
  29. demo/chest/normal3.jpg +3 -0
  30. demo/chest/normal4.jpg +3 -0
  31. demo/chest/normal5.jpg +3 -0
  32. demo/chest/normal6.jpg +3 -0
  33. demo/chest/pneumonia1.jpg +0 -0
  34. demo/chest/pneumonia2.jpg +0 -0
  35. demo/chest/pneumonia3.jpg +0 -0
  36. demo/chest/pneumonia4.jpg +3 -0
  37. demo/chest/pneumonia5.jpg +3 -0
  38. experiments/README.md +63 -0
  39. experiments/analyze_axes.py +385 -0
  40. experiments/benchmark_chexagent.py +316 -0
  41. experiments/benchmark_gpt4o.py +331 -0
  42. experiments/benchmark_llama.py +443 -0
  43. experiments/benchmark_llavamed.py +541 -0
  44. experiments/benchmark_medrax.ipynb +374 -0
  45. experiments/chexbench_gpt4.py +405 -0
  46. experiments/compare_runs.py +290 -0
  47. experiments/inspect_logs.py +210 -0
  48. experiments/validate_logs.py +162 -0
  49. handler.py +57 -0
  50. interface.py +259 -0
.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo_fast.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/demo_fast.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ assets/medrax_logo.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/medrax_logo.png filter=lfs diff=lfs merge=lfs -text
40
+ data/stats/age_distribution.png filter=lfs diff=lfs merge=lfs -text
41
+ data/stats/area_of_interest_distribution.png filter=lfs diff=lfs merge=lfs -text
42
+ data/stats/gender_distribution.png filter=lfs diff=lfs merge=lfs -text
43
+ demo/chest/LIDC.dcm filter=lfs diff=lfs merge=lfs -text
44
+ demo/chest/Pseudo.dcm filter=lfs diff=lfs merge=lfs -text
45
+ demo/chest/RIDER.dcm filter=lfs diff=lfs merge=lfs -text
46
+ demo/chest/TCGAA.dcm filter=lfs diff=lfs merge=lfs -text
47
+ demo/chest/effusion1.png filter=lfs diff=lfs merge=lfs -text
48
+ demo/chest/normal1.jpg filter=lfs diff=lfs merge=lfs -text
49
+ demo/chest/normal2.jpg filter=lfs diff=lfs merge=lfs -text
50
+ demo/chest/normal3.jpg filter=lfs diff=lfs merge=lfs -text
51
+ demo/chest/normal4.jpg filter=lfs diff=lfs merge=lfs -text
52
+ demo/chest/normal5.jpg filter=lfs diff=lfs merge=lfs -text
53
+ demo/chest/normal6.jpg filter=lfs diff=lfs merge=lfs -text
54
+ demo/chest/pneumonia4.jpg filter=lfs diff=lfs merge=lfs -text
55
+ demo/chest/pneumonia5.jpg filter=lfs diff=lfs merge=lfs -text
56
+ medrax/llava/serve/examples/bio_patch.png filter=lfs diff=lfs merge=lfs -text
57
+ medrax/llava/serve/examples/med_img_1.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ # ruff
165
+ ruff-cache/
166
+ .ruff_cache/
167
+
168
+ afallah/
169
+
170
+ logs/
171
+
172
+ temp/
173
+
174
+ .gradio/
.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python Debugger: main.py",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "program": "main.py",
12
+ "console": "integratedTerminal"
13
+ }
14
+ ]
15
+ }
Dockerfile ADDED
@@ -0,0 +1,26 @@
1
+ FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0
2
+
3
+ WORKDIR /workspace
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ libgl1-mesa-glx \
8
+ libglib2.0-0 \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements and install Python dependencies
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy application code
16
+ COPY . .
17
+
18
+ # Create directories for models and temporary files
19
+ RUN mkdir -p /model-weights /workspace/temp
20
+
21
+ # Download models (if needed)
22
+ # RUN python download_models.py
23
+
24
+ EXPOSE 8000
25
+
26
+ CMD ["python", "handler.py"]
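As a rough local-testing sketch (not part of the repository; the `medrax-handler` tag, the GPU flag, and the `--env-file` usage are assumptions), the image could be built and run like this:

```bash
# Build the image from the repository root
docker build -t medrax-handler .

# Run the handler with GPU access, the exposed port mapped, and API keys passed via .env
docker run --gpus all -p 8000:8000 --env-file .env medrax-handler
```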
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,256 @@
1
  ---
2
- title: Medrax.org
3
- emoji: 📈
4
- colorFrom: indigo
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.16.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
 
 
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: medrax.org
3
+ app_file: interface.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.16.0
 
 
6
  ---
7
+ <h1 align="center">
8
+ 🤖 MedRAX: Medical Reasoning Agent for Chest X-ray
9
+ </h1>
10
+ <p align="center"> <a href="https://arxiv.org/abs/2502.02673" target="_blank"><img src="https://img.shields.io/badge/arXiv-Paper-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a> <a href="https://github.com/bowang-lab/MedRAX"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a> <a href="https://huggingface.co/datasets/wanglab/chest-agent-bench"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a> </p>
11
 
12
+ ![](assets/demo_fast.gif?autoplay=1)
13
+
14
+ <br>
15
+
16
+ ## Abstract
17
+ Chest X-rays (CXRs) play an integral role in driving critical decisions in disease management and patient care. While recent innovations have led to specialized models for various CXR interpretation tasks, these solutions often operate in isolation, limiting their practical utility in clinical practice. We present MedRAX, the first versatile AI agent that seamlessly integrates state-of-the-art CXR analysis tools and multimodal large language models into a unified framework. MedRAX dynamically leverages these models to address complex medical queries without requiring additional training. To rigorously evaluate its capabilities, we introduce ChestAgentBench, a comprehensive benchmark containing 2,500 complex medical queries across 7 diverse categories. Our experiments demonstrate that MedRAX achieves state-of-the-art performance compared to both open-source and proprietary models, representing a significant step toward the practical deployment of automated CXR interpretation systems.
18
+ <br><br>
19
+
20
+
21
+ ## MedRAX
22
+ MedRAX is built on a robust technical foundation:
23
+ - **Core Architecture**: Built on LangChain and LangGraph frameworks
24
+ - **Language Model**: Uses GPT-4o with vision capabilities as the backbone LLM
25
+ - **Deployment**: Supports both local and cloud-based deployments
26
+ - **Interface**: Production-ready interface built with Gradio
27
+ - **Modular Design**: Tool-agnostic architecture allowing easy integration of new capabilities
28
+
29
+ ### Integrated Tools
30
+ - **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
31
+ - **Segmentation**: Employs MedSAM and a PSPNet model trained on ChestX-Det for precise anatomical structure identification
32
+ - **Grounding**: Uses Maira-2 for localizing specific findings in medical images
33
+ - **Report Generation**: Implements SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
34
+ - **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
35
+ - **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
36
+ - **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
37
+
38
+ Note that the current version of MedRAX is an experimental release and does not yet support vision for GPT-4o or MedSAM. We will be integrating these shortly.
39
+ <br><br>
40
+
41
+
42
+ ## ChestAgentBench
43
+ We introduce ChestAgentBench, a comprehensive evaluation framework with 2,500 complex medical queries across 7 categories, built from 675 expert-curated clinical cases. The benchmark evaluates complex multi-step reasoning in CXR interpretation through:
44
+
45
+ - Detection
46
+ - Classification
47
+ - Localization
48
+ - Comparison
49
+ - Relationship
50
+ - Diagnosis
51
+ - Characterization
52
+
53
+ Download the benchmark: [ChestAgentBench on Hugging Face](https://huggingface.co/datasets/wanglab/chest-agent-bench)
54
+ ```
55
+ huggingface-cli download wanglab/chestagentbench --repo-type dataset --local-dir chestagentbench
56
+ ```
57
+
58
+ Unzip the Eurorad figures to your local `MedMAX` directory.
59
+ ```
60
+ unzip chestagentbench/figures.zip
61
+ ```
62
+
63
+ To evaluate with GPT-4o, set your OpenAI API key and run the quickstart script.
64
+ ```
65
+ export OPENAI_API_KEY="<your-openai-api-key>"
66
+ python quickstart.py \
67
+ --model chatgpt-4o-latest \
68
+ --temperature 0.2 \
69
+ --max-cases 2 \
70
+ --log-prefix chatgpt-4o-latest \
71
+ --use-urls
72
+ ```
73
+
74
+
75
+ <br>
76
+
77
+ ## Installation
78
+ ### Prerequisites
79
+ - Python 3.8+
80
+ - CUDA/GPU for best performance
81
+
82
+ ### Installation Steps
83
+ ```bash
84
+ # Clone the repository
85
+ git clone https://github.com/bowang-lab/MedRAX.git
86
+ cd MedRAX
87
+
88
+ # Install package
89
+ pip install -e .
90
+ ```
91
+
92
+ ### Getting Started
93
+ ```bash
94
+ # Start the Gradio interface
95
+ python main.py
96
+ ```
97
+ or, if you run into permission issues:
98
+ ```bash
99
+ sudo -E env "PATH=$PATH" python main.py
100
+ ```
101
+ Set `model_dir` inside `main.py` to the directory where you want to download, or already have, the Hugging Face weights for the tools above.
102
+ Comment out the tools that you do not have access to.
103
+ Make sure to set up your OpenAI API key in the `.env` file!
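A minimal `.env` sketch (the variable name follows the standard OpenAI convention; the value is a placeholder):

```bash
# .env — keep this file out of version control
OPENAI_API_KEY=sk-your-key-here
```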
104
+ <br><br><br>
105
+
106
+
107
+ ## Tool Selection and Initialization
108
+
109
+ MedRAX supports selective tool initialization, allowing you to use only the tools you need. Tools can be specified when initializing the agent (look at `main.py`):
110
+
111
+ ```python
112
+ selected_tools = [
113
+ "ImageVisualizerTool",
114
+ "ChestXRayClassifierTool",
115
+ "ChestXRaySegmentationTool",
116
+ # Add or remove tools as needed
117
+ ]
118
+
119
+ agent, tools_dict = initialize_agent(
120
+ "medrax/docs/system_prompts.txt",
121
+ tools_to_use=selected_tools,
122
+ model_dir="/model-weights"
123
+ )
124
+ ```
125
+
126
+ <br><br>
127
+ ## Automatically Downloaded Models
128
+
129
+ The following tools will automatically download their model weights when initialized:
130
+
131
+ ### Classification Tool
132
+ ```python
133
+ ChestXRayClassifierTool(device=device)
134
+ ```
135
+
136
+ ### Segmentation Tool
137
+ ```python
138
+ ChestXRaySegmentationTool(device=device)
139
+ ```
140
+
141
+ ### Grounding Tool
142
+ ```python
143
+ XRayPhraseGroundingTool(
144
+ cache_dir=model_dir,
145
+ temp_dir=temp_dir,
146
+ load_in_8bit=True,
147
+ device=device
148
+ )
149
+ ```
150
+ - Maira-2 weights download to the specified `cache_dir`
151
+ - 8-bit and 4-bit quantization available for reduced memory usage
152
+
153
+ ### LLaVA-Med Tool
154
+ ```python
155
+ LlavaMedTool(
156
+ cache_dir=model_dir,
157
+ device=device,
158
+ load_in_8bit=True
159
+ )
160
+ ```
161
+ - Automatic weight download to `cache_dir`
162
+ - 8-bit and 4-bit quantization available for reduced memory usage
163
+
164
+ ### Report Generation Tool
165
+ ```python
166
+ ChestXRayReportGeneratorTool(
167
+ cache_dir=model_dir,
168
+ device=device
169
+ )
170
+ ```
171
+
172
+ ### Visual QA Tool
173
+ ```python
174
+ XRayVQATool(
175
+ cache_dir=model_dir,
176
+ device=device
177
+ )
178
+ ```
179
+ - CheXagent weights download automatically
180
+
181
+ ### MedSAM Tool
182
+ Support for MedSAM segmentation will be added in a future update.
185
+
186
+ ### Utility Tools
187
+ No additional model weights required:
188
+ ```python
189
+ ImageVisualizerTool()
190
+ DicomProcessorTool(temp_dir=temp_dir)
191
+ ```
192
+ <br>
193
+
194
+ ## Manual Setup Required
195
+
196
+ ### Image Generation Tool
197
+ ```python
198
+ ChestXRayGeneratorTool(
199
+ model_path=f"{model_dir}/roentgen",
200
+ temp_dir=temp_dir,
201
+ device=device
202
+ )
203
+ ```
204
+ - RoentGen weights require manual setup:
205
+ 1. Contact authors: https://github.com/StanfordMIMI/RoentGen
206
+ 2. Place weights in `{model_dir}/roentgen`
207
+ 3. Optional tool, can be excluded if not needed
208
+ <br>
209
+
210
+ ## Configuration Notes
211
+
212
+ ### Required Parameters
213
+ - `model_dir` or `cache_dir`: Base directory where Hugging Face downloads and caches model weights
214
+ - `temp_dir`: Directory for temporary files
215
+ - `device`: "cuda" for GPU, "cpu" for CPU-only
216
+
217
+ ### Memory Management
218
+ - Consider selective tool initialization for resource constraints
219
+ - Use 8-bit quantization where available
220
+ - Some tools (LLaVA-Med, Grounding) are more resource-intensive
221
+ <br><br>
222
+
223
+ ## Authors
224
+ - **Adibvafa Fallahpour**¹²³ * ([email protected])
225
+ - **Jun Ma**²³ *
226
+ - **Alif Munim**³⁴ *
227
+ - **Hongwei Lyu**³
228
+ - **Bo Wang**¹²³⁵
229
+
230
+ ¹ Department of Computer Science, University of Toronto, Toronto, Canada
231
+ ² Vector Institute, Toronto, Canada
232
+ ³ University Health Network, Toronto, Canada
233
+ ⁴ Cohere For AI, Toronto, Canada
234
+ ⁵ Department of Laboratory Medicine and Pathobiology, University of Toronto, Toronto, Canada <br>
235
+ \* Equal contribution
236
+ <br><br>
237
+
238
+
239
+ ## Citation
240
+ If you find this work useful, please cite our paper:
241
+ ```bibtex
242
+ @misc{fallahpour2025medraxmedicalreasoningagent,
243
+ title={MedRAX: Medical Reasoning Agent for Chest X-ray},
244
+ author={Adibvafa Fallahpour and Jun Ma and Alif Munim and Hongwei Lyu and Bo Wang},
245
+ year={2025},
246
+ eprint={2502.02673},
247
+ archivePrefix={arXiv},
248
+ primaryClass={cs.LG},
249
+ url={https://arxiv.org/abs/2502.02673},
250
+ }
251
+ ```
252
+
253
+ ---
254
+ <p align="center">
255
+ Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
256
+ </p>
assets/demo_fast.gif ADDED

Git LFS Details

  • SHA256: f9017455401581570b1366228acbcf0a816f1e31d1548a9f363f7eee0002432e
  • Pointer size: 133 Bytes
  • Size of remote file: 26.8 MB
assets/demo_fast.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f3bab0a7f21187f66bf9e7f5809489b6348d8b140b7e90d4fdf750ec989f92f
3
+ size 4109270
assets/medrax_logo.jpg ADDED

Git LFS Details

  • SHA256: 306aa20d47067df102e4ba26d637f22a7d95f449a5969d320ceeca03b71da1d1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
assets/medrax_logo.png ADDED

Git LFS Details

  • SHA256: 5af3f42308022abe028b670e6716152e714c1f25ebbe6375532775a557b66b2c
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
benchmark/__init__.py ADDED
File without changes
benchmark/create_benchmark.py ADDED
@@ -0,0 +1,352 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Medical X-ray Question Generation Benchmark aka ChestAgentBench
4
+
5
+ This script generates clinical questions from X-ray case data of Eurorad dataset using GPT-4o.
6
+ It structures questions across different analytical categories and saves them as JSON.
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import json
12
+ from typing import *
13
+ from pprint import pprint
14
+
15
+ import openai
16
+ import numpy as np
17
+ from scipy import stats
18
+ import plotly.graph_objects as go
19
+ from tqdm import tqdm
20
+
21
+ from benchmark.utils import load_eurorad_dataset
22
+ from benchmark.llm import get_llm_response
23
+
24
+ # Constants
25
+ DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
26
+ DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")
27
+
28
+ SYSTEM_PROMPT = """
29
+ You are an expert medical benchmark creation assistant.
30
+ Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
31
+ """.strip()
32
+
33
+ CATEGORIES_META = {
34
+ "detection": "Identify and locate specific findings in the chest X-ray.",
35
+ "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
36
+ "enumeration": "Count the number of target findings in the chest X-ray.",
37
+ "localization": "Locate a given finding in the chest X-ray.",
38
+ "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
39
+ "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
40
+ "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
41
+ "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
42
+ "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
43
+ }
44
+ CATEGORIES = list(CATEGORIES_META.keys())
45
+
46
+ CATEGORY_COMBINATIONS = [
47
+ ["detection", "localization", "characterization", "reasoning"], # Detailed Finding Analysis
48
+ ["detection", "classification", "relationship", "reasoning"], # Pattern Recognition & Relations
49
+ ["localization", "comparison", "relationship", "reasoning"], # Spatial Understanding
50
+ ["classification", "comparison", "diagnosis", "reasoning"], # Clinical Decision Making
51
+ ["classification", "characterization", "diagnosis", "reasoning"], # Diagnostic Characterization
52
+ ]
53
+
54
+ DEFAULT_SECTIONS = [
55
+ "history",
56
+ "image_finding",
57
+ "discussion",
58
+ "differential_diagnosis",
59
+ "diagnosis",
60
+ "figures",
61
+ ]
62
+
63
+
64
+ class Question:
65
+ """A class to generate clinical questions from case data.
66
+
67
+ This class handles creating structured clinical questions by combining case data with
68
+ specified categories and difficulty levels.
69
+
70
+ Attributes:
71
+ type (str): The type of question (e.g. multiple choice)
72
+ difficulty (str): Difficulty level of the question
73
+ case_data (Dict[str, Any]): Dictionary containing the clinical case data
74
+ case_content (str): Formatted case data from selected sections
75
+ case_id (str): Unique identifier for the case
76
+ categories (List[str]): List of analytical categories this question tests
77
+ sections (List[str]): Case sections to include in question
78
+ raw_content (Optional[str]): Raw LLM response to the question prompt
79
+ content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ type: str,
85
+ difficulty: str,
86
+ case_data: Dict[str, Any],
87
+ categories: List[str],
88
+ sections: List[str] = [
89
+ "history",
90
+ "image_finding",
91
+ "discussion",
92
+ "differential_diagnosis",
93
+ "diagnosis",
94
+ "figures",
95
+ ],
96
+ system_prompt: str = "You are an expert medical benchmark creation assistant.",
97
+ ) -> None:
98
+ self.type = type
99
+ self.difficulty = difficulty
100
+ self.case_data = case_data
101
+ self.case_id = case_data["case_id"]
102
+ self.categories = categories
103
+ self.sections = sections
104
+ self.system_prompt = system_prompt
105
+ self.case_content = self.select_case_sections()
106
+ self.raw_content: Optional[str] = None
107
+ self.content: Optional[Dict[str, str]] = None
108
+
109
+ def create_question_prompt(self) -> str:
110
+ """Creates a formatted prompt for generating a clinical question.
111
+
112
+ Returns:
113
+ str: A structured prompt containing the question parameters and clinical data
114
+ """
115
+ category_descriptions = "\n".join(
116
+ f"{category}: {desc}"
117
+ for category, desc in CATEGORIES_META.items()
118
+ if category in self.categories
119
+ )
120
+
121
+ return f"""
122
+ You must follow these guidelines:
123
+ 1. Questions must be answerable using only context and chest X-rays.
124
+ - Questions must explicitly mention the referenced figures
125
+ - Questions can only reference the chest X-ray figures
126
+
127
+ 2. Questions must have unambiguous, verifiable answers, and should:
128
+ - Challenge the agent's analytical capabilities
129
+ - Require multi-step reasoning
130
+ - Test ability to make precise observations
131
+ - Evaluate capability to derive insights and findings from the chest X-ray
132
+
133
+ 3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex enough to require the use of such tools.
134
+
135
+
136
+ Create a {self.difficulty} {self.type} clinical question that integrates the following:
137
+
138
+ {category_descriptions}
139
+
140
+ based on the following clinical case:
141
+
142
+ {self.case_content}
143
+
144
+ Do not use any information derived from the CT and MRI images. Do not provide any information or findings about the chest X-rays.
145
+ Your question should require the agent to derive insights and findings from the chest X-ray by itself.
146
+ Your answer should be verifiable directly in the context of the case.
147
+ You can only use the image findings that come from the chest X-ray figures.
148
+
149
+ Your response must follow this exact format:
150
+ THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question]
151
+ QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
152
+ FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
153
+ EXPLANATION: [short explanation of why your answer is verifiable in the case]
154
+ ANSWER: [correct answer e.g. "A"]
155
+ """.strip().replace(
156
+ " ", ""
157
+ ) # remove tabs
158
+
159
+ def select_case_sections(self) -> str:
160
+ """Extract and format selected sections from case data into paragraphs.
161
+
162
+ Returns:
163
+ str: Formatted string with case sections and content
164
+ """
165
+ section_mapping = {
166
+ "history": ("history", "No history provided."),
167
+ "image_finding": ("image_finding", "No findings provided."),
168
+ "discussion": ("discussion", "No discussion provided."),
169
+ "differential_diagnosis": (
170
+ "differential_diagnosis",
171
+ "No differential diagnosis provided.",
172
+ ),
173
+ "diagnosis": ("diagnosis", "No diagnosis provided."),
174
+ "figures": ("figures", "No figures provided."),
175
+ }
176
+
177
+ formatted = []
178
+ for section in self.sections:
179
+ if section in section_mapping:
180
+ key, default = section_mapping[section]
181
+ content = self.case_data.get(key, default)
182
+
183
+ if key == "figures":
184
+ figures_text = []
185
+ for figure in content:
186
+ for subfig in figure["subfigures"]:
187
+ figures_text.append(f"{subfig['number']}: {subfig['caption']}")
188
+ content = "\n".join(figures_text)
189
+
190
+ formatted.append(f"{section}:\n{content}")
191
+
192
+ return "\n\n".join(formatted)
193
+
194
+ def create_question(
195
+ self,
196
+ client: openai.OpenAI,
197
+ temperature: float = 0.7,
198
+ top_p: float = 0.95,
199
+ max_tokens: int = 500,
200
+ model: str = "gpt-4o",
201
+ ) -> str:
202
+ """Create a clinical question using LLM.
203
+
204
+ Args:
205
+ client (openai.OpenAI): OpenAI client instance
206
+ temperature (float): Controls randomness in responses. Defaults to 0.7.
207
+ top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
208
+ max_tokens (int): Max tokens in model response. Defaults to 500.
209
+ model (str): OpenAI model to use. Defaults to "gpt-4o".
210
+
211
+ Returns:
212
+ str: LLM response containing formatted question components
213
+ """
214
+ self.raw_content = get_llm_response(
215
+ client=client,
216
+ prompt=self.create_question_prompt(),
217
+ system_prompt=self.system_prompt,
218
+ temperature=temperature,
219
+ top_p=top_p,
220
+ max_tokens=max_tokens,
221
+ model=model,
222
+ )
223
+ self.content = self.extract_content()
224
+
225
+ return self.raw_content
226
+
227
+ def extract_content(self) -> Dict[str, str]:
228
+ """Extract sections from raw LLM response using regex patterns.
229
+
230
+ Returns:
231
+ Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer
232
+ """
233
+ keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]
234
+
235
+ content = {}
236
+ for kw in keywords:
237
+ pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
238
+ match = re.search(pattern, self.raw_content, re.DOTALL)
239
+ content[kw.lower()] = match.group(1).strip() if match else None
240
+
241
+ return content
242
+
243
+ def save(self, output_path: str) -> Dict[str, Any]:
244
+ """Save question content and metadata as a JSON file.
245
+
246
+ Args:
247
+ output_path (str): Directory path where the JSON file will be saved
248
+
249
+ Returns:
250
+ Dict[str, Any]: Question data including content (thoughts, question, figures, options,
251
+ explanation, answer) and metadata (type, difficulty, categories, etc.)
252
+ """
253
+ question_metadata = self.content.copy()
254
+
255
+ # Add metadata
256
+ question_metadata["metadata"] = {
257
+ "case_id": self.case_id,
258
+ "type": self.type,
259
+ "difficulty": self.difficulty,
260
+ "categories": self.categories,
261
+ "sections": self.sections,
262
+ }
263
+
264
+ # Create a directory for the case
265
+ case_dir = os.path.join(output_path, str(self.case_id))
266
+ os.makedirs(case_dir, exist_ok=True)
267
+
268
+ # Save the question metadata to a JSON file
269
+ output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
270
+ with open(output_file, "w") as f:
271
+ json.dump(question_metadata, f, indent=2)
272
+
273
+ return question_metadata
274
+
275
+
276
+ def generate_questions(
277
+ dataset: Dict[str, Any],
278
+ client: openai.OpenAI,
279
+ output_dir: str,
280
+ skip_first: int = 100,
281
+ temperature: float = 0.7,
282
+ top_p: float = 0.95,
283
+ max_tokens: int = 1200,
284
+ model: str = "gpt-4o",
285
+ ) -> None:
286
+ """Generate questions for each case and category combination.
287
+
288
+ Args:
289
+ dataset: Dictionary of case data
290
+ client: OpenAI client instance
291
+ output_dir: Directory to save generated questions
292
+ skip_first: Number of initial cases to skip
293
+ temperature: LLM temperature parameter
294
+ top_p: LLM top_p parameter
295
+ max_tokens: Maximum tokens for LLM response
296
+ model: LLM model name
297
+ """
298
+ target_cases = sorted(list(dataset.keys()), key=int)[-len(dataset) : -skip_first]
299
+
300
+ for case_id in tqdm(target_cases, desc="Processing cases"):
301
+ case_data = dataset[case_id]
302
+
303
+ for category in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
304
+ question = Question(
305
+ type="multiple choice (A/B/C/D/E/F)",
306
+ difficulty="complex",
307
+ case_data=case_data,
308
+ categories=category,
309
+ sections=DEFAULT_SECTIONS,
310
+ system_prompt=SYSTEM_PROMPT,
311
+ )
312
+
313
+ response = question.create_question(
314
+ client=client,
315
+ temperature=temperature,
316
+ top_p=top_p,
317
+ max_tokens=max_tokens,
318
+ model=model,
319
+ )
320
+ question.save(output_dir)
321
+
322
+
323
+ def main():
324
+ """Main execution function."""
325
+ client = openai.OpenAI()
326
+
327
+ # Load and verify dataset
328
+ dataset = load_eurorad_dataset(
329
+ DATASET_PATH,
330
+ section="Chest Imaging",
331
+ as_dict=True,
332
+ filter_by_caption=[
333
+ "xray",
334
+ "x-ray",
335
+ "x ray",
336
+ "ray",
337
+ "xr",
338
+ "radiograph",
339
+ ],
340
+ )
341
+ print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")
342
+
343
+ # Optional: Print sample case for verification
344
+ case_data = dataset["16798"]
345
+ pprint(case_data, sort_dicts=False)
346
+
347
+ # Generate questions
348
+ generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")
349
+
350
+
351
+ if __name__ == "__main__":
352
+ main()
benchmark/llm.py ADDED
@@ -0,0 +1,42 @@
1
+ import openai
2
+ from typing import List
3
+
4
+
5
+ def get_llm_response(
6
+ client: openai.OpenAI,
7
+ prompt: str,
8
+ system_prompt: str = "You are a helpful assistant.",
9
+ model: str = "gpt-4o-mini",
10
+ temperature: float = 0.7,
11
+ top_p: float = 0.95,
12
+ max_tokens: int = 500,
13
+ ) -> str:
14
+ """
15
+ Get response from OpenAI language model.
16
+
17
+ Args:
18
+ client (openai.OpenAI): OpenAI client
19
+ prompt (str): The user prompt/question to send to the model
20
+ system_prompt (str, optional): System prompt to set model behavior.
21
+ model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
22
+ temperature (float, optional): Controls randomness in responses. Defaults to 0.7.
23
+ top_p (float, optional): Controls diversity via nucleus sampling. Defaults to 0.95.
24
+ max_tokens (int, optional): Max tokens in model response. Defaults to 500.
25
+
26
+ Returns:
27
+ str: The model's response text
28
+ """
29
+ messages = [
30
+ {"role": "system", "content": system_prompt},
31
+ {"role": "user", "content": prompt},
32
+ ]
33
+
34
+ response = client.chat.completions.create(
35
+ model=model,
36
+ messages=messages,
37
+ temperature=temperature,
38
+ top_p=top_p,
39
+ max_tokens=max_tokens,
40
+ )
41
+
42
+ return response.choices[0].message.content
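A minimal usage sketch for `get_llm_response` (assumes `OPENAI_API_KEY` is set in the environment; the prompts are illustrative):

```python
import openai

from benchmark.llm import get_llm_response

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

answer = get_llm_response(
    client=client,
    prompt="List three findings commonly reported on a normal chest X-ray.",
    system_prompt="You are a helpful radiology assistant.",
    model="gpt-4o-mini",
    max_tokens=200,
)
print(answer)
```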
benchmark/utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ import json
3
+ from typing import Dict, List
4
+
5
+
6
+ def load_eurorad_dataset(
7
+ dataset_path: str,
8
+ section: str = "any",
9
+ as_dict: bool = False,
10
+ filter_by_caption: List[str] = [
11
+ "xray",
12
+ "x-ray",
13
+ "x ray",
14
+ "ray",
15
+ "xr",
16
+ "radiograph",
17
+ "radiogram",
18
+ "plain film",
19
+ ],
20
+ ) -> List[Dict] | Dict[str, Dict]:
21
+ """
22
+ Load a dataset from a JSON file.
23
+
24
+ Args:
25
+ dataset_path (str): Path to the JSON dataset file.
26
+ section (str, optional): Section of the dataset to load. Defaults to "any".
27
+ as_dict (bool, optional): Whether to return data as dict. Defaults to False.
28
+ filter_by_caption (List[str], optional): List of strings to filter cases by caption content. Defaults to common X-ray-related terms.
29
+
30
+ Returns:
31
+ List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.
32
+
33
+ Raises:
34
+ FileNotFoundError: If dataset_path does not exist
35
+ json.JSONDecodeError: If file is not valid JSON
36
+ """
37
+
38
+ with open(dataset_path, "r", encoding="utf-8") as file:
39
+ data = json.load(file)
40
+
41
+ if filter_by_caption:
42
+ filtered_data = {}
43
+ for case_id, case in data.items():
44
+ if any(
45
+ any(x in subfig["caption"].lower() for x in filter_by_caption)
46
+ for figure in case["figures"]
47
+ for subfig in figure["subfigures"]
48
+ ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
49
+ filtered_data[case_id] = case
50
+ data = filtered_data
51
+
52
+ if section != "any":
53
+ section = section.strip().lower()
54
+ if not as_dict:
55
+ data = [
56
+ item for item in data.values() if item.get("section", "").strip().lower() == section
57
+ ]
58
+ else:
59
+ data = {
60
+ k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
61
+ }
62
+
63
+ elif not as_dict:
64
+ data = list(data.values())
65
+
66
+ return data
67
+
68
+
69
+ def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
70
+ """
71
+ Save a dataset to a JSON file.
72
+
73
+ Args:
74
+ dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
75
+ dataset_path (str): Path where the JSON dataset file will be saved.
76
+ """
77
+ with open(dataset_path, "w", encoding="utf-8") as file:
78
+ json.dump(dataset, file)
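A short sketch of how these helpers fit together, mirroring the call in `create_benchmark.py` (the data directory is a placeholder you must set yourself):

```python
import os

from benchmark.utils import load_eurorad_dataset, save_dataset

DATA_DIR = "/path/to/MedRAX/data"  # placeholder

# Load only "Chest Imaging" cases whose captions mention an X-ray
dataset = load_eurorad_dataset(
    os.path.join(DATA_DIR, "eurorad_metadata.json"),
    section="Chest Imaging",
    as_dict=True,
)
print(f"Loaded {len(dataset)} chest imaging cases")

# Persist the filtered subset for later runs
save_dataset(dataset, os.path.join(DATA_DIR, "chest_xray_cases.json"))
```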
data/eurorad_metadata.json ADDED
The diff for this file is too large to render. See raw diff
 
data/figures.py ADDED
@@ -0,0 +1,74 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+
8
+ def download_eurorad_figures(metadata_path: str, output_dir: str) -> None:
9
+ """
10
+ Download figures from Eurorad dataset and save them organized by case_id.
11
+
12
+ Args:
13
+ metadata_path: Path to the eurorad_metadata.json file
14
+ output_dir: Base directory where figures will be saved
15
+
16
+ The figures will be saved as:
17
+ {output_dir}/{case_id}/{figure_number}.jpg
18
+ Example:
19
+ figures/189/Figure_1a.jpg
20
+ """
21
+ # Create output directory if it doesn't exist
22
+ output_path = Path(output_dir)
23
+ output_path.mkdir(exist_ok=True)
24
+
25
+ # Load metadata
26
+ with open(metadata_path) as f:
27
+ metadata = json.load(f)
28
+
29
+ # Iterate through all cases with progress bar
30
+ for case_id in tqdm(metadata, desc="Downloading cases", unit="case"):
31
+ case = metadata[case_id]
32
+ case_dir = output_path / str(case["case_id"])
33
+ case_dir.mkdir(exist_ok=True)
34
+
35
+ # Process all figures and their subfigures
36
+ for figure in case["figures"]:
37
+ for subfig in figure["subfigures"]:
38
+
39
+ # Remove leading and trailing whitespace and convert to lowercase
40
+ subfig_name = f"{subfig['number'].strip().replace(' ', '_').lower()}.jpg"
41
+ subfig_path = Path(case_dir) / subfig_name
42
+
43
+ save_figure(
44
+ url=subfig["url"],
45
+ output_path=subfig_path,
46
+ )
47
+
48
+
49
+ def save_figure(url: str, output_path: Path) -> None:
50
+ """
51
+ Download and save a single figure.
52
+
53
+ Args:
54
+ url: URL of the figure to download
55
+ output_path: Path where the figure should be saved
56
+ """
57
+ if output_path.exists():
58
+ return
59
+
60
+ try:
61
+ response = requests.get(url, timeout=10)
62
+ response.raise_for_status()
63
+ with open(output_path, "wb") as f:
64
+ f.write(response.content)
65
+ except Exception as e:
66
+ print(f"Error downloading {url}: {e}")
67
+
68
+
69
+ if __name__ == "__main__":
70
+ root = os.path.dirname(os.path.abspath(__file__))
71
+ download_eurorad_figures(
72
+ metadata_path=os.path.join(root, "eurorad_metadata.json"),
73
+ output_dir=os.path.join(root, "figures"),
74
+ )
data/get_cases.py ADDED
@@ -0,0 +1,51 @@
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import time
4
+ import json
5
+ from tqdm import tqdm
6
+
7
+
8
+ def get_response(url):
9
+ headers = {
10
+ "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
11
+ }
12
+ return requests.get(url, headers=headers)
13
+
14
+ def get_case_numbers_from_page(page):
15
+ url = f"https://www.eurorad.org/advanced-search?sort_by=published_at&sort_order=ASC&page={page}&filter%5B0%5D=section%3A40"
16
+
17
+ # Fetch the page directly; routing through a proxy tends to trigger the site's bot protection
18
+ response = get_response(url)
19
+ # print(response.text)  # debug: uncomment to inspect the raw HTML response
20
+
21
+ soup = BeautifulSoup(response.text, "html.parser")
22
+ spans = soup.find_all("span", class_="case__number small")
23
+
24
+ # Remove '#' from the span text and strip extra whitespace
25
+ numbers = [span.text.strip().replace("#", "").strip() for span in spans]
26
+ return numbers
27
+
28
+
29
+ def main():
30
+ total_pages = 107 # Pages 0 through 106
31
+ all_numbers = []
32
+
33
+ for page in tqdm(range(total_pages)):
34
+ numbers = get_case_numbers_from_page(page)
35
+ all_numbers.extend(numbers)
36
+
37
+ if page != total_pages - 1 and len(numbers) != 9:
38
+ print(f"Warning: Page {page} returned {len(numbers)} cases instead of 9")
39
+
40
+ # Be kind to the server – avoid hitting it too fast
41
+ time.sleep(1)
43
+
44
+ with open('case_numbers.json', 'w') as f:
45
+ json.dump(all_numbers, f)
46
+
47
+ print(f"Saved {len(all_numbers)} case numbers to case_numbers.json")
48
+
49
+
50
+ if __name__ == "__main__":
51
+ main()
data/stats/age_distribution.png ADDED

Git LFS Details

  • SHA256: 0409ec03f305ccd8fdee1c097dede52b7cf0f84f05b99fbd18727fb8e67238ad
  • Pointer size: 132 Bytes
  • Size of remote file: 2.71 MB
data/stats/area_of_interest_distribution.png ADDED

Git LFS Details

  • SHA256: 2a80d9aa1bf9b025b8aaa2b1c0d4807e36afc175747ba71b500ef1ceaf542081
  • Pointer size: 132 Bytes
  • Size of remote file: 2.91 MB
data/stats/gender_distribution.png ADDED

Git LFS Details

  • SHA256: a4cfd37f71fc91a848d990f6e2ff6c9611f555e09e435885680ffbbb85458838
  • Pointer size: 132 Bytes
  • Size of remote file: 1.96 MB
demo/chest/LIDC.dcm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11d25b1d34dff083057de994fef7da3dcef75bd7b334823ec6cb9c16b3ba0338
3
+ size 17071804
demo/chest/Pseudo.dcm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b35ae460fb5f62eb6d6c4c5117f6683100ad92c5fb6ba1a3c36da39703c4652
3
+ size 7535280
demo/chest/RIDER.dcm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc15f7afa5434991e1359f596433870ad611b42227db87d484d31976545de7fd
3
+ size 7534066
demo/chest/TCGAA.dcm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e8137290ac823d3da3c00ce3e18120123eaa62a786934c7afc52a989b0b64cf
3
+ size 7535274
demo/chest/__init__.py ADDED
File without changes
demo/chest/effusion1.png ADDED

Git LFS Details

  • SHA256: ba5af84601f11ab44142e5dfaf578b49d76de45633470e606c7edc4b1c77ba07
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB
demo/chest/normal1.jpg ADDED

Git LFS Details

  • SHA256: 785419c9ec7d0235fe056c254cd3be785d6052b558ae32c595ad558be57062dd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
demo/chest/normal2.jpg ADDED

Git LFS Details

  • SHA256: cecf56a8b90e9ccb3c54641beb40652e72a2bdcb311efc696a331fe4de7efbf0
  • Pointer size: 131 Bytes
  • Size of remote file: 798 kB
demo/chest/normal3.jpg ADDED

Git LFS Details

  • SHA256: 3f721831529e9604c99e3bd999483321e0e0648c5987351570fe45e48c190948
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
demo/chest/normal4.jpg ADDED

Git LFS Details

  • SHA256: ed84d75328f1eb80c6554e3c6ba8dcd573e733914b2934bfce399ae6e8f38ec4
  • Pointer size: 131 Bytes
  • Size of remote file: 566 kB
demo/chest/normal5.jpg ADDED

Git LFS Details

  • SHA256: 9e7c4251d9b300f9256c6fe72ef1c3167beeecca747e6b9c8b80ee3260ea9ac8
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
demo/chest/normal6.jpg ADDED

Git LFS Details

  • SHA256: 4b47dd1665b828ab3610d1a60ec08c37083579f834b2dd5891570c8a105825a5
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
demo/chest/pneumonia1.jpg ADDED
demo/chest/pneumonia2.jpg ADDED
demo/chest/pneumonia3.jpg ADDED
demo/chest/pneumonia4.jpg ADDED

Git LFS Details

  • SHA256: 8223cf57d33d1528782f83b62d3d62d2f41fe9bf34053553a86e609c2b2ba94b
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
demo/chest/pneumonia5.jpg ADDED

Git LFS Details

  • SHA256: 59bee7e6a36e7629a320e1c74d65dd0683c8310dbbb2489f5d32054419a3a667
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
experiments/README.md ADDED
@@ -0,0 +1,63 @@
1
+ # Experiments
2
+ Below are the instructions for running experiments using our novel ChestAgentBench and CheXbench, the previous state-of-the-art benchmark. ChestAgentBench is a comprehensive benchmark containing over 2,500 complex medical queries across 8 diverse categories.
3
+
4
+ ### ChestAgentBench
5
+
6
+ To run GPT-4o on ChestAgentBench, enter the `experiments` directory and run the following script:
7
+ ```bash
8
+ python benchmark_gpt4o.py
9
+ ```
10
+
11
+ To run Llama 3.2 Vision 90B on ChestAgentBench, run the following:
12
+ ```bash
13
+ python benchmark_llama.py
14
+ ```
15
+
16
+ To run CheXagent on ChestAgentBench, run the following:
17
+ ```bash
18
+ python benchmark_chexagent.py
19
+ ```
20
+
21
+ To run LLaVA-Med on ChestAgentBench, clone its repository, follow its setup instructions, and then copy the following script into it:
22
+ ```bash
23
+ mv benchmark_llavamed.py ~/LLaVA-Med/llava/serve
24
+ python -m llava.serve.benchmark_llavamed --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
25
+ ```
26
+
27
+ If you want to inspect the logs, you can run the following. It will select the most recent log file by default.
28
+ ```bash
29
+ python inspect_logs.py [optional: log-file] -n [num-logs]
30
+ ```
31
+
32
+ Finally, to analyze results, run:
33
+ ```bash
34
+ python analyze_axes.py results/[logfile].json ../benchmark/questions/ --model [gpt4|llama|chexagent|llava-med] --max-questions [optional:int]
35
+ ```
36
+
37
+ ### CheXbench
38
+
39
+ To run the models on CheXbench, you can use `chexbench_gpt4.py` as a reference. You'll need to download the dataset files locally and upload them for each request. Rad-ReStruct and Open-I use the same set of images, so you can download the `NLMCXR.zip` file just once and copy the images to both directories.
40
+
41
+ You can find the datasets here:
42
+ 1. [SLAKE: A Semantically-Labeled Knowledge-Enhanced Dataset for Medical Visual Question Answering](https://www.med-vqa.com/slake/). Save this to `MedMAX/data/slake`.
43
+ 2. [Rad-ReStruct: A Novel VQA Benchmark and Method for Structured Radiology Reporting](https://github.com/ChantalMP/Rad-ReStruct). Save the images to `MedMAX/data/rad-restruct/images`.
44
+ 3. [Open-I Service of the National Library of Medicine](https://openi.nlm.nih.gov/faq). Save the images to `MedMAX/data/openi/images`.
45
+
46
+ Once you're finished, fix the paths in `chexbench.json` to point to your local copies using the `MedMAX/data/fix_chexbench.py` script.
47
+
48
+
49
+ ### Compare Runs
50
+ Analyze a single results file, reporting overall accuracy and accuracy along different axes:
51
+ ```
52
+ python compare_runs.py results/medmax.json
53
+ ```
54
+
55
+ For a direct evaluation comparing **2** models on the exact same questions:
56
+ ```
57
+ python compare_runs.py results/medmax.json results/gpt4o.json
58
+ ```
59
+
60
+ For a direct evaluation comparing **ALL** models on the exact same questions (add as many model log files as you want):
61
+ ```
62
+ python compare_runs.py results/medmax.json results/gpt4o.json results/llama.json results/chexagent.json results/llavamed.json
63
+ ```
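The CheXbench setup above depends on `MedMAX/data/fix_chexbench.py` to point `chexbench.json` at local image copies. That script is not shown in this commit, so purely as an illustration, a path rewrite of this kind might look like the sketch below; the `image_path` key and the flat directory layout are assumptions, not the script's actual interface.

```python
import json
from pathlib import Path

LOCAL_ROOT = Path("MedMAX/data")  # assumed local data root

with open("chexbench.json") as f:
    bench = json.load(f)

# Rewrite each image path so it points under the local data root,
# keeping only the original file name.
for entry in bench:
    old = Path(entry["image_path"])  # hypothetical key name
    entry["image_path"] = str(LOCAL_ROOT / old.name)

with open("chexbench.json", "w") as f:
    json.dump(bench, f, indent=2)
```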
experiments/analyze_axes.py ADDED
@@ -0,0 +1,385 @@
1
+ from typing import Dict, List, Optional, Tuple, Union, Any
2
+ import json
3
+ import os
4
+ import sys
5
+ import argparse
6
+ from collections import defaultdict
7
+ from tqdm import tqdm
8
+
9
+ QUESTION_TYPES = {
10
+ "Detailed Finding Analysis": ["detection", "localization", "characterization"],
11
+ "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
12
+ "Spatial Understanding": ["localization", "comparison", "relationship"],
13
+ "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
14
+ "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
15
+ }
16
+
17
+
18
+ def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
19
+ """
20
+ Extract just the letter from various answer formats.
21
+
22
+ Args:
23
+ answer: The answer text to extract letter from
24
+
25
+ Returns:
26
+ Optional[str]: The extracted letter in uppercase, or None if no letter found
27
+ """
28
+ if not answer:
29
+ return None
30
+
31
+ # Convert to string and clean
32
+ answer = str(answer).strip()
33
+
34
+ # If it's just a single letter, return it
35
+ if len(answer) == 1 and answer.isalpha():
36
+ return answer.upper()
37
+
38
+ # Try to extract letter from format like "A)" or "A."
39
+ if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
40
+ return answer[0].upper()
41
+
42
+ # Try to extract letter from format like "A) Some text"
43
+ if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
44
+ return answer[0].upper()
45
+
46
+ return None
47
+
48
+
49
+ def analyze_gpt4_results(
50
+ results_file: str, max_questions: Optional[int] = None
51
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
52
+ """
53
+ Analyze results in GPT-4 format.
54
+
55
+ Args:
56
+ results_file: Path to results file
57
+ max_questions: Maximum number of questions to analyze
58
+
59
+ Returns:
60
+ Tuple containing:
61
+ - overall_accuracy (float)
62
+ - category_accuracies (Dict)
63
+ - question_type_stats (Dict)
64
+ - correct_ids (List[str])
65
+ - incorrect_ids (List[str])
66
+ """
67
+ category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
68
+ all_questions = 0
69
+ all_correct = 0
70
+ correct_ids = []
71
+ incorrect_ids = []
72
+
73
+ with open(results_file, "r") as f:
74
+ lines = f.readlines()
75
+
76
+ processed_questions = 0
77
+
78
+ for line in tqdm(lines, desc="Analyzing Benchmark Results"):
79
+ # Check if we've hit the maximum questions
80
+ if max_questions is not None and processed_questions >= max_questions:
81
+ break
82
+ if line.startswith("HTTP Request:"):
83
+ continue
84
+
85
+ try:
86
+ entry = json.loads(line)
87
+ metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
88
+ question_id = entry.get("question_id")
89
+
90
+ model_letter = extract_answer_letter(entry.get("model_answer"))
91
+ correct_letter = extract_answer_letter(entry.get("correct_answer"))
92
+
93
+ if model_letter and correct_letter:
94
+ all_questions += 1
95
+ processed_questions += 1
96
+ is_correct = model_letter == correct_letter
97
+
98
+ if is_correct:
99
+ all_correct += 1
100
+ correct_ids.append(question_id)
101
+ else:
102
+ incorrect_ids.append(question_id)
103
+
104
+ for category in metadata.get("categories", []):
105
+ category_performance[category]["total"] += 1
106
+ if is_correct:
107
+ category_performance[category]["correct"] += 1
108
+
109
+ except json.JSONDecodeError:
110
+ continue
111
+
112
+ return process_results(
113
+ category_performance, all_questions, all_correct, correct_ids, incorrect_ids
114
+ )
115
+
116
+
117
+ def analyze_llama_results(
118
+ results_file: str, max_questions: Optional[int] = None
119
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
120
+ """
121
+ Analyze results in Llama format.
122
+
123
+ Args:
124
+ results_file: Path to results file
125
+ max_questions: Maximum number of questions to analyze
126
+
127
+ Returns:
128
+ Tuple containing:
129
+ - overall_accuracy (float)
130
+ - category_accuracies (Dict)
131
+ - question_type_stats (Dict)
132
+ - correct_ids (List[str])
133
+ - incorrect_ids (List[str])
134
+ """
135
+ category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
136
+ all_questions = 0
137
+ all_correct = 0
138
+ correct_ids = []
139
+ incorrect_ids = []
140
+
141
+ with open(results_file, "r") as f:
142
+ lines = f.readlines()
143
+
144
+ # If max_questions is set, limit the number of lines processed
145
+ if max_questions is not None:
146
+ lines = lines[:max_questions]
147
+
148
+ for line in tqdm(lines, desc="Analyzing Benchmark Results"):
149
+ if line.startswith("HTTP Request:"):
150
+ continue
151
+
152
+ try:
153
+ entry = json.loads(line)
154
+ metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
155
+ question_id = entry.get("question_id")
156
+
157
+ model_letter = extract_answer_letter(entry.get("model_answer"))
158
+ correct_letter = extract_answer_letter(entry.get("correct_answer"))
159
+
160
+ if model_letter and correct_letter:
161
+ all_questions += 1
162
+ is_correct = model_letter == correct_letter
163
+
164
+ if is_correct:
165
+ all_correct += 1
166
+ correct_ids.append(question_id)
167
+ else:
168
+ incorrect_ids.append(question_id)
169
+
170
+ for category in metadata.get("categories", []):
171
+ category_performance[category]["total"] += 1
172
+ if is_correct:
173
+ category_performance[category]["correct"] += 1
174
+
175
+ except json.JSONDecodeError:
176
+ continue
177
+
178
+ return process_results(
179
+ category_performance, all_questions, all_correct, correct_ids, incorrect_ids
180
+ )
181
+
182
+
183
+ def analyze_chexagent_results(
184
+ results_file: str, max_questions: Optional[int] = None
185
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
186
+ """
187
+ Analyze results in CheXagent format.
188
+
189
+ Args:
190
+ results_file: Path to results file
191
+ max_questions: Maximum number of questions to analyze
192
+
193
+ Returns:
194
+ Tuple containing:
195
+ - overall_accuracy (float)
196
+ - category_accuracies (Dict)
197
+ - question_type_stats (Dict)
198
+ - correct_ids (List[str])
199
+ - incorrect_ids (List[str])
200
+ """
201
+ category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
202
+ all_questions = 0
203
+ all_correct = 0
204
+ correct_ids = []
205
+ incorrect_ids = []
206
+
207
+ with open(results_file, "r") as f:
208
+ lines = f.readlines()
209
+
210
+ # If max_questions is set, limit the number of lines processed
211
+ if max_questions is not None:
212
+ lines = lines[:max_questions]
213
+
214
+ for line in tqdm(lines, desc="Analyzing Benchmark Results"):
215
+ try:
216
+ entry = json.loads(line)
217
+ metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
218
+ question_id = entry.get("question_id")
219
+
220
+ model_letter = extract_answer_letter(entry.get("model_answer"))
221
+ correct_letter = extract_answer_letter(entry.get("correct_answer"))
222
+
223
+ if model_letter and correct_letter:
224
+ all_questions += 1
225
+ is_correct = model_letter == correct_letter
226
+
227
+ if is_correct:
228
+ all_correct += 1
229
+ correct_ids.append(question_id)
230
+ else:
231
+ incorrect_ids.append(question_id)
232
+
233
+ for category in metadata.get("categories", []):
234
+ category_performance[category]["total"] += 1
235
+ if is_correct:
236
+ category_performance[category]["correct"] += 1
237
+
238
+ except json.JSONDecodeError:
239
+ continue
240
+
241
+ return process_results(
242
+ category_performance, all_questions, all_correct, correct_ids, incorrect_ids
243
+ )
244
+
245
+
246
+ def process_results(
247
+ category_performance: Dict,
248
+ all_questions: int,
249
+ all_correct: int,
250
+ correct_ids: Optional[List[str]] = None,
251
+ incorrect_ids: Optional[List[str]] = None,
252
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
253
+ """
254
+ Process raw results into final statistics.
255
+
256
+ Args:
257
+ category_performance: Dict containing performance by category
258
+ all_questions: Total number of questions
259
+ all_correct: Total number of correct answers
260
+ correct_ids: List of IDs for correctly answered questions
261
+ incorrect_ids: List of IDs for incorrectly answered questions
262
+
263
+ Returns:
264
+ Tuple containing:
265
+ - overall_accuracy (float)
266
+ - category_accuracies (Dict)
267
+ - question_type_stats (Dict)
268
+ - correct_ids (List[str])
269
+ - incorrect_ids (List[str])
270
+ """
271
+ category_accuracies = {
272
+ category: {
273
+ "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
274
+ "total": stats["total"],
275
+ "correct": stats["correct"],
276
+ }
277
+ for category, stats in category_performance.items()
278
+ }
279
+
280
+ question_type_stats = {}
281
+ for qtype, categories in QUESTION_TYPES.items():
282
+ total = sum(
283
+ category_performance[cat]["total"] for cat in categories if cat in category_performance
284
+ )
285
+ correct = sum(
286
+ category_performance[cat]["correct"]
287
+ for cat in categories
288
+ if cat in category_performance
289
+ )
290
+
291
+ question_type_stats[qtype] = {
292
+ "accuracy": (correct / total * 100) if total > 0 else 0,
293
+ "total": total,
294
+ "correct": correct,
295
+ }
296
+
297
+ overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0
298
+
299
+ return (
300
+ overall_accuracy,
301
+ category_accuracies,
302
+ question_type_stats,
303
+ correct_ids or [],
304
+ incorrect_ids or [],
305
+ )
306
+
307
+
308
+ def print_analysis(
309
+ overall_accuracy: float,
310
+ category_accuracies: Dict,
311
+ question_type_stats: Dict,
312
+ correct_ids: List[str],
313
+ incorrect_ids: List[str],
314
+ model_name: str,
315
+ ) -> None:
316
+ """
317
+ Print analysis results.
318
+
319
+ Args:
320
+ overall_accuracy: Overall accuracy percentage
321
+ category_accuracies: Dict containing accuracy metrics by category
322
+ question_type_stats: Dict containing stats by question type
323
+ correct_ids: List of IDs for correctly answered questions
324
+ incorrect_ids: List of IDs for incorrectly answered questions
325
+ model_name: Name of the model being analyzed
326
+ """
327
+ total_questions = len(correct_ids) + len(incorrect_ids)
328
+ print(
329
+ f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
330
+ )
331
+
332
+ print("\nCategory Performance:")
333
+ sorted_categories = sorted(
334
+ category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
335
+ )
336
+ for category, metrics in sorted_categories:
337
+ print(f"{category}:")
338
+ print(f" Accuracy: {metrics['accuracy']:.2f}%")
339
+ print(f" Total Questions: {metrics['total']}")
340
+ print(f" Correct Questions: {metrics['correct']}")
341
+
342
+ print("\nQuestion Type Performance:")
343
+ sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
344
+ for qtype, metrics in sorted_types:
345
+ print(f"\n{qtype}:")
346
+ print(f" Accuracy: {metrics['accuracy']:.2f}%")
347
+ print(f" Total Questions: {metrics['total']}")
348
+ print(f" Correct Questions: {metrics['correct']}")
349
+ print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")
350
+
351
+ # Save question IDs to JSON
352
+ question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
353
+
354
+ output_filename = f"{model_name}_question_ids.json"
355
+ with open(output_filename, "w") as f:
356
+ json.dump(question_ids, f, indent=2)
357
+
358
+ print(f"\nQuestion IDs have been saved to {output_filename}")
359
+
360
+
361
+ if __name__ == "__main__":
362
+ parser = argparse.ArgumentParser(description="Analyze benchmark results")
363
+ parser.add_argument("results_file", help="Path to results file")
364
+ parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
365
+ parser.add_argument(
366
+ "--model",
367
+ choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
368
+ default="gpt4",
369
+ help="Specify model format (default: gpt4)",
370
+ )
371
+ parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
372
+ args = parser.parse_args()
373
+
374
+ if args.model == "gpt4":
375
+ results = analyze_gpt4_results(args.results_file, args.max_questions)
376
+ elif args.model == "llama":
377
+ results = analyze_llama_results(args.results_file, args.max_questions)
378
+ elif args.model == "chexagent":
379
+ results = analyze_chexagent_results(args.results_file, args.max_questions)
380
+ elif args.model == "medrax":
+ results = analyze_gpt4_results(args.results_file, args.max_questions)
+ elif args.model == "llava-med":
+ # Assumption: llava-med logs follow the same JSONL schema as the llama runs
+ results = analyze_llama_results(args.results_file, args.max_questions)
+ else:
+ parser.error(f"Unsupported model: {args.model}")
384
+
385
+ print_analysis(*results, args.model)
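The part of this script most sensitive to log formatting is `extract_answer_letter`, which normalizes the various answer strings found in model logs. An illustrative check of the formats it is meant to accept (not part of the file; assumes `analyze_axes.py` is importable from the `experiments` directory):

```python
from analyze_axes import extract_answer_letter

assert extract_answer_letter("b") == "B"                    # bare letter
assert extract_answer_letter("A) Pleural effusion") == "A"  # letter plus option text
assert extract_answer_letter("C.") == "C"                   # letter with punctuation
assert extract_answer_letter("") is None                    # empty answers are skipped
```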
experiments/benchmark_chexagent.py ADDED
@@ -0,0 +1,316 @@
1
+ import re
2
+ import json
3
+ import os
4
+ import glob
5
+ import time
6
+ import logging
7
+ from datetime import datetime
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from tqdm import tqdm
12
+
13
+ # Configure model settings
14
+ MODEL_NAME = "StanfordAIMI/CheXagent-2-3b"
15
+ DTYPE = torch.bfloat16
16
+ DEVICE = "cuda"
17
+
18
+ # Configure logging
19
+ log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
20
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
21
+
22
+
23
+ def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
24
+ """Initialize the CheXagent model and tokenizer.
25
+
26
+ Returns:
27
+ tuple containing:
28
+ - AutoModelForCausalLM: The initialized CheXagent model
29
+ - AutoTokenizer: The initialized tokenizer
30
+ """
31
+ print("Loading model and tokenizer...")
32
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ MODEL_NAME, device_map="auto", trust_remote_code=True
35
+ )
36
+ model = model.to(DTYPE)
37
+ model.eval()
38
+ return model, tokenizer
39
+
40
+
41
+ def create_inference_request(
42
+ question_data: dict,
43
+ case_details: dict,
44
+ case_id: str,
45
+ question_id: str,
46
+ model: AutoModelForCausalLM,
47
+ tokenizer: AutoTokenizer,
48
+ ) -> str | None:
49
+ """Create and execute an inference request for the CheXagent model.
50
+
51
+ Args:
52
+ question_data: Dictionary containing question details and metadata
53
+ case_details: Dictionary containing case information and image paths
54
+ case_id: Unique identifier for the medical case
55
+ question_id: Unique identifier for the question
56
+ model: The initialized CheXagent model
57
+ tokenizer: The initialized tokenizer
58
+
59
+ Returns:
60
+ str | None: Single letter answer (A-F) if successful, None if failed
61
+ """
62
+ system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
63
+ Rules:
64
+ 1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
65
+ 2. Do not add periods, explanations, or any other text
66
+ 3. Do not use markdown or formatting
67
+ 4. Do not restate the question
68
+ 5. Do not explain your reasoning
69
+
70
+ Examples of valid responses:
71
+ A
72
+ B
73
+ C
74
+
75
+ Examples of invalid responses:
76
+ "A."
77
+ "Answer: B"
78
+ "C) This shows..."
79
+ "The answer is D"
80
+ """
81
+
82
+ prompt = f"""Given the following medical case:
83
+ Please answer this multiple choice question:
84
+ {question_data['question']}
85
+ Base your answer only on the provided images and case information."""
86
+
87
+ # Parse required figures
88
+ try:
89
+ if isinstance(question_data["figures"], str):
90
+ try:
91
+ required_figures = json.loads(question_data["figures"])
92
+ except json.JSONDecodeError:
93
+ required_figures = [question_data["figures"]]
94
+ elif isinstance(question_data["figures"], list):
95
+ required_figures = question_data["figures"]
96
+ else:
97
+ required_figures = [str(question_data["figures"])]
98
+ except Exception as e:
99
+ print(f"Error parsing figures: {e}")
100
+ required_figures = []
101
+
102
+ required_figures = [
103
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
104
+ ]
105
+
106
+ # Get image paths
107
+ image_paths = []
108
+ for figure in required_figures:
109
+ base_figure_num = "".join(filter(str.isdigit, figure))
110
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
111
+
112
+ matching_figures = [
113
+ case_figure
114
+ for case_figure in case_details.get("figures", [])
115
+ if case_figure["number"] == f"Figure {base_figure_num}"
116
+ ]
117
+
118
+ for case_figure in matching_figures:
119
+ subfigures = []
120
+ if figure_letter:
121
+ subfigures = [
122
+ subfig
123
+ for subfig in case_figure.get("subfigures", [])
124
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
125
+ or subfig.get("label", "").lower() == figure_letter.lower()
126
+ ]
127
+ else:
128
+ subfigures = case_figure.get("subfigures", [])
129
+
130
+ for subfig in subfigures:
131
+ if "local_path" in subfig:
132
+ image_paths.append("medrax/data/" + subfig["local_path"])
133
+
134
+ if not image_paths:
135
+ print(f"No local images found for case {case_id}, question {question_id}")
136
+ return None
137
+
138
+ try:
139
+ start_time = time.time()
140
+
141
+ # Prepare input for the model
142
+ query = tokenizer.from_list_format(
143
+ [*[{"image": path} for path in image_paths], {"text": prompt}]
144
+ )
145
+ conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}]
146
+ input_ids = tokenizer.apply_chat_template(
147
+ conv, add_generation_prompt=True, return_tensors="pt"
148
+ )
149
+
150
+ # Generate response
151
+ with torch.no_grad():
152
+ output = model.generate(
153
+ input_ids.to(DEVICE),
154
+ do_sample=False,
155
+ num_beams=1,
156
+ temperature=1.0,
157
+ top_p=1.0,
158
+ use_cache=True,
159
+ max_new_tokens=512,
160
+ )[0]
161
+
162
+ response = tokenizer.decode(output[input_ids.size(1) : -1])
163
+ duration = time.time() - start_time
164
+
165
+ # Clean response
166
+ clean_answer = validate_answer(response)
167
+
168
+ # Log response
169
+ log_entry = {
170
+ "case_id": case_id,
171
+ "question_id": question_id,
172
+ "timestamp": datetime.now().isoformat(),
173
+ "model": MODEL_NAME,
174
+ "duration": round(duration, 2),
175
+ "model_answer": clean_answer,
176
+ "correct_answer": question_data["answer"],
177
+ "input": {
178
+ "question_data": {
179
+ "question": question_data["question"],
180
+ "explanation": question_data["explanation"],
181
+ "metadata": question_data.get("metadata", {}),
182
+ "figures": question_data["figures"],
183
+ },
184
+ "image_paths": image_paths,
185
+ },
186
+ }
187
+ logging.info(json.dumps(log_entry))
188
+ return clean_answer
189
+
190
+ except Exception as e:
191
+ print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
192
+ log_entry = {
193
+ "case_id": case_id,
194
+ "question_id": question_id,
195
+ "timestamp": datetime.now().isoformat(),
196
+ "model": MODEL_NAME,
197
+ "status": "error",
198
+ "error": str(e),
199
+ "input": {
200
+ "question_data": {
201
+ "question": question_data["question"],
202
+ "explanation": question_data["explanation"],
203
+ "metadata": question_data.get("metadata", {}),
204
+ "figures": question_data["figures"],
205
+ },
206
+ "image_paths": image_paths,
207
+ },
208
+ }
209
+ logging.info(json.dumps(log_entry))
210
+ return None
211
+
212
+
213
+ def validate_answer(response_text: str) -> str | None:
214
+ """Enforce strict single-letter response format.
215
+
216
+ Args:
217
+ response_text: Raw response text from the model
218
+
219
+ Returns:
220
+ str | None: Single uppercase letter (A-F) if valid, None if invalid
221
+ """
222
+ if not response_text:
223
+ return None
224
+
225
+ # Remove all whitespace and convert to uppercase
226
+ cleaned = response_text.strip().upper()
227
+
228
+ # Check if it's exactly one valid letter
229
+ if len(cleaned) == 1 and cleaned in "ABCDEF":
230
+ return cleaned
231
+
232
+ # If not, try to extract just the letter
233
+ match = re.search(r"([A-F])", cleaned)
234
+ return match.group(1) if match else None
235
+
236
+
237
+ def load_benchmark_questions(case_id: str) -> list[str]:
238
+ """Find all question files for a given case ID.
239
+
240
+ Args:
241
+ case_id: Unique identifier for the medical case
242
+
243
+ Returns:
244
+ list[str]: List of paths to question JSON files
245
+ """
246
+ benchmark_dir = "../benchmark/questions"
247
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
248
+
249
+
250
+ def count_total_questions() -> tuple[int, int]:
251
+ """Count total number of cases and questions in benchmark.
252
+
253
+ Returns:
254
+ tuple containing:
255
+ - int: Total number of cases
256
+ - int: Total number of questions
257
+ """
258
+ total_cases = len(glob.glob("../benchmark/questions/*"))
259
+ total_questions = sum(
260
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
261
+ for case_id in os.listdir("../benchmark/questions")
262
+ )
263
+ return total_cases, total_questions
264
+
265
+
266
+ def main():
267
+ # Load the cases with local paths
268
+ with open("medrax/data/updated_cases.json", "r") as file:
269
+ data = json.load(file)
270
+
271
+ # Initialize model and tokenizer
272
+ model, tokenizer = initialize_model()
273
+
274
+ total_cases, total_questions = count_total_questions()
275
+ cases_processed = 0
276
+ questions_processed = 0
277
+ skipped_questions = 0
278
+
279
+ print(f"\nBeginning inference with {MODEL_NAME}")
280
+ print(f"Found {total_cases} cases with {total_questions} total questions")
281
+
282
+ # Process each case with progress bar
283
+ for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
284
+ question_files = load_benchmark_questions(case_id)
285
+ if not question_files:
286
+ continue
287
+
288
+ cases_processed += 1
289
+ for question_file in tqdm(
290
+ question_files, desc=f"Processing questions for case {case_id}", leave=False
291
+ ):
292
+ with open(question_file, "r") as file:
293
+ question_data = json.load(file)
294
+ question_id = os.path.basename(question_file).split(".")[0]
295
+
296
+ questions_processed += 1
297
+ answer = create_inference_request(
298
+ question_data, case_details, case_id, question_id, model, tokenizer
299
+ )
300
+
301
+ if answer is None:
302
+ skipped_questions += 1
303
+ continue
304
+
305
+ print(f"\nCase {case_id}, Question {question_id}")
306
+ print(f"Model Answer: {answer}")
307
+ print(f"Correct Answer: {question_data['answer']}")
308
+
309
+ print(f"\nInference Summary:")
310
+ print(f"Total Cases Processed: {cases_processed}")
311
+ print(f"Total Questions Processed: {questions_processed}")
312
+ print(f"Total Questions Skipped: {skipped_questions}")
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
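This runner, like the GPT-4o and Llama runners below, resolves figure references such as "Figure 3a" into a base figure number plus an optional sub-figure letter before matching them against the case metadata. A small standalone trace of that parsing, shown only for illustration:

```python
# Mirrors the figure parsing used in create_inference_request above.
figure = "Figure 3a"
base_figure_num = "".join(filter(str.isdigit, figure))                    # -> "3"
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None  # -> "a"
print(base_figure_num, figure_letter)                                     # 3 a
```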
experiments/benchmark_gpt4o.py ADDED
@@ -0,0 +1,331 @@
1
+ import json
2
+ import openai
3
+ import os
4
+ import glob
5
+ import time
6
+ import logging
7
+ from datetime import datetime
8
+ from tenacity import retry, wait_exponential, stop_after_attempt
9
+
10
+ model_name = "chatgpt-4o-latest"
11
+ temperature = 0.2
12
+ log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
13
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
14
+
15
+
16
+ def calculate_cost(
17
+ prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
18
+ ) -> float:
19
+ """Calculate the cost of API usage based on token counts.
20
+
21
+ Args:
22
+ prompt_tokens: Number of tokens in the prompt
23
+ completion_tokens: Number of tokens in the completion
24
+ model: Model name to use for pricing, defaults to chatgpt-4o-latest
25
+
26
+ Returns:
27
+ float: Cost in USD
28
+ """
29
+ pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
30
+ rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
31
+ return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000
32
+
33
+
34
+ @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
35
+ def create_multimodal_request(
36
+ question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
37
+ ) -> openai.types.chat.ChatCompletion:
38
+ """Create and send a multimodal request to the OpenAI API.
39
+
40
+ Args:
41
+ question_data: Dictionary containing question details and figures
42
+ case_details: Dictionary containing case information and figures
43
+ case_id: Identifier for the medical case
44
+ question_id: Identifier for the specific question
45
+ client: OpenAI client instance
46
+
47
+ Returns:
48
+ openai.types.chat.ChatCompletion: API response object, or None if request fails
49
+ """
50
+ prompt = f"""Given the following medical case:
51
+ Please answer this multiple choice question:
52
+ {question_data['question']}
53
+ Base your answer only on the provided images and case information."""
54
+
55
+ content = [{"type": "text", "text": prompt}]
56
+
57
+ # Parse required figures
58
+ try:
59
+ # Try multiple ways of parsing figures
60
+ if isinstance(question_data["figures"], str):
61
+ try:
62
+ required_figures = json.loads(question_data["figures"])
63
+ except json.JSONDecodeError:
64
+ required_figures = [question_data["figures"]]
65
+ elif isinstance(question_data["figures"], list):
66
+ required_figures = question_data["figures"]
67
+ else:
68
+ required_figures = [str(question_data["figures"])]
69
+ except Exception as e:
70
+ print(f"Error parsing figures: {e}")
71
+ required_figures = []
72
+
73
+ # Ensure each figure starts with "Figure "
74
+ required_figures = [
75
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
76
+ ]
77
+
78
+ subfigures = []
79
+ for figure in required_figures:
80
+ # Handle both regular figures and those with letter suffixes
81
+ base_figure_num = "".join(filter(str.isdigit, figure))
82
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
83
+
84
+ # Find matching figures in case details
85
+ matching_figures = [
86
+ case_figure
87
+ for case_figure in case_details.get("figures", [])
88
+ if case_figure["number"] == f"Figure {base_figure_num}"
89
+ ]
90
+
91
+ if not matching_figures:
92
+ print(f"No matching figure found for {figure} in case {case_id}")
93
+ continue
94
+
95
+ for case_figure in matching_figures:
96
+ # If a specific letter is specified, filter subfigures
97
+ if figure_letter:
98
+ matching_subfigures = [
99
+ subfig
100
+ for subfig in case_figure.get("subfigures", [])
101
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
102
+ or subfig.get("label", "").lower() == figure_letter.lower()
103
+ ]
104
+ subfigures.extend(matching_subfigures)
105
+ else:
106
+ # If no letter specified, add all subfigures
107
+ subfigures.extend(case_figure.get("subfigures", []))
108
+
109
+ # Add images to content
110
+ for subfig in subfigures:
111
+ if "url" in subfig:
112
+ content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
113
+ else:
114
+ print(f"Subfigure missing URL: {subfig}")
115
+
116
+ # If no images found, log and return None
117
+ if len(content) == 1: # Only the text prompt exists
118
+ print(f"No images found for case {case_id}, question {question_id}")
119
+ return None
120
+
121
+ messages = [
122
+ {
123
+ "role": "system",
124
+ "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
125
+ },
126
+ {"role": "user", "content": content},
127
+ ]
128
+
129
+ if len(content) == 1: # Only the text prompt exists
130
+ print(f"No images found for case {case_id}, question {question_id}")
131
+ log_entry = {
132
+ "case_id": case_id,
133
+ "question_id": question_id,
134
+ "timestamp": datetime.now().isoformat(),
135
+ "model": model_name,
136
+ "temperature": temperature,
137
+ "status": "skipped",
138
+ "reason": "no_images",
139
+ "cost": 0,
140
+ "input": {
141
+ "messages": messages,
142
+ "question_data": {
143
+ "question": question_data["question"],
144
+ "explanation": question_data["explanation"],
145
+ "metadata": question_data.get("metadata", {}),
146
+ "figures": question_data["figures"],
147
+ },
148
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
149
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
150
+ },
151
+ }
152
+ logging.info(json.dumps(log_entry))
153
+ return None
154
+
155
+ try:
156
+ start_time = time.time()
157
+
158
+ response = client.chat.completions.create(
159
+ model=model_name, messages=messages, max_tokens=50, temperature=temperature
160
+ )
161
+ duration = time.time() - start_time
162
+
163
+ log_entry = {
164
+ "case_id": case_id,
165
+ "question_id": question_id,
166
+ "timestamp": datetime.now().isoformat(),
167
+ "model": model_name,
168
+ "temperature": temperature,
169
+ "duration": round(duration, 2),
170
+ "usage": {
171
+ "prompt_tokens": response.usage.prompt_tokens,
172
+ "completion_tokens": response.usage.completion_tokens,
173
+ "total_tokens": response.usage.total_tokens,
174
+ },
175
+ "cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
176
+ "model_answer": response.choices[0].message.content,
177
+ "correct_answer": question_data["answer"],
178
+ "input": {
179
+ "messages": messages,
180
+ "question_data": {
181
+ "question": question_data["question"],
182
+ "explanation": question_data["explanation"],
183
+ "metadata": question_data.get("metadata", {}),
184
+ "figures": question_data["figures"],
185
+ },
186
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
187
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
188
+ },
189
+ }
190
+ logging.info(json.dumps(log_entry))
191
+ return response
192
+
193
+ except openai.RateLimitError:
194
+ log_entry = {
195
+ "case_id": case_id,
196
+ "question_id": question_id,
197
+ "timestamp": datetime.now().isoformat(),
198
+ "model": model_name,
199
+ "temperature": temperature,
200
+ "status": "error",
201
+ "reason": "rate_limit",
202
+ "cost": 0,
203
+ "input": {
204
+ "messages": messages,
205
+ "question_data": {
206
+ "question": question_data["question"],
207
+ "explanation": question_data["explanation"],
208
+ "metadata": question_data.get("metadata", {}),
209
+ "figures": question_data["figures"],
210
+ },
211
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
212
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
213
+ },
214
+ }
215
+ logging.info(json.dumps(log_entry))
216
+ print(
217
+ f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
218
+ flush=True,
219
+ )
220
+ time.sleep(20)
221
+ raise
222
+ except Exception as e:
223
+ log_entry = {
224
+ "case_id": case_id,
225
+ "question_id": question_id,
226
+ "timestamp": datetime.now().isoformat(),
227
+ "model": model_name,
228
+ "temperature": temperature,
229
+ "status": "error",
230
+ "error": str(e),
231
+ "cost": 0,
232
+ "input": {
233
+ "messages": messages,
234
+ "question_data": {
235
+ "question": question_data["question"],
236
+ "explanation": question_data["explanation"],
237
+ "metadata": question_data.get("metadata", {}),
238
+ "figures": question_data["figures"],
239
+ },
240
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
241
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
242
+ },
243
+ }
244
+ logging.info(json.dumps(log_entry))
245
+ print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
246
+ raise
247
+
248
+
249
+ def load_benchmark_questions(case_id: str) -> list:
250
+ """Load benchmark questions for a given case.
251
+
252
+ Args:
253
+ case_id: Identifier for the medical case
254
+
255
+ Returns:
256
+ list: List of paths to question files
257
+ """
258
+ benchmark_dir = "../benchmark/questions"
259
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
260
+
261
+
262
+ def count_total_questions() -> tuple[int, int]:
263
+ """Count total number of cases and questions in benchmark.
264
+
265
+ Returns:
266
+ tuple: (total_cases, total_questions)
267
+ """
268
+ total_cases = len(glob.glob("../benchmark/questions/*"))
269
+ total_questions = sum(
270
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
271
+ for case_id in os.listdir("../benchmark/questions")
272
+ )
273
+ return total_cases, total_questions
274
+
275
+
276
+ def main() -> None:
277
+ """Main function to run the benchmark evaluation."""
278
+ with open("../data/eurorad_metadata.json", "r") as file:
279
+ data = json.load(file)
280
+
281
+ api_key = os.getenv("OPENAI_API_KEY")
282
+ if not api_key:
283
+ raise ValueError("OPENAI_API_KEY environment variable is not set.")
284
+ global client
285
+ client = openai.OpenAI(api_key=api_key)
286
+
287
+ total_cases, total_questions = count_total_questions()
288
+ cases_processed = 0
289
+ questions_processed = 0
290
+ skipped_questions = 0
291
+
292
+ print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")
293
+
294
+ for case_id, case_details in data.items():
295
+ question_files = load_benchmark_questions(case_id)
296
+ if not question_files:
297
+ continue
298
+
299
+ cases_processed += 1
300
+ for question_file in question_files:
301
+ with open(question_file, "r") as file:
302
+ question_data = json.load(file)
303
+ question_id = os.path.basename(question_file).split(".")[0]
304
+
305
+ questions_processed += 1
306
+ response = create_multimodal_request(
307
+ question_data, case_details, case_id, question_id, client
308
+ )
309
+
310
+ # Handle cases where response is None
311
+ if response is None:
312
+ skipped_questions += 1
313
+ print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
314
+ continue
315
+
316
+ print(
317
+ f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
318
+ )
319
+ print(f"Case ID: {case_id}")
320
+ print(f"Question ID: {question_id}")
321
+ print(f"Model Answer: {response.choices[0].message.content}")
322
+ print(f"Correct Answer: {question_data['answer']}\n")
323
+
324
+ print(f"\nBenchmark Summary:")
325
+ print(f"Total Cases Processed: {cases_processed}")
326
+ print(f"Total Questions Processed: {questions_processed}")
327
+ print(f"Total Questions Skipped: {skipped_questions}")
328
+
329
+
330
+ if __name__ == "__main__":
331
+ main()
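`calculate_cost` above prices `chatgpt-4o-latest` at $5 per million prompt tokens and $15 per million completion tokens, so a single one-image question costs a fraction of a cent. A worked example:

```python
# 1,000 prompt tokens and 50 completion tokens at $5/$15 per million tokens:
cost = (1_000 * 5.0 + 50 * 15.0) / 1_000_000
print(f"${cost:.6f}")  # $0.005750
```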
experiments/benchmark_llama.py ADDED
@@ -0,0 +1,443 @@
1
+ from typing import Dict, List, Optional, Any, Union, Tuple
2
+ import re
3
+ import json
4
+ import os
5
+ import glob
6
+ import time
7
+ import logging
8
+ import socket
9
+ import requests
10
+ import httpx
11
+ import backoff
12
+ from datetime import datetime
13
+ from tenacity import retry, wait_exponential, stop_after_attempt
14
+ from openai import OpenAI
15
+
16
+ # Configure model settings
17
+ MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
18
+ temperature = 0.2
19
+
20
+ # Configure logging
21
+ log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
22
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
23
+
24
+
25
+ def verify_dns() -> bool:
26
+ """Verify DNS resolution and connectivity.
27
+
28
+ Returns:
29
+ bool: True if DNS resolution succeeds, False otherwise
30
+ """
31
+ try:
32
+ # Try to resolve openrouter.ai
33
+ socket.gethostbyname("openrouter.ai")
34
+ return True
35
+ except socket.gaierror:
36
+ print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
37
+ # Modify resolv.conf to use Google DNS
38
+ try:
39
+ with open("/etc/resolv.conf", "w") as f:
40
+ f.write("nameserver 8.8.8.8\n")
41
+ return True
42
+ except Exception as e:
43
+ print(f"Failed to update DNS settings: {e}")
44
+ return False
45
+
46
+
47
+ def verify_connection() -> bool:
48
+ """Verify connection to OpenRouter API.
49
+
50
+ Returns:
51
+ bool: True if connection succeeds, False otherwise
52
+ """
53
+ try:
54
+ response = requests.get("https://openrouter.ai/api/v1/status", timeout=10)
55
+ return response.status_code == 200
56
+ except Exception as e:
57
+ print(f"Connection test failed: {e}")
58
+ return False
59
+
60
+
61
+ def initialize_client() -> OpenAI:
62
+ """Initialize the OpenRouter client with proper timeout settings and connection verification.
63
+
64
+ Returns:
65
+ OpenAI: Configured OpenAI client for OpenRouter
66
+
67
+ Raises:
68
+ ValueError: If OPENROUTER_API_KEY environment variable is not set
69
+ ConnectionError: If DNS verification or connection test fails
70
+ """
71
+ api_key = os.getenv("OPENROUTER_API_KEY")
72
+ if not api_key:
73
+ raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
74
+
75
+ # Configure timeout settings for the client
76
+ timeout_settings = 120 # Increased timeout for large images/responses
77
+
78
+ # Verify DNS and connection
79
+ if not verify_dns():
80
+ raise ConnectionError("DNS verification failed. Please check your network settings.")
81
+
82
+ if not verify_connection():
83
+ raise ConnectionError(
84
+ "Cannot connect to OpenRouter. Please check your internet connection."
85
+ )
86
+
87
+ # Set up client with retry and timeout settings
88
+ return OpenAI(
89
+ base_url="https://openrouter.ai/api/v1",
90
+ api_key=api_key,
91
+ timeout=timeout_settings,
92
+ http_client=httpx.Client(
93
+ timeout=timeout_settings, transport=httpx.HTTPTransport(retries=3)
94
+ ),
95
+ )
96
+
97
+
98
+ @backoff.on_exception(
99
+ backoff.expo,
100
+ (ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
101
+ max_tries=5,
102
+ max_time=300, # Maximum total time to try in seconds
103
+ )
104
+ def create_multimodal_request(
105
+ question_data: Dict[str, Any],
106
+ case_details: Dict[str, Any],
107
+ case_id: str,
108
+ question_id: str,
109
+ client: OpenAI,
110
+ ) -> Optional[Any]:
111
+ """Create and send a multimodal request to the model.
112
+
113
+ Args:
114
+ question_data: Dictionary containing question details
115
+ case_details: Dictionary containing case information
116
+ case_id: ID of the medical case
117
+ question_id: ID of the specific question
118
+ client: OpenAI client instance
119
+
120
+ Returns:
121
+ Optional[Any]: Model response if successful, None if skipped
122
+
123
+ Raises:
124
+ ConnectionError: If connection fails
125
+ TimeoutError: If request times out
126
+ Exception: For other errors
127
+ """
128
+
129
+ system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
130
+ Rules:
131
+ 1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
132
+ 2. Do not add periods, explanations, or any other text
133
+ 3. Do not use markdown or formatting
134
+ 4. Do not restate the question
135
+ 5. Do not explain your reasoning
136
+
137
+ Examples of valid responses:
138
+ A
139
+ B
140
+ C
141
+
142
+ Examples of invalid responses:
143
+ "A."
144
+ "Answer: B"
145
+ "C) This shows..."
146
+ "The answer is D"
147
+ """
148
+
149
+ prompt = f"""Given the following medical case:
150
+ Please answer this multiple choice question:
151
+ {question_data['question']}
152
+ Base your answer only on the provided images and case information."""
153
+
154
+ # Parse required figures
155
+ try:
156
+ if isinstance(question_data["figures"], str):
157
+ try:
158
+ required_figures = json.loads(question_data["figures"])
159
+ except json.JSONDecodeError:
160
+ required_figures = [question_data["figures"]]
161
+ elif isinstance(question_data["figures"], list):
162
+ required_figures = question_data["figures"]
163
+ else:
164
+ required_figures = [str(question_data["figures"])]
165
+ except Exception as e:
166
+ print(f"Error parsing figures: {e}")
167
+ required_figures = []
168
+
169
+ required_figures = [
170
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
171
+ ]
172
+
173
+ # Process subfigures and prepare content
174
+ content = [{"type": "text", "text": prompt}]
175
+ image_urls = []
176
+ image_captions = []
177
+
178
+ for figure in required_figures:
179
+ base_figure_num = "".join(filter(str.isdigit, figure))
180
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
181
+
182
+ matching_figures = [
183
+ case_figure
184
+ for case_figure in case_details.get("figures", [])
185
+ if case_figure["number"] == f"Figure {base_figure_num}"
186
+ ]
187
+
188
+ for case_figure in matching_figures:
189
+ subfigures = []
190
+ if figure_letter:
191
+ subfigures = [
192
+ subfig
193
+ for subfig in case_figure.get("subfigures", [])
194
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
195
+ or subfig.get("label", "").lower() == figure_letter.lower()
196
+ ]
197
+ else:
198
+ subfigures = case_figure.get("subfigures", [])
199
+
200
+ for subfig in subfigures:
201
+ if "url" in subfig:
202
+ content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
203
+ image_urls.append(subfig["url"])
204
+ image_captions.append(subfig.get("caption", ""))
205
+
206
+ if len(content) == 1: # Only the text prompt exists
207
+ print(f"No images found for case {case_id}, question {question_id}")
208
+ # Log the skipped question
209
+ log_entry = {
210
+ "case_id": case_id,
211
+ "question_id": question_id,
212
+ "timestamp": datetime.now().isoformat(),
213
+ "model": MODEL_NAME,
214
+ "status": "skipped",
215
+ "reason": "no_images",
216
+ "input": {
217
+ "question_data": {
218
+ "question": question_data["question"],
219
+ "explanation": question_data["explanation"],
220
+ "metadata": question_data.get("metadata", {}),
221
+ "figures": question_data["figures"],
222
+ },
223
+ "image_urls": image_urls,
224
+ },
225
+ }
226
+ logging.info(json.dumps(log_entry))
227
+ return None
228
+
229
+ try:
230
+ start_time = time.time()
231
+
232
+ response = client.chat.completions.create(
233
+ model=MODEL_NAME,
234
+ temperature=temperature,
235
+ messages=[
236
+ {"role": "system", "content": system_prompt},
237
+ {"role": "user", "content": content},
238
+ ],
239
+ )
240
+ duration = time.time() - start_time
241
+
242
+ # Get raw response
243
+ raw_answer = response.choices[0].message.content
244
+
245
+ # Validate and clean
246
+ clean_answer = validate_answer(raw_answer)
247
+
248
+ if not clean_answer:
249
+ print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
250
+ print(f"Raw response: {raw_answer}")
251
+
252
+ # Update response object with cleaned answer
253
+ response.choices[0].message.content = clean_answer
254
+
255
+ # Log response
256
+ log_entry = {
257
+ "case_id": case_id,
258
+ "question_id": question_id,
259
+ "timestamp": datetime.now().isoformat(),
260
+ "model": MODEL_NAME,
261
+ "temperature": temperature,
262
+ "duration": round(duration, 2),
263
+ "usage": {
264
+ "prompt_tokens": response.usage.prompt_tokens,
265
+ "completion_tokens": response.usage.completion_tokens,
266
+ "total_tokens": response.usage.total_tokens,
267
+ },
268
+ "model_answer": response.choices[0].message.content,
269
+ "correct_answer": question_data["answer"],
270
+ "input": {
271
+ "question_data": {
272
+ "question": question_data["question"],
273
+ "explanation": question_data["explanation"],
274
+ "metadata": question_data.get("metadata", {}),
275
+ "figures": question_data["figures"],
276
+ },
277
+ "image_urls": image_urls,
278
+ },
279
+ }
280
+ logging.info(json.dumps(log_entry))
281
+ return response
282
+
283
+ except ConnectionError as e:
284
+ print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
285
+ print("Retrying after a longer delay...")
286
+ time.sleep(30) # Add a longer delay before retry
287
+ raise
288
+ except TimeoutError as e:
289
+ print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
290
+ print("Retrying with increased timeout...")
291
+ raise
292
+ except Exception as e:
293
+ # Log failed requests too
294
+ log_entry = {
295
+ "case_id": case_id,
296
+ "question_id": question_id,
297
+ "timestamp": datetime.now().isoformat(),
298
+ "model": MODEL_NAME,
299
+ "temperature": temperature,
300
+ "status": "error",
301
+ "error": str(e),
302
+ "input": {
303
+ "question_data": {
304
+ "question": question_data["question"],
305
+ "explanation": question_data["explanation"],
306
+ "metadata": question_data.get("metadata", {}),
307
+ "figures": question_data["figures"],
308
+ },
309
+ "image_urls": image_urls,
310
+ },
311
+ }
312
+ logging.info(json.dumps(log_entry))
313
+ raise
314
+
315
+
316
+ def extract_answer(response_text: str) -> Optional[str]:
317
+ """Extract single letter answer from model response.
318
+
319
+ Args:
320
+ response_text: Raw text response from model
321
+
322
+ Returns:
323
+ Optional[str]: Single letter answer if found, None otherwise
324
+ """
325
+ # Convert to uppercase and remove periods
326
+ text = response_text.upper().replace(".", "")
327
+
328
+ # Look for common patterns
329
+ patterns = [
330
+ r"ANSWER:\s*([A-F])", # Matches "ANSWER: X"
331
+ r"OPTION\s*([A-F])", # Matches "OPTION X"
332
+ r"([A-F])\)", # Matches "X)"
333
+ r"\b([A-F])\b", # Matches single letter
334
+ ]
335
+
336
+ for pattern in patterns:
337
+ matches = re.findall(pattern, text)
338
+ if matches:
339
+ return matches[0]
340
+
341
+ return None
342
+
343
+
344
+ def validate_answer(response_text: str) -> Optional[str]:
345
+ """Enforce strict single-letter response format.
346
+
347
+ Args:
348
+ response_text: Raw text response from model
349
+
350
+ Returns:
351
+ Optional[str]: Valid single letter answer if found, None otherwise
352
+ """
353
+ if not response_text:
354
+ return None
355
+
356
+ # Remove all whitespace and convert to uppercase
357
+ cleaned = response_text.strip().upper()
358
+
359
+ # Check if it's exactly one valid letter
360
+ if len(cleaned) == 1 and cleaned in "ABCDEF":
361
+ return cleaned
362
+
363
+ # If not, try to extract just the letter
364
+ match = re.search(r"([A-F])", cleaned)
365
+ return match.group(1) if match else None
366
+
367
+
368
+ def load_benchmark_questions(case_id: str) -> List[str]:
369
+ """Find all question files for a given case ID.
370
+
371
+ Args:
372
+ case_id: ID of the medical case
373
+
374
+ Returns:
375
+ List[str]: List of paths to question files
376
+ """
377
+ benchmark_dir = "../benchmark/questions"
378
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
379
+
380
+
381
+ def count_total_questions() -> Tuple[int, int]:
382
+ """Count total number of cases and questions.
383
+
384
+ Returns:
385
+ Tuple[int, int]: (total_cases, total_questions)
386
+ """
387
+ total_cases = len(glob.glob("../benchmark/questions/*"))
388
+ total_questions = sum(
389
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
390
+ for case_id in os.listdir("../benchmark/questions")
391
+ )
392
+ return total_cases, total_questions
393
+
394
+
395
+ def main():
396
+ with open("../data/eurorad_metadata.json", "r") as file:
397
+ data = json.load(file)
398
+
399
+ client = initialize_client()
400
+ total_cases, total_questions = count_total_questions()
401
+ cases_processed = 0
402
+ questions_processed = 0
403
+ skipped_questions = 0
404
+
405
+ print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
406
+
407
+ for case_id, case_details in data.items():
408
+ question_files = load_benchmark_questions(case_id)
409
+ if not question_files:
410
+ continue
411
+
412
+ cases_processed += 1
413
+ for question_file in question_files:
414
+ with open(question_file, "r") as file:
415
+ question_data = json.load(file)
416
+ question_id = os.path.basename(question_file).split(".")[0]
417
+
418
+ questions_processed += 1
419
+ response = create_multimodal_request(
420
+ question_data, case_details, case_id, question_id, client
421
+ )
422
+
423
+ if response is None:
424
+ skipped_questions += 1
425
+ print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
426
+ continue
427
+
428
+ print(
429
+ f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
430
+ )
431
+ print(f"Case ID: {case_id}")
432
+ print(f"Question ID: {question_id}")
433
+ print(f"Model Answer: {response.choices[0].message.content}")
434
+ print(f"Correct Answer: {question_data['answer']}\n")
435
+
436
+ print(f"\nBenchmark Summary:")
437
+ print(f"Total Cases Processed: {cases_processed}")
438
+ print(f"Total Questions Processed: {questions_processed}")
439
+ print(f"Total Questions Skipped: {skipped_questions}")
440
+
441
+
442
+ if __name__ == "__main__":
443
+ main()
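The Llama runner keeps two extraction helpers: `validate_answer` enforces the strict single-letter contract, while `extract_answer` is a looser regex fallback for responses that wrap the letter in extra text. An illustrative check of the fallback patterns (not part of the script; assumes `benchmark_llama.py` is importable):

```python
from benchmark_llama import extract_answer

assert extract_answer("Answer: c") == "C"        # "ANSWER: X" pattern
assert extract_answer("Option B") == "B"         # "OPTION X" pattern
assert extract_answer("D) Pneumothorax") == "D"  # "X)" pattern
```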
experiments/benchmark_llavamed.py ADDED
@@ -0,0 +1,541 @@
1
+ import argparse
2
+ import json
3
+ import requests
4
+ import base64
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from llava.conversation import conv_templates
8
+ import time
9
+ import os
10
+ import glob
11
+ import logging
12
+ from datetime import datetime
13
+ from tqdm import tqdm
14
+ import re
15
+ from typing import Dict, List, Optional, Union, Any, Tuple
16
+
17
+
18
+ def process_image(image_path: str, target_size: int = 640) -> Image.Image:
19
+ """Process and resize an image to match model requirements.
20
+
21
+ Args:
22
+ image_path: Path to the input image file
23
+ target_size: Target size for both width and height in pixels
24
+
25
+ Returns:
26
+ PIL.Image: Processed and padded image with dimensions (target_size, target_size)
27
+ """
28
+ image = Image.open(image_path)
29
+ if image.mode != "RGB":
30
+ image = image.convert("RGB")
31
+
32
+ # Calculate scaling to maintain aspect ratio
33
+ ratio = min(target_size / image.width, target_size / image.height)
34
+ new_size = (int(image.width * ratio), int(image.height * ratio))
35
+
36
+ # Resize image
37
+ image = image.resize(new_size, Image.LANCZOS)
38
+
39
+ # Create new image with padding
40
+ new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
41
+ # Paste resized image in center
42
+ offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
43
+ new_image.paste(image, offset)
44
+
45
+ return new_image
46
+
47
+
48
+ def validate_answer(response_text: str) -> Optional[str]:
49
+ """Extract and validate a single-letter response from the model's output.
50
+ Handles multiple response formats and edge cases.
51
+
52
+ Args:
53
+ response_text: The full text output from the model
54
+
55
+ Returns:
56
+ A single letter answer (A-F) or None if no valid answer found
57
+ """
58
+ if not response_text:
59
+ return None
60
+
61
+ # Clean the response text
62
+ cleaned = response_text.strip()
63
+
64
+ # Comprehensive set of patterns to extract the answer
65
+ extraction_patterns = [
66
+ # Strict format with explicit letter answer
67
+ r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
68
+ # Patterns for extracting from longer descriptions
69
+ r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
70
+ r"\b(?:answer|option)\s*([A-F])[):]\s*",
71
+ # Patterns for extracting from descriptive sentences
72
+ r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
73
+ r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
74
+ # Patterns with contextual words
75
+ r"characteriz[e]?d?\s+by\s+([A-F])\b",
76
+ r"indicat[e]?s?\s+([A-F])\b",
77
+ # Fallback to "Option X" or "Letter X" formats
78
+ r"Option\s*([A-F])\b",
79
+ r"\b([A-F])\)\s*",
80
+ # Fallback to standalone letter
81
+ r"^\s*([A-F])\s*$",
82
+ ]
83
+
84
+ # Try each pattern
85
+ for pattern in extraction_patterns:
86
+ matches = re.findall(pattern, cleaned, re.IGNORECASE)
87
+ for match in matches:
88
+ # Ensure match is a single valid letter
89
+ if isinstance(match, tuple):
90
+ match = match[0] if match[0] in "ABCDEF" else None
91
+ if match and match.upper() in "ABCDEF":
92
+ return match.upper()
93
+
94
+ # Final fallback: look for standalone letters in context
95
+ context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
96
+ context_letters = [m for m in context_matches if m in "ABCDEF"]
97
+ if context_letters:
98
+ return context_letters[0]
99
+
100
+ # No valid answer found
101
+ return None
102
+
103
+
104
+ def load_benchmark_questions(case_id: str) -> List[str]:
105
+ """Find all question files for a given case ID.
106
+
107
+ Args:
108
+ case_id: The ID of the medical case
109
+
110
+ Returns:
111
+ List of paths to question JSON files
112
+ """
113
+ benchmark_dir = "MedMAX/benchmark/questions"
114
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
115
+
116
+
117
+ def count_total_questions() -> Tuple[int, int]:
118
+ """Count total number of cases and questions in benchmark.
119
+
120
+ Returns:
121
+ Tuple containing (total_cases, total_questions)
122
+ """
123
+ total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
124
+ total_questions = sum(
125
+ len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
126
+ for case_id in os.listdir("MedMAX/benchmark/questions")
127
+ )
128
+ return total_cases, total_questions
129
+
130
+
131
+ def create_inference_request(
132
+ question_data: Dict[str, Any],
133
+ case_details: Dict[str, Any],
134
+ case_id: str,
135
+ question_id: str,
136
+ worker_addr: str,
137
+ model_name: str,
138
+ raw_output: bool = False,
139
+ ) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
140
+ """Create and send inference request to worker.
141
+
142
+ Args:
143
+ question_data: Dictionary containing question details and figures
144
+ case_details: Dictionary containing case information and figures
145
+ case_id: Identifier for the medical case
146
+ question_id: Identifier for the specific question
147
+ worker_addr: Address of the worker endpoint
148
+ model_name: Name of the model to use
149
+ raw_output: Whether to return raw model output
150
+
151
+ Returns:
152
+ If raw_output is False: Tuple of (validated_answer, duration)
153
+ If raw_output is True: Dictionary with full inference details
154
+ """
155
+ system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
156
+ """
157
+
158
+ prompt = f"""Given the following medical case:
159
+ Please answer this multiple choice question:
160
+ {question_data['question']}
161
+ Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
162
+
163
+ try:
164
+ # Parse required figures
165
+ if isinstance(question_data["figures"], str):
166
+ try:
167
+ required_figures = json.loads(question_data["figures"])
168
+ except json.JSONDecodeError:
169
+ required_figures = [question_data["figures"]]
170
+ elif isinstance(question_data["figures"], list):
171
+ required_figures = question_data["figures"]
172
+ else:
173
+ required_figures = [str(question_data["figures"])]
174
+ except Exception as e:
175
+ print(f"Error parsing figures: {e}")
176
+ required_figures = []
177
+
178
+ required_figures = [
179
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
180
+ ]
181
+
182
+ # Get image paths
183
+ image_paths = []
184
+ for figure in required_figures:
185
+ base_figure_num = "".join(filter(str.isdigit, figure))
186
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
187
+
188
+ matching_figures = [
189
+ case_figure
190
+ for case_figure in case_details.get("figures", [])
191
+ if case_figure["number"] == f"Figure {base_figure_num}"
192
+ ]
193
+
194
+ for case_figure in matching_figures:
195
+ subfigures = []
196
+ if figure_letter:
197
+ subfigures = [
198
+ subfig
199
+ for subfig in case_figure.get("subfigures", [])
200
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
201
+ or subfig.get("label", "").lower() == figure_letter.lower()
202
+ ]
203
+ else:
204
+ subfigures = case_figure.get("subfigures", [])
205
+
206
+ for subfig in subfigures:
207
+ if "local_path" in subfig:
208
+ image_paths.append("MedMAX/data/" + subfig["local_path"])
209
+
210
+ if not image_paths:
211
+ print(f"No local images found for case {case_id}, question {question_id}")
212
+ return "skipped", 0.0 # Return a special 'skipped' marker
213
+
214
+ try:
215
+ start_time = time.time()
216
+
217
+ # Process each image
218
+ processed_images = [process_image(path) for path in image_paths]
219
+
220
+ # Create conversation
221
+ conv = conv_templates["mistral_instruct"].copy()
222
+
223
+ # Add image and message
224
+ if "<image>" not in prompt:
225
+ text = prompt + "\n<image>"
226
+ else:
227
+ text = prompt
228
+
229
+ message = (text, processed_images[0], "Default") # Currently handling first image
230
+ conv.append_message(conv.roles[0], message)
231
+ conv.append_message(conv.roles[1], None)
232
+
233
+ prompt = conv.get_prompt()
234
+ headers = {"User-Agent": "LLaVA-Med Client"}
235
+ pload = {
236
+ "model": model_name,
237
+ "prompt": prompt,
238
+ "max_new_tokens": 150, # Reduce this since we only need one letter
239
+ "temperature": 0.5, # Lower temperature for more focused responses
240
+ "stop": conv.sep2,
241
+ "images": conv.get_images(),
242
+ "top_p": 1, # Lower top_p for more focused sampling
243
+ "frequency_penalty": 0.0,
244
+ "presence_penalty": 0.0,
245
+ }
246
+
247
+ max_retries = 3
248
+ retry_delay = 5
249
+ response_text = None
250
+
251
+ for attempt in range(max_retries):
252
+ try:
253
+ response = requests.post(
254
+ worker_addr + "/worker_generate_stream",
255
+ headers=headers,
256
+ json=pload,
257
+ stream=True,
258
+ timeout=30,
259
+ )
260
+
261
+ complete_output = ""
262
+ for chunk in response.iter_lines(
263
+ chunk_size=8192, decode_unicode=False, delimiter=b"\0"
264
+ ):
265
+ if chunk:
266
+ data = json.loads(chunk.decode("utf-8"))
267
+ if data["error_code"] == 0:
268
+ output = data["text"].split("[/INST]")[-1]
269
+ complete_output = output
270
+ else:
271
+ print(f"\nError: {data['text']} (error_code: {data['error_code']})")
272
+ if attempt < max_retries - 1:
273
+ time.sleep(retry_delay)
274
+ break
275
+ return None, None
276
+
277
+ if complete_output:
278
+ response_text = complete_output
279
+ break
280
+
281
+ except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
282
+ if attempt < max_retries - 1:
283
+ print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
284
+ time.sleep(retry_delay)
285
+ else:
286
+ print(f"\nFailed after {max_retries} attempts: {str(e)}")
287
+ return None, None
288
+
289
+ duration = time.time() - start_time
290
+
291
+ if raw_output:
292
+ inference_details = {
293
+ "raw_output": response_text,
294
+ "validated_answer": validate_answer(response_text),
295
+ "duration": duration,
296
+ "prompt": prompt,
297
+ "system_prompt": system_prompt,
298
+ "image_paths": image_paths,
299
+ "payload": pload,
300
+ }
301
+ return inference_details
302
+
303
+ return validate_answer(response_text), duration
304
+
305
+ except Exception as e:
306
+ print(f"Error in inference request: {str(e)}")
307
+ return None, None
308
+
309
+
310
+ def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
311
+ """Remove image-related and large data from the payload to keep the log lean.
312
+
313
+ Args:
314
+ payload: Original request payload dictionary
315
+
316
+ Returns:
317
+ Cleaned payload dictionary with large data removed
318
+ """
319
+ if not payload:
320
+ return None
321
+
322
+ # Create a copy of the payload to avoid modifying the original
323
+ cleaned_payload = payload.copy()
324
+
325
+ # Remove large or sensitive data
326
+ if "images" in cleaned_payload:
327
+ del cleaned_payload["images"]
328
+
329
+ return cleaned_payload
330
+
331
+
332
+ def main():
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
335
+ parser.add_argument("--worker-address", type=str)
336
+ parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
337
+ parser.add_argument("--output-dir", type=str, default="benchmark_results")
338
+ parser.add_argument(
339
+ "--raw-output", action="store_true", help="Return raw model output without validation"
340
+ )
341
+ parser.add_argument(
342
+ "--num-cases",
343
+ type=int,
344
+ help="Number of cases to process if looking at raw outputs",
345
+ default=2,
346
+ )
347
+ args = parser.parse_args()
348
+
349
+ # Setup output directory
350
+ os.makedirs(args.output_dir, exist_ok=True)
351
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
352
+
353
+ # Setup live logging files
354
+ live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
355
+ final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
356
+
357
+ # Initialize live log file
358
+ with open(live_log_filename, "w") as live_log_file:
359
+ live_log_file.write("[\n") # Start of JSON array
360
+
361
+ # Setup logging
362
+ logging.basicConfig(
363
+ filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
364
+ level=logging.INFO,
365
+ format="%(message)s",
366
+ )
367
+
368
+ # Get worker address
369
+ if args.worker_address:
370
+ worker_addr = args.worker_address
371
+ else:
372
+ try:
373
+ requests.post(args.controller_address + "/refresh_all_workers")
374
+ ret = requests.post(args.controller_address + "/list_models")
375
+ models = ret.json()["models"]
376
+ ret = requests.post(
377
+ args.controller_address + "/get_worker_address", json={"model": args.model_name}
378
+ )
379
+ worker_addr = ret.json()["address"]
380
+ print(f"Worker address: {worker_addr}")
381
+ except requests.exceptions.RequestException as e:
382
+ print(f"Failed to connect to controller: {e}")
383
+ return
384
+
385
+ if worker_addr == "":
386
+ print("No available worker")
387
+ return
388
+
389
+ # Load cases with local paths
390
+ with open("MedMAX/data/updated_cases.json", "r") as file:
391
+ data = json.load(file)
392
+
393
+ total_cases, total_questions = count_total_questions()
394
+ print(f"\nStarting benchmark with {args.model_name}")
395
+ print(f"Found {total_cases} cases with {total_questions} total questions")
396
+
397
+ results = {
398
+ "model": args.model_name,
399
+ "timestamp": datetime.now().isoformat(),
400
+ "total_cases": total_cases,
401
+ "total_questions": total_questions,
402
+ "results": [],
403
+ }
404
+
405
+ cases_processed = 0
406
+ questions_processed = 0
407
+ correct_answers = 0
408
+ skipped_questions = 0
409
+ total_processed_entries = 0
410
+
411
+ # Process each case
412
+ for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
413
+ question_files = load_benchmark_questions(case_id)
414
+ if not question_files:
415
+ continue
416
+
417
+ cases_processed += 1
418
+ for question_file in tqdm(
419
+ question_files, desc=f"Processing questions for case {case_id}", leave=False
420
+ ):
421
+ with open(question_file, "r") as file:
422
+ question_data = json.load(file)
423
+ question_id = os.path.basename(question_file).split(".")[0]
424
+
425
+ questions_processed += 1
426
+
427
+ # Get model's answer
428
+ inference_result = create_inference_request(
429
+ question_data,
430
+ case_details,
431
+ case_id,
432
+ question_id,
433
+ worker_addr,
434
+ args.model_name,
435
+ raw_output=True, # Always use raw output for detailed logging
436
+ )
437
+
438
+ # Handle skipped questions
439
+ if inference_result == ("skipped", 0.0):
440
+ skipped_questions += 1
441
+ print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
442
+
443
+ # Log skipped question
444
+ skipped_entry = {
445
+ "case_id": case_id,
446
+ "question_id": question_id,
447
+ "status": "skipped",
448
+ "reason": "No images found",
449
+ }
450
+ with open(live_log_filename, "a") as live_log_file:
451
+ json.dump(skipped_entry, live_log_file, indent=2)
452
+ live_log_file.write(",\n") # Add comma for next entry
453
+
454
+ continue
455
+
456
+ # Extract information
457
+ answer = inference_result["validated_answer"]
458
+ duration = inference_result["duration"]
459
+
460
+ # Prepare detailed logging entry
461
+ log_entry = {
462
+ "case_id": case_id,
463
+ "question_id": question_id,
464
+ "question": question_data["question"],
465
+ "correct_answer": question_data["answer"],
466
+ "raw_output": inference_result["raw_output"],
467
+ "validated_answer": answer,
468
+ "model_answer": answer,
469
+ "is_correct": answer == question_data["answer"] if answer else False,
470
+ "duration": duration,
471
+ "system_prompt": inference_result["system_prompt"],
472
+ "input_prompt": inference_result["prompt"],
473
+ "image_paths": inference_result["image_paths"],
474
+ "payload": clean_payload(inference_result["payload"]),
475
+ }
476
+
477
+ # Write to live log file
478
+ with open(live_log_filename, "a") as live_log_file:
479
+ json.dump(log_entry, live_log_file, indent=2)
480
+ live_log_file.write(",\n") # Add comma for next entry
481
+
482
+ # Print to console
483
+ print(f"\nCase {case_id}, Question {question_id}")
484
+ print(f"Model Answer: {answer}")
485
+ print(f"Correct Answer: {question_data['answer']}")
486
+ print(f"Time taken: {duration:.2f}s")
487
+
488
+ # Track correct answers
489
+ if answer == question_data["answer"]:
490
+ correct_answers += 1
491
+
492
+ # Append to results
493
+ results["results"].append(log_entry)
494
+ total_processed_entries += 1
495
+
496
+ # Optional: stop early once the requested number of cases has been processed (exit the question loop)
497
+ if args.raw_output and cases_processed == args.num_cases:
498
+ break
499
+
500
+ # Optional: stop early once the requested number of cases has been processed (exit the case loop)
501
+ if args.raw_output and cases_processed == args.num_cases:
502
+ break
503
+
504
+ # Close live log file
505
+ with open(live_log_filename, "rb+") as live_log_file:
506
+ # Overwrite the trailing ',\n' with the closing bracket (append mode would ignore the seek)
507
+ live_log_file.seek(-2, os.SEEK_END)  # Go back 2 bytes to the trailing ',\n'
508
+ live_log_file.write(b"\n]")
509
+
510
+ # Calculate final statistics
511
+ results["summary"] = {
512
+ "cases_processed": cases_processed,
513
+ "questions_processed": questions_processed,
514
+ "total_processed_entries": total_processed_entries,
515
+ "correct_answers": correct_answers,
516
+ "skipped_questions": skipped_questions,
517
+ "accuracy": (
518
+ correct_answers / (questions_processed - skipped_questions)
519
+ if (questions_processed - skipped_questions) > 0
520
+ else 0
521
+ ),
522
+ }
523
+
524
+ # Save final results
525
+ with open(final_results_filename, "w") as f:
526
+ json.dump(results, f, indent=2)
527
+
528
+ print(f"\nBenchmark Summary:")
529
+ print(f"Total Cases Processed: {cases_processed}")
530
+ print(f"Total Questions Processed: {questions_processed}")
531
+ print(f"Total Processed Entries: {total_processed_entries}")
532
+ print(f"Correct Answers: {correct_answers}")
533
+ print(f"Skipped Questions: {skipped_questions}")
534
+ print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%")
535
+ print(f"\nResults saved to {args.output_dir}")
536
+ print(f"Live log: {live_log_filename}")
537
+ print(f"Final results: {final_results_filename}")
538
+
539
+
540
+ if __name__ == "__main__":
541
+ main()
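+
+ # Note: the MedMAX/... paths above resolve against the current working directory, so run this
+ # script from the directory that contains MedMAX/. Example invocation (assumed setup):
+ #   python benchmark_llavamed.py --model-name llava-med-v1.5-mistral-7b --output-dir benchmark_results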
experiments/benchmark_medrax.ipynb ADDED
@@ -0,0 +1,374 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import operator\n",
10
+ "import warnings\n",
11
+ "from typing import *\n",
12
+ "import traceback\n",
13
+ "\n",
14
+ "import os\n",
15
+ "import torch\n",
16
+ "from dotenv import load_dotenv\n",
17
+ "from IPython.display import Image\n",
18
+ "from langgraph.checkpoint.memory import MemorySaver\n",
19
+ "from langgraph.graph import END, StateGraph\n",
20
+ "from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
21
+ "from langchain_openai import ChatOpenAI\n",
22
+ "from transformers import logging\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "import numpy as np\n",
25
+ "import re\n",
26
+ "\n",
27
+ "from medrax.agent import *\n",
28
+ "from medrax.tools import *\n",
29
+ "from medrax.utils import *\n",
30
+ "\n",
31
+ "import json\n",
32
+ "import openai\n",
33
+ "import os\n",
34
+ "import glob\n",
35
+ "import time\n",
36
+ "import logging\n",
37
+ "from datetime import datetime\n",
38
+ "from tenacity import retry, wait_exponential, stop_after_attempt\n",
39
+ "\n",
40
+ "warnings.filterwarnings(\"ignore\")\n",
41
+ "_ = load_dotenv()\n",
42
+ "\n",
43
+ "\n",
44
+ "# Setup directory paths\n",
45
+ "ROOT = \"set this directory to where MedRAX is, e.g. /home/MedRAX\"\n",
46
+ "PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
47
+ "BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
48
+ "MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
49
+ "FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
50
+ "\n",
51
+ "model_name = \"medrax\"\n",
52
+ "temperature = 0.2\n",
53
+ "medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
54
+ "log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
55
+ "logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
56
+ "device = \"cuda\""
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "def get_tools():\n",
66
+ " report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
67
+ " xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
68
+ " segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
69
+ " grounding_tool = XRayPhraseGroundingTool(\n",
70
+ " cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
71
+ " )\n",
72
+ " xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
73
+ " llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
74
+ "\n",
75
+ " return [\n",
76
+ " report_tool,\n",
77
+ " xray_classification_tool,\n",
78
+ " segmentation_tool,\n",
79
+ " grounding_tool,\n",
80
+ " xray_vqa_tool,\n",
81
+ " llava_med_tool,\n",
82
+ " ]\n",
83
+ "\n",
84
+ "\n",
85
+ "def get_agent(tools):\n",
86
+ " prompts = load_prompts_from_file(PROMPT_FILE)\n",
87
+ " prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
88
+ "\n",
89
+ " checkpointer = MemorySaver()\n",
90
+ " model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
91
+ " agent = Agent(\n",
92
+ " model,\n",
93
+ " tools=tools,\n",
94
+ " log_tools=True,\n",
95
+ " log_dir=\"logs\",\n",
96
+ " system_prompt=prompt,\n",
97
+ " checkpointer=checkpointer,\n",
98
+ " )\n",
99
+ " thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
100
+ " return agent, thread\n",
101
+ "\n",
102
+ "\n",
103
+ "def run_medrax(agent, thread, prompt, image_urls=[]):\n",
104
+ " messages = [\n",
105
+ " HumanMessage(\n",
106
+ " content=[\n",
107
+ " {\"type\": \"text\", \"text\": prompt},\n",
108
+ " ]\n",
109
+ " + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
110
+ " )\n",
111
+ " ]\n",
112
+ "\n",
113
+ " final_response = None\n",
114
+ " for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
115
+ " for v in event.values():\n",
116
+ " final_response = v\n",
117
+ "\n",
118
+ " final_response = final_response[\"messages\"][-1].content.strip()\n",
119
+ " agent_state = agent.workflow.get_state(thread)\n",
120
+ "\n",
121
+ " return final_response, str(agent_state)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 3,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
131
+ " # Parse required figures\n",
132
+ " try:\n",
133
+ " # Try multiple ways of parsing figures\n",
134
+ " if isinstance(question_data[\"figures\"], str):\n",
135
+ " try:\n",
136
+ " required_figures = json.loads(question_data[\"figures\"])\n",
137
+ " except json.JSONDecodeError:\n",
138
+ " required_figures = [question_data[\"figures\"]]\n",
139
+ " elif isinstance(question_data[\"figures\"], list):\n",
140
+ " required_figures = question_data[\"figures\"]\n",
141
+ " else:\n",
142
+ " required_figures = [str(question_data[\"figures\"])]\n",
143
+ " except Exception as e:\n",
144
+ " print(f\"Error parsing figures: {e}\")\n",
145
+ " required_figures = []\n",
146
+ "\n",
147
+ " # Ensure each figure starts with \"Figure \"\n",
148
+ " required_figures = [\n",
149
+ " fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
150
+ " ]\n",
151
+ "\n",
152
+ " subfigures = []\n",
153
+ " for figure in required_figures:\n",
154
+ " # Handle both regular figures and those with letter suffixes\n",
155
+ " base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
156
+ " figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
157
+ "\n",
158
+ " # Find matching figures in case details\n",
159
+ " matching_figures = [\n",
160
+ " case_figure\n",
161
+ " for case_figure in case_details.get(\"figures\", [])\n",
162
+ " if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
163
+ " ]\n",
164
+ "\n",
165
+ " if not matching_figures:\n",
166
+ " print(f\"No matching figure found for {figure} in case {case_id}\")\n",
167
+ " continue\n",
168
+ "\n",
169
+ " for case_figure in matching_figures:\n",
170
+ " # If a specific letter is specified, filter subfigures\n",
171
+ " if figure_letter:\n",
172
+ " matching_subfigures = [\n",
173
+ " subfig\n",
174
+ " for subfig in case_figure.get(\"subfigures\", [])\n",
175
+ " if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
176
+ " or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
177
+ " ]\n",
178
+ " subfigures.extend(matching_subfigures)\n",
179
+ " else:\n",
180
+ " # If no letter specified, add all subfigures\n",
181
+ " subfigures.extend(case_figure.get(\"subfigures\", []))\n",
182
+ "\n",
183
+ " # Add images to content\n",
184
+ " figure_prompt = \"\"\n",
185
+ " image_urls = []\n",
186
+ "\n",
187
+ " for subfig in subfigures:\n",
188
+ " if \"number\" in subfig:\n",
189
+ " subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
190
+ " subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
191
+ " figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
192
+ " if \"url\" in subfig:\n",
193
+ " image_urls.append(subfig[\"url\"])\n",
194
+ " else:\n",
195
+ " print(f\"Subfigure missing URL: {subfig}\")\n",
196
+ "\n",
197
+ " prompt = (\n",
198
+ " f\"Answer this question correctly using chain of thought reasoning and \"\n",
199
+ " \"carefully evaluating choices. Solve using your own vision and reasoning and then \"\n",
200
+ " \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
201
+ " f\"{question_data['question']}\\n{figure_prompt}\"\n",
202
+ " )\n",
203
+ "\n",
204
+ " try:\n",
205
+ " start_time = time.time()\n",
206
+ "\n",
207
+ " final_response, agent_state = run_medrax(\n",
208
+ " agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
209
+ " )\n",
210
+ " model_answer, agent_state = run_medrax(\n",
211
+ " agent=agent,\n",
212
+ " thread=thread,\n",
213
+ " prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
214
+ " )\n",
215
+ " duration = time.time() - start_time\n",
216
+ "\n",
217
+ " log_entry = {\n",
218
+ " \"case_id\": case_id,\n",
219
+ " \"question_id\": question_id,\n",
220
+ " \"timestamp\": datetime.now().isoformat(),\n",
221
+ " \"model\": model_name,\n",
222
+ " \"temperature\": temperature,\n",
223
+ " \"duration\": round(duration, 2),\n",
224
+ " \"usage\": \"\",\n",
225
+ " \"cost\": 0,\n",
226
+ " \"raw_response\": final_response,\n",
227
+ " \"model_answer\": model_answer.strip(),\n",
228
+ " \"correct_answer\": question_data[\"answer\"][0],\n",
229
+ " \"input\": {\n",
230
+ " \"messages\": prompt,\n",
231
+ " \"question_data\": {\n",
232
+ " \"question\": question_data[\"question\"],\n",
233
+ " \"explanation\": question_data[\"explanation\"],\n",
234
+ " \"metadata\": question_data.get(\"metadata\", {}),\n",
235
+ " \"figures\": question_data[\"figures\"],\n",
236
+ " },\n",
237
+ " \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
238
+ " \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
239
+ " },\n",
240
+ " \"agent_state\": agent_state,\n",
241
+ " }\n",
242
+ " logging.info(json.dumps(log_entry))\n",
243
+ " return final_response, model_answer.strip()\n",
244
+ "\n",
245
+ " except Exception as e:\n",
246
+ " log_entry = {\n",
247
+ " \"case_id\": case_id,\n",
248
+ " \"question_id\": question_id,\n",
249
+ " \"timestamp\": datetime.now().isoformat(),\n",
250
+ " \"model\": model_name,\n",
251
+ " \"temperature\": temperature,\n",
252
+ " \"status\": \"error\",\n",
253
+ " \"error\": str(e),\n",
254
+ " \"cost\": 0,\n",
255
+ " \"input\": {\n",
256
+ " \"messages\": prompt,\n",
257
+ " \"question_data\": {\n",
258
+ " \"question\": question_data[\"question\"],\n",
259
+ " \"explanation\": question_data[\"explanation\"],\n",
260
+ " \"metadata\": question_data.get(\"metadata\", {}),\n",
261
+ " \"figures\": question_data[\"figures\"],\n",
262
+ " },\n",
263
+ " \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
264
+ " \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
265
+ " },\n",
266
+ " }\n",
267
+ " logging.info(json.dumps(log_entry))\n",
268
+ " print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
269
+ " return \"\", \"\"\n",
270
+ "\n",
271
+ "\n",
272
+ "def load_benchmark_questions(case_id):\n",
273
+ " benchmark_dir = \"../benchmark/questions\"\n",
274
+ " return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
275
+ "\n",
276
+ "\n",
277
+ "def count_total_questions():\n",
278
+ " total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
279
+ " total_questions = sum(\n",
280
+ " len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
281
+ " for case_id in os.listdir(\"../benchmark/questions\")\n",
282
+ " )\n",
283
+ " return total_cases, total_questions\n",
284
+ "\n",
285
+ "\n",
286
+ "def main(tools):\n",
287
+ " with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
288
+ " data = json.load(file)\n",
289
+ "\n",
290
+ " total_cases, total_questions = count_total_questions()\n",
291
+ " cases_processed = 0\n",
292
+ " questions_processed = 0\n",
293
+ " skipped_questions = 0\n",
294
+ "\n",
295
+ " print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
296
+ "\n",
297
+ " for case_id, case_details in data.items():\n",
298
+ " if int(case_details[\"case_id\"]) <= 17158:\n",
299
+ " continue\n",
300
+ "\n",
301
+ " print(f\"----------------------------------------------------------------\")\n",
302
+ " agent, thread = get_agent(tools)\n",
303
+ "\n",
304
+ " question_files = load_benchmark_questions(case_id)\n",
305
+ " if not question_files:\n",
306
+ " continue\n",
307
+ "\n",
308
+ " cases_processed += 1\n",
309
+ " for question_file in question_files:\n",
310
+ " with open(question_file, \"r\") as file:\n",
311
+ " question_data = json.load(file)\n",
312
+ " question_id = os.path.basename(question_file).split(\".\")[0]\n",
313
+ "\n",
314
+ " # agent, thread = get_agent(tools)\n",
315
+ " questions_processed += 1\n",
316
+ " final_response, model_answer = create_multimodal_request(\n",
317
+ " question_data, case_details, case_id, question_id, agent, thread\n",
318
+ " )\n",
319
+ "\n",
320
+ " # Handle cases where response is None\n",
321
+ " if final_response is None:\n",
322
+ " skipped_questions += 1\n",
323
+ " print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
324
+ " continue\n",
325
+ "\n",
326
+ " print(\n",
327
+ " f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
328
+ " )\n",
329
+ " print(f\"Case ID: {case_id}\")\n",
330
+ " print(f\"Question ID: {question_id}\")\n",
331
+ " print(f\"Final Response: {final_response}\")\n",
332
+ " print(f\"Model Answer: {model_answer}\")\n",
333
+ " print(f\"Correct Answer: {question_data['answer']}\")\n",
334
+ " print(f\"----------------------------------------------------------------\\n\")\n",
335
+ "\n",
336
+ " print(f\"\\nBenchmark Summary:\")\n",
337
+ " print(f\"Total Cases Processed: {cases_processed}\")\n",
338
+ " print(f\"Total Questions Processed: {questions_processed}\")\n",
339
+ " print(f\"Total Questions Skipped: {skipped_questions}\")"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "tools = get_tools()\n",
349
+ "main(tools)"
350
+ ]
351
+ }
352
+ ],
353
+ "metadata": {
354
+ "kernelspec": {
355
+ "display_name": "medmax",
356
+ "language": "python",
357
+ "name": "python3"
358
+ },
359
+ "language_info": {
360
+ "codemirror_mode": {
361
+ "name": "ipython",
362
+ "version": 3
363
+ },
364
+ "file_extension": ".py",
365
+ "mimetype": "text/x-python",
366
+ "name": "python",
367
+ "nbconvert_exporter": "python",
368
+ "pygments_lexer": "ipython3",
369
+ "version": "3.10.16"
370
+ }
371
+ },
372
+ "nbformat": 4,
373
+ "nbformat_minor": 2
374
+ }
experiments/chexbench_gpt4.py ADDED
@@ -0,0 +1,405 @@
1
+ import json
2
+ import openai
3
+ import os
4
+ from datetime import datetime
5
+ import base64
6
+ import logging
7
+ from pathlib import Path
8
+ import time
9
+ from tqdm import tqdm
10
+ from typing import Dict, List, Optional, Union, Any
11
+
12
+ # Configuration constants
13
+ DEBUG_MODE = False
14
+ OUTPUT_DIR = "results"
15
+ MODEL_NAME = "gpt-4o-2024-05-13"
16
+ TEMPERATURE = 0.2
17
+ SUBSET = "Visual Question Answering"
18
+
19
+ # Set up logging configuration
20
+ logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
21
+ logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def get_mime_type(file_path: str) -> str:
26
+ """
27
+ Determine MIME type based on file extension.
28
+
29
+ Args:
30
+ file_path (str): Path to the file
31
+
32
+ Returns:
33
+ str: MIME type string for the file
34
+ """
35
+ extension = os.path.splitext(file_path)[1].lower()
36
+ mime_types = {
37
+ ".png": "image/png",
38
+ ".jpg": "image/jpeg",
39
+ ".jpeg": "image/jpeg",
40
+ ".gif": "image/gif",
41
+ }
42
+ return mime_types.get(extension, "application/octet-stream")
43
+
44
+
45
+ def encode_image(image_path: str) -> str:
46
+ """
47
+ Encode image to base64 with extensive error checking.
48
+
49
+ Args:
50
+ image_path (str): Path to the image file
51
+
52
+ Returns:
53
+ str: Base64 encoded image string
54
+
55
+ Raises:
56
+ FileNotFoundError: If image file does not exist
57
+ ValueError: If image file is empty or too large
58
+ Exception: For other image processing errors
59
+ """
60
+ logger.debug(f"Attempting to read image from: {image_path}")
61
+ if not os.path.exists(image_path):
62
+ raise FileNotFoundError(f"Image file not found: {image_path}")
63
+
64
+ # Add check for file size
65
+ file_size = os.path.getsize(image_path)
66
+ if file_size > 20 * 1024 * 1024: # 20MB limit
67
+ raise ValueError("Image file size exceeds 20MB limit")
68
+ if file_size == 0:
69
+ raise ValueError("Image file is empty")
70
+ logger.debug(f"Image file size: {file_size / 1024:.2f} KB")
71
+
72
+ try:
73
+ from PIL import Image
74
+
75
+ # Try to open and verify the image
76
+ with Image.open(image_path) as img:
77
+ # Get image details
78
+ width, height = img.size
79
+ format = img.format
80
+ mode = img.mode
81
+ logger.debug(
82
+ f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}"
83
+ )
84
+
85
+ if format not in ["PNG", "JPEG", "GIF"]:
86
+ raise ValueError(f"Unsupported image format: {format}")
87
+
88
+ with open(image_path, "rb") as image_file:
89
+ # Read the first few bytes to verify it's a valid PNG
90
+ header = image_file.read(8)
91
+ # if header != b'\x89PNG\r\n\x1a\n':
92
+ # logger.warning("File does not have a valid PNG signature")
93
+
94
+ # Reset file pointer and read entire file
95
+ image_file.seek(0)
96
+ encoded = base64.b64encode(image_file.read()).decode("utf-8")
97
+ encoded_length = len(encoded)
98
+ logger.debug(f"Base64 encoded length: {encoded_length} characters")
99
+
100
+ # Verify the encoded string is not empty and starts correctly
101
+ if encoded_length == 0:
102
+ raise ValueError("Base64 encoding produced empty string")
103
+ if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
104
+ logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
105
+
106
+ return encoded
107
+ except Exception as e:
108
+ logger.error(f"Error reading/encoding image: {str(e)}")
109
+ raise
110
+
111
+
112
+ def create_single_request(
113
+ image_path: str, question: str, options: Dict[str, str]
114
+ ) -> List[Dict[str, Any]]:
115
+ """
116
+ Create a single API request with image and question.
117
+
118
+ Args:
119
+ image_path (str): Path to the image file
120
+ question (str): Question text
121
+ options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'
122
+
123
+ Returns:
124
+ List[Dict[str, Any]]: List of message dictionaries for the API request
125
+
126
+ Raises:
127
+ Exception: For errors in request creation
128
+ """
129
+ if DEBUG_MODE:
130
+ logger.debug("Creating API request...")
131
+
132
+ prompt = f"""Given the following medical examination question:
133
+ Please answer this multiple choice question:
134
+
135
+ Question: {question}
136
+
137
+ Options:
138
+ A) {options['option_0']}
139
+ B) {options['option_1']}
140
+
141
+ Base your answer only on the provided image and select either A or B."""
142
+
143
+ try:
144
+ encoded_image = encode_image(image_path)
145
+ mime_type = get_mime_type(image_path)
146
+
147
+ if DEBUG_MODE:
148
+ logger.debug(f"Image encoded with MIME type: {mime_type}")
149
+
150
+ messages = [
151
+ {
152
+ "role": "system",
153
+ "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
154
+ },
155
+ {
156
+ "role": "user",
157
+ "content": [
158
+ {"type": "text", "text": prompt},
159
+ {
160
+ "type": "image_url",
161
+ "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
162
+ },
163
+ ],
164
+ },
165
+ ]
166
+
167
+ if DEBUG_MODE:
168
+ log_messages = json.loads(json.dumps(messages))
169
+ log_messages[1]["content"][1]["image_url"][
170
+ "url"
171
+ ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
172
+ logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
173
+
174
+ return messages
175
+
176
+ except Exception as e:
177
+ logger.error(f"Error creating request: {str(e)}")
178
+ raise
179
+
180
+
181
+ def check_answer(model_answer: str, correct_answer: int) -> bool:
182
+ """
183
+ Check if the model's answer matches the correct answer.
184
+
185
+ Args:
186
+ model_answer (str): The model's answer (A or B)
187
+ correct_answer (int): The correct answer index (0 for A, 1 for B)
188
+
189
+ Returns:
190
+ bool: True if answer is correct, False otherwise
191
+ """
192
+ if not isinstance(model_answer, str):
193
+ return False
194
+
195
+ # Clean the model answer to get just the letter
196
+ model_letter = model_answer.strip().upper()
197
+ if model_letter.startswith("A"):
198
+ model_index = 0
199
+ elif model_letter.startswith("B"):
200
+ model_index = 1
201
+ else:
202
+ return False
203
+
204
+ return model_index == correct_answer
205
+
206
+
207
+ def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
208
+ """
209
+ Save results to a JSON file with timestamp.
210
+
211
+ Args:
212
+ results (List[Dict[str, Any]]): List of result dictionaries
213
+ output_dir (str): Directory to save results
214
+
215
+ Returns:
216
+ str: Path to the saved file
217
+ """
218
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
219
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
220
+ output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
221
+
222
+ with open(output_file, "w") as f:
223
+ json.dump(results, f, indent=2)
224
+
225
+ logger.info(f"Batch results saved to {output_file}")
226
+ return output_file
227
+
228
+
229
+ def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
230
+ """
231
+ Calculate accuracy from results, handling error cases.
232
+
233
+ Args:
234
+ results (List[Dict[str, Any]]): List of result dictionaries
235
+
236
+ Returns:
237
+ tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
238
+ """
239
+ if not results:
240
+ return 0.0, 0, 0
241
+
242
+ total = len(results)
243
+ valid_results = [r for r in results if "output" in r]
244
+ correct = sum(
245
+ 1 for result in valid_results if result.get("output", {}).get("is_correct", False)
246
+ )
247
+
248
+ accuracy = (correct / total * 100) if total > 0 else 0
249
+ return accuracy, correct, total
250
+
251
+
252
+ def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
253
+ """
254
+ Calculate accuracy for the current batch.
255
+
256
+ Args:
257
+ results (List[Dict[str, Any]]): List of result dictionaries
258
+
259
+ Returns:
260
+ float: Accuracy percentage for the batch
261
+ """
262
+ valid_results = [r for r in results if "output" in r]
263
+ if not valid_results:
264
+ return 0.0
265
+ return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
266
+
267
+
268
+ def process_batch(
269
+ data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
270
+ ) -> List[Dict[str, Any]]:
271
+ """
272
+ Process a batch of examples and return results.
273
+
274
+ Args:
275
+ data (List[Dict[str, Any]]): List of data items to process
276
+ client (openai.OpenAI): OpenAI client instance
277
+ start_idx (int, optional): Starting index for batch. Defaults to 0
278
+ batch_size (int, optional): Size of batch to process. Defaults to 50
279
+
280
+ Returns:
281
+ List[Dict[str, Any]]: List of processed results
282
+ """
283
+ batch_results = []
284
+ end_idx = min(start_idx + batch_size, len(data))
285
+
286
+ pbar = tqdm(
287
+ range(start_idx, end_idx),
288
+ desc=f"Processing batch {start_idx//batch_size + 1}",
289
+ unit="example",
290
+ )
291
+
292
+ for index in pbar:
293
+ vqa_item = data[index]
294
+ options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
295
+
296
+ try:
297
+ messages = create_single_request(
298
+ image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
299
+ )
300
+
301
+ response = client.chat.completions.create(
302
+ model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
303
+ )
304
+
305
+ model_answer = response.choices[0].message.content.strip()
306
+ is_correct = check_answer(model_answer, vqa_item["answer"])
307
+
308
+ result = {
309
+ "timestamp": datetime.now().isoformat(),
310
+ "example_index": index,
311
+ "input": {
312
+ "question": vqa_item["question"],
313
+ "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
314
+ "image_path": vqa_item["image_path"],
315
+ },
316
+ "output": {
317
+ "model_answer": model_answer,
318
+ "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
319
+ "is_correct": is_correct,
320
+ "usage": {
321
+ "prompt_tokens": response.usage.prompt_tokens,
322
+ "completion_tokens": response.usage.completion_tokens,
323
+ "total_tokens": response.usage.total_tokens,
324
+ },
325
+ },
326
+ }
327
+ batch_results.append(result)
328
+
329
+ # Update progress bar with current accuracy
330
+ current_accuracy = calculate_batch_accuracy(batch_results)
331
+ pbar.set_description(
332
+ f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
333
+ f"({len(batch_results)}/{index-start_idx+1} examples)"
334
+ )
335
+
336
+ except Exception as e:
337
+ error_result = {
338
+ "timestamp": datetime.now().isoformat(),
339
+ "example_index": index,
340
+ "error": str(e),
341
+ "input": {
342
+ "question": vqa_item["question"],
343
+ "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
344
+ "image_path": vqa_item["image_path"],
345
+ },
346
+ }
347
+ batch_results.append(error_result)
348
+ if DEBUG_MODE:
349
+ pbar.write(f"Error processing example {index}: {str(e)}")
350
+
351
+ time.sleep(1) # Rate limiting
352
+
353
+ return batch_results
354
+
355
+
356
+ def main() -> None:
357
+ """
358
+ Main function to process the entire dataset.
359
+
360
+ Raises:
361
+ ValueError: If OPENAI_API_KEY is not set
362
+ Exception: For other processing errors
363
+ """
364
+ logger.info("Starting full dataset processing...")
365
+ json_path = "../data/chexbench_updated.json"
366
+
367
+ try:
368
+ api_key = os.getenv("OPENAI_API_KEY")
369
+ if not api_key:
370
+ raise ValueError("OPENAI_API_KEY environment variable is not set.")
371
+ client = openai.OpenAI(api_key=api_key)
372
+
373
+ with open(json_path, "r") as f:
374
+ data = json.load(f)
375
+
376
+ subset_data = data[SUBSET]
377
+ total_examples = len(subset_data)
378
+ logger.info(f"Found {total_examples} examples in {SUBSET} subset")
379
+
380
+ all_results = []
381
+ batch_size = 50 # Process in batches of 50 examples
382
+
383
+ # Process all examples in batches
384
+ for start_idx in range(0, total_examples, batch_size):
385
+ batch_results = process_batch(subset_data, client, start_idx, batch_size)
386
+ all_results.extend(batch_results)
387
+
388
+ # Save intermediate results after each batch
389
+ output_file = save_results_to_json(all_results, OUTPUT_DIR)
390
+
391
+ # Calculate and log overall progress
392
+ overall_accuracy, correct, total = calculate_accuracy(all_results)
393
+ logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
394
+ logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")
395
+
396
+ logger.info("Processing completed!")
397
+ logger.info(f"Final results saved to: {output_file}")
398
+
399
+ except Exception as e:
400
+ logger.error(f"Fatal error: {str(e)}")
401
+ raise
402
+
403
+
404
+ if __name__ == "__main__":
405
+ main()
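+
+ # Note: requires the OPENAI_API_KEY environment variable to be set; the relative
+ # ../data/chexbench_updated.json path assumes the script is run from the experiments/ directory.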
experiments/compare_runs.py ADDED
@@ -0,0 +1,290 @@
1
+ import json
2
+ import argparse
3
+ import random
4
+ from typing import List, Dict, Any, Tuple
5
+ import re
6
+ from collections import defaultdict
7
+
8
+ # Define category order
9
+ CATEGORY_ORDER = [
10
+ "detection",
11
+ "classification",
12
+ "localization",
13
+ "comparison",
14
+ "relationship",
15
+ "diagnosis",
16
+ "characterization",
17
+ ]
18
+
19
+
20
+ def extract_letter_answer(answer: str) -> str:
21
+ """Extract just the letter answer from various answer formats.
22
+
23
+ Args:
24
+ answer: The answer string to extract a letter from
25
+
26
+ Returns:
27
+ str: The extracted letter in uppercase, or empty string if no letter found
28
+ """
29
+ if not answer:
30
+ return ""
31
+
32
+ # Convert to string and clean
33
+ answer = str(answer).strip()
34
+
35
+ # If it's just a single letter A-F, return it
36
+ if len(answer) == 1 and answer.upper() in "ABCDEF":
37
+ return answer.upper()
38
+
39
+ # Try to match patterns like "A)", "A.", "A ", etc.
40
+ match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
41
+ if match:
42
+ return match.group(1).upper()
43
+
44
+ # Try to find any standalone A-F letters preceded by space or start of string
45
+ # and followed by space, period, parenthesis or end of string
46
+ matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
47
+ if matches:
48
+ return matches[0].upper()
49
+
50
+ # Last resort: just find any A-F letter
51
+ letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
52
+ if letters:
53
+ return letters[0].upper()
54
+
55
+ # If no letter found, return original (cleaned)
56
+ return answer.strip().upper()
57
+
58
+
59
+ def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
60
+ """Parse JSON Lines file and extract valid predictions.
61
+
62
+ Args:
63
+ file_path: Path to the JSON Lines file to parse
64
+
65
+ Returns:
66
+ Tuple containing:
67
+ - str: Model name or file path if model name not found
68
+ - List[Dict[str, Any]]: List of valid prediction entries
69
+ """
70
+ valid_predictions = []
71
+ model_name = None
72
+
73
+ # First try to parse as LLaVA format
74
+ try:
75
+ with open(file_path, "r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+ if data.get("model") == "llava-med-v1.5-mistral-7b":
78
+ model_name = data["model"]
79
+ for result in data.get("results", []):
80
+ if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
81
+ # Extract answer with priority: model_answer > validated_answer > raw_output
82
+ model_answer = (
83
+ result.get("model_answer")
84
+ or result.get("validated_answer")
85
+ or result.get("raw_output", "")
86
+ )
87
+
88
+ # Add default categories for LLaVA results
89
+ prediction = {
90
+ "case_id": result["case_id"],
91
+ "question_id": result["question_id"],
92
+ "model_answer": model_answer,
93
+ "correct_answer": result["correct_answer"],
94
+ "input": {
95
+ "question_data": {
96
+ "metadata": {
97
+ "categories": [
98
+ "detection",
99
+ "classification",
100
+ "localization",
101
+ "comparison",
102
+ "relationship",
103
+ "diagnosis",
104
+ "characterization",
105
+ ]
106
+ }
107
+ }
108
+ },
109
+ }
110
+ valid_predictions.append(prediction)
111
+ return model_name, valid_predictions
112
+ except (json.JSONDecodeError, KeyError):
113
+ pass
114
+
115
+ # If not LLaVA format, process as original format
116
+ with open(file_path, "r", encoding="utf-8") as f:
117
+ for line in f:
118
+ if line.startswith("HTTP Request:"):
119
+ continue
120
+ try:
121
+ data = json.loads(line.strip())
122
+ if "model" in data:
123
+ model_name = data["model"]
124
+ if all(
125
+ k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
126
+ ):
127
+ valid_predictions.append(data)
128
+ except json.JSONDecodeError:
129
+ continue
130
+
131
+ return model_name if model_name else file_path, valid_predictions
132
+
133
+
134
+ def filter_common_questions(
135
+ predictions_list: List[List[Dict[str, Any]]]
136
+ ) -> List[List[Dict[str, Any]]]:
137
+ """Ensure only questions that exist across all models are evaluated.
138
+
139
+ Args:
140
+ predictions_list: List of prediction lists from different models
141
+
142
+ Returns:
143
+ List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
144
+ """
145
+ question_sets = [
146
+ set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
147
+ ]
148
+ common_questions = set.intersection(*question_sets)
149
+
150
+ return [
151
+ [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
152
+ for preds in predictions_list
153
+ ]
154
+
155
+
156
+ def calculate_accuracy(
157
+ predictions: List[Dict[str, Any]]
158
+ ) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
159
+ """Compute overall and category-level accuracy.
160
+
161
+ Args:
162
+ predictions: List of prediction entries to analyze
163
+
164
+ Returns:
165
+ Tuple containing:
166
+ - float: Overall accuracy percentage
167
+ - int: Number of correct predictions
168
+ - int: Total number of predictions
169
+ - Dict[str, Dict[str, float]]: Category-level accuracy statistics
170
+ """
171
+ if not predictions:
172
+ return 0.0, 0, 0, {}
173
+
174
+ category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
175
+ correct = 0
176
+ total = 0
177
+ sample_size = min(5, len(predictions))
178
+ sampled_indices = random.sample(range(len(predictions)), sample_size)
179
+
180
+ print("\nSample extracted answers:")
181
+ for i in sampled_indices:
182
+ pred = predictions[i]
183
+ model_ans = extract_letter_answer(pred["model_answer"])
184
+ correct_ans = extract_letter_answer(pred["correct_answer"])
185
+ print(f"QID: {pred['question_id']}")
186
+ print(f" Raw Model Answer: {pred['model_answer']}")
187
+ print(f" Extracted Model Answer: {model_ans}")
188
+ print(f" Raw Correct Answer: {pred['correct_answer']}")
189
+ print(f" Extracted Correct Answer: {correct_ans}")
190
+ print("-" * 80)
191
+
192
+ for pred in predictions:
193
+ try:
194
+ model_ans = extract_letter_answer(pred["model_answer"])
195
+ correct_ans = extract_letter_answer(pred["correct_answer"])
196
+ categories = (
197
+ pred.get("input", {})
198
+ .get("question_data", {})
199
+ .get("metadata", {})
200
+ .get("categories", [])
201
+ )
202
+
203
+ if model_ans and correct_ans:
204
+ total += 1
205
+ is_correct = model_ans == correct_ans
206
+ if is_correct:
207
+ correct += 1
208
+
209
+ for category in categories:
210
+ category_performance[category]["total"] += 1
211
+ if is_correct:
212
+ category_performance[category]["correct"] += 1
213
+
214
+ except KeyError:
215
+ continue
216
+
217
+ category_accuracies = {
218
+ category: {
219
+ "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
220
+ "total": stats["total"],
221
+ "correct": stats["correct"],
222
+ }
223
+ for category, stats in category_performance.items()
224
+ }
225
+
226
+ return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
227
+
228
+
229
+ def compare_models(file_paths: List[str]) -> None:
230
+ """Compare accuracy between multiple model prediction files.
231
+
232
+ Args:
233
+ file_paths: List of paths to model prediction files to compare
234
+ """
235
+ # Parse all files
236
+ parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
237
+ model_names, predictions_list = zip(*parsed_results)
238
+
239
+ # Get initial stats
240
+ print(f"\n📊 **Initial Accuracy**:")
241
+ results = []
242
+ category_results = []
243
+
244
+ for preds, name in zip(predictions_list, model_names):
245
+ acc, correct, total, category_acc = calculate_accuracy(preds)
246
+ results.append((acc, correct, total, name))
247
+ category_results.append(category_acc)
248
+ print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
249
+
250
+ # Get common questions across all models
251
+ filtered_predictions = filter_common_questions(predictions_list)
252
+ print(
253
+ f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
254
+ )
255
+
256
+ # Compute accuracy on common questions
257
+ print(f"\n📊 **Accuracy on Common Questions**:")
258
+ filtered_results = []
259
+ filtered_category_results = []
260
+
261
+ for preds, name in zip(filtered_predictions, model_names):
262
+ acc, correct, total, category_acc = calculate_accuracy(preds)
263
+ filtered_results.append((acc, correct, total, name))
264
+ filtered_category_results.append(category_acc)
265
+ print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
266
+
267
+ # Print category-wise accuracy
268
+ print("\nCategory Performance (Common Questions):")
269
+ for category in CATEGORY_ORDER:
270
+ print(f"\n{category.capitalize()}:")
271
+ for model_name, category_acc in zip(model_names, filtered_category_results):
272
+ stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
273
+ print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")
274
+
275
+
276
+ def main():
277
+ parser = argparse.ArgumentParser(
278
+ description="Compare accuracy across multiple model prediction files"
279
+ )
280
+ parser.add_argument("files", nargs="+", help="Paths to model prediction files")
281
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
282
+
283
+ args = parser.parse_args()
284
+ random.seed(args.seed)
285
+
286
+ compare_models(args.files)
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
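For reference, a minimal sketch of the prediction-entry shape that calculate_accuracy consumes, inferred from its field accesses above. The values and the category name are invented for illustration, and it assumes the module's extract_letter_answer (defined earlier in compare_runs.py) returns the option letter from each answer string.

example_predictions = [
    {
        "question_id": "case1_q1",
        "model_answer": "A. Pneumothorax",  # raw model output; only the extracted letter is compared
        "correct_answer": "A",
        "input": {"question_data": {"metadata": {"categories": ["detection"]}}},
    },
]

accuracy, correct, total, per_category = calculate_accuracy(example_predictions)
# -> 100.0, 1, 1, {"detection": {"accuracy": 100.0, "total": 1, "correct": 1}}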
experiments/inspect_logs.py ADDED
@@ -0,0 +1,210 @@
1
+ from typing import Optional, List
2
+ import argparse
3
+ import json
4
+ import glob
5
+ from pathlib import Path
6
+ from datetime import datetime
7
+
8
+
9
+ def get_latest_log() -> str:
10
+ """Find the most recently modified log file in the current directory.
11
+
12
+ Returns:
13
+ str: Path to the most recently modified log file
14
+
15
+ Raises:
16
+ FileNotFoundError: If no log files are found in the current directory
17
+ """
18
+ logs = list(Path(".").glob("api_usage_*.json"))
19
+ if not logs:
20
+ raise FileNotFoundError("No log files found in the current directory.")
21
+ return str(max(logs, key=lambda p: p.stat().st_mtime))
22
+
23
+
24
+ def format_cost(entry: dict) -> str:
25
+ """Format cost if available, otherwise return 'N/A'
26
+
27
+ Args:
28
+ entry: Log entry dictionary containing cost information
29
+
30
+ Returns:
31
+ str: Formatted cost string with $ and 4 decimal places, or 'N/A' if cost not found
32
+ """
33
+ return f"${entry.get('cost', 'N/A'):.4f}" if "cost" in entry else "N/A"
34
+
35
+
36
+ def print_gpt4_entry(entry: dict) -> None:
37
+ """Print entry for GPT-4 format
38
+
39
+ Args:
40
+ entry: Log entry dictionary in GPT-4 format containing model info, inputs and outputs
41
+ """
42
+ print("\n=== Log Entry ===")
43
+ print(f"Model: {entry['model']}")
44
+ print(f"Case ID: {entry['case_id']}")
45
+ print(f"Question ID: {entry['question_id']}")
46
+
47
+ print("\n=== Model Input ===")
48
+ messages = entry["input"]["messages"]
49
+ print("System message:", messages[0]["content"])
50
+ user_content = messages[1]["content"]
51
+ print("\nUser prompt:", user_content[0]["text"])
52
+ print("\nImages provided:")
53
+ for content in user_content[1:]:
54
+ print(f" - {content['image_url']['url']}")
55
+
56
+ print("\n=== Model Output ===")
57
+ print(f"Answer: {entry['model_answer']}")
58
+ print(f"Correct: {entry['correct_answer']}")
59
+
60
+ print("\n=== Usage Stats ===")
61
+ print(f"Duration: {entry['duration']}s")
62
+ print(f"Cost: {format_cost(entry)}")
63
+ print(
64
+ f"Tokens: {entry['usage']['total_tokens']}",
65
+ f"(prompt: {entry['usage']['prompt_tokens']},",
66
+ f"completion: {entry['usage']['completion_tokens']})",
67
+ )
68
+
69
+
70
+ def print_llama_entry(entry: dict) -> None:
71
+ """Print entry for Llama-3.2 format
72
+
73
+ Args:
74
+ entry: Log entry dictionary in Llama format containing model info, inputs and outputs
75
+ """
76
+ print("\n=== Log Entry ===")
77
+ print(f"Model: {entry['model']}")
78
+ print(f"Case ID: {entry['case_id']}")
79
+ print(f"Question ID: {entry['question_id']}")
80
+
81
+ print("\n=== Model Input ===")
82
+ print(f"Question: {entry['input']['question_data']['question']}")
83
+ print("\nImages provided:")
84
+ for url in entry["input"]["image_urls"]:
85
+ print(f" - {url}")
86
+ if entry["input"]["image_captions"]:
87
+ print("\nImage captions:")
88
+ for caption in entry["input"]["image_captions"]:
89
+ if caption:
90
+ print(f" - {caption}")
91
+
92
+ print("\n=== Model Output ===")
93
+ print(f"Answer: {entry['model_answer']}")
94
+ print(f"Correct: {entry['correct_answer']}")
95
+
96
+ print("\n=== Usage Stats ===")
97
+ print(f"Duration: {entry['duration']}s")
98
+ if "usage" in entry:
99
+ print(
100
+ f"Tokens: {entry['usage']['total_tokens']}",
101
+ f"(prompt: {entry['usage']['prompt_tokens']},",
102
+ f"completion: {entry['usage']['completion_tokens']})",
103
+ )
104
+
105
+
106
+ def determine_model_type(entry: dict) -> str:
107
+ """Determine the model type from the entry
108
+
109
+ Args:
110
+ entry: Log entry dictionary containing model information
111
+
112
+ Returns:
113
+ str: Model type - 'gpt4', 'llama', or 'unknown'
114
+ """
115
+ model = entry.get("model", "").lower()
116
+ if "gpt-4" in model:
117
+ return "gpt4"
118
+ elif "llama" in model:
119
+ return "llama"
120
+ elif "chexagent" in model:
121
+ return "chexagent"
122
+ elif "medrax" in model:
123
+ return "medrax"
124
+ else:
125
+ return "unknown"
126
+
127
+
128
+ def print_log_entry(
129
+ log_file: Optional[str] = None,
130
+ num_entries: Optional[int] = None,
131
+ model_filter: Optional[str] = None,
132
+ ) -> None:
133
+ """Print log entries from the specified log file or the latest log file.
134
+
135
+ Args:
136
+ log_file: Path to the log file. If None, uses the latest log file.
137
+ num_entries: Number of entries to print. If None, prints all entries.
138
+ model_filter: Filter entries by model type ('gpt4' or 'llama'). If None, prints all.
139
+ """
140
+ if log_file is None:
141
+ log_file = get_latest_log()
142
+ print(f"Using latest log file: {log_file}")
143
+
144
+ entries_printed = 0
145
+ total_entries = 0
146
+ filtered_entries = 0
147
+
148
+ with open(log_file, "r") as f:
149
+ for line in f:
150
+ if line.startswith("HTTP"):
151
+ continue
152
+ try:
153
+ total_entries += 1
154
+ entry = json.loads(line)
155
+
156
+ # Apply model filter if specified
157
+ model_type = determine_model_type(entry)
158
+ if model_filter and model_type != model_filter:
159
+ filtered_entries += 1
160
+ continue
161
+
162
+ if model_type == "gpt4":
163
+ print_gpt4_entry(entry)
164
+ elif model_type == "llama":
165
+ print_llama_entry(entry)
166
+ else:
167
+ print(f"Unknown model type in entry: {entry['model']}")
168
+ continue
169
+
170
+ print("=" * 50)
171
+ entries_printed += 1
172
+ if num_entries and entries_printed >= num_entries:
173
+ break
174
+
175
+ except (json.JSONDecodeError, KeyError) as e:
176
+ print(f"Error processing entry: {e}")
177
+ continue
178
+
179
+ print(f"\nSummary:")
180
+ print(f"Total entries: {total_entries}")
181
+ print(f"Entries printed: {entries_printed}")
182
+ if model_filter:
183
+ print(f"Entries filtered: {filtered_entries}")
184
+
185
+
186
+ def main() -> None:
187
+ """Main entry point for the script"""
188
+ parser = argparse.ArgumentParser(
189
+ description="Parse and display log entries from API usage logs."
190
+ )
191
+ parser.add_argument("-l", "--log_file", nargs="?", help="Path to the log file (optional)")
192
+ parser.add_argument("-n", "--num_entries", type=int, help="Number of entries to display")
193
+ parser.add_argument(
194
+ "-m",
195
+ "--model",
196
+ choices=["gpt4", "llama"],
197
+ default="gpt4",
198
+ help="Model type to display (default: gpt4)",
199
+ )
200
+ args = parser.parse_args()
201
+
202
+ try:
203
+ print_log_entry(args.log_file, args.num_entries, args.model)
204
+ except FileNotFoundError as e:
205
+ print(f"Error: {e}")
206
+ exit(1)
207
+
208
+
209
+ if __name__ == "__main__":
210
+ main()
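For reference, a hand-built example of the GPT-4-format log entry that print_gpt4_entry displays. Every field is inferred from the accesses in the functions above rather than taken from a real log, so treat it as a sketch of the expected schema.

example_entry = {
    "model": "gpt-4o",
    "case_id": "case_001",
    "question_id": "q1",
    "input": {
        "messages": [
            {"role": "system", "content": "You are a radiology assistant."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Which option best describes the finding?"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cxr.png"}},
                ],
            },
        ]
    },
    "model_answer": "A",
    "correct_answer": "A",
    "duration": 2.3,
    "cost": 0.0123,
    "usage": {"total_tokens": 900, "prompt_tokens": 850, "completion_tokens": 50},
}

print_gpt4_entry(example_entry)  # prints the input, output, and usage sections shown above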
experiments/validate_logs.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Dict, List, Tuple, Optional
2
+ import json
3
+ import sys
4
+ import glob
5
+ from pathlib import Path
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_latest_log() -> str:
10
+ """Find the most recently modified log file in the current directory.
11
+
12
+ Returns:
13
+ str: Path to the most recently modified log file
14
+
15
+ Raises:
16
+ SystemExit: If no log files are found in current directory
17
+ """
18
+ log_pattern = "api_usage_*.json"
19
+ logs = list(Path(".").glob(log_pattern))
20
+ if not logs:
21
+ print(f"No files matching pattern '{log_pattern}' found in current directory")
22
+ sys.exit(1)
23
+ return str(max(logs, key=lambda p: p.stat().st_mtime))
24
+
25
+
26
+ def analyze_log_file(filename: str) -> Tuple[List[Dict], List[Dict], Dict[str, List[str]]]:
27
+ """Analyze a log file for entries missing images and errors.
28
+
29
+ Args:
30
+ filename: Path to the log file to analyze
31
+
32
+ Returns:
33
+ Tuple containing:
34
+ - List of entries with no images
35
+ - List of skipped/error entries
36
+ - Dict of processing errors by type
37
+
38
+ Raises:
39
+ SystemExit: If file cannot be found or read
40
+ """
41
+ no_images = []
42
+ errors = defaultdict(list)
43
+ skipped = []
44
+
45
+ try:
46
+ with open(filename, "r") as f:
47
+ for line_num, line in enumerate(f, 1):
48
+ # Skip HTTP request logs
49
+ if line.startswith("HTTP Request:") or line.strip() == "":
50
+ continue
51
+ try:
52
+ # Try to parse the JSON line
53
+ if not line.strip().startswith("{"):
54
+ continue
55
+ entry = json.loads(line.strip())
56
+ case_id = entry.get("case_id")
57
+ question_id = entry.get("question_id")
58
+
59
+ # Skip if we can't identify the question
60
+ if not case_id or not question_id:
61
+ continue
62
+
63
+ # Check for explicit skip/error status
64
+ if entry.get("status") in ["skipped", "error"]:
65
+ skipped.append(
66
+ {
67
+ "case_id": case_id,
68
+ "question_id": question_id,
69
+ "reason": entry.get("reason"),
70
+ "status": entry.get("status"),
71
+ }
72
+ )
73
+ continue
74
+
75
+ # Check user content for images
76
+ messages = entry.get("input", {}).get("messages", [])
77
+ has_image = False
78
+ for msg in messages:
79
+ content = msg.get("content", [])
80
+ if isinstance(content, list):
81
+ for item in content:
82
+ if isinstance(item, dict) and item.get("type") == "image_url":
83
+ has_image = True
84
+ break
85
+ if not has_image:
86
+ no_images.append(
87
+ {
88
+ "case_id": case_id,
89
+ "question_id": question_id,
90
+ "question": entry.get("input", {})
91
+ .get("question_data", {})
92
+ .get("question", "")[:100]
93
+ + "...", # First 100 chars of question
94
+ }
95
+ )
96
+ except json.JSONDecodeError:
97
+ errors["json_decode"].append(f"Line {line_num}: Invalid JSON")
98
+ continue
99
+ except Exception as e:
100
+ errors["other"].append(f"Line {line_num}: Error processing entry: {str(e)}")
101
+ except FileNotFoundError:
102
+ print(f"Error: Could not find log file: {filename}")
103
+ sys.exit(1)
104
+ except Exception as e:
105
+ print(f"Error reading file {filename}: {str(e)}")
106
+ sys.exit(1)
107
+
108
+ return no_images, skipped, errors
109
+
110
+
111
+ def print_results(
112
+ filename: str, no_images: List[Dict], skipped: List[Dict], errors: Dict[str, List[str]]
113
+ ) -> None:
114
+ """Print analysis results.
115
+
116
+ Args:
117
+ filename: Name of the analyzed log file
118
+ no_images: List of entries with no images
119
+ skipped: List of skipped/error entries
120
+ errors: Dict of processing errors by type
121
+ """
122
+ print(f"\nAnalyzing log file: {filename}")
123
+ print("\n=== Questions with No Images ===")
124
+ if no_images:
125
+ for entry in no_images:
126
+ print(f"\nCase ID: {entry['case_id']}")
127
+ print(f"Question ID: {entry['question_id']}")
128
+ print(f"Question Preview: {entry['question']}")
129
+ print(f"\nTotal questions without images: {len(no_images)}")
130
+
131
+ print("\n=== Skipped/Error Questions ===")
132
+ if skipped:
133
+ for entry in skipped:
134
+ print(f"\nCase ID: {entry['case_id']}")
135
+ print(f"Question ID: {entry['question_id']}")
136
+ print(f"Status: {entry['status']}")
137
+ print(f"Reason: {entry.get('reason', 'unknown')}")
138
+ print(f"\nTotal skipped/error questions: {len(skipped)}")
139
+
140
+ if errors:
141
+ print("\n=== Processing Errors ===")
142
+ for error_type, messages in errors.items():
143
+ if messages:
144
+ print(f"\n{error_type}:")
145
+ for msg in messages:
146
+ print(f" {msg}")
147
+
148
+
149
+ def main() -> None:
150
+ """Main entry point for log validation script."""
151
+ # If a file is specified as an argument, use it; otherwise find the latest log
152
+ if len(sys.argv) > 1:
153
+ log_file = sys.argv[1]
154
+ else:
155
+ log_file = get_latest_log()
156
+
157
+ no_images, skipped, errors = analyze_log_file(log_file)
158
+ print_results(log_file, no_images, skipped, errors)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()
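The validator can also be driven programmatically instead of via the command line; a minimal sketch with a hypothetical file name:

no_images, skipped, errors = analyze_log_file("api_usage_2024_01_01.json")
print_results("api_usage_2024_01_01.json", no_images, skipped, errors)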
handler.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ import runpod
3
+ from main import initialize_agent
4
+
5
+ # Initialize the agent (select which tools to use as needed)
6
+ selected_tools = [
7
+ "ImageVisualizerTool",
8
+ "DicomProcessorTool",
9
+ "ChestXRayClassifierTool",
10
+ "ChestXRaySegmentationTool",
11
+ "ChestXRayReportGeneratorTool",
12
+ "XRayVQATool",
13
+ ]
14
+
15
+ agent, tools_dict = initialize_agent(
16
+ "medrax/docs/system_prompts.txt",
17
+ tools_to_use=selected_tools,
18
+ model_dir="/model-weights"
19
+ )
20
+
21
+ def handler(event):
22
+ """
23
+ Main handler function for RunPod API requests.
24
+ """
25
+ try:
26
+ # Get the request parameters
27
+ job_input = event["input"]
28
+
29
+ # Validate required parameters
30
+ if "image" not in job_input:
31
+ return {"error": "Missing required parameter: image"}
32
+
33
+ if "task" not in job_input:
34
+ return {"error": "Missing required parameter: task"}
35
+
36
+ image_data = job_input["image"]  # Assumed to be a base64-encoded image
37
+ task = job_input["task"]  # Task type
38
+
39
+ # Call the appropriate tool based on the task type
40
+ if task == "classification":
41
+ result = tools_dict["ChestXRayClassifierTool"].run(image_data)
42
+ elif task == "segmentation":
43
+ result = tools_dict["ChestXRaySegmentationTool"].run(image_data)
44
+ elif task == "report":
45
+ result = tools_dict["ChestXRayReportGeneratorTool"].run(image_data)
46
+ else:
47
+ return {"error": f"Unsupported task type: {task}"}
48
+
49
+ return {
50
+ "status": "success",
51
+ "result": result
52
+ }
53
+
54
+ except Exception as e:
55
+ return {"error": str(e)}
56
+
57
+ runpod.serverless.start({"handler": handler})
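For reference, the request shape the handler above expects. The payload below is a sketch: the base64 string is abbreviated, and whether each tool's run method accepts base64 input directly depends on the tool implementations, which are not shown here.

example_event = {
    "input": {
        "image": "<base64-encoded chest X-ray>",
        "task": "classification",  # or "segmentation" / "report"
    }
}
# handler(example_event) dispatches to ChestXRayClassifierTool.run(...) and
# returns {"status": "success", "result": ...} or {"error": ...}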
interface.py ADDED
@@ -0,0 +1,259 @@
1
+ import re
2
+ import gradio as gr
3
+ from pathlib import Path
4
+ import time
5
+ import shutil
6
+ from typing import AsyncGenerator, List, Optional, Tuple
7
+ from gradio import ChatMessage
8
+
9
+
10
+ class ChatInterface:
11
+ """
12
+ A chat interface for interacting with a medical AI agent through Gradio.
13
+
14
+ Handles file uploads, message processing, and chat history management.
15
+ Supports both regular image files and DICOM medical imaging files.
16
+ """
17
+
18
+ def __init__(self, agent, tools_dict):
19
+ """
20
+ Initialize the chat interface.
21
+
22
+ Args:
23
+ agent: The medical AI agent to handle requests
24
+ tools_dict (dict): Dictionary of available tools for image processing
25
+ """
26
+ self.agent = agent
27
+ self.tools_dict = tools_dict
28
+ self.upload_dir = Path("temp")
29
+ self.upload_dir.mkdir(exist_ok=True)
30
+ self.current_thread_id = None
31
+ # Separate storage for original and display paths
32
+ self.original_file_path = None # For LLM (.dcm or other)
33
+ self.display_file_path = None # For UI (always viewable format)
34
+
35
+ def handle_upload(self, file_path: str) -> str:
36
+ """
37
+ Handle new file upload and set appropriate paths.
38
+
39
+ Args:
40
+ file_path (str): Path to the uploaded file
41
+
42
+ Returns:
43
+ str: Display path for UI, or None if no file uploaded
44
+ """
45
+ if not file_path:
46
+ return None
47
+
48
+ source = Path(file_path)
49
+ timestamp = int(time.time())
50
+
51
+ # Save original file with proper suffix
52
+ suffix = source.suffix.lower()
53
+ saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
54
+ shutil.copy2(file_path, saved_path) # Use file_path directly instead of source
55
+ self.original_file_path = str(saved_path)
56
+
57
+ # Handle DICOM conversion for display only
58
+ if suffix == ".dcm":
59
+ output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
60
+ self.display_file_path = output["image_path"]
61
+ else:
62
+ self.display_file_path = str(saved_path)
63
+
64
+ return self.display_file_path
65
+
66
+ def add_message(
67
+ self, message: str, display_image: str, history: List[dict]
68
+ ) -> Tuple[List[dict], gr.Textbox]:
69
+ """
70
+ Add a new message to the chat history.
71
+
72
+ Args:
73
+ message (str): Text message to add
74
+ display_image (str): Path to image being displayed
75
+ history (List[dict]): Current chat history
76
+
77
+ Returns:
78
+ Tuple[List[dict], gr.Textbox]: Updated history and textbox component
79
+ """
80
+ image_path = self.original_file_path or display_image
81
+ if image_path is not None:
82
+ history.append({"role": "user", "content": {"path": image_path}})
83
+ if message is not None:
84
+ history.append({"role": "user", "content": message})
85
+ return history, gr.Textbox(value=message, interactive=False)
86
+
87
+ async def process_message(
88
+ self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
89
+ ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
90
+ """
91
+ Process a message and generate responses.
92
+
93
+ Args:
94
+ message (str): User message to process
95
+ display_image (Optional[str]): Path to currently displayed image
96
+ chat_history (List[ChatMessage]): Current chat history
97
+
98
+ Yields:
99
+ Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
100
+ """
101
+ chat_history = chat_history or []
102
+
103
+ # Initialize thread if needed
104
+ if not self.current_thread_id:
105
+ self.current_thread_id = str(time.time())
106
+
107
+ messages = []
108
+ image_path = self.original_file_path or display_image
109
+ if image_path is not None:
110
+ messages.append({"role": "user", "content": f"path: {image_path}"})
111
+ if message is not None:
112
+ messages.append({"role": "user", "content": message})
113
+
114
+ try:
115
+ for event in self.agent.workflow.stream(
116
+ {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
117
+ ):
118
+ if isinstance(event, dict):
119
+ if "process" in event:
120
+ content = event["process"]["messages"][-1].content
121
+ if content:
122
+ content = re.sub(r"temp/[^\s]*", "", content)
123
+ chat_history.append(ChatMessage(role="assistant", content=content))
124
+ yield chat_history, self.display_file_path, ""
125
+
126
+ elif "execute" in event:
127
+ for message in event["execute"]["messages"]:
128
+ tool_name = message.name
129
+ tool_result = eval(message.content)[0]
130
+
131
+ if tool_result:
132
+ metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
133
+ formatted_result = " ".join(
134
+ line.strip() for line in str(tool_result).splitlines()
135
+ ).strip()
136
+ metadata["description"] = formatted_result
137
+ chat_history.append(
138
+ ChatMessage(
139
+ role="assistant",
140
+ content=formatted_result,
141
+ metadata=metadata,
142
+ )
143
+ )
144
+
145
+ # For image_visualizer, use display path
146
+ if tool_name == "image_visualizer":
147
+ self.display_file_path = tool_result["image_path"]
148
+ chat_history.append(
149
+ ChatMessage(
150
+ role="assistant",
151
+ # content=gr.Image(value=self.display_file_path),
152
+ content={"path": self.display_file_path},
153
+ )
154
+ )
155
+
156
+ yield chat_history, self.display_file_path, ""
157
+
158
+ except Exception as e:
159
+ chat_history.append(
160
+ ChatMessage(
161
+ role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
162
+ )
163
+ )
164
+ yield chat_history, self.display_file_path, ""
165
+
166
+
167
+ def create_demo(agent, tools_dict):
168
+ """
169
+ Create a Gradio demo interface for the medical AI agent.
170
+
171
+ Args:
172
+ agent: The medical AI agent to handle requests
173
+ tools_dict (dict): Dictionary of available tools for image processing
174
+
175
+ Returns:
176
+ gr.Blocks: Gradio Blocks interface
177
+ """
178
+ interface = ChatInterface(agent, tools_dict)
179
+
180
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
181
+ with gr.Column():
182
+ gr.Markdown(
183
+ """
184
+ # 🏥 MedRAX
185
+ Medical Reasoning Agent for Chest X-ray
186
+ """
187
+ )
188
+
189
+ with gr.Row():
190
+ with gr.Column(scale=3):
191
+ chatbot = gr.Chatbot(
192
+ [],
193
+ height=800,
194
+ container=True,
195
+ show_label=True,
196
+ elem_classes="chat-box",
197
+ type="messages",
198
+ label="Agent",
199
+ avatar_images=(
200
+ None,
201
+ "assets/medrax_logo.jpg",
202
+ ),
203
+ )
204
+ with gr.Row():
205
+ with gr.Column(scale=3):
206
+ txt = gr.Textbox(
207
+ show_label=False,
208
+ placeholder="Ask about the X-ray...",
209
+ container=False,
210
+ )
211
+
212
+ with gr.Column(scale=3):
213
+ image_display = gr.Image(
214
+ label="Image", type="filepath", height=700, container=True
215
+ )
216
+ with gr.Row():
217
+ upload_button = gr.UploadButton(
218
+ "📎 Upload X-Ray",
219
+ file_types=["image"],
220
+ )
221
+ dicom_upload = gr.UploadButton(
222
+ "📄 Upload DICOM",
223
+ file_types=["file"],
224
+ )
225
+ with gr.Row():
226
+ clear_btn = gr.Button("Clear Chat")
227
+ new_thread_btn = gr.Button("New Thread")
228
+
229
+ # Event handlers
230
+ def clear_chat():
231
+ interface.original_file_path = None
232
+ interface.display_file_path = None
233
+ return [], None
234
+
235
+ def new_thread():
236
+ interface.current_thread_id = str(time.time())
237
+ return [], interface.display_file_path
238
+
239
+ def handle_file_upload(file):
240
+ return interface.handle_upload(file.name)
241
+
242
+ chat_msg = txt.submit(
243
+ interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
244
+ )
245
+ bot_msg = chat_msg.then(
246
+ interface.process_message,
247
+ inputs=[txt, image_display, chatbot],
248
+ outputs=[chatbot, image_display, txt],
249
+ )
250
+ bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
251
+
252
+ upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
253
+
254
+ dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
255
+
256
+ clear_btn.click(clear_chat, outputs=[chatbot, image_display])
257
+ new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
258
+
259
+ return demo
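A hypothetical launch script for the interface, mirroring the initialization used in handler.py; the initialize_agent arguments and the model_dir path are assumptions carried over from that file rather than documented defaults.

from main import initialize_agent
from interface import create_demo

agent, tools_dict = initialize_agent(
    "medrax/docs/system_prompts.txt",
    tools_to_use=["ImageVisualizerTool", "DicomProcessorTool", "ChestXRayClassifierTool"],
    model_dir="/model-weights",
)
demo = create_demo(agent, tools_dict)
demo.launch()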