Space status: Build error

Commit: Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitattributes +22 -0
- .gitignore +174 -0
- .vscode/launch.json +15 -0
- Dockerfile +26 -0
- LICENSE +201 -0
- README.md +251 -7
- assets/demo_fast.gif +3 -0
- assets/demo_fast.mp4 +3 -0
- assets/medrax_logo.jpg +3 -0
- assets/medrax_logo.png +3 -0
- benchmark/__init__.py +0 -0
- benchmark/create_benchmark.py +352 -0
- benchmark/llm.py +42 -0
- benchmark/utils.py +78 -0
- data/eurorad_metadata.json +0 -0
- data/figures.py +74 -0
- data/get_cases.py +51 -0
- data/stats/age_distribution.png +3 -0
- data/stats/area_of_interest_distribution.png +3 -0
- data/stats/gender_distribution.png +3 -0
- demo/chest/LIDC.dcm +3 -0
- demo/chest/Pseudo.dcm +3 -0
- demo/chest/RIDER.dcm +3 -0
- demo/chest/TCGAA.dcm +3 -0
- demo/chest/__init__.py +0 -0
- demo/chest/effusion1.png +3 -0
- demo/chest/normal1.jpg +3 -0
- demo/chest/normal2.jpg +3 -0
- demo/chest/normal3.jpg +3 -0
- demo/chest/normal4.jpg +3 -0
- demo/chest/normal5.jpg +3 -0
- demo/chest/normal6.jpg +3 -0
- demo/chest/pneumonia1.jpg +0 -0
- demo/chest/pneumonia2.jpg +0 -0
- demo/chest/pneumonia3.jpg +0 -0
- demo/chest/pneumonia4.jpg +3 -0
- demo/chest/pneumonia5.jpg +3 -0
- experiments/README.md +63 -0
- experiments/analyze_axes.py +385 -0
- experiments/benchmark_chexagent.py +316 -0
- experiments/benchmark_gpt4o.py +331 -0
- experiments/benchmark_llama.py +443 -0
- experiments/benchmark_llavamed.py +541 -0
- experiments/benchmark_medrax.ipynb +374 -0
- experiments/chexbench_gpt4.py +405 -0
- experiments/compare_runs.py +290 -0
- experiments/inspect_logs.py +210 -0
- experiments/validate_logs.py +162 -0
- handler.py +57 -0
- interface.py +259 -0
.gitattributes
CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo_fast.gif filter=lfs diff=lfs merge=lfs -text
+assets/demo_fast.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/medrax_logo.jpg filter=lfs diff=lfs merge=lfs -text
+assets/medrax_logo.png filter=lfs diff=lfs merge=lfs -text
+data/stats/age_distribution.png filter=lfs diff=lfs merge=lfs -text
+data/stats/area_of_interest_distribution.png filter=lfs diff=lfs merge=lfs -text
+data/stats/gender_distribution.png filter=lfs diff=lfs merge=lfs -text
+demo/chest/LIDC.dcm filter=lfs diff=lfs merge=lfs -text
+demo/chest/Pseudo.dcm filter=lfs diff=lfs merge=lfs -text
+demo/chest/RIDER.dcm filter=lfs diff=lfs merge=lfs -text
+demo/chest/TCGAA.dcm filter=lfs diff=lfs merge=lfs -text
+demo/chest/effusion1.png filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal1.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal2.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal3.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal4.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal5.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/normal6.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/pneumonia4.jpg filter=lfs diff=lfs merge=lfs -text
+demo/chest/pneumonia5.jpg filter=lfs diff=lfs merge=lfs -text
+medrax/llava/serve/examples/bio_patch.png filter=lfs diff=lfs merge=lfs -text
+medrax/llava/serve/examples/med_img_1.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file. For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# ruff
ruff-cache/
.ruff_cache/

afallah/

logs/

temp/

.gradio/
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "main.py",
            "console": "integratedTerminal"
        }
    ]
}
Dockerfile
ADDED
@@ -0,0 +1,26 @@
FROM runpod/pytorch:2.1.0-py3.10-cuda11.8.0

WORKDIR /workspace

# Install system dependencies
RUN apt-get update && apt-get install -y \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create directories for models and temporary files
RUN mkdir -p /model-weights /workspace/temp

# Download models (if needed)
# RUN python download_models.py

EXPOSE 8000

CMD ["python", "handler.py"]
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,12 +1,256 @@
 ---
-title:
-
-colorFrom: indigo
-colorTo: red
+title: medrax.org
+app_file: interface.py
 sdk: gradio
 sdk_version: 5.16.0
-app_file: app.py
-pinned: false
 ---
+<h1 align="center">
+  🤖 MedRAX: Medical Reasoning Agent for Chest X-ray
+</h1>
+<p align="center"> <a href="https://arxiv.org/abs/2502.02673" target="_blank"><img src="https://img.shields.io/badge/arXiv-Paper-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a> <a href="https://github.com/bowang-lab/MedRAX"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a> <a href="https://huggingface.co/datasets/wanglab/chest-agent-bench"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a> </p>
 
-
+
+
+<br>
+
+## Abstract
+Chest X-rays (CXRs) play an integral role in driving critical decisions in disease management and patient care. While recent innovations have led to specialized models for various CXR interpretation tasks, these solutions often operate in isolation, limiting their practical utility in clinical practice. We present MedRAX, the first versatile AI agent that seamlessly integrates state-of-the-art CXR analysis tools and multimodal large language models into a unified framework. MedRAX dynamically leverages these models to address complex medical queries without requiring additional training. To rigorously evaluate its capabilities, we introduce ChestAgentBench, a comprehensive benchmark containing 2,500 complex medical queries across 7 diverse categories. Our experiments demonstrate that MedRAX achieves state-of-the-art performance compared to both open-source and proprietary models, representing a significant step toward the practical deployment of automated CXR interpretation systems.
+<br><br>
+
+
+## MedRAX
+MedRAX is built on a robust technical foundation:
+- **Core Architecture**: Built on LangChain and LangGraph frameworks
+- **Language Model**: Uses GPT-4o with vision capabilities as the backbone LLM
+- **Deployment**: Supports both local and cloud-based deployments
+- **Interface**: Production-ready interface built with Gradio
+- **Modular Design**: Tool-agnostic architecture allowing easy integration of new capabilities
+
+### Integrated Tools
+- **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
+- **Segmentation**: Employs MedSAM and a PSPNet model trained on ChestX-Det for precise anatomical structure identification
+- **Grounding**: Uses Maira-2 for localizing specific findings in medical images
+- **Report Generation**: Implements a SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
+- **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
+- **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
+- **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities
+
+Note that the current version of MedRAX is experimentally released and does not support vision for GPT-4o and MedSAM. We will be integrating these shortly.
+<br><br>
+
+
+## ChestAgentBench
+We introduce ChestAgentBench, a comprehensive evaluation framework with 2,500 complex medical queries across 7 categories, built from 675 expert-curated clinical cases. The benchmark evaluates complex multi-step reasoning in CXR interpretation through:
+
+- Detection
+- Classification
+- Localization
+- Comparison
+- Relationship
+- Diagnosis
+- Characterization
+
+Download the benchmark: [ChestAgentBench on Hugging Face](https://huggingface.co/datasets/wanglab/chest-agent-bench)
+```
+huggingface-cli download wanglab/chestagentbench --repo-type dataset --local-dir chestagentbench
+```
+
+Unzip the Eurorad figures to your local `MedRAX` directory.
+```
+unzip chestagentbench/figures.zip
+```
+
+To evaluate with GPT-4o, set your OpenAI API key and run the quickstart script.
+```
+export OPENAI_API_KEY="<your-openai-api-key>"
+python quickstart.py \
+    --model chatgpt-4o-latest \
+    --temperature 0.2 \
+    --max-cases 2 \
+    --log-prefix chatgpt-4o-latest \
+    --use-urls
+```
+
+
+<br>
+
+## Installation
+### Prerequisites
+- Python 3.8+
+- CUDA/GPU for best performance
+
+### Installation Steps
+```bash
+# Clone the repository
+git clone https://github.com/bowang-lab/MedRAX.git
+cd MedRAX
+
+# Install package
+pip install -e .
+```
+
+### Getting Started
+```bash
+# Start the Gradio interface
+python main.py
+```
+or, if you run into permission issues,
+```bash
+sudo -E env "PATH=$PATH" python main.py
+```
+You need to set `model_dir` inside `main.py` to the directory where you want to download, or already have, the weights of the above tools from Hugging Face.
+Comment out the tools that you do not have access to.
+Make sure to set up your OpenAI API key in the `.env` file!
+<br><br><br>
+
+
+## Tool Selection and Initialization
+
+MedRAX supports selective tool initialization, allowing you to use only the tools you need. Tools can be specified when initializing the agent (look at `main.py`):
+
+```python
+selected_tools = [
+    "ImageVisualizerTool",
+    "ChestXRayClassifierTool",
+    "ChestXRaySegmentationTool",
+    # Add or remove tools as needed
+]
+
+agent, tools_dict = initialize_agent(
+    "medrax/docs/system_prompts.txt",
+    tools_to_use=selected_tools,
+    model_dir="/model-weights"
+)
+```
+
+<br><br>
+## Automatically Downloaded Models
+
+The following tools will automatically download their model weights when initialized:
+
+### Classification Tool
+```python
+ChestXRayClassifierTool(device=device)
+```
+
+### Segmentation Tool
+```python
+ChestXRaySegmentationTool(device=device)
+```
+
+### Grounding Tool
+```python
+XRayPhraseGroundingTool(
+    cache_dir=model_dir,
+    temp_dir=temp_dir,
+    load_in_8bit=True,
+    device=device
+)
+```
+- Maira-2 weights download to the specified `cache_dir`
+- 8-bit and 4-bit quantization available for reduced memory usage
+
+### LLaVA-Med Tool
+```python
+LlavaMedTool(
+    cache_dir=model_dir,
+    device=device,
+    load_in_8bit=True
+)
+```
+- Automatic weight download to `cache_dir`
+- 8-bit and 4-bit quantization available for reduced memory usage
+
+### Report Generation Tool
+```python
+ChestXRayReportGeneratorTool(
+    cache_dir=model_dir,
+    device=device
+)
+```
+
+### Visual QA Tool
+```python
+XRayVQATool(
+    cache_dir=model_dir,
+    device=device
+)
+```
+- CheXagent weights download automatically
+
+### MedSAM Tool
+```
+Support for MedSAM segmentation will be added in a future update.
+```
+
+### Utility Tools
+No additional model weights required:
+```python
+ImageVisualizerTool()
+DicomProcessorTool(temp_dir=temp_dir)
+```
+<br>
+
+## Manual Setup Required
+
+### Image Generation Tool
+```python
+ChestXRayGeneratorTool(
+    model_path=f"{model_dir}/roentgen",
+    temp_dir=temp_dir,
+    device=device
+)
+```
+- RoentGen weights require manual setup:
+  1. Contact authors: https://github.com/StanfordMIMI/RoentGen
+  2. Place weights in `{model_dir}/roentgen`
+  3. Optional tool, can be excluded if not needed
+<br>
+
+## Configuration Notes
+
+### Required Parameters
+- `model_dir` or `cache_dir`: Base directory for model weights that Hugging Face uses
+- `temp_dir`: Directory for temporary files
+- `device`: "cuda" for GPU, "cpu" for CPU-only
+
+### Memory Management
+- Consider selective tool initialization for resource constraints
+- Use 8-bit quantization where available
+- Some tools (LLaVA-Med, Grounding) are more resource-intensive
+<br><br>
+
+## Authors
+- **Adibvafa Fallahpour**¹²³ * ([email protected])
+- **Jun Ma**²³ *
+- **Alif Munim**³⁴ *
+- **Hongwei Lyu**³
+- **Bo Wang**¹²³⁵
+
+¹ Department of Computer Science, University of Toronto, Toronto, Canada
+² Vector Institute, Toronto, Canada
+³ University Health Network, Toronto, Canada
+⁴ Cohere For AI, Toronto, Canada
+⁵ Department of Laboratory Medicine and Pathobiology, University of Toronto, Toronto, Canada <br>
+\* Equal contribution
+<br><br>
+
+
+## Citation
+If you find this work useful, please cite our paper:
+```bibtex
+@misc{fallahpour2025medraxmedicalreasoningagent,
+      title={MedRAX: Medical Reasoning Agent for Chest X-ray},
+      author={Adibvafa Fallahpour and Jun Ma and Alif Munim and Hongwei Lyu and Bo Wang},
+      year={2025},
+      eprint={2502.02673},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG},
+      url={https://arxiv.org/abs/2502.02673},
+}
+```
+
+---
+<p align="center">
+Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
+</p>
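
Note: the README above downloads ChestAgentBench with the `huggingface-cli` tool. As a rough Python equivalent (a sketch, not part of the upload: it assumes `huggingface_hub` is installed and reuses the repo id and local directory from the CLI command above), something like the following should fetch the same dataset:

```python
# Sketch: download ChestAgentBench via the huggingface_hub Python API instead of the CLI.
# Assumes `pip install huggingface_hub`; log in first if the dataset requires authentication.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="wanglab/chestagentbench",  # same repo as the CLI command in the README
    repo_type="dataset",
    local_dir="chestagentbench",        # matches the README's --local-dir
)
print(f"Benchmark downloaded to {local_dir}")
```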
assets/demo_fast.gif
ADDED
Binary image added via Git LFS (preview omitted in this view).
assets/demo_fast.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f3bab0a7f21187f66bf9e7f5809489b6348d8b140b7e90d4fdf750ec989f92f
size 4109270
assets/medrax_logo.jpg
ADDED
Binary image added via Git LFS (preview omitted in this view).
assets/medrax_logo.png
ADDED
Binary image added via Git LFS (preview omitted in this view).
benchmark/__init__.py
ADDED
File without changes
benchmark/create_benchmark.py
ADDED
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
"""
Medical X-ray Question Generation Benchmark aka ChestAgentBench

This script generates clinical questions from X-ray case data of the Eurorad dataset using GPT-4o.
It structures questions across different analytical categories and saves them as JSON.
"""

import os
import re
import json
from typing import *
from pprint import pprint

import openai
import numpy as np
from scipy import stats
import plotly.graph_objects as go
from tqdm import tqdm

from benchmark.utils import load_eurorad_dataset
from benchmark.llm import get_llm_response

# Constants
DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")

SYSTEM_PROMPT = """
You are an expert medical benchmark creation assistant.
Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
""".strip()

CATEGORIES_META = {
    "detection": "Identify and locate specific findings in the chest X-ray.",
    "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
    "enumeration": "Count the number of target findings in the chest X-ray.",
    "localization": "Locate a given finding in the chest X-ray.",
    "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
    "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
    "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
    "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
    "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
}
CATEGORIES = list(CATEGORIES_META.keys())

CATEGORY_COMBINATIONS = [
    ["detection", "localization", "characterization", "reasoning"],  # Detailed Finding Analysis
    ["detection", "classification", "relationship", "reasoning"],  # Pattern Recognition & Relations
    ["localization", "comparison", "relationship", "reasoning"],  # Spatial Understanding
    ["classification", "comparison", "diagnosis", "reasoning"],  # Clinical Decision Making
    ["classification", "characterization", "diagnosis", "reasoning"],  # Diagnostic Characterization
]

DEFAULT_SECTIONS = [
    "history",
    "image_finding",
    "discussion",
    "differential_diagnosis",
    "diagnosis",
    "figures",
]


class Question:
    """A class to generate clinical questions from case data.

    This class handles creating structured clinical questions by combining case data with
    specified categories and difficulty levels.

    Attributes:
        type (str): The type of question (e.g. multiple choice)
        difficulty (str): Difficulty level of the question
        case_data (Dict[str, Any]): Dictionary containing the clinical case data
        case_content (str): Formatted case data from selected sections
        case_id (str): Unique identifier for the case
        categories (List[str]): List of analytical categories this question tests
        sections (List[str]): Case sections to include in question
        raw_content (Optional[str]): Raw LLM response to the question prompt
        content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
    """

    def __init__(
        self,
        type: str,
        difficulty: str,
        case_data: Dict[str, Any],
        categories: List[str],
        sections: List[str] = [
            "history",
            "image_finding",
            "discussion",
            "differential_diagnosis",
            "diagnosis",
            "figures",
        ],
        system_prompt: str = "You are an expert medical benchmark creation assistant.",
    ) -> None:
        self.type = type
        self.difficulty = difficulty
        self.case_data = case_data
        self.case_id = case_data["case_id"]
        self.categories = categories
        self.sections = sections
        self.system_prompt = system_prompt
        self.case_content = self.select_case_sections()
        self.raw_content: Optional[str] = None
        self.content: Optional[Dict[str, str]] = None

    def create_question_prompt(self) -> str:
        """Creates a formatted prompt for generating a clinical question.

        Returns:
            str: A structured prompt containing the question parameters and clinical data
        """
        category_descriptions = "\n".join(
            f"{category}: {desc}"
            for category, desc in CATEGORIES_META.items()
            if category in self.categories
        )

        return f"""
You must follow these guidelines:
1. Questions must be answerable using only context and chest X-rays.
    - Questions must explicitly mention the referenced figures
    - Questions can only reference the chest X-ray figures

2. Questions must have unambiguous, verifiable answers, and should:
    - Challenge the agent's analytical capabilities
    - Require multi-step reasoning
    - Test ability to make precise observations
    - Evaluate capability to derive insights and findings from the chest X-ray

3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex to require the use of such tools.


Create a {self.difficulty} {self.type} clinical question that integrates the following:

{category_descriptions}

based on the following clinical case:

{self.case_content}

Do not use any information derived from the CT and MRI images. Do not provide any information and findings about the chest X-rays.
Your question should require the agent to derive insights and findings from the chest X-ray by itself.
Your answer should be verifiable directly in the context of the case.
You can only use the image findings that come from the chest X-ray figures.

Your response must follow this exact format:
THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question]
QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
EXPLANATION: [short explanation of why your answer is verifiable in the case]
ANSWER: [correct answer e.g. "A"]
""".strip().replace(
            "    ", ""
        )  # remove tabs

    def select_case_sections(self) -> str:
        """Extract and format selected sections from case data into paragraphs.

        Returns:
            str: Formatted string with case sections and content
        """
        section_mapping = {
            "history": ("history", "No history provided."),
            "image_finding": ("image_finding", "No findings provided."),
            "discussion": ("discussion", "No discussion provided."),
            "differential_diagnosis": (
                "differential_diagnosis",
                "No differential diagnosis provided.",
            ),
            "diagnosis": ("diagnosis", "No diagnosis provided."),
            "figures": ("figures", "No figures provided."),
        }

        formatted = []
        for section in self.sections:
            if section in section_mapping:
                key, default = section_mapping[section]
                content = self.case_data.get(key, default)

                if key == "figures":
                    figures_text = []
                    for figure in content:
                        for subfig in figure["subfigures"]:
                            figures_text.append(f"{subfig['number']}: {subfig['caption']}")
                    content = "\n".join(figures_text)

                formatted.append(f"{section}:\n{content}")

        return "\n\n".join(formatted)

    def create_question(
        self,
        client: openai.OpenAI,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_tokens: int = 500,
        model: str = "gpt-4o",
    ) -> str:
        """Create a clinical question using LLM.

        Args:
            client (openai.OpenAI): OpenAI client instance
            temperature (float): Controls randomness in responses. Defaults to 0.7.
            top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
            max_tokens (int): Max tokens in model response. Defaults to 500.
            model (str): OpenAI model to use. Defaults to "gpt-4o".

        Returns:
            str: LLM response containing formatted question components
        """
        self.raw_content = get_llm_response(
            client=client,
            prompt=self.create_question_prompt(),
            system_prompt=self.system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model=model,
        )
        self.content = self.extract_content()

        return self.raw_content

    def extract_content(self) -> Dict[str, str]:
        """Extract sections from raw LLM response using regex patterns.

        Returns:
            Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer
        """
        keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]

        content = {}
        for kw in keywords:
            pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
            match = re.search(pattern, self.raw_content, re.DOTALL)
            content[kw.lower()] = match.group(1).strip() if match else None

        return content

    def save(self, output_path: str) -> Dict[str, Any]:
        """Save question content and metadata as a JSON file.

        Args:
            output_path (str): Directory path where the JSON file will be saved

        Returns:
            Dict[str, Any]: Question data including content (thoughts, question, figures, options,
                explanation, answer) and metadata (type, difficulty, categories, etc.)
        """
        question_metadata = self.content.copy()

        # Add metadata
        question_metadata["metadata"] = {
            "case_id": self.case_id,
            "type": self.type,
            "difficulty": self.difficulty,
            "categories": self.categories,
            "sections": self.sections,
        }

        # Create a directory for the case
        case_dir = os.path.join(output_path, str(self.case_id))
        os.makedirs(case_dir, exist_ok=True)

        # Save the question metadata to a JSON file
        output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
        with open(output_file, "w") as f:
            json.dump(question_metadata, f, indent=2)

        return question_metadata


def generate_questions(
    dataset: Dict[str, Any],
    client: openai.OpenAI,
    output_dir: str,
    skip_first: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 1200,
    model: str = "gpt-4o",
) -> None:
    """Generate questions for each case and category combination.

    Args:
        dataset: Dictionary of case data
        client: OpenAI client instance
        output_dir: Directory to save generated questions
        skip_first: Number of initial cases to skip
        temperature: LLM temperature parameter
        top_p: LLM top_p parameter
        max_tokens: Maximum tokens for LLM response
        model: LLM model name
    """
    target_cases = sorted(list(dataset.keys()), key=int)[-len(dataset) : -skip_first]

    for case_id in tqdm(target_cases, desc="Processing cases"):
        case_data = dataset[case_id]

        for category in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
            question = Question(
                type="multiple choice (A/B/C/D/E/F)",
                difficulty="complex",
                case_data=case_data,
                categories=category,
                sections=DEFAULT_SECTIONS,
                system_prompt=SYSTEM_PROMPT,
            )

            response = question.create_question(
                client=client,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                model=model,
            )
            question.save(output_dir)


def main():
    """Main execution function."""
    client = openai.OpenAI()

    # Load and verify dataset
    dataset = load_eurorad_dataset(
        DATASET_PATH,
        section="Chest Imaging",
        as_dict=True,
        filter_by_caption=[
            "xray",
            "x-ray",
            "x ray",
            "ray",
            "xr",
            "radiograph",
        ],
    )
    print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")

    # Optional: Print sample case for verification
    case_data = dataset["16798"]
    pprint(case_data, sort_dicts=False)

    # Generate questions
    generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")


if __name__ == "__main__":
    main()
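
For orientation, a minimal sketch of how the `Question` class above can be driven for a single case (the metadata path and output directory are assumptions; the real entry point is `generate_questions()` via `main()`, and an `OPENAI_API_KEY` must be set in the environment):

```python
# Hypothetical one-off use of the Question class defined in benchmark/create_benchmark.py.
import openai
from benchmark.create_benchmark import DEFAULT_SECTIONS, SYSTEM_PROMPT, Question
from benchmark.utils import load_eurorad_dataset

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

# Load chest-imaging cases keyed by case id (path assumed to point at the uploaded metadata file).
dataset = load_eurorad_dataset("data/eurorad_metadata.json", section="Chest Imaging", as_dict=True)
case = dataset["16798"]  # sample case id also used in main() above

question = Question(
    type="multiple choice (A/B/C/D/E/F)",
    difficulty="complex",
    case_data=case,
    categories=["detection", "localization", "characterization", "reasoning"],
    sections=DEFAULT_SECTIONS,
    system_prompt=SYSTEM_PROMPT,
)
question.create_question(client=client, model="gpt-4o")
question.save("benchmark/questions")  # writes benchmark/questions/<case_id>/<case_id>_<hash>.json
```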
benchmark/llm.py
ADDED
@@ -0,0 +1,42 @@
import openai
from typing import List


def get_llm_response(
    client: openai.OpenAI,
    prompt: str,
    system_prompt: str = "You are a helpful assistant.",
    model: str = "gpt-4o-mini",
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 500,
) -> str:
    """
    Get response from OpenAI language model.

    Args:
        client (openai.OpenAI): OpenAI client
        prompt (str): The user prompt/question to send to the model
        system_prompt (str, optional): System prompt to set model behavior.
        model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
        temperature (float, optional): Controls randomness in responses. Defaults to 0.7.
        top_p (float, optional): Controls diversity via nucleus sampling. Defaults to 0.95.
        max_tokens (int, optional): Max tokens in model response. Defaults to 500.

    Returns:
        str: The model's response text
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content
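
A minimal usage sketch for `get_llm_response` (the prompt text is illustrative only; the client picks up `OPENAI_API_KEY` from the environment):

```python
# Sketch: call the helper defined in benchmark/llm.py with an example prompt.
import openai
from benchmark.llm import get_llm_response

client = openai.OpenAI()
answer = get_llm_response(
    client=client,
    prompt="Summarize the key radiographic signs of a tension pneumothorax.",  # illustrative prompt
    system_prompt="You are a helpful assistant.",
    model="gpt-4o-mini",
    max_tokens=200,
)
print(answer)
```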
benchmark/utils.py
ADDED
@@ -0,0 +1,78 @@
import os
import json
from typing import Dict, List


def load_eurorad_dataset(
    dataset_path: str,
    section: str = "any",
    as_dict: bool = False,
    filter_by_caption: List[str] = [
        "xray",
        "x-ray",
        "x ray",
        "ray",
        "xr",
        "radiograph",
        "radiogram",
        "plain film",
    ],
) -> List[Dict] | Dict[str, Dict]:
    """
    Load a dataset from a JSON file.

    Args:
        dataset_path (str): Path to the JSON dataset file.
        section (str, optional): Section of the dataset to load. Defaults to "any".
        as_dict (bool, optional): Whether to return data as dict. Defaults to False.
        filter_by_caption (List[str], optional): List of strings to filter cases by caption content. Defaults to [].

    Returns:
        List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.

    Raises:
        FileNotFoundError: If dataset_path does not exist
        json.JSONDecodeError: If file is not valid JSON
    """

    with open(dataset_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    if filter_by_caption:
        filtered_data = {}
        for case_id, case in data.items():
            if any(
                any(x in subfig["caption"].lower() for x in filter_by_caption)
                for figure in case["figures"]
                for subfig in figure["subfigures"]
            ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
                filtered_data[case_id] = case
        data = filtered_data

    if section != "any":
        section = section.strip().lower()
        if not as_dict:
            data = [
                item for item in data.values() if item.get("section", "").strip().lower() == section
            ]
        else:
            data = {
                k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
            }

    elif not as_dict:
        data = list(data.values())

    return data


def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
    """
    Save a dataset to a JSON file.

    Args:
        dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
        dataset_path (str): Path where the JSON dataset file will be saved.
    """
    with open(dataset_path, "w", encoding="utf-8") as file:
        json.dump(dataset, file)
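
A short sketch of how these two helpers fit together (file paths are assumptions; `eurorad_metadata.json` lives under `data/` in this upload, and the output filename is hypothetical):

```python
# Sketch: load the X-ray-filtered chest-imaging cases and persist the subset for later use.
from benchmark.utils import load_eurorad_dataset, save_dataset

cases = load_eurorad_dataset(
    "data/eurorad_metadata.json",  # assumed path to the uploaded metadata file
    section="Chest Imaging",
    as_dict=True,                  # keep cases keyed by case id
)
print(f"{len(cases)} chest X-ray cases loaded")

save_dataset(cases, "data/chest_xray_cases.json")  # hypothetical output path
```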
data/eurorad_metadata.json
ADDED
The diff for this file is too large to render; see the raw diff.
data/figures.py
ADDED
@@ -0,0 +1,74 @@
import json
import os
from pathlib import Path
import requests
from tqdm import tqdm


def download_eurorad_figures(metadata_path: str, output_dir: str) -> None:
    """
    Download figures from Eurorad dataset and save them organized by case_id.

    Args:
        metadata_path: Path to the eurorad_metadata.json file
        output_dir: Base directory where figures will be saved

    The figures will be saved as:
        {output_dir}/{case_id}/{figure_number}.jpg
    Example:
        figures/189/Figure_1a.jpg
    """
    # Create output directory if it doesn't exist
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)

    # Load metadata
    with open(metadata_path) as f:
        metadata = json.load(f)

    # Iterate through all cases with progress bar
    for case_id in tqdm(metadata, desc="Downloading cases", unit="case"):
        case = metadata[case_id]
        case_dir = output_path / str(case["case_id"])
        case_dir.mkdir(exist_ok=True)

        # Process all figures and their subfigures
        for figure in case["figures"]:
            for subfig in figure["subfigures"]:

                # Remove leading and trailing whitespace and convert to lowercase
                subfig_name = f"{subfig['number'].strip().replace(' ', '_').lower()}.jpg"
                subfig_path = Path(case_dir) / subfig_name

                save_figure(
                    url=subfig["url"],
                    output_path=subfig_path,
                )


def save_figure(url: str, output_path: Path) -> None:
    """
    Download and save a single figure.

    Args:
        url: URL of the figure to download
        output_path: Path where the figure should be saved
    """
    if output_path.exists():
        return

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(response.content)
    except Exception as e:
        print(f"Error downloading {url}: {e}")


if __name__ == "__main__":
    root = os.path.dirname(os.path.abspath(__file__))
    download_eurorad_figures(
        metadata_path=os.path.join(root, "eurorad_metadata.json"),
        output_dir=os.path.join(root, "figures"),
    )
data/get_cases.py
ADDED
@@ -0,0 +1,51 @@
import requests
from bs4 import BeautifulSoup
import time
import json
from tqdm import tqdm


def get_response(url):
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
    }
    return requests.get(url, headers=headers)

def get_case_numbers_from_page(page):
    url = f"https://www.eurorad.org/advanced-search?sort_by=published_at&sort_order=ASC&page={page}&filter%5B0%5D=section%3A40"

    # Remove proxy usage since it's likely triggering the protection
    response = get_response(url)
    print(response.text)  # debug output: dumps the raw HTML of every fetched page

    soup = BeautifulSoup(response.text, "html.parser")
    spans = soup.find_all("span", class_="case__number small")

    # Remove '#' from the span text and strip extra whitespace
    numbers = [span.text.strip().replace("#", "").strip() for span in spans]
    return numbers


def main():
    total_pages = 107  # Pages 0 through 106
    all_numbers = []

    for page in tqdm(range(total_pages)):
        numbers = get_case_numbers_from_page(page)
        all_numbers.extend(numbers)

        if page != total_pages - 1 and len(numbers) != 9:
            print(f"Warning: Page {page} returned {len(numbers)} cases instead of 9")

        # Be kind to the server – avoid hitting it too fast
        time.sleep(1)
        break  # NOTE: exits after the first page; remove this to scrape all 107 pages

    with open('case_numbers.json', 'w') as f:
        json.dump(all_numbers, f)

    print(f"Saved {len(all_numbers)} case numbers to case_numbers.json")


if __name__ == "__main__":
    main()

data/stats/age_distribution.png
ADDED
(binary image, stored via Git LFS)

data/stats/area_of_interest_distribution.png
ADDED
(binary image, stored via Git LFS)

data/stats/gender_distribution.png
ADDED
(binary image, stored via Git LFS)

demo/chest/LIDC.dcm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11d25b1d34dff083057de994fef7da3dcef75bd7b334823ec6cb9c16b3ba0338
size 17071804

demo/chest/Pseudo.dcm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b35ae460fb5f62eb6d6c4c5117f6683100ad92c5fb6ba1a3c36da39703c4652
size 7535280

demo/chest/RIDER.dcm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc15f7afa5434991e1359f596433870ad611b42227db87d484d31976545de7fd
size 7534066

demo/chest/TCGAA.dcm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e8137290ac823d3da3c00ce3e18120123eaa62a786934c7afc52a989b0b64cf
size 7535274

demo/chest/__init__.py
ADDED
File without changes

demo/chest/effusion1.png
ADDED
(binary image, stored via Git LFS)

demo/chest/normal1.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/normal2.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/normal3.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/normal4.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/normal5.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/normal6.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/pneumonia1.jpg
ADDED
(binary image)

demo/chest/pneumonia2.jpg
ADDED
(binary image)

demo/chest/pneumonia3.jpg
ADDED
(binary image)

demo/chest/pneumonia4.jpg
ADDED
(binary image, stored via Git LFS)

demo/chest/pneumonia5.jpg
ADDED
(binary image, stored via Git LFS)
experiments/README.md
ADDED
@@ -0,0 +1,63 @@
# Experiments
Below are the instructions for running experiments using our novel ChestAgentBench and the previous SoTA benchmark, CheXbench. ChestAgentBench is a comprehensive benchmark containing over 2,500 complex medical queries across 8 diverse categories.

### ChestAgentBench

To run GPT-4o on ChestAgentBench, enter the `experiments` directory and run the following script:
```bash
python benchmark_gpt4o.py
```

To run Llama 3.2 Vision 90B on ChestAgentBench, run the following:
```bash
python benchmark_llama.py
```

To run CheXagent on ChestAgentBench, run the following:
```bash
python benchmark_chexagent.py
```

To run LLaVA-Med on ChestAgentBench, you'll need to clone their repo, follow their setup instructions, and then copy the following script into it:
```bash
mv benchmark_llavamed.py ~/LLaVA-Med/llava/serve
python -m llava.serve.benchmark_llavamed --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
```

If you want to inspect the logs, you can run the following. It will select the most recent log file by default.
```bash
python inspect_logs.py [optional: log-file] -n [num-logs]
```

Finally, to analyze results, run:
```bash
python analyze_axes.py results/[logfile].json ../benchmark/questions/ --model [gpt4|llama|chexagent|llava-med] --max-questions [optional:int]
```

### CheXbench

To run the models on CheXbench, you can use `chexbench_gpt4.py` as a reference. You'll need to download the dataset files locally and upload them with each request. Rad-ReStruct and Open-I use the same set of images, so you can download the `NLMCXR.zip` file just once and copy the images to both directories (a sketch of that copy step follows the list below).

You can find the datasets here:
1. [SLAKE: A Semantically-Labeled Knowledge-Enhanced Dataset for Medical Visual Question Answering](https://www.med-vqa.com/slake/). Save this to `MedMAX/data/slake`.
2. [Rad-ReStruct: A Novel VQA Benchmark and Method for Structured Radiology Reporting](https://github.com/ChantalMP/Rad-ReStruct). Save the images to `MedMAX/data/rad-restruct/images`.
3. [Open-I Service of the National Library of Medicine](https://openi.nlm.nih.gov/faq). Save the images to `MedMAX/data/openi/images`.

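A minimal sketch of that copy step in Python. The source and target paths are assumptions; point `src` at wherever you extracted `NLMCXR.zip` and the targets at your local dataset directories.

```python
import shutil
from pathlib import Path

# Assumed locations: adjust to your local layout.
src = Path("data/NLMCXR")  # extracted contents of NLMCXR.zip
targets = [Path("data/rad-restruct/images"), Path("data/openi/images")]

for target in targets:
    target.mkdir(parents=True, exist_ok=True)
    for img in src.iterdir():
        if img.is_file():
            shutil.copy2(img, target / img.name)
```
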
Once you're finished, fix the paths in the `chexbench.json` file to point at your local copies using the `MedMax/data/fix_chexbench.py` script.


### Compare Runs
Analyze a single results file for overall accuracy and accuracy along different axes:
```
python compare_runs.py results/medmax.json
```

For a direct evaluation comparing **2** models on the exact same questions:
```
python compare_runs.py results/medmax.json results/gpt4o.json
```

For a direct evaluation comparing **ALL** models on the exact same questions (add as many model log files as you want):
```
python compare_runs.py results/medmax.json results/gpt4o.json results/llama.json results/chexagent.json results/llavamed.json
```

experiments/analyze_axes.py
ADDED
@@ -0,0 +1,385 @@
from typing import Dict, List, Optional, Tuple, Union, Any
import json
import os
import sys
import argparse
from collections import defaultdict
from tqdm import tqdm

QUESTION_TYPES = {
    "Detailed Finding Analysis": ["detection", "localization", "characterization"],
    "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
    "Spatial Understanding": ["localization", "comparison", "relationship"],
    "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
    "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}


def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
    """
    Extract just the letter from various answer formats.

    Args:
        answer: The answer text to extract letter from

    Returns:
        Optional[str]: The extracted letter in uppercase, or None if no letter found
    """
    if not answer:
        return None

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter, return it
    if len(answer) == 1 and answer.isalpha():
        return answer.upper()

    # Try to extract letter from format like "A)" or "A."
    if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
        return answer[0].upper()

    # Try to extract letter from format like "A) Some text"
    if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
        return answer[0].upper()

    return None


def analyze_gpt4_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in GPT-4 format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    processed_questions = 0

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        # Check if we've hit the maximum questions
        if max_questions is not None and processed_questions >= max_questions:
            break
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                processed_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_llama_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in Llama format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_chexagent_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in CheXagent format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def process_results(
    category_performance: Dict,
    all_questions: int,
    all_correct: int,
    correct_ids: Optional[List[str]] = None,
    incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Process raw results into final statistics.

    Args:
        category_performance: Dict containing performance by category
        all_questions: Total number of questions
        all_correct: Total number of correct answers
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_accuracies = {
        category: {
            "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    question_type_stats = {}
    for qtype, categories in QUESTION_TYPES.items():
        total = sum(
            category_performance[cat]["total"] for cat in categories if cat in category_performance
        )
        correct = sum(
            category_performance[cat]["correct"]
            for cat in categories
            if cat in category_performance
        )

        question_type_stats[qtype] = {
            "accuracy": (correct / total * 100) if total > 0 else 0,
            "total": total,
            "correct": correct,
        }

    overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0

    return (
        overall_accuracy,
        category_accuracies,
        question_type_stats,
        correct_ids or [],
        incorrect_ids or [],
    )


def print_analysis(
    overall_accuracy: float,
    category_accuracies: Dict,
    question_type_stats: Dict,
    correct_ids: List[str],
    incorrect_ids: List[str],
    model_name: str,
) -> None:
    """
    Print analysis results.

    Args:
        overall_accuracy: Overall accuracy percentage
        category_accuracies: Dict containing accuracy metrics by category
        question_type_stats: Dict containing stats by question type
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions
        model_name: Name of the model being analyzed
    """
    total_questions = len(correct_ids) + len(incorrect_ids)
    print(
        f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
    )

    print("\nCategory Performance:")
    sorted_categories = sorted(
        category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for category, metrics in sorted_categories:
        print(f"{category}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")

    print("\nQuestion Type Performance:")
    sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
    for qtype, metrics in sorted_types:
        print(f"\n{qtype}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")
        print(f"  Categories: {', '.join(QUESTION_TYPES[qtype])}")

    # Save question IDs to JSON
    question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}

    output_filename = f"{model_name}_question_ids.json"
    with open(output_filename, "w") as f:
        json.dump(question_ids, f, indent=2)

    print(f"\nQuestion IDs have been saved to {output_filename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze benchmark results")
    parser.add_argument("results_file", help="Path to results file")
    parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
    parser.add_argument(
        "--model",
        choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
        default="gpt4",
        help="Specify model format (default: gpt4)",
    )
    parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
    args = parser.parse_args()

    if args.model == "gpt4":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    elif args.model == "llama":
        results = analyze_llama_results(args.results_file, args.max_questions)
    elif args.model == "chexagent":
        results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        parser.error(f"Unsupported model: {args.model}")

    print_analysis(*results, args.model)
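
As a quick illustration of how `process_results` above pools per-category counts into the composite question types (the numbers below are made up): a category such as `detection` contributes to every question type that lists it, so question-type totals overlap rather than partition the questions.

```python
from analyze_axes import process_results  # assumes you run this from the experiments/ directory

category_performance = {
    "detection": {"total": 100, "correct": 80},
    "localization": {"total": 100, "correct": 70},
    "characterization": {"total": 100, "correct": 60},
}
overall, by_category, by_type, _, _ = process_results(category_performance, 300, 210)

print(round(overall, 1))                     # 70.0
print(by_type["Detailed Finding Analysis"])  # {'accuracy': 70.0, 'total': 300, 'correct': 210}
print(by_type["Spatial Understanding"])      # only localization present -> {'accuracy': 70.0, 'total': 100, 'correct': 70}
```
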
experiments/benchmark_chexagent.py
ADDED
@@ -0,0 +1,316 @@
import re
import json
import os
import glob
import time
import logging
from datetime import datetime
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Configure model settings
MODEL_NAME = "StanfordAIMI/CheXagent-2-3b"
DTYPE = torch.bfloat16
DEVICE = "cuda"

# Configure logging
log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Initialize the CheXagent model and tokenizer.

    Returns:
        tuple containing:
        - AutoModelForCausalLM: The initialized CheXagent model
        - AutoTokenizer: The initialized tokenizer
    """
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", trust_remote_code=True
    )
    model = model.to(DTYPE)
    model.eval()
    return model, tokenizer


def create_inference_request(
    question_data: dict,
    case_details: dict,
    case_id: str,
    question_id: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> str | None:
    """Create and execute an inference request for the CheXagent model.

    Args:
        question_data: Dictionary containing question details and metadata
        case_details: Dictionary containing case information and image paths
        case_id: Unique identifier for the medical case
        question_id: Unique identifier for the question
        model: The initialized CheXagent model
        tokenizer: The initialized tokenizer

    Returns:
        str | None: Single letter answer (A-F) if successful, None if failed
    """
    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning

Examples of valid responses:
A
B
C

Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    # Parse required figures
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # Get image paths
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            subfigures = []
            if figure_letter:
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("medrax/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return None

    try:
        start_time = time.time()

        # Prepare input for the model
        query = tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}]
        input_ids = tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        )

        # Generate response
        with torch.no_grad():
            output = model.generate(
                input_ids.to(DEVICE),
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=512,
            )[0]

        response = tokenizer.decode(output[input_ids.size(1) : -1])
        duration = time.time() - start_time

        # Clean response
        clean_answer = validate_answer(response)

        # Log response
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "duration": round(duration, 2),
            "model_answer": clean_answer,
            "correct_answer": question_data["answer"],
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_paths": image_paths,
            },
        }
        logging.info(json.dumps(log_entry))
        return clean_answer

    except Exception as e:
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "error",
            "error": str(e),
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_paths": image_paths,
            },
        }
        logging.info(json.dumps(log_entry))
        return None


def validate_answer(response_text: str) -> str | None:
    """Enforce strict single-letter response format.

    Args:
        response_text: Raw response text from the model

    Returns:
        str | None: Single uppercase letter (A-F) if valid, None if invalid
    """
    if not response_text:
        return None

    # Remove all whitespace and convert to uppercase
    cleaned = response_text.strip().upper()

    # Check if it's exactly one valid letter
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned

    # If not, try to extract just the letter
    match = re.search(r"([A-F])", cleaned)
    return match.group(1) if match else None


def load_benchmark_questions(case_id: str) -> list[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: Unique identifier for the medical case

    Returns:
        list[str]: List of paths to question JSON files
    """
    benchmark_dir = "../benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        tuple containing:
        - int: Total number of cases
        - int: Total number of questions
    """
    total_cases = len(glob.glob("../benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("../benchmark/questions")
    )
    return total_cases, total_questions


def main():
    # Load the cases with local paths
    with open("medrax/data/updated_cases.json", "r") as file:
        data = json.load(file)

    # Initialize model and tokenizer
    model, tokenizer = initialize_model()

    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"\nBeginning inference with {MODEL_NAME}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    # Process each case with progress bar
    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1
            answer = create_inference_request(
                question_data, case_details, case_id, question_id, model, tokenizer
            )

            if answer is None:
                skipped_questions += 1
                continue

            print(f"\nCase {case_id}, Question {question_id}")
            print(f"Model Answer: {answer}")
            print(f"Correct Answer: {question_data['answer']}")

    print(f"\nInference Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


if __name__ == "__main__":
    main()
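
One caveat worth knowing about `validate_answer` above: when the reply is not a single letter, it falls back to the first A-F character anywhere in the uppercased text, so a verbose reply can be mis-parsed. A small illustration (the behavior follows directly from the code above):

```python
from benchmark_chexagent import validate_answer  # assumes the script's dependencies (torch, transformers) are installed

print(validate_answer("B"))          # "B"
print(validate_answer("  c \n"))     # "C"
print(validate_answer("Answer: B"))  # "A" - the "A" in "ANSWER" matches first
print(validate_answer("12345"))      # None
```
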
experiments/benchmark_gpt4o.py
ADDED
@@ -0,0 +1,331 @@
import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt

model_name = "chatgpt-4o-latest"
temperature = 0.2
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def calculate_cost(
    prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
) -> float:
    """Calculate the cost of API usage based on token counts.

    Args:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens in the completion
        model: Model name to use for pricing, defaults to chatgpt-4o-latest

    Returns:
        float: Cost in USD
    """
    pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
    rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
    return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000


@retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
def create_multimodal_request(
    question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
) -> openai.types.chat.ChatCompletion:
    """Create and send a multimodal request to the OpenAI API.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        client: OpenAI client instance

    Returns:
        openai.types.chat.ChatCompletion: API response object, or None if request fails
    """
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    content = [{"type": "text", "text": prompt}]

    # Parse required figures
    try:
        # Try multiple ways of parsing figures
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    # Ensure each figure starts with "Figure "
    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    subfigures = []
    for figure in required_figures:
        # Handle both regular figures and those with letter suffixes
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        # Find matching figures in case details
        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        if not matching_figures:
            print(f"No matching figure found for {figure} in case {case_id}")
            continue

        for case_figure in matching_figures:
            # If a specific letter is specified, filter subfigures
            if figure_letter:
                matching_subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
                subfigures.extend(matching_subfigures)
            else:
                # If no letter specified, add all subfigures
                subfigures.extend(case_figure.get("subfigures", []))

    # Add images to content
    for subfig in subfigures:
        if "url" in subfig:
            content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
        else:
            print(f"Subfigure missing URL: {subfig}")

    # If no images found, log and return None
    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        return None

    messages = [
        {
            "role": "system",
            "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
        },
        {"role": "user", "content": content},
    ]

    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "skipped",
            "reason": "no_images",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()

        response = client.chat.completions.create(
            model=model_name, messages=messages, max_tokens=50, temperature=temperature
        )
        duration = time.time() - start_time

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            "cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
            "model_answer": response.choices[0].message.content,
            "correct_answer": question_data["answer"],
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return response

    except openai.RateLimitError:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "reason": "rate_limit",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(
            f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
            flush=True,
        )
        time.sleep(20)
        raise
    except Exception as e:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        raise


def load_benchmark_questions(case_id: str) -> list:
    """Load benchmark questions for a given case.

    Args:
        case_id: Identifier for the medical case

    Returns:
        list: List of paths to question files
    """
    benchmark_dir = "../benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        tuple: (total_cases, total_questions)
    """
    total_cases = len(glob.glob("../benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("../benchmark/questions")
    )
    return total_cases, total_questions


def main() -> None:
    """Main function to run the benchmark evaluation."""
    with open("../data/eurorad_metadata.json", "r") as file:
        data = json.load(file)

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    global client
    client = openai.OpenAI(api_key=api_key)

    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")

    for case_id, case_details in data.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in question_files:
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1
            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )

            # Handle cases where response is None
            if response is None:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")

    print(f"\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


if __name__ == "__main__":
    main()
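
To make the pricing in `calculate_cost` above concrete: the script hard-codes $5 per million prompt tokens and $15 per million completion tokens (actual OpenAI pricing may differ), so a hypothetical request with 1,200 prompt tokens and 50 completion tokens costs

```python
# (1_200 * 5.0 + 50 * 15.0) / 1_000_000 = 0.00675 USD for this request
cost = calculate_cost(prompt_tokens=1_200, completion_tokens=50)
print(f"${cost:.5f}")  # $0.00675
```
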
experiments/benchmark_llama.py
ADDED
@@ -0,0 +1,443 @@
from typing import Dict, List, Optional, Any, Union
import re
import json
import os
import glob
import time
import logging
import socket
import requests
import httpx
import backoff
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt
from openai import OpenAI

# Configure model settings
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
temperature = 0.2

# Configure logging
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def verify_dns() -> bool:
    """Verify DNS resolution and connectivity.

    Returns:
        bool: True if DNS resolution succeeds, False otherwise
    """
    try:
        # Try to resolve openrouter.ai
        socket.gethostbyname("openrouter.ai")
        return True
    except socket.gaierror:
        print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
        # Modify resolv.conf to use Google DNS
        try:
            with open("/etc/resolv.conf", "w") as f:
                f.write("nameserver 8.8.8.8\n")
            return True
        except Exception as e:
            print(f"Failed to update DNS settings: {e}")
            return False


def verify_connection() -> bool:
    """Verify connection to OpenRouter API.

    Returns:
        bool: True if connection succeeds, False otherwise
    """
    try:
        response = requests.get("https://openrouter.ai/api/v1/status", timeout=10)
        return response.status_code == 200
    except Exception as e:
        print(f"Connection test failed: {e}")
        return False


def initialize_client() -> OpenAI:
    """Initialize the OpenRouter client with proper timeout settings and connection verification.

    Returns:
        OpenAI: Configured OpenAI client for OpenRouter

    Raises:
        ValueError: If OPENROUTER_API_KEY environment variable is not set
        ConnectionError: If DNS verification or connection test fails
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY environment variable is not set.")

    # Configure timeout settings for the client
    timeout_settings = 120  # Increased timeout for large images/responses

    # Verify DNS and connection
    if not verify_dns():
        raise ConnectionError("DNS verification failed. Please check your network settings.")

    if not verify_connection():
        raise ConnectionError(
            "Cannot connect to OpenRouter. Please check your internet connection."
        )

    # Set up client with retry and timeout settings
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        timeout=timeout_settings,
        http_client=httpx.Client(
            timeout=timeout_settings, transport=httpx.HTTPTransport(retries=3)
        ),
    )


@backoff.on_exception(
    backoff.expo,
    (ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
    max_tries=5,
    max_time=300,  # Maximum total time to try in seconds
)
def create_multimodal_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    client: OpenAI,
) -> Optional[Any]:
    """Create and send a multimodal request to the model.

    Args:
        question_data: Dictionary containing question details
        case_details: Dictionary containing case information
        case_id: ID of the medical case
        question_id: ID of the specific question
        client: OpenAI client instance

    Returns:
        Optional[Any]: Model response if successful, None if skipped

    Raises:
        ConnectionError: If connection fails
        TimeoutError: If request times out
        Exception: For other errors
    """

    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning

Examples of valid responses:
A
B
C

Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    # Parse required figures
+
try:
|
156 |
+
if isinstance(question_data["figures"], str):
|
157 |
+
try:
|
158 |
+
required_figures = json.loads(question_data["figures"])
|
159 |
+
except json.JSONDecodeError:
|
160 |
+
required_figures = [question_data["figures"]]
|
161 |
+
elif isinstance(question_data["figures"], list):
|
162 |
+
required_figures = question_data["figures"]
|
163 |
+
else:
|
164 |
+
required_figures = [str(question_data["figures"])]
|
165 |
+
except Exception as e:
|
166 |
+
print(f"Error parsing figures: {e}")
|
167 |
+
required_figures = []
|
168 |
+
|
169 |
+
required_figures = [
|
170 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
171 |
+
]
|
172 |
+
|
173 |
+
# Process subfigures and prepare content
|
174 |
+
content = [{"type": "text", "text": prompt}]
|
175 |
+
image_urls = []
|
176 |
+
image_captions = []
|
177 |
+
|
178 |
+
for figure in required_figures:
|
179 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
180 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
181 |
+
|
182 |
+
matching_figures = [
|
183 |
+
case_figure
|
184 |
+
for case_figure in case_details.get("figures", [])
|
185 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
186 |
+
]
|
187 |
+
|
188 |
+
for case_figure in matching_figures:
|
189 |
+
subfigures = []
|
190 |
+
if figure_letter:
|
191 |
+
subfigures = [
|
192 |
+
subfig
|
193 |
+
for subfig in case_figure.get("subfigures", [])
|
194 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
195 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
196 |
+
]
|
197 |
+
else:
|
198 |
+
subfigures = case_figure.get("subfigures", [])
|
199 |
+
|
200 |
+
for subfig in subfigures:
|
201 |
+
if "url" in subfig:
|
202 |
+
content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
|
203 |
+
image_urls.append(subfig["url"])
|
204 |
+
image_captions.append(subfig.get("caption", ""))
|
205 |
+
|
206 |
+
if len(content) == 1: # Only the text prompt exists
|
207 |
+
print(f"No images found for case {case_id}, question {question_id}")
|
208 |
+
# Log the skipped question
|
209 |
+
log_entry = {
|
210 |
+
"case_id": case_id,
|
211 |
+
"question_id": question_id,
|
212 |
+
"timestamp": datetime.now().isoformat(),
|
213 |
+
"model": MODEL_NAME,
|
214 |
+
"status": "skipped",
|
215 |
+
"reason": "no_images",
|
216 |
+
"input": {
|
217 |
+
"question_data": {
|
218 |
+
"question": question_data["question"],
|
219 |
+
"explanation": question_data["explanation"],
|
220 |
+
"metadata": question_data.get("metadata", {}),
|
221 |
+
"figures": question_data["figures"],
|
222 |
+
},
|
223 |
+
"image_urls": image_urls,
|
224 |
+
},
|
225 |
+
}
|
226 |
+
logging.info(json.dumps(log_entry))
|
227 |
+
return None
|
228 |
+
|
229 |
+
try:
|
230 |
+
start_time = time.time()
|
231 |
+
|
232 |
+
response = client.chat.completions.create(
|
233 |
+
model=MODEL_NAME,
|
234 |
+
temperature=temperature,
|
235 |
+
messages=[
|
236 |
+
{"role": "system", "content": system_prompt},
|
237 |
+
{"role": "user", "content": content},
|
238 |
+
],
|
239 |
+
)
|
240 |
+
duration = time.time() - start_time
|
241 |
+
|
242 |
+
# Get raw response
|
243 |
+
raw_answer = response.choices[0].message.content
|
244 |
+
|
245 |
+
# Validate and clean
|
246 |
+
clean_answer = validate_answer(raw_answer)
|
247 |
+
|
248 |
+
if not clean_answer:
|
249 |
+
print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
|
250 |
+
print(f"Raw response: {raw_answer}")
|
251 |
+
|
252 |
+
# Update response object with cleaned answer
|
253 |
+
response.choices[0].message.content = clean_answer
|
254 |
+
|
255 |
+
# Log response
|
256 |
+
log_entry = {
|
257 |
+
"case_id": case_id,
|
258 |
+
"question_id": question_id,
|
259 |
+
"timestamp": datetime.now().isoformat(),
|
260 |
+
"model": MODEL_NAME,
|
261 |
+
"temperature": temperature,
|
262 |
+
"duration": round(duration, 2),
|
263 |
+
"usage": {
|
264 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
265 |
+
"completion_tokens": response.usage.completion_tokens,
|
266 |
+
"total_tokens": response.usage.total_tokens,
|
267 |
+
},
|
268 |
+
"model_answer": response.choices[0].message.content,
|
269 |
+
"correct_answer": question_data["answer"],
|
270 |
+
"input": {
|
271 |
+
"question_data": {
|
272 |
+
"question": question_data["question"],
|
273 |
+
"explanation": question_data["explanation"],
|
274 |
+
"metadata": question_data.get("metadata", {}),
|
275 |
+
"figures": question_data["figures"],
|
276 |
+
},
|
277 |
+
"image_urls": image_urls,
|
278 |
+
},
|
279 |
+
}
|
280 |
+
logging.info(json.dumps(log_entry))
|
281 |
+
return response
|
282 |
+
|
283 |
+
except ConnectionError as e:
|
284 |
+
print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
|
285 |
+
print("Retrying after a longer delay...")
|
286 |
+
time.sleep(30) # Add a longer delay before retry
|
287 |
+
raise
|
288 |
+
except TimeoutError as e:
|
289 |
+
print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
|
290 |
+
print("Retrying with increased timeout...")
|
291 |
+
raise
|
292 |
+
except Exception as e:
|
293 |
+
# Log failed requests too
|
294 |
+
log_entry = {
|
295 |
+
"case_id": case_id,
|
296 |
+
"question_id": question_id,
|
297 |
+
"timestamp": datetime.now().isoformat(),
|
298 |
+
"model": MODEL_NAME,
|
299 |
+
"temperature": temperature,
|
300 |
+
"status": "error",
|
301 |
+
"error": str(e),
|
302 |
+
"input": {
|
303 |
+
"question_data": {
|
304 |
+
"question": question_data["question"],
|
305 |
+
"explanation": question_data["explanation"],
|
306 |
+
"metadata": question_data.get("metadata", {}),
|
307 |
+
"figures": question_data["figures"],
|
308 |
+
},
|
309 |
+
"image_urls": image_urls,
|
310 |
+
},
|
311 |
+
}
|
312 |
+
logging.info(json.dumps(log_entry))
|
313 |
+
raise
|
314 |
+
|
315 |
+
|
316 |
+
def extract_answer(response_text: str) -> Optional[str]:
|
317 |
+
"""Extract single letter answer from model response.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
response_text: Raw text response from model
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
Optional[str]: Single letter answer if found, None otherwise
|
324 |
+
"""
|
325 |
+
# Convert to uppercase and remove periods
|
326 |
+
text = response_text.upper().replace(".", "")
|
327 |
+
|
328 |
+
# Look for common patterns
|
329 |
+
patterns = [
|
330 |
+
r"ANSWER:\s*([A-F])", # Matches "ANSWER: X"
|
331 |
+
r"OPTION\s*([A-F])", # Matches "OPTION X"
|
332 |
+
r"([A-F])\)", # Matches "X)"
|
333 |
+
r"\b([A-F])\b", # Matches single letter
|
334 |
+
]
|
335 |
+
|
336 |
+
for pattern in patterns:
|
337 |
+
matches = re.findall(pattern, text)
|
338 |
+
if matches:
|
339 |
+
return matches[0]
|
340 |
+
|
341 |
+
return None
|
342 |
+
|
343 |
+
|
344 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
345 |
+
"""Enforce strict single-letter response format.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
response_text: Raw text response from model
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
Optional[str]: Valid single letter answer if found, None otherwise
|
352 |
+
"""
|
353 |
+
if not response_text:
|
354 |
+
return None
|
355 |
+
|
356 |
+
# Trim surrounding whitespace and convert to uppercase
|
357 |
+
cleaned = response_text.strip().upper()
|
358 |
+
|
359 |
+
# Check if it's exactly one valid letter
|
360 |
+
if len(cleaned) == 1 and cleaned in "ABCDEF":
|
361 |
+
return cleaned
|
362 |
+
|
363 |
+
# If not, try to extract just the letter
|
364 |
+
match = re.search(r"([A-F])", cleaned)
|
365 |
+
return match.group(1) if match else None
|
366 |
+
|
367 |
+
|
368 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
369 |
+
"""Find all question files for a given case ID.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
case_id: ID of the medical case
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
List[str]: List of paths to question files
|
376 |
+
"""
|
377 |
+
benchmark_dir = "../benchmark/questions"
|
378 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
379 |
+
|
380 |
+
|
381 |
+
def count_total_questions() -> Tuple[int, int]:
|
382 |
+
"""Count total number of cases and questions.
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
Tuple[int, int]: (total_cases, total_questions)
|
386 |
+
"""
|
387 |
+
total_cases = len(glob.glob("../benchmark/questions/*"))
|
388 |
+
total_questions = sum(
|
389 |
+
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
|
390 |
+
for case_id in os.listdir("../benchmark/questions")
|
391 |
+
)
|
392 |
+
return total_cases, total_questions
|
393 |
+
|
394 |
+
|
395 |
+
def main():
|
396 |
+
with open("../data/eurorad_metadata.json", "r") as file:
|
397 |
+
data = json.load(file)
|
398 |
+
|
399 |
+
client = initialize_client()
|
400 |
+
total_cases, total_questions = count_total_questions()
|
401 |
+
cases_processed = 0
|
402 |
+
questions_processed = 0
|
403 |
+
skipped_questions = 0
|
404 |
+
|
405 |
+
print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
|
406 |
+
|
407 |
+
for case_id, case_details in data.items():
|
408 |
+
question_files = load_benchmark_questions(case_id)
|
409 |
+
if not question_files:
|
410 |
+
continue
|
411 |
+
|
412 |
+
cases_processed += 1
|
413 |
+
for question_file in question_files:
|
414 |
+
with open(question_file, "r") as file:
|
415 |
+
question_data = json.load(file)
|
416 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
417 |
+
|
418 |
+
questions_processed += 1
|
419 |
+
response = create_multimodal_request(
|
420 |
+
question_data, case_details, case_id, question_id, client
|
421 |
+
)
|
422 |
+
|
423 |
+
if response is None:
|
424 |
+
skipped_questions += 1
|
425 |
+
print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
|
426 |
+
continue
|
427 |
+
|
428 |
+
print(
|
429 |
+
f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
|
430 |
+
)
|
431 |
+
print(f"Case ID: {case_id}")
|
432 |
+
print(f"Question ID: {question_id}")
|
433 |
+
print(f"Model Answer: {response.choices[0].message.content}")
|
434 |
+
print(f"Correct Answer: {question_data['answer']}\n")
|
435 |
+
|
436 |
+
print(f"\nBenchmark Summary:")
|
437 |
+
print(f"Total Cases Processed: {cases_processed}")
|
438 |
+
print(f"Total Questions Processed: {questions_processed}")
|
439 |
+
print(f"Total Questions Skipped: {skipped_questions}")
|
440 |
+
|
441 |
+
|
442 |
+
if __name__ == "__main__":
|
443 |
+
main()
|
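For reference, a minimal usage sketch of the answer-cleaning helpers defined above. This is not part of the diff: the sample response strings are invented, and it assumes the file above is importable as a module named benchmark_llama (importing it also configures its logging as a side effect).

# Hedged sketch: exercises extract_answer / validate_answer from experiments/benchmark_llama.py.
from benchmark_llama import extract_answer, validate_answer

print(validate_answer("C"))            # "C"   - already a clean single letter
print(validate_answer("  b  "))        # "B"   - stripped and upper-cased
print(validate_answer(""))             # None  - empty responses are rejected
print(extract_answer("Answer: B"))     # "B"   - matches the "ANSWER: X" pattern
print(extract_answer("Option C"))      # "C"   - matches the "OPTION X" pattern
print(extract_answer("D) effusion"))   # "D"   - matches the "X)" pattern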
experiments/benchmark_llavamed.py
ADDED
@@ -0,0 +1,541 @@
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import base64
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
from llava.conversation import conv_templates
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import logging
|
12 |
+
from datetime import datetime
|
13 |
+
from tqdm import tqdm
|
14 |
+
import re
|
15 |
+
from typing import Dict, List, Optional, Union, Any, Tuple
|
16 |
+
|
17 |
+
|
18 |
+
def process_image(image_path: str, target_size: int = 640) -> Image.Image:
|
19 |
+
"""Process and resize an image to match model requirements.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
image_path: Path to the input image file
|
23 |
+
target_size: Target size for both width and height in pixels
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
PIL.Image: Processed and padded image with dimensions (target_size, target_size)
|
27 |
+
"""
|
28 |
+
image = Image.open(image_path)
|
29 |
+
if image.mode != "RGB":
|
30 |
+
image = image.convert("RGB")
|
31 |
+
|
32 |
+
# Calculate scaling to maintain aspect ratio
|
33 |
+
ratio = min(target_size / image.width, target_size / image.height)
|
34 |
+
new_size = (int(image.width * ratio), int(image.height * ratio))
|
35 |
+
|
36 |
+
# Resize image
|
37 |
+
image = image.resize(new_size, Image.LANCZOS)
|
38 |
+
|
39 |
+
# Create new image with padding
|
40 |
+
new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
|
41 |
+
# Paste resized image in center
|
42 |
+
offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
|
43 |
+
new_image.paste(image, offset)
|
44 |
+
|
45 |
+
return new_image
|
46 |
+
|
47 |
+
|
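As a quick, hedged illustration (not part of the diff) of the letterbox behaviour of process_image above; the input path is a made-up example.

# Hypothetical check: any input is scaled to fit, then padded to a black square canvas.
padded = process_image("some_chest_xray.jpg", target_size=640)  # path is an assumption
print(padded.size)  # (640, 640): aspect ratio preserved, padding added around the image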
48 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
49 |
+
"""Extract and validate a single-letter response from the model's output.
|
50 |
+
Handles multiple response formats and edge cases.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
response_text: The full text output from the model
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
A single letter answer (A-F) or None if no valid answer found
|
57 |
+
"""
|
58 |
+
if not response_text:
|
59 |
+
return None
|
60 |
+
|
61 |
+
# Clean the response text
|
62 |
+
cleaned = response_text.strip()
|
63 |
+
|
64 |
+
# Comprehensive set of patterns to extract the answer
|
65 |
+
extraction_patterns = [
|
66 |
+
# Strict format with explicit letter answer
|
67 |
+
r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
|
68 |
+
# Patterns for extracting from longer descriptions
|
69 |
+
r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
70 |
+
r"\b(?:answer|option)\s*([A-F])[):]\s*",
|
71 |
+
# Patterns for extracting from descriptive sentences
|
72 |
+
r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
73 |
+
r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
|
74 |
+
# Patterns with contextual words
|
75 |
+
r"characteriz[e]?d?\s+by\s+([A-F])\b",
|
76 |
+
r"indicat[e]?s?\s+([A-F])\b",
|
77 |
+
# Fallback to Option X or Letter X formats
|
78 |
+
r"Option\s*([A-F])\b",
|
79 |
+
r"\b([A-F])\)\s*",
|
80 |
+
# Fallback to standalone letter
|
81 |
+
r"^\s*([A-F])\s*$",
|
82 |
+
]
|
83 |
+
|
84 |
+
# Try each pattern
|
85 |
+
for pattern in extraction_patterns:
|
86 |
+
matches = re.findall(pattern, cleaned, re.IGNORECASE)
|
87 |
+
for match in matches:
|
88 |
+
# Ensure match is a single valid letter
|
89 |
+
if isinstance(match, tuple):
|
90 |
+
match = match[0] if match[0] in "ABCDEF" else None
|
91 |
+
if match and match.upper() in "ABCDEF":
|
92 |
+
return match.upper()
|
93 |
+
|
94 |
+
# Final fallback: look for standalone letters in context
|
95 |
+
context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
|
96 |
+
context_letters = [m for m in context_matches if m in "ABCDEF"]
|
97 |
+
if context_letters:
|
98 |
+
return context_letters[0]
|
99 |
+
|
100 |
+
# No valid answer found
|
101 |
+
return None
|
102 |
+
|
103 |
+
|
104 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
105 |
+
"""Find all question files for a given case ID.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
case_id: The ID of the medical case
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
List of paths to question JSON files
|
112 |
+
"""
|
113 |
+
benchmark_dir = "MedMAX/benchmark/questions"
|
114 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
115 |
+
|
116 |
+
|
117 |
+
def count_total_questions() -> Tuple[int, int]:
|
118 |
+
"""Count total number of cases and questions in benchmark.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Tuple containing (total_cases, total_questions)
|
122 |
+
"""
|
123 |
+
total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
|
124 |
+
total_questions = sum(
|
125 |
+
len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
|
126 |
+
for case_id in os.listdir("MedMAX/benchmark/questions")
|
127 |
+
)
|
128 |
+
return total_cases, total_questions
|
129 |
+
|
130 |
+
|
131 |
+
def create_inference_request(
|
132 |
+
question_data: Dict[str, Any],
|
133 |
+
case_details: Dict[str, Any],
|
134 |
+
case_id: str,
|
135 |
+
question_id: str,
|
136 |
+
worker_addr: str,
|
137 |
+
model_name: str,
|
138 |
+
raw_output: bool = False,
|
139 |
+
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
|
140 |
+
"""Create and send inference request to worker.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
question_data: Dictionary containing question details and figures
|
144 |
+
case_details: Dictionary containing case information and figures
|
145 |
+
case_id: Identifier for the medical case
|
146 |
+
question_id: Identifier for the specific question
|
147 |
+
worker_addr: Address of the worker endpoint
|
148 |
+
model_name: Name of the model to use
|
149 |
+
raw_output: Whether to return raw model output
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
If raw_output is False: Tuple of (validated_answer, duration)
|
153 |
+
If raw_output is True: Dictionary with full inference details
|
154 |
+
"""
|
155 |
+
system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
|
156 |
+
"""
|
157 |
+
|
158 |
+
prompt = f"""Given the following medical case:
|
159 |
+
Please answer this multiple choice question:
|
160 |
+
{question_data['question']}
|
161 |
+
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
|
162 |
+
|
163 |
+
try:
|
164 |
+
# Parse required figures
|
165 |
+
if isinstance(question_data["figures"], str):
|
166 |
+
try:
|
167 |
+
required_figures = json.loads(question_data["figures"])
|
168 |
+
except json.JSONDecodeError:
|
169 |
+
required_figures = [question_data["figures"]]
|
170 |
+
elif isinstance(question_data["figures"], list):
|
171 |
+
required_figures = question_data["figures"]
|
172 |
+
else:
|
173 |
+
required_figures = [str(question_data["figures"])]
|
174 |
+
except Exception as e:
|
175 |
+
print(f"Error parsing figures: {e}")
|
176 |
+
required_figures = []
|
177 |
+
|
178 |
+
required_figures = [
|
179 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
180 |
+
]
|
181 |
+
|
182 |
+
# Get image paths
|
183 |
+
image_paths = []
|
184 |
+
for figure in required_figures:
|
185 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
186 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
187 |
+
|
188 |
+
matching_figures = [
|
189 |
+
case_figure
|
190 |
+
for case_figure in case_details.get("figures", [])
|
191 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
192 |
+
]
|
193 |
+
|
194 |
+
for case_figure in matching_figures:
|
195 |
+
subfigures = []
|
196 |
+
if figure_letter:
|
197 |
+
subfigures = [
|
198 |
+
subfig
|
199 |
+
for subfig in case_figure.get("subfigures", [])
|
200 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
201 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
202 |
+
]
|
203 |
+
else:
|
204 |
+
subfigures = case_figure.get("subfigures", [])
|
205 |
+
|
206 |
+
for subfig in subfigures:
|
207 |
+
if "local_path" in subfig:
|
208 |
+
image_paths.append("MedMAX/data/" + subfig["local_path"])
|
209 |
+
|
210 |
+
if not image_paths:
|
211 |
+
print(f"No local images found for case {case_id}, question {question_id}")
|
212 |
+
return "skipped", 0.0 # Return a special 'skipped' marker
|
213 |
+
|
214 |
+
try:
|
215 |
+
start_time = time.time()
|
216 |
+
|
217 |
+
# Process each image
|
218 |
+
processed_images = [process_image(path) for path in image_paths]
|
219 |
+
|
220 |
+
# Create conversation
|
221 |
+
conv = conv_templates["mistral_instruct"].copy()
|
222 |
+
|
223 |
+
# Add image and message
|
224 |
+
if "<image>" not in prompt:
|
225 |
+
text = prompt + "\n<image>"
|
226 |
+
else:
|
227 |
+
text = prompt
|
228 |
+
|
229 |
+
message = (text, processed_images[0], "Default") # Currently handling first image
|
230 |
+
conv.append_message(conv.roles[0], message)
|
231 |
+
conv.append_message(conv.roles[1], None)
|
232 |
+
|
233 |
+
prompt = conv.get_prompt()
|
234 |
+
headers = {"User-Agent": "LLaVA-Med Client"}
|
235 |
+
pload = {
|
236 |
+
"model": model_name,
|
237 |
+
"prompt": prompt,
|
238 |
+
"max_new_tokens": 150, # Reduce this since we only need one letter
|
239 |
+
"temperature": 0.5, # Lower temperature for more focused responses
|
240 |
+
"stop": conv.sep2,
|
241 |
+
"images": conv.get_images(),
|
242 |
+
"top_p": 1, # Lower top_p for more focused sampling
|
243 |
+
"frequency_penalty": 0.0,
|
244 |
+
"presence_penalty": 0.0,
|
245 |
+
}
|
246 |
+
|
247 |
+
max_retries = 3
|
248 |
+
retry_delay = 5
|
249 |
+
response_text = None
|
250 |
+
|
251 |
+
for attempt in range(max_retries):
|
252 |
+
try:
|
253 |
+
response = requests.post(
|
254 |
+
worker_addr + "/worker_generate_stream",
|
255 |
+
headers=headers,
|
256 |
+
json=pload,
|
257 |
+
stream=True,
|
258 |
+
timeout=30,
|
259 |
+
)
|
260 |
+
|
261 |
+
complete_output = ""
|
262 |
+
for chunk in response.iter_lines(
|
263 |
+
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
264 |
+
):
|
265 |
+
if chunk:
|
266 |
+
data = json.loads(chunk.decode("utf-8"))
|
267 |
+
if data["error_code"] == 0:
|
268 |
+
output = data["text"].split("[/INST]")[-1]
|
269 |
+
complete_output = output
|
270 |
+
else:
|
271 |
+
print(f"\nError: {data['text']} (error_code: {data['error_code']})")
|
272 |
+
if attempt < max_retries - 1:
|
273 |
+
time.sleep(retry_delay)
|
274 |
+
break
|
275 |
+
return None, None
|
276 |
+
|
277 |
+
if complete_output:
|
278 |
+
response_text = complete_output
|
279 |
+
break
|
280 |
+
|
281 |
+
except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
|
282 |
+
if attempt < max_retries - 1:
|
283 |
+
print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
|
284 |
+
time.sleep(retry_delay)
|
285 |
+
else:
|
286 |
+
print(f"\nFailed after {max_retries} attempts: {str(e)}")
|
287 |
+
return None, None
|
288 |
+
|
289 |
+
duration = time.time() - start_time
|
290 |
+
|
291 |
+
if raw_output:
|
292 |
+
inference_details = {
|
293 |
+
"raw_output": response_text,
|
294 |
+
"validated_answer": validate_answer(response_text),
|
295 |
+
"duration": duration,
|
296 |
+
"prompt": prompt,
|
297 |
+
"system_prompt": system_prompt,
|
298 |
+
"image_paths": image_paths,
|
299 |
+
"payload": pload,
|
300 |
+
}
|
301 |
+
return inference_details
|
302 |
+
|
303 |
+
return validate_answer(response_text), duration
|
304 |
+
|
305 |
+
except Exception as e:
|
306 |
+
print(f"Error in inference request: {str(e)}")
|
307 |
+
return None, None
|
308 |
+
|
309 |
+
|
310 |
+
def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
311 |
+
"""Remove image-related and large data from the payload to keep the log lean.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
payload: Original request payload dictionary
|
315 |
+
|
316 |
+
Returns:
|
317 |
+
Cleaned payload dictionary with large data removed
|
318 |
+
"""
|
319 |
+
if not payload:
|
320 |
+
return None
|
321 |
+
|
322 |
+
# Create a copy of the payload to avoid modifying the original
|
323 |
+
cleaned_payload = payload.copy()
|
324 |
+
|
325 |
+
# Remove large or sensitive data
|
326 |
+
if "images" in cleaned_payload:
|
327 |
+
del cleaned_payload["images"]
|
328 |
+
|
329 |
+
return cleaned_payload
|
330 |
+
|
331 |
+
|
332 |
+
def main():
|
333 |
+
parser = argparse.ArgumentParser()
|
334 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
335 |
+
parser.add_argument("--worker-address", type=str)
|
336 |
+
parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
|
337 |
+
parser.add_argument("--output-dir", type=str, default="benchmark_results")
|
338 |
+
parser.add_argument(
|
339 |
+
"--raw-output", action="store_true", help="Return raw model output without validation"
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--num-cases",
|
343 |
+
type=int,
|
344 |
+
help="Number of cases to process if looking at raw outputs",
|
345 |
+
default=2,
|
346 |
+
)
|
347 |
+
args = parser.parse_args()
|
348 |
+
|
349 |
+
# Setup output directory
|
350 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
351 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
352 |
+
|
353 |
+
# Setup live logging files
|
354 |
+
live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
|
355 |
+
final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
|
356 |
+
|
357 |
+
# Initialize live log file
|
358 |
+
with open(live_log_filename, "w") as live_log_file:
|
359 |
+
live_log_file.write("[\n") # Start of JSON array
|
360 |
+
|
361 |
+
# Setup logging
|
362 |
+
logging.basicConfig(
|
363 |
+
filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
|
364 |
+
level=logging.INFO,
|
365 |
+
format="%(message)s",
|
366 |
+
)
|
367 |
+
|
368 |
+
# Get worker address
|
369 |
+
if args.worker_address:
|
370 |
+
worker_addr = args.worker_address
|
371 |
+
else:
|
372 |
+
try:
|
373 |
+
requests.post(args.controller_address + "/refresh_all_workers")
|
374 |
+
ret = requests.post(args.controller_address + "/list_models")
|
375 |
+
models = ret.json()["models"]
|
376 |
+
ret = requests.post(
|
377 |
+
args.controller_address + "/get_worker_address", json={"model": args.model_name}
|
378 |
+
)
|
379 |
+
worker_addr = ret.json()["address"]
|
380 |
+
print(f"Worker address: {worker_addr}")
|
381 |
+
except requests.exceptions.RequestException as e:
|
382 |
+
print(f"Failed to connect to controller: {e}")
|
383 |
+
return
|
384 |
+
|
385 |
+
if worker_addr == "":
|
386 |
+
print("No available worker")
|
387 |
+
return
|
388 |
+
|
389 |
+
# Load cases with local paths
|
390 |
+
with open("MedMAX/data/updated_cases.json", "r") as file:
|
391 |
+
data = json.load(file)
|
392 |
+
|
393 |
+
total_cases, total_questions = count_total_questions()
|
394 |
+
print(f"\nStarting benchmark with {args.model_name}")
|
395 |
+
print(f"Found {total_cases} cases with {total_questions} total questions")
|
396 |
+
|
397 |
+
results = {
|
398 |
+
"model": args.model_name,
|
399 |
+
"timestamp": datetime.now().isoformat(),
|
400 |
+
"total_cases": total_cases,
|
401 |
+
"total_questions": total_questions,
|
402 |
+
"results": [],
|
403 |
+
}
|
404 |
+
|
405 |
+
cases_processed = 0
|
406 |
+
questions_processed = 0
|
407 |
+
correct_answers = 0
|
408 |
+
skipped_questions = 0
|
409 |
+
total_processed_entries = 0
|
410 |
+
|
411 |
+
# Process each case
|
412 |
+
for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
|
413 |
+
question_files = load_benchmark_questions(case_id)
|
414 |
+
if not question_files:
|
415 |
+
continue
|
416 |
+
|
417 |
+
cases_processed += 1
|
418 |
+
for question_file in tqdm(
|
419 |
+
question_files, desc=f"Processing questions for case {case_id}", leave=False
|
420 |
+
):
|
421 |
+
with open(question_file, "r") as file:
|
422 |
+
question_data = json.load(file)
|
423 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
424 |
+
|
425 |
+
questions_processed += 1
|
426 |
+
|
427 |
+
# Get model's answer
|
428 |
+
inference_result = create_inference_request(
|
429 |
+
question_data,
|
430 |
+
case_details,
|
431 |
+
case_id,
|
432 |
+
question_id,
|
433 |
+
worker_addr,
|
434 |
+
args.model_name,
|
435 |
+
raw_output=True, # Always use raw output for detailed logging
|
436 |
+
)
|
437 |
+
|
438 |
+
# Handle skipped questions
|
439 |
+
if inference_result == ("skipped", 0.0):
|
440 |
+
skipped_questions += 1
|
441 |
+
print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
|
442 |
+
|
443 |
+
# Log skipped question
|
444 |
+
skipped_entry = {
|
445 |
+
"case_id": case_id,
|
446 |
+
"question_id": question_id,
|
447 |
+
"status": "skipped",
|
448 |
+
"reason": "No images found",
|
449 |
+
}
|
450 |
+
with open(live_log_filename, "a") as live_log_file:
|
451 |
+
json.dump(skipped_entry, live_log_file, indent=2)
|
452 |
+
live_log_file.write(",\n") # Add comma for next entry
|
453 |
+
|
454 |
+
continue
|
455 |
+
|
456 |
+
# Extract information
|
457 |
+
answer = inference_result["validated_answer"]
|
458 |
+
duration = inference_result["duration"]
|
459 |
+
|
460 |
+
# Prepare detailed logging entry
|
461 |
+
log_entry = {
|
462 |
+
"case_id": case_id,
|
463 |
+
"question_id": question_id,
|
464 |
+
"question": question_data["question"],
|
465 |
+
"correct_answer": question_data["answer"],
|
466 |
+
"raw_output": inference_result["raw_output"],
|
467 |
+
"validated_answer": answer,
|
468 |
+
"model_answer": answer,
|
469 |
+
"is_correct": answer == question_data["answer"] if answer else False,
|
470 |
+
"duration": duration,
|
471 |
+
"system_prompt": inference_result["system_prompt"],
|
472 |
+
"input_prompt": inference_result["prompt"],
|
473 |
+
"image_paths": inference_result["image_paths"],
|
474 |
+
"payload": clean_payload(inference_result["payload"]),
|
475 |
+
}
|
476 |
+
|
477 |
+
# Write to live log file
|
478 |
+
with open(live_log_filename, "a") as live_log_file:
|
479 |
+
json.dump(log_entry, live_log_file, indent=2)
|
480 |
+
live_log_file.write(",\n") # Add comma for next entry
|
481 |
+
|
482 |
+
# Print to console
|
483 |
+
print(f"\nCase {case_id}, Question {question_id}")
|
484 |
+
print(f"Model Answer: {answer}")
|
485 |
+
print(f"Correct Answer: {question_data['answer']}")
|
486 |
+
print(f"Time taken: {duration:.2f}s")
|
487 |
+
|
488 |
+
# Track correct answers
|
489 |
+
if answer == question_data["answer"]:
|
490 |
+
correct_answers += 1
|
491 |
+
|
492 |
+
# Append to results
|
493 |
+
results["results"].append(log_entry)
|
494 |
+
total_processed_entries += 1
|
495 |
+
|
496 |
+
# Optional: break if reached specified number of cases
|
497 |
+
if args.raw_output and cases_processed == args.num_cases:
|
498 |
+
break
|
499 |
+
|
500 |
+
# Optional: break if reached specified number of cases
|
501 |
+
if args.raw_output and cases_processed == args.num_cases:
|
502 |
+
break
|
503 |
+
|
504 |
+
# Close live log file
|
505 |
+
with open(live_log_filename, "a") as live_log_file:
|
506 |
+
# Remove trailing comma and close JSON array
|
507 |
+
live_log_file.seek(live_log_file.tell() - 2, 0) # Go back 2 chars to remove ',\n'
|
508 |
+
live_log_file.write("\n]")
|
509 |
+
|
510 |
+
# Calculate final statistics
|
511 |
+
results["summary"] = {
|
512 |
+
"cases_processed": cases_processed,
|
513 |
+
"questions_processed": questions_processed,
|
514 |
+
"total_processed_entries": total_processed_entries,
|
515 |
+
"correct_answers": correct_answers,
|
516 |
+
"skipped_questions": skipped_questions,
|
517 |
+
"accuracy": (
|
518 |
+
correct_answers / (questions_processed - skipped_questions)
|
519 |
+
if (questions_processed - skipped_questions) > 0
|
520 |
+
else 0
|
521 |
+
),
|
522 |
+
}
|
523 |
+
|
524 |
+
# Save final results
|
525 |
+
with open(final_results_filename, "w") as f:
|
526 |
+
json.dump(results, f, indent=2)
|
527 |
+
|
528 |
+
print(f"\nBenchmark Summary:")
|
529 |
+
print(f"Total Cases Processed: {cases_processed}")
|
530 |
+
print(f"Total Questions Processed: {questions_processed}")
|
531 |
+
print(f"Total Processed Entries: {total_processed_entries}")
|
532 |
+
print(f"Correct Answers: {correct_answers}")
|
533 |
+
print(f"Skipped Questions: {skipped_questions}")
|
534 |
+
print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%")
|
535 |
+
print(f"\nResults saved to {args.output_dir}")
|
536 |
+
print(f"Live log: {live_log_filename}")
|
537 |
+
print(f"Final results: {final_results_filename}")
|
538 |
+
|
539 |
+
|
540 |
+
if __name__ == "__main__":
|
541 |
+
main()
|
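For orientation, a hedged sketch (not part of the diff) of how this script could be launched once a LLaVA-Med controller and worker are running; the flag names come from the argparse setup above, while the concrete values are only the parser defaults or placeholders.

# Hypothetical invocation of experiments/benchmark_llavamed.py with explicit arguments.
import subprocess

subprocess.run(
    [
        "python", "experiments/benchmark_llavamed.py",
        "--controller-address", "http://localhost:21001",  # parser default
        "--model-name", "llava-med-v1.5-mistral-7b",        # parser default
        "--output-dir", "benchmark_results",                # parser default
        "--raw-output",                                     # log raw model output
        "--num-cases", "2",                                 # stop early (only with --raw-output)
    ],
    check=True,
)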
experiments/benchmark_medrax.ipynb
ADDED
@@ -0,0 +1,374 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import operator\n",
|
10 |
+
"import warnings\n",
|
11 |
+
"from typing import *\n",
|
12 |
+
"import traceback\n",
|
13 |
+
"\n",
|
14 |
+
"import os\n",
|
15 |
+
"import torch\n",
|
16 |
+
"from dotenv import load_dotenv\n",
|
17 |
+
"from IPython.display import Image\n",
|
18 |
+
"from langgraph.checkpoint.memory import MemorySaver\n",
|
19 |
+
"from langgraph.graph import END, StateGraph\n",
|
20 |
+
"from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
|
21 |
+
"from langchain_openai import ChatOpenAI\n",
|
22 |
+
"from transformers import logging\n",
|
23 |
+
"import matplotlib.pyplot as plt\n",
|
24 |
+
"import numpy as np\n",
|
25 |
+
"import re\n",
|
26 |
+
"\n",
|
27 |
+
"from medrax.agent import *\n",
|
28 |
+
"from medrax.tools import *\n",
|
29 |
+
"from medrax.utils import *\n",
|
30 |
+
"\n",
|
31 |
+
"import json\n",
|
32 |
+
"import openai\n",
|
33 |
+
"import os\n",
|
34 |
+
"import glob\n",
|
35 |
+
"import time\n",
|
36 |
+
"import logging\n",
|
37 |
+
"from datetime import datetime\n",
|
38 |
+
"from tenacity import retry, wait_exponential, stop_after_attempt\n",
|
39 |
+
"\n",
|
40 |
+
"warnings.filterwarnings(\"ignore\")\n",
|
41 |
+
"_ = load_dotenv()\n",
|
42 |
+
"\n",
|
43 |
+
"\n",
|
44 |
+
"# Setup directory paths\n",
|
45 |
+
"ROOT = \"set this directory to where MedRAX is, .e.g /home/MedRAX\"\n",
|
46 |
+
"PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
|
47 |
+
"BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
|
48 |
+
"MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
|
49 |
+
"FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
|
50 |
+
"\n",
|
51 |
+
"model_name = \"medrax\"\n",
|
52 |
+
"temperature = 0.2\n",
|
53 |
+
"medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
|
54 |
+
"log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
|
55 |
+
"logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
|
56 |
+
"device = \"cuda\""
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 2,
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"def get_tools():\n",
|
66 |
+
" report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
|
67 |
+
" xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
|
68 |
+
" segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
|
69 |
+
" grounding_tool = XRayPhraseGroundingTool(\n",
|
70 |
+
" cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
|
71 |
+
" )\n",
|
72 |
+
" xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
|
73 |
+
" llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
|
74 |
+
"\n",
|
75 |
+
" return [\n",
|
76 |
+
" report_tool,\n",
|
77 |
+
" xray_classification_tool,\n",
|
78 |
+
" segmentation_tool,\n",
|
79 |
+
" grounding_tool,\n",
|
80 |
+
" xray_vqa_tool,\n",
|
81 |
+
" llava_med_tool,\n",
|
82 |
+
" ]\n",
|
83 |
+
"\n",
|
84 |
+
"\n",
|
85 |
+
"def get_agent(tools):\n",
|
86 |
+
" prompts = load_prompts_from_file(PROMPT_FILE)\n",
|
87 |
+
" prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
|
88 |
+
"\n",
|
89 |
+
" checkpointer = MemorySaver()\n",
|
90 |
+
" model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
|
91 |
+
" agent = Agent(\n",
|
92 |
+
" model,\n",
|
93 |
+
" tools=tools,\n",
|
94 |
+
" log_tools=True,\n",
|
95 |
+
" log_dir=\"logs\",\n",
|
96 |
+
" system_prompt=prompt,\n",
|
97 |
+
" checkpointer=checkpointer,\n",
|
98 |
+
" )\n",
|
99 |
+
" thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
|
100 |
+
" return agent, thread\n",
|
101 |
+
"\n",
|
102 |
+
"\n",
|
103 |
+
"def run_medrax(agent, thread, prompt, image_urls=[]):\n",
|
104 |
+
" messages = [\n",
|
105 |
+
" HumanMessage(\n",
|
106 |
+
" content=[\n",
|
107 |
+
" {\"type\": \"text\", \"text\": prompt},\n",
|
108 |
+
" ]\n",
|
109 |
+
" + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
|
110 |
+
" )\n",
|
111 |
+
" ]\n",
|
112 |
+
"\n",
|
113 |
+
" final_response = None\n",
|
114 |
+
" for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
|
115 |
+
" for v in event.values():\n",
|
116 |
+
" final_response = v\n",
|
117 |
+
"\n",
|
118 |
+
" final_response = final_response[\"messages\"][-1].content.strip()\n",
|
119 |
+
" agent_state = agent.workflow.get_state(thread)\n",
|
120 |
+
"\n",
|
121 |
+
" return final_response, str(agent_state)"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": 3,
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
|
131 |
+
" # Parse required figures\n",
|
132 |
+
" try:\n",
|
133 |
+
" # Try multiple ways of parsing figures\n",
|
134 |
+
" if isinstance(question_data[\"figures\"], str):\n",
|
135 |
+
" try:\n",
|
136 |
+
" required_figures = json.loads(question_data[\"figures\"])\n",
|
137 |
+
" except json.JSONDecodeError:\n",
|
138 |
+
" required_figures = [question_data[\"figures\"]]\n",
|
139 |
+
" elif isinstance(question_data[\"figures\"], list):\n",
|
140 |
+
" required_figures = question_data[\"figures\"]\n",
|
141 |
+
" else:\n",
|
142 |
+
" required_figures = [str(question_data[\"figures\"])]\n",
|
143 |
+
" except Exception as e:\n",
|
144 |
+
" print(f\"Error parsing figures: {e}\")\n",
|
145 |
+
" required_figures = []\n",
|
146 |
+
"\n",
|
147 |
+
" # Ensure each figure starts with \"Figure \"\n",
|
148 |
+
" required_figures = [\n",
|
149 |
+
" fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
|
150 |
+
" ]\n",
|
151 |
+
"\n",
|
152 |
+
" subfigures = []\n",
|
153 |
+
" for figure in required_figures:\n",
|
154 |
+
" # Handle both regular figures and those with letter suffixes\n",
|
155 |
+
" base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
|
156 |
+
" figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
|
157 |
+
"\n",
|
158 |
+
" # Find matching figures in case details\n",
|
159 |
+
" matching_figures = [\n",
|
160 |
+
" case_figure\n",
|
161 |
+
" for case_figure in case_details.get(\"figures\", [])\n",
|
162 |
+
" if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
|
163 |
+
" ]\n",
|
164 |
+
"\n",
|
165 |
+
" if not matching_figures:\n",
|
166 |
+
" print(f\"No matching figure found for {figure} in case {case_id}\")\n",
|
167 |
+
" continue\n",
|
168 |
+
"\n",
|
169 |
+
" for case_figure in matching_figures:\n",
|
170 |
+
" # If a specific letter is specified, filter subfigures\n",
|
171 |
+
" if figure_letter:\n",
|
172 |
+
" matching_subfigures = [\n",
|
173 |
+
" subfig\n",
|
174 |
+
" for subfig in case_figure.get(\"subfigures\", [])\n",
|
175 |
+
" if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
|
176 |
+
" or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
|
177 |
+
" ]\n",
|
178 |
+
" subfigures.extend(matching_subfigures)\n",
|
179 |
+
" else:\n",
|
180 |
+
" # If no letter specified, add all subfigures\n",
|
181 |
+
" subfigures.extend(case_figure.get(\"subfigures\", []))\n",
|
182 |
+
"\n",
|
183 |
+
" # Add images to content\n",
|
184 |
+
" figure_prompt = \"\"\n",
|
185 |
+
" image_urls = []\n",
|
186 |
+
"\n",
|
187 |
+
" for subfig in subfigures:\n",
|
188 |
+
" if \"number\" in subfig:\n",
|
189 |
+
" subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
|
190 |
+
" subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
|
191 |
+
" figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
|
192 |
+
" if \"url\" in subfig:\n",
|
193 |
+
" image_urls.append(subfig[\"url\"])\n",
|
194 |
+
" else:\n",
|
195 |
+
" print(f\"Subfigure missing URL: {subfig}\")\n",
|
196 |
+
"\n",
|
197 |
+
" prompt = (\n",
|
198 |
+
" f\"Answer this question correctly using chain of thought reasoning and \"\n",
|
199 |
+
" \"carefully evaluating choices. Solve using our own vision and reasoning and then\"\n",
|
200 |
+
" \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
|
201 |
+
" f\"{question_data['question']}\\n{figure_prompt}\"\n",
|
202 |
+
" )\n",
|
203 |
+
"\n",
|
204 |
+
" try:\n",
|
205 |
+
" start_time = time.time()\n",
|
206 |
+
"\n",
|
207 |
+
" final_response, agent_state = run_medrax(\n",
|
208 |
+
" agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
|
209 |
+
" )\n",
|
210 |
+
" model_answer, agent_state = run_medrax(\n",
|
211 |
+
" agent=agent,\n",
|
212 |
+
" thread=thread,\n",
|
213 |
+
" prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
|
214 |
+
" )\n",
|
215 |
+
" duration = time.time() - start_time\n",
|
216 |
+
"\n",
|
217 |
+
" log_entry = {\n",
|
218 |
+
" \"case_id\": case_id,\n",
|
219 |
+
" \"question_id\": question_id,\n",
|
220 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
221 |
+
" \"model\": model_name,\n",
|
222 |
+
" \"temperature\": temperature,\n",
|
223 |
+
" \"duration\": round(duration, 2),\n",
|
224 |
+
" \"usage\": \"\",\n",
|
225 |
+
" \"cost\": 0,\n",
|
226 |
+
" \"raw_response\": final_response,\n",
|
227 |
+
" \"model_answer\": model_answer.strip(),\n",
|
228 |
+
" \"correct_answer\": question_data[\"answer\"][0],\n",
|
229 |
+
" \"input\": {\n",
|
230 |
+
" \"messages\": prompt,\n",
|
231 |
+
" \"question_data\": {\n",
|
232 |
+
" \"question\": question_data[\"question\"],\n",
|
233 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
234 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
235 |
+
" \"figures\": question_data[\"figures\"],\n",
|
236 |
+
" },\n",
|
237 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
238 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
239 |
+
" },\n",
|
240 |
+
" \"agent_state\": agent_state,\n",
|
241 |
+
" }\n",
|
242 |
+
" logging.info(json.dumps(log_entry))\n",
|
243 |
+
" return final_response, model_answer.strip()\n",
|
244 |
+
"\n",
|
245 |
+
" except Exception as e:\n",
|
246 |
+
" log_entry = {\n",
|
247 |
+
" \"case_id\": case_id,\n",
|
248 |
+
" \"question_id\": question_id,\n",
|
249 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
250 |
+
" \"model\": model_name,\n",
|
251 |
+
" \"temperature\": temperature,\n",
|
252 |
+
" \"status\": \"error\",\n",
|
253 |
+
" \"error\": str(e),\n",
|
254 |
+
" \"cost\": 0,\n",
|
255 |
+
" \"input\": {\n",
|
256 |
+
" \"messages\": prompt,\n",
|
257 |
+
" \"question_data\": {\n",
|
258 |
+
" \"question\": question_data[\"question\"],\n",
|
259 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
260 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
261 |
+
" \"figures\": question_data[\"figures\"],\n",
|
262 |
+
" },\n",
|
263 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
264 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
265 |
+
" },\n",
|
266 |
+
" }\n",
|
267 |
+
" logging.info(json.dumps(log_entry))\n",
|
268 |
+
" print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
|
269 |
+
" return \"\", \"\"\n",
|
270 |
+
"\n",
|
271 |
+
"\n",
|
272 |
+
"def load_benchmark_questions(case_id):\n",
|
273 |
+
" benchmark_dir = \"../benchmark/questions\"\n",
|
274 |
+
" return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
|
275 |
+
"\n",
|
276 |
+
"\n",
|
277 |
+
"def count_total_questions():\n",
|
278 |
+
" total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
|
279 |
+
" total_questions = sum(\n",
|
280 |
+
" len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
|
281 |
+
" for case_id in os.listdir(\"../benchmark/questions\")\n",
|
282 |
+
" )\n",
|
283 |
+
" return total_cases, total_questions\n",
|
284 |
+
"\n",
|
285 |
+
"\n",
|
286 |
+
"def main(tools):\n",
|
287 |
+
" with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
|
288 |
+
" data = json.load(file)\n",
|
289 |
+
"\n",
|
290 |
+
" total_cases, total_questions = count_total_questions()\n",
|
291 |
+
" cases_processed = 0\n",
|
292 |
+
" questions_processed = 0\n",
|
293 |
+
" skipped_questions = 0\n",
|
294 |
+
"\n",
|
295 |
+
" print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
|
296 |
+
"\n",
|
297 |
+
" for case_id, case_details in data.items():\n",
|
298 |
+
" if int(case_details[\"case_id\"]) <= 17158:\n",
|
299 |
+
" continue\n",
|
300 |
+
"\n",
|
301 |
+
" print(f\"----------------------------------------------------------------\")\n",
|
302 |
+
" agent, thread = get_agent(tools)\n",
|
303 |
+
"\n",
|
304 |
+
" question_files = load_benchmark_questions(case_id)\n",
|
305 |
+
" if not question_files:\n",
|
306 |
+
" continue\n",
|
307 |
+
"\n",
|
308 |
+
" cases_processed += 1\n",
|
309 |
+
" for question_file in question_files:\n",
|
310 |
+
" with open(question_file, \"r\") as file:\n",
|
311 |
+
" question_data = json.load(file)\n",
|
312 |
+
" question_id = os.path.basename(question_file).split(\".\")[0]\n",
|
313 |
+
"\n",
|
314 |
+
" # agent, thread = get_agent(tools)\n",
|
315 |
+
" questions_processed += 1\n",
|
316 |
+
" final_response, model_answer = create_multimodal_request(\n",
|
317 |
+
" question_data, case_details, case_id, question_id, agent, thread\n",
|
318 |
+
" )\n",
|
319 |
+
"\n",
|
320 |
+
" # Handle cases where response is None\n",
|
321 |
+
" if final_response is None:\n",
|
322 |
+
" skipped_questions += 1\n",
|
323 |
+
" print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
|
324 |
+
" continue\n",
|
325 |
+
"\n",
|
326 |
+
" print(\n",
|
327 |
+
" f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
|
328 |
+
" )\n",
|
329 |
+
" print(f\"Case ID: {case_id}\")\n",
|
330 |
+
" print(f\"Question ID: {question_id}\")\n",
|
331 |
+
" print(f\"Final Response: {final_response}\")\n",
|
332 |
+
" print(f\"Model Answer: {model_answer}\")\n",
|
333 |
+
" print(f\"Correct Answer: {question_data['answer']}\")\n",
|
334 |
+
" print(f\"----------------------------------------------------------------\\n\")\n",
|
335 |
+
"\n",
|
336 |
+
" print(f\"\\nBenchmark Summary:\")\n",
|
337 |
+
" print(f\"Total Cases Processed: {cases_processed}\")\n",
|
338 |
+
" print(f\"Total Questions Processed: {questions_processed}\")\n",
|
339 |
+
" print(f\"Total Questions Skipped: {skipped_questions}\")"
|
340 |
+
]
|
341 |
+
},
|
342 |
+
{
|
343 |
+
"cell_type": "code",
|
344 |
+
"execution_count": null,
|
345 |
+
"metadata": {},
|
346 |
+
"outputs": [],
|
347 |
+
"source": [
|
348 |
+
"tools = get_tools()\n",
|
349 |
+
"main(tools)"
|
350 |
+
]
|
351 |
+
}
|
352 |
+
],
|
353 |
+
"metadata": {
|
354 |
+
"kernelspec": {
|
355 |
+
"display_name": "medmax",
|
356 |
+
"language": "python",
|
357 |
+
"name": "python3"
|
358 |
+
},
|
359 |
+
"language_info": {
|
360 |
+
"codemirror_mode": {
|
361 |
+
"name": "ipython",
|
362 |
+
"version": 3
|
363 |
+
},
|
364 |
+
"file_extension": ".py",
|
365 |
+
"mimetype": "text/x-python",
|
366 |
+
"name": "python",
|
367 |
+
"nbconvert_exporter": "python",
|
368 |
+
"pygments_lexer": "ipython3",
|
369 |
+
"version": "3.10.16"
|
370 |
+
}
|
371 |
+
},
|
372 |
+
"nbformat": 4,
|
373 |
+
"nbformat_minor": 2
|
374 |
+
}
|
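As a hedged after-the-fact helper (not part of the notebook), one way to score the JSON-lines log the notebook writes: it assumes each log line is a single json.dumps entry with model_answer and correct_answer keys, as produced by the logging calls above, and the path below is a made-up example.

# Hypothetical scorer for a MedRAX benchmark log produced by the notebook above.
import json

def score_medrax_log(path: str) -> float:
    correct = total = 0
    with open(path) as f:
        for line in f:
            entry = json.loads(line)
            if entry.get("status") == "error" or "model_answer" not in entry:
                continue  # skip failed requests
            total += 1
            correct += entry["model_answer"].strip().upper() == entry["correct_answer"].strip().upper()
    return correct / total if total else 0.0

print(score_medrax_log("medrax_logs/medrax_20250101_000000.json"))  # placeholder filename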
experiments/chexbench_gpt4.py
ADDED
@@ -0,0 +1,405 @@
|
1 |
+
import json
|
2 |
+
import openai
|
3 |
+
import os
|
4 |
+
from datetime import datetime
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
import time
|
9 |
+
from tqdm import tqdm
|
10 |
+
from typing import Dict, List, Optional, Union, Any
|
11 |
+
|
12 |
+
# Configuration constants
|
13 |
+
DEBUG_MODE = False
|
14 |
+
OUTPUT_DIR = "results"
|
15 |
+
MODEL_NAME = "gpt-4o-2024-05-13"
|
16 |
+
TEMPERATURE = 0.2
|
17 |
+
SUBSET = "Visual Question Answering"
|
18 |
+
|
19 |
+
# Set up logging configuration
|
20 |
+
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
|
21 |
+
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def get_mime_type(file_path: str) -> str:
|
26 |
+
"""
|
27 |
+
Determine MIME type based on file extension.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
file_path (str): Path to the file
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: MIME type string for the file
|
34 |
+
"""
|
35 |
+
extension = os.path.splitext(file_path)[1].lower()
|
36 |
+
mime_types = {
|
37 |
+
".png": "image/png",
|
38 |
+
".jpg": "image/jpeg",
|
39 |
+
".jpeg": "image/jpeg",
|
40 |
+
".gif": "image/gif",
|
41 |
+
}
|
42 |
+
return mime_types.get(extension, "application/octet-stream")
|
43 |
+
|
44 |
+
|
45 |
+
def encode_image(image_path: str) -> str:
|
46 |
+
"""
|
47 |
+
Encode image to base64 with extensive error checking.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
image_path (str): Path to the image file
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
str: Base64 encoded image string
|
54 |
+
|
55 |
+
Raises:
|
56 |
+
FileNotFoundError: If image file does not exist
|
57 |
+
ValueError: If image file is empty or too large
|
58 |
+
Exception: For other image processing errors
|
59 |
+
"""
|
60 |
+
logger.debug(f"Attempting to read image from: {image_path}")
|
61 |
+
if not os.path.exists(image_path):
|
62 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
63 |
+
|
64 |
+
# Add check for file size
|
65 |
+
file_size = os.path.getsize(image_path)
|
66 |
+
if file_size > 20 * 1024 * 1024: # 20MB limit
|
67 |
+
raise ValueError("Image file size exceeds 20MB limit")
|
68 |
+
if file_size == 0:
|
69 |
+
raise ValueError("Image file is empty")
|
70 |
+
logger.debug(f"Image file size: {file_size / 1024:.2f} KB")
|
71 |
+
|
72 |
+
try:
|
73 |
+
from PIL import Image
|
74 |
+
|
75 |
+
# Try to open and verify the image
|
76 |
+
with Image.open(image_path) as img:
|
77 |
+
# Get image details
|
78 |
+
width, height = img.size
|
79 |
+
format = img.format
|
80 |
+
mode = img.mode
|
81 |
+
logger.debug(
|
82 |
+
f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}"
|
83 |
+
)
|
84 |
+
|
85 |
+
if format not in ["PNG", "JPEG", "GIF"]:
|
86 |
+
raise ValueError(f"Unsupported image format: {format}")
|
87 |
+
|
88 |
+
with open(image_path, "rb") as image_file:
|
89 |
+
# Read the first few bytes to verify it's a valid PNG
|
90 |
+
header = image_file.read(8)
|
91 |
+
# if header != b'\x89PNG\r\n\x1a\n':
|
92 |
+
# logger.warning("File does not have a valid PNG signature")
|
93 |
+
|
94 |
+
# Reset file pointer and read entire file
|
95 |
+
image_file.seek(0)
|
96 |
+
encoded = base64.b64encode(image_file.read()).decode("utf-8")
|
97 |
+
encoded_length = len(encoded)
|
98 |
+
logger.debug(f"Base64 encoded length: {encoded_length} characters")
|
99 |
+
|
100 |
+
# Verify the encoded string is not empty and starts correctly
|
101 |
+
if encoded_length == 0:
|
102 |
+
raise ValueError("Base64 encoding produced empty string")
|
103 |
+
if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
|
104 |
+
logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
|
105 |
+
|
106 |
+
return encoded
|
107 |
+
except Exception as e:
|
108 |
+
logger.error(f"Error reading/encoding image: {str(e)}")
|
109 |
+
raise
|
110 |
+
|
111 |
+
|
112 |
+
def create_single_request(
|
113 |
+
image_path: str, question: str, options: Dict[str, str]
|
114 |
+
) -> List[Dict[str, Any]]:
|
115 |
+
"""
|
116 |
+
Create a single API request with image and question.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
image_path (str): Path to the image file
|
120 |
+
question (str): Question text
|
121 |
+
options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
List[Dict[str, Any]]: List of message dictionaries for the API request
|
125 |
+
|
126 |
+
Raises:
|
127 |
+
Exception: For errors in request creation
|
128 |
+
"""
|
129 |
+
if DEBUG_MODE:
|
130 |
+
logger.debug("Creating API request...")
|
131 |
+
|
132 |
+
prompt = f"""Given the following medical examination question:
|
133 |
+
Please answer this multiple choice question:
|
134 |
+
|
135 |
+
Question: {question}
|
136 |
+
|
137 |
+
Options:
|
138 |
+
A) {options['option_0']}
|
139 |
+
B) {options['option_1']}
|
140 |
+
|
141 |
+
Base your answer only on the provided image and select either A or B."""
|
142 |
+
|
143 |
+
try:
|
144 |
+
encoded_image = encode_image(image_path)
|
145 |
+
mime_type = get_mime_type(image_path)
|
146 |
+
|
147 |
+
if DEBUG_MODE:
|
148 |
+
logger.debug(f"Image encoded with MIME type: {mime_type}")
|
149 |
+
|
150 |
+
messages = [
|
151 |
+
{
|
152 |
+
"role": "system",
|
153 |
+
"content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"role": "user",
|
157 |
+
"content": [
|
158 |
+
{"type": "text", "text": prompt},
|
159 |
+
{
|
160 |
+
"type": "image_url",
|
161 |
+
"image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
|
162 |
+
},
|
163 |
+
],
|
164 |
+
},
|
165 |
+
]
|
166 |
+
|
167 |
+
if DEBUG_MODE:
|
168 |
+
log_messages = json.loads(json.dumps(messages))
|
169 |
+
log_messages[1]["content"][1]["image_url"][
|
170 |
+
"url"
|
171 |
+
] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
|
172 |
+
logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
|
173 |
+
|
174 |
+
return messages
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
logger.error(f"Error creating request: {str(e)}")
|
178 |
+
raise
|
179 |
+
|
180 |
+
|
181 |
+
def check_answer(model_answer: str, correct_answer: int) -> bool:
|
182 |
+
"""
|
183 |
+
Check if the model's answer matches the correct answer.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
model_answer (str): The model's answer (A or B)
|
187 |
+
correct_answer (int): The correct answer index (0 for A, 1 for B)
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
bool: True if answer is correct, False otherwise
|
191 |
+
"""
|
192 |
+
if not isinstance(model_answer, str):
|
193 |
+
return False
|
194 |
+
|
195 |
+
# Clean the model answer to get just the letter
|
196 |
+
model_letter = model_answer.strip().upper()
|
197 |
+
if model_letter.startswith("A"):
|
198 |
+
model_index = 0
|
199 |
+
elif model_letter.startswith("B"):
|
200 |
+
model_index = 1
|
201 |
+
else:
|
202 |
+
return False
|
203 |
+
|
204 |
+
return model_index == correct_answer
|
205 |
+
|
206 |
+
|
207 |
+
def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
|
208 |
+
"""
|
209 |
+
Save results to a JSON file with timestamp.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
213 |
+
output_dir (str): Directory to save results
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
str: Path to the saved file
|
217 |
+
"""
|
218 |
+
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
219 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
220 |
+
output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
|
221 |
+
|
222 |
+
with open(output_file, "w") as f:
|
223 |
+
json.dump(results, f, indent=2)
|
224 |
+
|
225 |
+
logger.info(f"Batch results saved to {output_file}")
|
226 |
+
return output_file
|
227 |
+
|
228 |
+
|
229 |
+
def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
|
230 |
+
"""
|
231 |
+
Calculate accuracy from results, handling error cases.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
|
238 |
+
"""
|
239 |
+
if not results:
|
240 |
+
return 0.0, 0, 0
|
241 |
+
|
242 |
+
total = len(results)
|
243 |
+
valid_results = [r for r in results if "output" in r]
|
244 |
+
correct = sum(
|
245 |
+
1 for result in valid_results if result.get("output", {}).get("is_correct", False)
|
246 |
+
)
|
247 |
+
|
248 |
+
accuracy = (correct / total * 100) if total > 0 else 0
|
249 |
+
return accuracy, correct, total
|
250 |
+
|
251 |
+
|
252 |
+
def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
|
253 |
+
"""
|
254 |
+
Calculate accuracy for the current batch.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
float: Accuracy percentage for the batch
|
261 |
+
"""
|
262 |
+
valid_results = [r for r in results if "output" in r]
|
263 |
+
if not valid_results:
|
264 |
+
return 0.0
|
265 |
+
return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
|
266 |
+
|
267 |
+
|
268 |
+
def process_batch(
|
269 |
+
data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
|
270 |
+
) -> List[Dict[str, Any]]:
|
271 |
+
"""
|
272 |
+
Process a batch of examples and return results.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
data (List[Dict[str, Any]]): List of data items to process
|
276 |
+
client (openai.OpenAI): OpenAI client instance
|
277 |
+
start_idx (int, optional): Starting index for batch. Defaults to 0
|
278 |
+
batch_size (int, optional): Size of batch to process. Defaults to 50
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
List[Dict[str, Any]]: List of processed results
|
282 |
+
"""
|
283 |
+
batch_results = []
|
284 |
+
end_idx = min(start_idx + batch_size, len(data))
|
285 |
+
|
286 |
+
pbar = tqdm(
|
287 |
+
range(start_idx, end_idx),
|
288 |
+
desc=f"Processing batch {start_idx//batch_size + 1}",
|
289 |
+
unit="example",
|
290 |
+
)
|
291 |
+
|
292 |
+
for index in pbar:
|
293 |
+
vqa_item = data[index]
|
294 |
+
options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
|
295 |
+
|
296 |
+
try:
|
297 |
+
messages = create_single_request(
|
298 |
+
image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
|
299 |
+
)
|
300 |
+
|
301 |
+
response = client.chat.completions.create(
|
302 |
+
model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
|
303 |
+
)
|
304 |
+
|
305 |
+
model_answer = response.choices[0].message.content.strip()
|
306 |
+
is_correct = check_answer(model_answer, vqa_item["answer"])
|
307 |
+
|
308 |
+
result = {
|
309 |
+
"timestamp": datetime.now().isoformat(),
|
310 |
+
"example_index": index,
|
311 |
+
"input": {
|
312 |
+
"question": vqa_item["question"],
|
313 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
314 |
+
"image_path": vqa_item["image_path"],
|
315 |
+
},
|
316 |
+
"output": {
|
317 |
+
"model_answer": model_answer,
|
318 |
+
"correct_answer": "A" if vqa_item["answer"] == 0 else "B",
|
319 |
+
"is_correct": is_correct,
|
320 |
+
"usage": {
|
321 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
322 |
+
"completion_tokens": response.usage.completion_tokens,
|
323 |
+
"total_tokens": response.usage.total_tokens,
|
324 |
+
},
|
325 |
+
},
|
326 |
+
}
|
327 |
+
batch_results.append(result)
|
328 |
+
|
329 |
+
# Update progress bar with current accuracy
|
330 |
+
current_accuracy = calculate_batch_accuracy(batch_results)
|
331 |
+
pbar.set_description(
|
332 |
+
f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
|
333 |
+
f"({len(batch_results)}/{index-start_idx+1} examples)"
|
334 |
+
)
|
335 |
+
|
336 |
+
except Exception as e:
|
337 |
+
error_result = {
|
338 |
+
"timestamp": datetime.now().isoformat(),
|
339 |
+
"example_index": index,
|
340 |
+
"error": str(e),
|
341 |
+
"input": {
|
342 |
+
"question": vqa_item["question"],
|
343 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
344 |
+
"image_path": vqa_item["image_path"],
|
345 |
+
},
|
346 |
+
}
|
347 |
+
batch_results.append(error_result)
|
348 |
+
if DEBUG_MODE:
|
349 |
+
pbar.write(f"Error processing example {index}: {str(e)}")
|
350 |
+
|
351 |
+
time.sleep(1) # Rate limiting
|
352 |
+
|
353 |
+
return batch_results
|
354 |
+
|
355 |
+
|
356 |
+
def main() -> None:
|
357 |
+
"""
|
358 |
+
Main function to process the entire dataset.
|
359 |
+
|
360 |
+
Raises:
|
361 |
+
ValueError: If OPENAI_API_KEY is not set
|
362 |
+
Exception: For other processing errors
|
363 |
+
"""
|
364 |
+
logger.info("Starting full dataset processing...")
|
365 |
+
json_path = "../data/chexbench_updated.json"
|
366 |
+
|
367 |
+
try:
|
368 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
369 |
+
if not api_key:
|
370 |
+
raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
371 |
+
client = openai.OpenAI(api_key=api_key)
|
372 |
+
|
373 |
+
with open(json_path, "r") as f:
|
374 |
+
data = json.load(f)
|
375 |
+
|
376 |
+
subset_data = data[SUBSET]
|
377 |
+
total_examples = len(subset_data)
|
378 |
+
logger.info(f"Found {total_examples} examples in {SUBSET} subset")
|
379 |
+
|
380 |
+
all_results = []
|
381 |
+
batch_size = 50 # Process in batches of 50 examples
|
382 |
+
|
383 |
+
# Process all examples in batches
|
384 |
+
for start_idx in range(0, total_examples, batch_size):
|
385 |
+
batch_results = process_batch(subset_data, client, start_idx, batch_size)
|
386 |
+
all_results.extend(batch_results)
|
387 |
+
|
388 |
+
# Save intermediate results after each batch
|
389 |
+
output_file = save_results_to_json(all_results, OUTPUT_DIR)
|
390 |
+
|
391 |
+
# Calculate and log overall progress
|
392 |
+
overall_accuracy, correct, total = calculate_accuracy(all_results)
|
393 |
+
logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
|
394 |
+
logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")
|
395 |
+
|
396 |
+
logger.info("Processing completed!")
|
397 |
+
logger.info(f"Final results saved to: {output_file}")
|
398 |
+
|
399 |
+
except Exception as e:
|
400 |
+
logger.error(f"Fatal error: {str(e)}")
|
401 |
+
raise
|
402 |
+
|
403 |
+
|
404 |
+
if __name__ == "__main__":
|
405 |
+
main()
|
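
As a rough orientation, process_batch above reads each record of the "Visual Question Answering" subset through the keys image_path, question, option_0, option_1 and answer. A minimal sketch of one such record (the path value is purely hypothetical) and of how check_answer consumes it:

example_item = {
    "image_path": "images/example_cxr.png",  # hypothetical path to the chest X-ray
    "question": "Is there a pleural effusion?",
    "option_0": "Yes",
    "option_1": "No",
    "answer": 0,  # index of the correct option: 0 maps to A, 1 maps to B
}
# check_answer compares the model's letter to that index:
# check_answer("A", 0) -> True; check_answer("B) No", 0) -> False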
experiments/compare_runs.py
ADDED
@@ -0,0 +1,290 @@
import json
import argparse
import random
from typing import List, Dict, Any, Tuple
import re
from collections import defaultdict

# Define category order
CATEGORY_ORDER = [
    "detection",
    "classification",
    "localization",
    "comparison",
    "relationship",
    "diagnosis",
    "characterization",
]


def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from

    Returns:
        str: The extracted letter in uppercase, or empty string if no letter found
    """
    if not answer:
        return ""

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter A-F, return it
    if len(answer) == 1 and answer.upper() in "ABCDEF":
        return answer.upper()

    # Try to match patterns like "A)", "A.", "A ", etc.
    match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Try to find any standalone A-F letters preceded by space or start of string
    # and followed by space, period, parenthesis or end of string
    matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
    if matches:
        return matches[0].upper()

    # Last resort: just find any A-F letter
    letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
    if letters:
        return letters[0].upper()

    # If no letter found, return original (cleaned)
    return answer.strip().upper()


def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse JSON Lines file and extract valid predictions.

    Args:
        file_path: Path to the JSON Lines file to parse

    Returns:
        Tuple containing:
            - str: Model name or file path if model name not found
            - List[Dict[str, Any]]: List of valid prediction entries
    """
    valid_predictions = []
    model_name = None

    # First try to parse as LLaVA format
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data.get("model") == "llava-med-v1.5-mistral-7b":
                model_name = data["model"]
                for result in data.get("results", []):
                    if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
                        # Extract answer with priority: model_answer > validated_answer > raw_output
                        model_answer = (
                            result.get("model_answer")
                            or result.get("validated_answer")
                            or result.get("raw_output", "")
                        )

                        # Add default categories for LLaVA results
                        prediction = {
                            "case_id": result["case_id"],
                            "question_id": result["question_id"],
                            "model_answer": model_answer,
                            "correct_answer": result["correct_answer"],
                            "input": {
                                "question_data": {
                                    "metadata": {
                                        "categories": [
                                            "detection",
                                            "classification",
                                            "localization",
                                            "comparison",
                                            "relationship",
                                            "diagnosis",
                                            "characterization",
                                        ]
                                    }
                                }
                            },
                        }
                        valid_predictions.append(prediction)
                return model_name, valid_predictions
    except (json.JSONDecodeError, KeyError):
        pass

    # If not LLaVA format, process as original format
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("HTTP Request:"):
                continue
            try:
                data = json.loads(line.strip())
                if "model" in data:
                    model_name = data["model"]
                if all(
                    k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
                ):
                    valid_predictions.append(data)
            except json.JSONDecodeError:
                continue

    return model_name if model_name else file_path, valid_predictions


def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]]
) -> List[List[Dict[str, Any]]]:
    """Ensure only questions that exist across all models are evaluated.

    Args:
        predictions_list: List of prediction lists from different models

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
    """
    question_sets = [
        set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
    ]
    common_questions = set.intersection(*question_sets)

    return [
        [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
        for preds in predictions_list
    ]


def calculate_accuracy(
    predictions: List[Dict[str, Any]]
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and category-level accuracy.

    Args:
        predictions: List of prediction entries to analyze

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage
            - int: Number of correct predictions
            - int: Total number of predictions
            - Dict[str, Dict[str, float]]: Category-level accuracy statistics
    """
    if not predictions:
        return 0.0, 0, 0, {}

    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    correct = 0
    total = 0
    sample_size = min(5, len(predictions))
    sampled_indices = random.sample(range(len(predictions)), sample_size)

    print("\nSample extracted answers:")
    for i in sampled_indices:
        pred = predictions[i]
        model_ans = extract_letter_answer(pred["model_answer"])
        correct_ans = extract_letter_answer(pred["correct_answer"])
        print(f"QID: {pred['question_id']}")
        print(f" Raw Model Answer: {pred['model_answer']}")
        print(f" Extracted Model Answer: {model_ans}")
        print(f" Raw Correct Answer: {pred['correct_answer']}")
        print(f" Extracted Correct Answer: {correct_ans}")
        print("-" * 80)

    for pred in predictions:
        try:
            model_ans = extract_letter_answer(pred["model_answer"])
            correct_ans = extract_letter_answer(pred["correct_answer"])
            categories = (
                pred.get("input", {})
                .get("question_data", {})
                .get("metadata", {})
                .get("categories", [])
            )

            if model_ans and correct_ans:
                total += 1
                is_correct = model_ans == correct_ans
                if is_correct:
                    correct += 1

                for category in categories:
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except KeyError:
            continue

    category_accuracies = {
        category: {
            "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)


def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare
    """
    # Parse all files
    parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
    model_names, predictions_list = zip(*parsed_results)

    # Get initial stats
    print(f"\n📊 **Initial Accuracy**:")
    results = []
    category_results = []

    for preds, name in zip(predictions_list, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        results.append((acc, correct, total, name))
        category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Get common questions across all models
    filtered_predictions = filter_common_questions(predictions_list)
    print(
        f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
    )

    # Compute accuracy on common questions
    print(f"\n📊 **Accuracy on Common Questions**:")
    filtered_results = []
    filtered_category_results = []

    for preds, name in zip(filtered_predictions, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        filtered_results.append((acc, correct, total, name))
        filtered_category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Print category-wise accuracy
    print("\nCategory Performance (Common Questions):")
    for category in CATEGORY_ORDER:
        print(f"\n{category.capitalize()}:")
        for model_name, category_acc in zip(model_names, filtered_category_results):
            stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
            print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")


def main():
    parser = argparse.ArgumentParser(
        description="Compare accuracy across multiple model prediction files"
    )
    parser.add_argument("files", nargs="+", help="Paths to model prediction files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")

    args = parser.parse_args()
    random.seed(args.seed)

    compare_models(args.files)


if __name__ == "__main__":
    main()
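
To make the normalization step concrete, extract_letter_answer above reduces free-form model output to a single option letter before accuracy is computed. These sample calls are a sketch of its behavior, not part of the script:

extract_letter_answer("B")                  # -> "B"
extract_letter_answer("a) pneumonia")       # -> "A"
extract_letter_answer("The answer is C.")   # -> "C"
extract_letter_answer("")                   # -> ""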
experiments/inspect_logs.py
ADDED
@@ -0,0 +1,210 @@
from typing import Optional, List
import argparse
import json
import glob
from pathlib import Path
from datetime import datetime


def get_latest_log() -> str:
    """Find the most recently modified log file in the current directory.

    Returns:
        str: Path to the most recently modified log file

    Raises:
        FileNotFoundError: If no log files are found in the current directory
    """
    logs = list(Path(".").glob("api_usage_*.json"))
    if not logs:
        raise FileNotFoundError("No log files found in the current directory.")
    return str(max(logs, key=lambda p: p.stat().st_mtime))


def format_cost(entry: dict) -> str:
    """Format cost if available, otherwise return 'N/A'

    Args:
        entry: Log entry dictionary containing cost information

    Returns:
        str: Formatted cost string with $ and 4 decimal places, or 'N/A' if cost not found
    """
    return f"${entry.get('cost', 'N/A'):.4f}" if "cost" in entry else "N/A"


def print_gpt4_entry(entry: dict) -> None:
    """Print entry for GPT-4 format

    Args:
        entry: Log entry dictionary in GPT-4 format containing model info, inputs and outputs
    """
    print("\n=== Log Entry ===")
    print(f"Model: {entry['model']}")
    print(f"Case ID: {entry['case_id']}")
    print(f"Question ID: {entry['question_id']}")

    print("\n=== Model Input ===")
    messages = entry["input"]["messages"]
    print("System message:", messages[0]["content"])
    user_content = messages[1]["content"]
    print("\nUser prompt:", user_content[0]["text"])
    print("\nImages provided:")
    for content in user_content[1:]:
        print(f" - {content['image_url']['url']}")

    print("\n=== Model Output ===")
    print(f"Answer: {entry['model_answer']}")
    print(f"Correct: {entry['correct_answer']}")

    print("\n=== Usage Stats ===")
    print(f"Duration: {entry['duration']}s")
    print(f"Cost: {format_cost(entry)}")
    print(
        f"Tokens: {entry['usage']['total_tokens']}",
        f"(prompt: {entry['usage']['prompt_tokens']},",
        f"completion: {entry['usage']['completion_tokens']})",
    )


def print_llama_entry(entry: dict) -> None:
    """Print entry for Llama-3.2 format

    Args:
        entry: Log entry dictionary in Llama format containing model info, inputs and outputs
    """
    print("\n=== Log Entry ===")
    print(f"Model: {entry['model']}")
    print(f"Case ID: {entry['case_id']}")
    print(f"Question ID: {entry['question_id']}")

    print("\n=== Model Input ===")
    print(f"Question: {entry['input']['question_data']['question']}")
    print("\nImages provided:")
    for url in entry["input"]["image_urls"]:
        print(f" - {url}")
    if entry["input"]["image_captions"]:
        print("\nImage captions:")
        for caption in entry["input"]["image_captions"]:
            if caption:
                print(f" - {caption}")

    print("\n=== Model Output ===")
    print(f"Answer: {entry['model_answer']}")
    print(f"Correct: {entry['correct_answer']}")

    print("\n=== Usage Stats ===")
    print(f"Duration: {entry['duration']}s")
    if "usage" in entry:
        print(
            f"Tokens: {entry['usage']['total_tokens']}",
            f"(prompt: {entry['usage']['prompt_tokens']},",
            f"completion: {entry['usage']['completion_tokens']})",
        )


def determine_model_type(entry: dict) -> str:
    """Determine the model type from the entry

    Args:
        entry: Log entry dictionary containing model information

    Returns:
        str: Model type - 'gpt4', 'llama', 'chexagent', 'medrax', or 'unknown'
    """
    model = entry.get("model", "").lower()
    if "gpt-4" in model:
        return "gpt4"
    elif "llama" in model:
        return "llama"
    elif "chexagent" in model:
        return "chexagent"
    elif "medrax" in model:
        return "medrax"
    else:
        return "unknown"


def print_log_entry(
    log_file: Optional[str] = None,
    num_entries: Optional[int] = None,
    model_filter: Optional[str] = None,
) -> None:
    """Print log entries from the specified log file or the latest log file.

    Args:
        log_file: Path to the log file. If None, uses the latest log file.
        num_entries: Number of entries to print. If None, prints all entries.
        model_filter: Filter entries by model type ('gpt4' or 'llama'). If None, prints all.
    """
    if log_file is None:
        log_file = get_latest_log()
        print(f"Using latest log file: {log_file}")

    entries_printed = 0
    total_entries = 0
    filtered_entries = 0

    with open(log_file, "r") as f:
        for line in f:
            if line.startswith("HTTP"):
                continue
            try:
                total_entries += 1
                entry = json.loads(line)

                # Apply model filter if specified
                model_type = determine_model_type(entry)
                if model_filter and model_type != model_filter:
                    filtered_entries += 1
                    continue

                if model_type == "gpt4":
                    print_gpt4_entry(entry)
                elif model_type == "llama":
                    print_llama_entry(entry)
                else:
                    print(f"Unknown model type in entry: {entry['model']}")
                    continue

                print("=" * 50)
                entries_printed += 1
                if num_entries and entries_printed >= num_entries:
                    break

            except (json.JSONDecodeError, KeyError) as e:
                print(f"Error processing entry: {e}")
                continue

    print(f"\nSummary:")
    print(f"Total entries: {total_entries}")
    print(f"Entries printed: {entries_printed}")
    if model_filter:
        print(f"Entries filtered: {filtered_entries}")


def main() -> None:
    """Main entry point for the script"""
    parser = argparse.ArgumentParser(
        description="Parse and display log entries from API usage logs."
    )
    parser.add_argument("-l", "--log_file", nargs="?", help="Path to the log file (optional)")
    parser.add_argument("-n", "--num_entries", type=int, help="Number of entries to display")
    parser.add_argument(
        "-m",
        "--model",
        choices=["gpt4", "llama"],
        default="gpt4",
        help="Model type to display (default: gpt4)",
    )
    args = parser.parse_args()

    try:
        print_log_entry(args.log_file, args.num_entries, args.model)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        exit(1)


if __name__ == "__main__":
    main()
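
The same inspection can also be driven from Python instead of the command line; a minimal sketch using the functions defined above (the log file name is hypothetical):

# Show the first two GPT-4 entries from a specific API usage log
print_log_entry("api_usage_20250101_120000.json", num_entries=2, model_filter="gpt4")
# Or let the script pick the newest api_usage_*.json in the working directory
print_log_entry(num_entries=1, model_filter="llama")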
experiments/validate_logs.py
ADDED
@@ -0,0 +1,162 @@
from typing import Dict, List, Tuple, Optional
import json
import sys
import glob
from pathlib import Path
from collections import defaultdict


def get_latest_log() -> str:
    """Find the most recently modified log file in the current directory.

    Returns:
        str: Path to the most recently modified log file

    Raises:
        SystemExit: If no log files are found in current directory
    """
    log_pattern = "api_usage_*.json"
    logs = list(Path(".").glob(log_pattern))
    if not logs:
        print(f"No files matching pattern '{log_pattern}' found in current directory")
        sys.exit(1)
    return str(max(logs, key=lambda p: p.stat().st_mtime))


def analyze_log_file(filename: str) -> Tuple[List[Dict], List[Dict], Dict[str, List[str]]]:
    """Analyze a log file for entries missing images and errors.

    Args:
        filename: Path to the log file to analyze

    Returns:
        Tuple containing:
            - List of entries with no images
            - List of skipped/error entries
            - Dict of processing errors by type

    Raises:
        SystemExit: If file cannot be found or read
    """
    no_images = []
    errors = defaultdict(list)
    skipped = []

    try:
        with open(filename, "r") as f:
            for line_num, line in enumerate(f, 1):
                # Skip HTTP request logs
                if line.startswith("HTTP Request:") or line.strip() == "":
                    continue
                try:
                    # Try to parse the JSON line
                    if not line.strip().startswith("{"):
                        continue
                    entry = json.loads(line.strip())
                    case_id = entry.get("case_id")
                    question_id = entry.get("question_id")

                    # Skip if we can't identify the question
                    if not case_id or not question_id:
                        continue

                    # Check for explicit skip/error status
                    if entry.get("status") in ["skipped", "error"]:
                        skipped.append(
                            {
                                "case_id": case_id,
                                "question_id": question_id,
                                "reason": entry.get("reason"),
                                "status": entry.get("status"),
                            }
                        )
                        continue

                    # Check user content for images
                    messages = entry.get("input", {}).get("messages", [])
                    has_image = False
                    for msg in messages:
                        content = msg.get("content", [])
                        if isinstance(content, list):
                            for item in content:
                                if isinstance(item, dict) and item.get("type") == "image_url":
                                    has_image = True
                                    break
                    if not has_image:
                        no_images.append(
                            {
                                "case_id": case_id,
                                "question_id": question_id,
                                "question": entry.get("input", {})
                                .get("question_data", {})
                                .get("question", "")[:100]
                                + "...",  # First 100 chars of question
                            }
                        )
                except json.JSONDecodeError:
                    errors["json_decode"].append(f"Line {line_num}: Invalid JSON")
                    continue
                except Exception as e:
                    errors["other"].append(f"Line {line_num}: Error processing entry: {str(e)}")
    except FileNotFoundError:
        print(f"Error: Could not find log file: {filename}")
        sys.exit(1)
    except Exception as e:
        print(f"Error reading file {filename}: {str(e)}")
        sys.exit(1)

    return no_images, skipped, errors


def print_results(
    filename: str, no_images: List[Dict], skipped: List[Dict], errors: Dict[str, List[str]]
) -> None:
    """Print analysis results.

    Args:
        filename: Name of the analyzed log file
        no_images: List of entries with no images
        skipped: List of skipped/error entries
        errors: Dict of processing errors by type
    """
    print(f"\nAnalyzing log file: {filename}")
    print("\n=== Questions with No Images ===")
    if no_images:
        for entry in no_images:
            print(f"\nCase ID: {entry['case_id']}")
            print(f"Question ID: {entry['question_id']}")
            print(f"Question Preview: {entry['question']}")
    print(f"\nTotal questions without images: {len(no_images)}")

    print("\n=== Skipped/Error Questions ===")
    if skipped:
        for entry in skipped:
            print(f"\nCase ID: {entry['case_id']}")
            print(f"Question ID: {entry['question_id']}")
            print(f"Status: {entry['status']}")
            print(f"Reason: {entry.get('reason', 'unknown')}")
    print(f"\nTotal skipped/error questions: {len(skipped)}")

    if errors:
        print("\n=== Processing Errors ===")
        for error_type, messages in errors.items():
            if messages:
                print(f"\n{error_type}:")
                for msg in messages:
                    print(f" {msg}")


def main() -> None:
    """Main entry point for log validation script."""
    # If a file is specified as an argument, use it; otherwise find the latest log
    if len(sys.argv) > 1:
        log_file = sys.argv[1]
    else:
        log_file = get_latest_log()

    no_images, skipped, errors = analyze_log_file(log_file)
    print_results(log_file, no_images, skipped, errors)


if __name__ == "__main__":
    main()
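
To make the checks above concrete, a log line that analyze_log_file would classify as skipped minimally carries the identifiers plus a status and reason; the values here are illustrative only:

sample_line = '{"case_id": "C123", "question_id": "Q7", "status": "skipped", "reason": "missing image"}'
# analyze_log_file collects such entries into the skipped list that print_results reports;
# entries without any image_url item in their input messages land in no_images instead.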
handler.py
ADDED
@@ -0,0 +1,57 @@
import os
import runpod
from main import initialize_agent

# Initialize the agent (the set of tools can be adjusted as needed)
selected_tools = [
    "ImageVisualizerTool",
    "DicomProcessorTool",
    "ChestXRayClassifierTool",
    "ChestXRaySegmentationTool",
    "ChestXRayReportGeneratorTool",
    "XRayVQATool",
]

agent, tools_dict = initialize_agent(
    "medrax/docs/system_prompts.txt",
    tools_to_use=selected_tools,
    model_dir="/model-weights"
)

def handler(event):
    """
    Main function for handling RunPod API requests.
    """
    try:
        # Get the request parameters
        job_input = event["input"]

        # Validate required parameters
        if "image" not in job_input:
            return {"error": "Missing required parameter: image"}

        if "task" not in job_input:
            return {"error": "Missing required parameter: task"}

        image_data = job_input["image"]  # Assumed to be a base64-encoded image
        task = job_input["task"]  # Task type

        # Call the appropriate tool for the task type
        if task == "classification":
            result = tools_dict["ChestXRayClassifierTool"].run(image_data)
        elif task == "segmentation":
            result = tools_dict["ChestXRaySegmentationTool"].run(image_data)
        elif task == "report":
            result = tools_dict["ChestXRayReportGeneratorTool"].run(image_data)
        else:
            return {"error": f"Unsupported task type: {task}"}

        return {
            "status": "success",
            "result": result
        }

    except Exception as e:
        return {"error": str(e)}

runpod.serverless.start({"handler": handler})
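
For reference, the handler above expects a RunPod job event shaped like the following; the base64 payload is abbreviated and purely illustrative:

event = {
    "input": {
        "image": "<base64-encoded chest X-ray>",  # illustrative placeholder
        "task": "classification",  # one of: classification, segmentation, report
    }
}
# handler(event) returns {"status": "success", "result": ...} on success,
# or {"error": "..."} if a parameter is missing or the task type is unsupported.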
interface.py
ADDED
@@ -0,0 +1,259 @@
import re
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import AsyncGenerator, List, Optional, Tuple
from gradio import ChatMessage


class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.

    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests
            tools_dict (dict): Dictionary of available tools for image processing
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        # Separate storage for original and display paths
        self.original_file_path = None  # For LLM (.dcm or other)
        self.display_file_path = None  # For UI (always viewable format)

    def handle_upload(self, file_path: str) -> str:
        """
        Handle new file upload and set appropriate paths.

        Args:
            file_path (str): Path to the uploaded file

        Returns:
            str: Display path for UI, or None if no file uploaded
        """
        if not file_path:
            return None

        source = Path(file_path)
        timestamp = int(time.time())

        # Save original file with proper suffix
        suffix = source.suffix.lower()
        saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
        shutil.copy2(file_path, saved_path)  # Use file_path directly instead of source
        self.original_file_path = str(saved_path)

        # Handle DICOM conversion for display only
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self.display_file_path = output["image_path"]
        else:
            self.display_file_path = str(saved_path)

        return self.display_file_path

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add a new message to the chat history.

        Args:
            message (str): Text message to add
            display_image (str): Path to image being displayed
            history (List[dict]): Current chat history

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history and textbox component
        """
        image_path = self.original_file_path or display_image
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Process a message and generate responses.

        Args:
            message (str): User message to process
            display_image (Optional[str]): Path to currently displayed image
            chat_history (List[ChatMessage]): Current chat history

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
        """
        chat_history = chat_history or []

        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        messages = []
        image_path = self.original_file_path or display_image
        if image_path is not None:
            messages.append({"role": "user", "content": f"path: {image_path}"})
        if message is not None:
            messages.append({"role": "user", "content": message})

        try:
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if isinstance(event, dict):
                    if "process" in event:
                        content = event["process"]["messages"][-1].content
                        if content:
                            content = re.sub(r"temp/[^\s]*", "", content)
                            chat_history.append(ChatMessage(role="assistant", content=content))
                            yield chat_history, self.display_file_path, ""

                    elif "execute" in event:
                        for message in event["execute"]["messages"]:
                            tool_name = message.name
                            tool_result = eval(message.content)[0]

                            if tool_result:
                                metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
                                formatted_result = " ".join(
                                    line.strip() for line in str(tool_result).splitlines()
                                ).strip()
                                metadata["description"] = formatted_result
                                chat_history.append(
                                    ChatMessage(
                                        role="assistant",
                                        content=formatted_result,
                                        metadata=metadata,
                                    )
                                )

                                # For image_visualizer, use display path
                                if tool_name == "image_visualizer":
                                    self.display_file_path = tool_result["image_path"]
                                    chat_history.append(
                                        ChatMessage(
                                            role="assistant",
                                            # content=gr.Image(value=self.display_file_path),
                                            content={"path": self.display_file_path},
                                        )
                                    )

                            yield chat_history, self.display_file_path, ""

        except Exception as e:
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            yield chat_history, self.display_file_path, ""


def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown(
                """
                # 🏥 MedRAX
                Medical Reasoning Agent for Chest X-ray
                """
            )

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    height=800,
                    container=True,
                    show_label=True,
                    elem_classes="chat-box",
                    type="messages",
                    label="Agent",
                    avatar_images=(
                        None,
                        "assets/medrax_logo.jpg",
                    ),
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Ask about the X-ray...",
                            container=False,
                        )

            with gr.Column(scale=3):
                image_display = gr.Image(
                    label="Image", type="filepath", height=700, container=True
                )
                with gr.Row():
                    upload_button = gr.UploadButton(
                        "📎 Upload X-Ray",
                        file_types=["image"],
                    )
                    dicom_upload = gr.UploadButton(
                        "📄 Upload DICOM",
                        file_types=["file"],
                    )
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
                    new_thread_btn = gr.Button("New Thread")

        # Event handlers
        def clear_chat():
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            interface.current_thread_id = str(time.time())
            return [], interface.display_file_path

        def handle_file_upload(file):
            return interface.handle_upload(file.name)

        chat_msg = txt.submit(
            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
        )
        bot_msg = chat_msg.then(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=[chatbot, image_display, txt],
        )
        bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])

        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)

        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

    return demo
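
A minimal sketch of wiring this interface up, mirroring the initialize_agent call used in handler.py (demo.launch() is the standard Gradio entry point; host/port options can be added as needed):

from main import initialize_agent
from interface import create_demo

agent, tools_dict = initialize_agent(
    "medrax/docs/system_prompts.txt",
    tools_to_use=["ImageVisualizerTool", "DicomProcessorTool", "ChestXRayClassifierTool"],
    model_dir="/model-weights",
)
demo = create_demo(agent, tools_dict)
demo.launch()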