AImused committed
Commit 4ccc3f1 · verified · 1 parent: 9cc28c3

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +18 -0
  2. Dockerfile +46 -0
  3. LICENSE +201 -0
  4. README.md +13 -0
  5. chat.py +311 -0
  6. eval_mm/README.md +362 -0
  7. eval_mm/README_zh.md +358 -0
  8. eval_mm/vlmevalkit/requirements.txt +30 -0
  9. eval_mm/vlmevalkit/requirements/docs.txt +11 -0
  10. eval_mm/vlmevalkit/run.py +223 -0
  11. eval_mm/vlmevalkit/scripts/run_inference.sh +31 -0
  12. eval_mm/vlmevalkit/setup.py +122 -0
  13. eval_mm/vlmevalkit/vlmeval/__init__.py +16 -0
  14. eval_mm/vlmevalkit/vlmeval/api/__init__.py +5 -0
  15. eval_mm/vlmevalkit/vlmeval/api/base.py +265 -0
  16. eval_mm/vlmevalkit/vlmeval/api/gpt.py +248 -0
  17. eval_mm/vlmevalkit/vlmeval/config.py +19 -0
  18. eval_mm/vlmevalkit/vlmeval/dataset/__init__.py +186 -0
  19. eval_mm/vlmevalkit/vlmeval/dataset/dude.py +210 -0
  20. eval_mm/vlmevalkit/vlmeval/dataset/image_base.py +165 -0
  21. eval_mm/vlmevalkit/vlmeval/dataset/image_caption.py +75 -0
  22. eval_mm/vlmevalkit/vlmeval/dataset/image_mcq.py +484 -0
  23. eval_mm/vlmevalkit/vlmeval/dataset/image_mt.py +128 -0
  24. eval_mm/vlmevalkit/vlmeval/dataset/image_vqa.py +433 -0
  25. eval_mm/vlmevalkit/vlmeval/dataset/image_yorn.py +88 -0
  26. eval_mm/vlmevalkit/vlmeval/dataset/mmbench_video.py +252 -0
  27. eval_mm/vlmevalkit/vlmeval/dataset/mmlongbench.py +582 -0
  28. eval_mm/vlmevalkit/vlmeval/dataset/mvbench.py +577 -0
  29. eval_mm/vlmevalkit/vlmeval/dataset/slidevqa.py +189 -0
  30. eval_mm/vlmevalkit/vlmeval/dataset/text_base.py +88 -0
  31. eval_mm/vlmevalkit/vlmeval/dataset/text_mcq.py +123 -0
  32. eval_mm/vlmevalkit/vlmeval/dataset/utils/__init__.py +9 -0
  33. eval_mm/vlmevalkit/vlmeval/dataset/utils/judge_util.py +41 -0
  34. eval_mm/vlmevalkit/vlmeval/dataset/utils/llavabench.py +65 -0
  35. eval_mm/vlmevalkit/vlmeval/dataset/utils/mathv.py +170 -0
  36. eval_mm/vlmevalkit/vlmeval/dataset/utils/mathvista.py +164 -0
  37. eval_mm/vlmevalkit/vlmeval/dataset/utils/mmbench_video.py +70 -0
  38. eval_mm/vlmevalkit/vlmeval/dataset/utils/mmdu.py +126 -0
  39. eval_mm/vlmevalkit/vlmeval/dataset/utils/mmvet.py +106 -0
  40. eval_mm/vlmevalkit/vlmeval/dataset/utils/multiple_choice.py +442 -0
  41. eval_mm/vlmevalkit/vlmeval/dataset/utils/mvbench.py +450 -0
  42. eval_mm/vlmevalkit/vlmeval/dataset/utils/ocrbench.py +65 -0
  43. eval_mm/vlmevalkit/vlmeval/dataset/utils/videomme.py +140 -0
  44. eval_mm/vlmevalkit/vlmeval/dataset/utils/vqa_eval.py +285 -0
  45. eval_mm/vlmevalkit/vlmeval/dataset/utils/yorn.py +203 -0
  46. eval_mm/vlmevalkit/vlmeval/dataset/vcr.py +332 -0
  47. eval_mm/vlmevalkit/vlmeval/dataset/video_base.py +87 -0
  48. eval_mm/vlmevalkit/vlmeval/dataset/videomme.py +250 -0
  49. eval_mm/vlmevalkit/vlmeval/inference.py +171 -0
  50. eval_mm/vlmevalkit/vlmeval/inference_mt.py +180 -0
.gitattributes CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/checkpoint/assets/Skiing.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ models/checkpoint/assets/demo.wav filter=lfs diff=lfs merge=lfs -text
38
+ models/checkpoint/assets/input_examples/Trump_WEF_2018_10s.mp3 filter=lfs diff=lfs merge=lfs -text
39
+ models/checkpoint/assets/input_examples/assistant_default_female_voice.wav filter=lfs diff=lfs merge=lfs -text
40
+ models/checkpoint/assets/input_examples/assistant_male_voice.wav filter=lfs diff=lfs merge=lfs -text
41
+ models/checkpoint/assets/input_examples/audio_understanding.mp3 filter=lfs diff=lfs merge=lfs -text
42
+ models/checkpoint/assets/input_examples/chi-english-1.wav filter=lfs diff=lfs merge=lfs -text
43
+ models/checkpoint/assets/input_examples/cxk_original.wav filter=lfs diff=lfs merge=lfs -text
44
+ models/checkpoint/assets/input_examples/exciting-emotion.wav filter=lfs diff=lfs merge=lfs -text
45
+ models/checkpoint/assets/input_examples/fast-pace.wav filter=lfs diff=lfs merge=lfs -text
46
+ models/checkpoint/assets/input_examples/icl_20.wav filter=lfs diff=lfs merge=lfs -text
47
+ models/checkpoint/assets/input_examples/indian-accent.wav filter=lfs diff=lfs merge=lfs -text
48
+ models/checkpoint/assets/mimick.wav filter=lfs diff=lfs merge=lfs -text
49
+ models/checkpoint/assets/qa.wav filter=lfs diff=lfs merge=lfs -text
50
+ ref_audios/default.wav filter=lfs diff=lfs merge=lfs -text
51
+ ref_audios/female_example.wav filter=lfs diff=lfs merge=lfs -text
52
+ ref_audios/male_example.wav filter=lfs diff=lfs merge=lfs -text
53
+ ref_audios/video_default.wav filter=lfs diff=lfs merge=lfs -text
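
The attribute lines above are what `git lfs track` writes into `.gitattributes`. As a hedged illustration (the asset path below is hypothetical and not part of this commit), tracking one more large audio file would look like:

```bash
# Sketch only: "ref_audios/new_voice.wav" is a hypothetical asset, not one added by this commit.
git lfs install                                   # one-time setup of the Git LFS hooks
git lfs track "ref_audios/new_voice.wav"          # appends a matching filter line to .gitattributes
git add .gitattributes ref_audios/new_voice.wav
git commit -m "Track new reference audio with Git LFS"
```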
Dockerfile ADDED
@@ -0,0 +1,46 @@
1
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
2
+
3
+ # Set environment variables
4
+ ENV PYTHONUNBUFFERED=1 \
5
+ DEBIAN_FRONTEND=noninteractive \
6
+ CUDA_HOME=/usr/local/cuda \
7
+ PATH=/usr/local/cuda/bin:$PATH \
8
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
9
+ NVIDIA_VISIBLE_DEVICES=all \
10
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility \
11
+ HF_HOME=/app/models \
12
+ NUMBA_CACHE_DIR=/tmp/numba_cache
13
+
14
+ # Install system dependencies
15
+ RUN apt-get update && apt-get install -y --no-install-recommends \
16
+ python3 \
17
+ python3-pip \
18
+ python3-dev \
19
+ build-essential \
20
+ git \
21
+ ffmpeg \
22
+ libsndfile1 \
23
+ curl \
24
+ && rm -rf /var/lib/apt/lists/*
25
+
26
+ # Upgrade pip and install build tools
27
+ RUN python3 -m pip install --upgrade pip setuptools wheel uv
28
+
29
+ WORKDIR /app
30
+
31
+ # Create Numba cache directory
32
+ RUN mkdir -p /tmp/numba_cache && \
33
+ chown nobody:nogroup /tmp/numba_cache && \
34
+ chmod 700 /tmp/numba_cache
35
+
36
+ COPY requirements.txt .
37
+
38
+ # Install other requirements
39
+ RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow
40
+ RUN python3 -m uv pip install --no-build-isolation flash-attn
41
+
42
+ COPY . .
43
+
44
+ EXPOSE 8000
45
+
46
+ CMD ["python3", "server.py"]
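
For context, a minimal way to build and launch this image might look like the sketch below; the image tag, port mapping, and GPU flags are assumptions (they are not defined by the Dockerfile), and `server.py` is expected to exist in the build context as the `CMD` implies:

```bash
# Sketch only: image name and runtime flags are assumptions; requires the NVIDIA Container Toolkit for --gpus.
docker build -t minicpm-o-server .                     # build from the Dockerfile above
docker run --gpus all -p 8000:8000 minicpm-o-server    # publish the port declared via EXPOSE 8000
```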
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 OpenBMB
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - any-to-any
5
+ - omega
6
+ - omegalabs
7
+ - bittensor
8
+ - agi
9
+ ---
10
+
11
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
12
+
13
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
chat.py ADDED
@@ -0,0 +1,311 @@
1
+ import os
2
+
3
+ # import os
4
+ import torch
5
+
6
+ # import torch
7
+ import json
8
+ from PIL import Image
9
+ # from PIL import Image
10
+
11
+ # from PIL import Image
12
+ import base64
13
+ import io
14
+
15
+ # import io
16
+ from accelerate import load_checkpoint_and_dispatch, init_empty_weights
17
+
18
+ # from accelerate import load_checkpoint_and_dispatch, init_empty_weights
19
+
20
+ # from accelerate import load_checkpoint_and_dispatch, init_empty_weights
21
+ from transformers import AutoTokenizer, AutoModel
22
+
23
+ # from transformers import AutoTokenizer, AutoModel
24
+
25
+ from omnilmm.utils import disable_torch_init
26
+ from omnilmm.model.omnilmm import OmniLMMForCausalLM
27
+ from omnilmm.model.utils import build_transform
28
+
29
+ # from omnilmm.model.utils import build_transform
30
+ from omnilmm.train.train_utils import omni_preprocess
31
+ # from omnilmm.train.train_utils import omni_preprocess
32
+
33
+ DEFAULT_IMAGE_TOKEN = "<image>"
34
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
35
+ DEFAULT_IM_START_TOKEN = "<im_start>"
36
+ DEFAULT_IM_END_TOKEN = "<im_end>"
37
+
38
+
39
+
40
+ def init_omni_lmm(model_path):
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+ disable_torch_init()
43
+ model_name = os.path.expanduser(model_path)
44
+ print(f'Load omni_lmm model and tokenizer from {model_name}')
45
+ tokenizer = AutoTokenizer.from_pretrained(
46
+ model_name, model_max_length=2048)
47
+
48
+ if False:
49
+ # model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
50
+ with init_empty_weights():
51
+ model = OmniLMMForCausalLM.from_pretrained(model_name, tune_clip=True, torch_dtype=torch.bfloat16)
52
+ model = load_checkpoint_and_dispatch(model, model_name, dtype=torch.bfloat16,
53
+ device_map="auto", no_split_module_classes=['Eva','MistralDecoderLayer', 'ModuleList', 'Resampler']
54
+ )
55
+ else:
56
+ model = OmniLMMForCausalLM.from_pretrained(
57
+ model_name, tune_clip=True, torch_dtype=torch.bfloat16
58
+ ).to(device='cuda', dtype=torch.bfloat16)
59
+
60
+ image_processor = build_transform(
61
+ is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')
62
+
63
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
64
+ assert mm_use_im_start_end
65
+
66
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
67
+ DEFAULT_IM_END_TOKEN], special_tokens=True)
68
+
69
+
70
+ vision_config = model.model.vision_config
71
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
72
+ [DEFAULT_IMAGE_PATCH_TOKEN])[0]
73
+ vision_config.use_im_start_end = mm_use_im_start_end
74
+ vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
75
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
76
+ image_token_len = model.model.config.num_query
77
+
78
+ return model, image_processor, image_token_len, tokenizer
79
+
80
+ def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
81
+ if '<image>' in question_text[0]['content']:
82
+ question_text[0]['content'] = question_text[0]['content'].replace(
83
+ '<image>', im_st_token + im_patch_token * image_token_len + im_ed_token)
84
+ else:
85
+ question_text[0]['content'] = im_st_token + im_patch_token * \
86
+ image_token_len + im_ed_token + '\n' + question_text[0]['content']
87
+ return question_text
88
+
89
+ def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
90
+ question = expand_question_into_multimodal(
91
+ question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)
92
+
93
+ conversation = question
94
+ data_dict = omni_preprocess(sources=[conversation],
95
+ tokenizer=tokenizer,
96
+ generation=True)
97
+
98
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
99
+ labels=data_dict["labels"][0])
100
+ return data_dict
101
+
102
+
103
+
104
+ class OmniLMM12B:
105
+ def __init__(self, model_path) -> None:
106
+ model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
107
+ self.model = model
108
+ self.image_token_len = image_token_len
109
+ self.image_transform = img_processor
110
+ self.tokenizer = tokenizer
111
+ self.model.eval()
112
+
113
+ def decode(self, image, input_ids):
114
+ with torch.inference_mode():
115
+ output = self.model.generate_vllm(
116
+ input_ids=input_ids.unsqueeze(0).cuda(),
117
+ images=image.unsqueeze(0).half().cuda(),
118
+ temperature=0.6,
119
+ max_new_tokens=1024,
120
+ # num_beams=num_beams,
121
+ do_sample=True,
122
+ output_scores=True,
123
+ return_dict_in_generate=True,
124
+ repetition_penalty=1.1,
125
+ top_k=30,
126
+ top_p=0.9,
127
+ )
128
+
129
+ response = self.tokenizer.decode(
130
+ output.sequences[0], skip_special_tokens=True)
131
+ response = response.strip()
132
+ return response
133
+
134
+ def chat(self, input):
135
+ try:
136
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
137
+ except Exception as e:
138
+ return "Image decode error"
139
+
140
+ msgs = json.loads(input['question'])
141
+ input_ids = wrap_question_for_omni_lmm(
142
+ msgs, self.image_token_len, self.tokenizer)['input_ids']
143
+ input_ids = torch.as_tensor(input_ids)
144
+ #print('input_ids', input_ids)
145
+ image = self.image_transform(image)
146
+
147
+ out = self.decode(image, input_ids)
148
+
149
+ return out
150
+
151
+
152
+ def img2base64(file_name):
153
+ with open(file_name, 'rb') as f:
154
+ encoded_string = base64.b64encode(f.read())
155
+ return encoded_string
156
+
157
+ class MiniCPMV:
158
+ def __init__(self, model_path) -> None:
159
+ self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
160
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
161
+ self.model.eval().cuda()
162
+
163
+ def chat(self, input):
164
+ try:
165
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
166
+ except Exception as e:
167
+ return "Image decode error"
168
+
169
+ msgs = json.loads(input['question'])
170
+
171
+ answer, context, _ = self.model.chat(
172
+ image=image,
173
+ msgs=msgs,
174
+ context=None,
175
+ tokenizer=self.tokenizer,
176
+ sampling=True,
177
+ temperature=0.7
178
+ )
179
+ return answer
180
+
181
+ class MiniCPMV2_5:
182
+ def __init__(self, model_path) -> None:
183
+ self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
184
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
185
+ self.model.eval().cuda()
186
+
187
+ def chat(self, input):
188
+ try:
189
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
190
+ except Exception as e:
191
+ return "Image decode error"
192
+
193
+ msgs = json.loads(input['question'])
194
+
195
+ answer = self.model.chat(
196
+ image=image,
197
+ msgs=msgs,
198
+ tokenizer=self.tokenizer,
199
+ sampling=True,
200
+ temperature=0.7
201
+ )
202
+ return answer
203
+
204
+ class MiniCPMV2_6:
205
+ def __init__(self, model_path, multi_gpus=False) -> None:
206
+
207
+ print('torch_version:', torch.__version__)
208
+ if multi_gpus: # inference on multi-gpus
209
+ from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
210
+ with init_empty_weights():
211
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
212
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16)
213
+
214
+ device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"},
215
+ no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])
216
+ device_id = device_map["llm.model.embed_tokens"]
217
+ device_map["llm.lm_head"] = device_id # first and last layer of llm should be in the same device
218
+ device_map["vpm"] = device_id
219
+ device_map["resampler"] = device_id
220
+ device_id2 = device_map["llm.model.layers.26"]
221
+ device_map["llm.model.layers.8"] = device_id2
222
+ device_map["llm.model.layers.9"] = device_id2
223
+ device_map["llm.model.layers.10"] = device_id2
224
+ device_map["llm.model.layers.11"] = device_id2
225
+ device_map["llm.model.layers.12"] = device_id2
226
+ device_map["llm.model.layers.13"] = device_id2
227
+ device_map["llm.model.layers.14"] = device_id2
228
+ device_map["llm.model.layers.15"] = device_id2
229
+ device_map["llm.model.layers.16"] = device_id2
230
+ print(device_map)
231
+
232
+ self.model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
233
+ self.model.eval()
234
+ else:
235
+ self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
236
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16)
237
+ self.model.eval().cuda()
238
+
239
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
240
+
241
+ def chat(self, input):
242
+ image = None
243
+ if "image" in input and len(input["image"]) > 10: # legacy API
244
+ try:
245
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
246
+ except Exception as e:
247
+ return "Image decode error"
248
+
249
+ msgs = json.loads(input["question"])
250
+
251
+ for msg in msgs:
252
+ contents = msg.pop('content') # support str or List[Dict]
253
+ if isinstance(contents, str):
254
+ contents = [contents]
255
+
256
+ new_cnts = []
257
+ for c in contents:
258
+ if isinstance(c, dict):
259
+ if c['type'] == 'text':
260
+ c = c['pairs']
261
+ elif c['type'] == 'image':
262
+ c = Image.open(io.BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
263
+ else:
264
+ raise ValueError("content type only support text and image.")
265
+ new_cnts.append(c)
266
+ msg['content'] = new_cnts
267
+ print(f'msgs: {str(msgs)}')
268
+
269
+ answer = self.model.chat(
270
+ image=image,
271
+ msgs=msgs,
272
+ tokenizer=self.tokenizer,
273
+ )
274
+ return answer
275
+
276
+
277
+ class MiniCPMVChat:
278
+ def __init__(self, model_path, multi_gpus=False) -> None:
279
+ if '12B' in model_path:
280
+ self.model = OmniLMM12B(model_path)
281
+ elif 'MiniCPM-Llama3-V' in model_path:
282
+ self.model = MiniCPMV2_5(model_path)
283
+ elif 'MiniCPM-V-2_6' in model_path:
284
+ self.model = MiniCPMV2_6(model_path, multi_gpus)
285
+ else:
286
+ self.model = MiniCPMV(model_path)
287
+
288
+ def chat(self, input):
289
+ return self.model.chat(input)
290
+
291
+
292
+ if __name__ == '__main__':
293
+
294
+ model_path = 'openbmb/OmniLMM-12B'
295
+ chat_model = MiniCPMVChat(model_path)
296
+
297
+ im_64 = img2base64('./assets/worldmap_ck.jpg')
298
+
299
+ # first round chat
300
+ msgs = [{"role": "user", "content": "What is interesting about this image?"}]
301
+ input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
302
+ answer = chat_model.chat(input)
303
+ print(msgs[-1]["content"]+'\n', answer)
304
+
305
+ # second round chat
306
+ msgs.append({"role": "assistant", "content": answer})
307
+ msgs.append({"role": "user", "content": "Where is China in the image"})
308
+ input = {"image": im_64,"question": json.dumps(msgs, ensure_ascii=True)}
309
+ answer = chat_model.chat(input)
310
+ print(msgs[-1]["content"]+'\n', answer)
311
+
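
A minimal sketch of exercising the `__main__` demo above, assuming a CUDA GPU, that the repository's `requirements.txt` covers the `omnilmm` dependencies, and that `./assets/worldmap_ck.jpg` is present:

```bash
# Sketch only: assumes a CUDA-capable GPU and the demo asset referenced in chat.py.
pip install -r requirements.txt
python3 chat.py        # runs the two-round demo conversation at the bottom of the file
```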
eval_mm/README.md ADDED
@@ -0,0 +1,362 @@
1
+ # Evaluation
2
+
3
+ ## MiniCPM-V 2.6
4
+
5
+ ### opencompass
6
+ First, enter the `vlmevalkit` directory and install all dependencies:
7
+ ```bash
8
+ cd vlmevalkit
9
+ pip install --upgrade pip
10
+ pip install -e .
11
+ wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
12
+ wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
13
+ wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
14
+ pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
15
+ pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
16
+ pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
17
+ ```
18
+ <br />
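
As an optional sanity check (not part of the original instructions), one can confirm that the pinned CUDA 11.8 wheels and flash-attn import cleanly before launching any run:

```bash
# Hedged sanity check: only verifies that torch and flash-attn import and that CUDA is visible.
python -c "import torch, flash_attn; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```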
19
+
20
+ Then, run `scripts/run_inference.sh`, which receives three input parameters in sequence: `MODELNAME`, `DATALIST`, and `MODE`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference, and `MODE` represents evaluation mode:
21
+ ```bash
22
+ chmod +x ./scripts/run_inference.sh
23
+ ./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
24
+ ```
25
+ <br />
26
+
27
+ The four available choices for `MODELNAME` are listed in `vlmeval/config.py`:
28
+ ```bash
29
+ minicpm_series = {
30
+ 'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
31
+ 'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
32
+ 'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
33
+ 'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
34
+ }
35
+ ```
36
+ <br />
37
+
38
+ All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. Separate the names of different datasets with spaces and add quotation marks at both ends:
39
+ ```bash
40
+ DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
41
+ ```
42
+ <br />
43
+
44
+ To score each benchmark directly, set `MODE=all`. If only inference results are required, set `MODE=infer`. To reproduce the results in the table displayed on the homepage (columns between MME and HallusionBench), run the script with the following settings:
45
+ ```bash
46
+ # without CoT
47
+ ./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST" all
48
+ ./scripts/run_inference.sh MiniCPM-V-2_6 MME all
49
+ # with CoT
50
+ # While running the CoT version of MME, you need to modify the 'use_cot' function in vlmeval/vlm/minicpm_v.py and add MME to the branch that returns True.
51
+ ./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MMVet MMStar HallusionBench OCRBench" all
52
+ ./scripts/run_inference.sh MiniCPM-V-2_6 MME all
53
+ ```
54
+ <br />
55
+
56
+ ### vqadataset
57
+ First, enter the `vqaeval` directory and install all dependencies. Then, create a `downloads` subdirectory to store the downloaded datasets for all tasks:
58
+ ```bash
59
+ cd vqaeval
60
+ pip install -r requirements.txt
61
+ mkdir downloads
62
+ ```
63
+ <br />
64
+
65
+ Download the datasets from the following links and place them in the specified directories:
66
+ ###### TextVQA
67
+ ```bash
68
+ cd downloads
69
+ mkdir TextVQA && cd TextVQA
70
+ wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
71
+ unzip train_val_images.zip && rm train_val_images.zip
72
+ mv train_val_images/train_images . && rm -rf train_val_images
73
+ wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
74
+ cd ../..
75
+ ```
76
+
77
+ ###### DocVQA / DocVQATest
78
+
79
+ ```bash
80
+ cd downloads
81
+ mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
82
+ # Download Images and Annotations from Task 1 - Single Page Document Visual Question Answering at https://rrc.cvc.uab.es/?ch=17&com=downloads
83
+ # Move the spdocvqa_images.tar.gz and spdocvqa_qas.zip to DocVQA directory
84
+ tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
85
+ unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
86
+ cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
87
+ cd ../..
88
+ ```
89
+ <br />
90
+
91
+ The `downloads` directory should be organized according to the following structure:
92
+ ```bash
93
+ downloads
94
+ ├── TextVQA
95
+ │ ├── train_images
96
+ │ │ ├── ...
97
+ │ ├── TextVQA_0.5.1_val.json
98
+ ├── DocVQA
99
+ │ ├── spdocvqa_images
100
+ │ │ ├── ...
101
+ │ ├── val_v1.0_withQT.json
102
+ │ ├── test_v1.0.json
103
+ ```
104
+ <br />
105
+
106
+ Modify the parameters in `shell/run_inference.sh` and run inference:
107
+
108
+ ```bash
109
+ chmod +x ./shell/run_inference.sh
110
+ ./shell/run_inference.sh
111
+ ```
112
+ <br />
113
+
114
+ All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows.
115
+ For `MiniCPM-V-2_6`, set `model_name` to `minicpmv26`:
116
+ ```bash
117
+ # path to images and their corresponding questions
118
+ # TextVQA
119
+ --textVQA_image_dir
120
+ --textVQA_ann_path
121
+ # DocVQA
122
+ --docVQA_image_dir
123
+ --docVQA_ann_path
124
+ # DocVQATest
125
+ --docVQATest_image_dir
126
+ --docVQATest_ann_path
127
+
128
+ # whether to eval on certain task
129
+ --eval_textVQA
130
+ --eval_docVQA
131
+ --eval_docVQATest
132
+ --eval_all
133
+
134
+ # model name and model path
135
+ --model_name
136
+ --model_path
137
+ # load model from ckpt
138
+ --ckpt
139
+ # the way the model processes input data, "interleave" represents interleaved image-text form, while "old" represents non-interleaved.
140
+ --generate_method
141
+
142
+ --batchsize
143
+
144
+ # path to save the outputs
145
+ --answer_path
146
+ ```
147
+ <br />
148
+
149
+ When evaluating different tasks, set the parameters as follows:
150
+ ###### TextVQA
151
+ ```bash
152
+ --eval_textVQA
153
+ --textVQA_image_dir ./downloads/TextVQA/train_images
154
+ --textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
155
+ ```
156
+
157
+ ###### DocVQA
158
+ ```bash
159
+ --eval_docVQA
160
+ --docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
161
+ --docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
162
+ ```
163
+
164
+ ###### DocVQATest
165
+ ```bash
166
+ --eval_docVQATest
167
+ --docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
168
+ --docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
169
+ ```
170
+
171
+ <br />
172
+
173
+ For the DocVQATest task, to upload the inference results to the [official website](https://rrc.cvc.uab.es/?ch=17) for evaluation, run `shell/run_transform.sh` after inference to convert the output format. `input_file_path` is the path to the original output JSON, and `output_file_path` is the path to the converted JSON:
174
+ ```bash
175
+ chmod +x ./shell/run_transform.sh
176
+ ./shell/run_transform.sh
177
+ ```
178
+ <br />
179
+
180
+ ## MiniCPM-Llama3-V-2_5
181
+
182
+ <details>
183
+ <summary>Expand</summary>
184
+
185
+ ### opencompass
186
+ First, enter the `vlmevalkit` directory and install all dependencies:
187
+ ```bash
188
+ cd vlmevalkit
189
+ pip install -r requirements.txt
190
+ ```
191
+ <br />
192
+
193
+ Then, run `scripts/run_inference.sh`, which receives three input parameters in sequence: `MODELNAME`, `DATALIST`, and `MODE`. `MODELNAME` represents the name of the model, `DATALIST` represents the datasets used for inference, and `MODE` represents evaluation mode:
194
+ ```bash
195
+ chmod +x ./scripts/run_inference.sh
196
+ ./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
197
+ ```
198
+ <br />
199
+
200
+ The three available choices for `MODELNAME` are listed in `vlmeval/config.py`:
201
+ ```bash
202
+ ungrouped = {
203
+ 'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
204
+ 'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
205
+ 'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
206
+ }
207
+ ```
208
+ <br />
209
+
210
+ All available choices for `DATALIST` are listed in `vlmeval/utils/dataset_config.py`. When evaluating a single dataset, pass the dataset name directly without quotation marks; when evaluating multiple datasets, separate their names with spaces and wrap the whole list in quotation marks:
211
+ ```bash
212
+ DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
213
+ ```
214
+ <br />
215
+
216
+ To score each benchmark directly, set `MODE=all`. If only inference results are required, set `MODE=infer`. To reproduce the results in the table displayed on the homepage (columns between MME and RealWorldQA), run the script with the following settings:
217
+ ```bash
218
+ # run on all 7 datasets
219
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
220
+
221
+ # The following are instructions for running on a single dataset
222
+ # MME
223
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
224
+ # MMBench_TEST_EN
225
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
226
+ # MMBench_TEST_CN
227
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
228
+ # MMMU_DEV_VAL
229
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
230
+ # MathVista_MINI
231
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
232
+ # LLaVABench
233
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
234
+ # RealWorldQA
235
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
236
+ ```
237
+ <br />
238
+
239
+ ### vqadataset
240
+ First, enter the `vqaeval` directory and install all dependencies. Then, create a `downloads` subdirectory to store the downloaded datasets for all tasks:
241
+ ```bash
242
+ cd vqaeval
243
+ pip install -r requirements.txt
244
+ mkdir downloads
245
+ ```
246
+ <br />
247
+
248
+ Download the datasets from the following links and place them in the specified directories:
249
+ ###### TextVQA
250
+ ```bash
251
+ cd downloads
252
+ mkdir TextVQA && cd TextVQA
253
+ wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
254
+ unzip train_val_images.zip && rm train_val_images.zip
255
+ mv train_val_images/train_images . && rm -rf train_val_images
256
+ wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
257
+ cd ../..
258
+ ```
259
+
260
+ ###### DocVQA / DocVQATest
261
+
262
+ ```bash
263
+ cd downloads
264
+ mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
265
+ # Download Images and Annotations from Task 1 - Single Page Document Visual Question Answering at https://rrc.cvc.uab.es/?ch=17&com=downloads
266
+ # Move the spdocvqa_images.tar.gz and spdocvqa_qas.zip to DocVQA directory
267
+ tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
268
+ unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
269
+ cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
270
+ cd ../..
271
+ ```
272
+ <br />
273
+
274
+ The `downloads` directory should be organized according to the following structure:
275
+ ```bash
276
+ downloads
277
+ ├── TextVQA
278
+ │ ├── train_images
279
+ │ │ ├── ...
280
+ │ ├── TextVQA_0.5.1_val.json
281
+ ├── DocVQA
282
+ │ ├── spdocvqa_images
283
+ │ │ ├── ...
284
+ │ ├── val_v1.0_withQT.json
285
+ │ ├── test_v1.0.json
286
+ ```
287
+ <br />
288
+
289
+ Modify the parameters in `shell/run_inference.sh` and run inference:
290
+
291
+ ```bash
292
+ chmod +x ./shell/run_inference.sh
293
+ ./shell/run_inference.sh
294
+ ```
295
+ <br />
296
+
297
+ All optional parameters are listed in `eval_utils/getargs.py`. The meanings of some major parameters are listed as follows.
298
+ For `MiniCPM-Llama3-V-2_5`, set `model_name` to `minicpmv`:
299
+ ```bash
300
+ # path to images and their corresponding questions
301
+ # TextVQA
302
+ --textVQA_image_dir
303
+ --textVQA_ann_path
304
+ # DocVQA
305
+ --docVQA_image_dir
306
+ --docVQA_ann_path
307
+ # DocVQATest
308
+ --docVQATest_image_dir
309
+ --docVQATest_ann_path
310
+
311
+ # whether to eval on certain task
312
+ --eval_textVQA
313
+ --eval_docVQA
314
+ --eval_docVQATest
315
+ --eval_all
316
+
317
+ # model name and model path
318
+ --model_name
319
+ --model_path
320
+ # load model from ckpt
321
+ --ckpt
322
+ # the way the model processes input data, "interleave" represents interleaved image-text form, while "old" represents non-interleaved.
323
+ --generate_method
324
+
325
+ --batchsize
326
+
327
+ # path to save the outputs
328
+ --answer_path
329
+ ```
330
+ <br />
331
+
332
+ When evaluating different tasks, set the parameters as follows:
333
+ ###### TextVQA
334
+ ```bash
335
+ --eval_textVQA
336
+ --textVQA_image_dir ./downloads/TextVQA/train_images
337
+ --textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
338
+ ```
339
+
340
+ ###### DocVQA
341
+ ```bash
342
+ --eval_docVQA
343
+ --docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
344
+ --docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
345
+ ```
346
+
347
+ ###### DocVQATest
348
+ ```bash
349
+ --eval_docVQATest
350
+ --docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
351
+ --docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
352
+ ```
353
+
354
+ <br />
355
+
356
+ For the DocVQATest task, to upload the inference results to the [official website](https://rrc.cvc.uab.es/?ch=17) for evaluation, run `shell/run_transform.sh` after inference to convert the output format. `input_file_path` is the path to the original output JSON, and `output_file_path` is the path to the converted JSON:
357
+ ```bash
358
+ chmod +x ./shell/run_transform.sh
359
+ ./shell/run_transform.sh
360
+ ```
361
+
362
+ </details>
eval_mm/README_zh.md ADDED
@@ -0,0 +1,358 @@
1
+ # Evaluation
2
+
3
+ ## MiniCPM-V 2.6
4
+
5
+ ### opencompass
6
+ 首先,进入 `vlmevalkit` 目录下,安装必要的依赖:
7
+ ```bash
8
+ cd vlmevalkit
9
+ pip install --upgrade pip
10
+ pip install -e .
11
+ wget https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=4377e0a7fe8ff8ffc4f7c9c6130c1dcd3874050ae4fc28b7ff1d35234fbca423
12
+ wget https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=2e63d62e09d9b48b407d3e1b30eb8ae4e3abad6968e8d33093b60d0657542428
13
+ wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
14
+ pip install torch-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl
15
+ pip install torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl
16
+ pip install flash_attn-2.6.3+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
17
+ rm *.whl
18
+ ```
19
+ <br />
20
+
21
+ 然后,运行 `scripts/run_inference.sh`,该脚本依次接收三个输入参数:`MODELNAME`, `DATALIST`, `MODE`。`MODELNAME` 为模型名称,`DATALIST` 为目标数据集,`MODE` 为评测模式。
22
+ ```bash
23
+ chmod +x ./scripts/run_inference.sh
24
+ ./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
25
+ ```
26
+ <br />
27
+
28
+ `MODELNAME` 有四种选择,位于 `vlmeval/config.py` 中:
29
+ ```bash
30
+ minicpm_series = {
31
+ 'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
32
+ 'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
33
+ 'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
34
+ 'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
35
+ }
36
+ ```
37
+ <br />
38
+
39
+ 可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中。将不同数据集名称以空格隔开,两端加引号:
40
+ ```bash
41
+ $DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
42
+ ```
43
+ <br />
44
+
45
+ 直接对各 benchmark 进行评分时,设置 `MODE=all`。如果仅需要推理结果,则设置 `MODE=infer`。
46
+ 为了复现出首页展示的表格中的各项结果(MME 到 HallusionBench 之间的列),需要按照如下设置运行:
47
+ ```bash
48
+ # without CoT
49
+ ./scripts/run_inference.sh MiniCPM-V-2_6 "MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST" all
50
+ ./scripts/run_inference.sh MiniCPM-V-2_6 MME all
51
+ # with CoT,运行 CoT 版本的 MME 时,需要改写 vlmeval/vlm/minicpm_v.py 中的 'use_cot' 函数,将 MME 添加到 return True 的分支中
52
+ ./scripts/run_inference/sh MiniCPM-V-2_6 "MMMU_DEV_VAL MMVet MMStar HallusionBench OCRBench" all
53
+ ./scripts/run_inference.sh MiniCPM-V-2_6 MME all
54
+ ```
55
+ <br />
56
+
57
+ ### vqadataset
58
+ 首先,进入 `vqaeval` 目录下,安装必要的依赖,并创建 `downloads` 子目录,用于存储下载的数据集:
59
+ ```bash
60
+ cd vqaeval
61
+ pip install -r requirements.txt
62
+ mkdir downloads
63
+ ```
64
+ <br />
65
+
66
+ 然后,从下列各地址下载数据集并置于指定目录下:
67
+ ###### TextVQA
68
+ ```bash
69
+ cd downloads
70
+ mkdir TextVQA && cd TextVQA
71
+ wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
72
+ unzip train_val_images.zip && rm train_val_images.zip
73
+ mv train_val_images/train_images . && rm -rf train_val_images
74
+ wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
75
+ cd ../..
76
+ ```
77
+
78
+ ###### DocVQA / DocVQATest
79
+ ```bash
80
+ cd downloads
81
+ mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
82
+ # 在 https://rrc.cvc.uab.es/?ch=17&com=downloads 下载 Task 1 - Single Page Document Visual Question Answering 下的 Images 和 Annotations
83
+ # 将下载得到的 spdocvqa_images.tar.gz 以及 spdocvqa_qas.zip 置于 DocVQA 目录下
84
+ tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
85
+ unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
86
+ cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
87
+ cd ../..
88
+ ```
89
+ <br />
90
+
91
+ `downloads` 目录应当按照下列结构组织:
92
+ ```bash
93
+ downloads
94
+ ├── TextVQA
95
+ │ ├── train_images
96
+ │ │ ├── ...
97
+ │ ├── TextVQA_0.5.1_val.json
98
+ ├── DocVQA
99
+ │ ├── spdocvqa_images
100
+ │ │ ├── ...
101
+ │ ├── val_v1.0_withQT.json
102
+ │ ├── test_v1.0.json
103
+ ```
104
+ <br />
105
+
106
+ 准备好相应的数据集之后,修改 `shell/run_inference.sh` 的参数,运行推理:
107
+
108
+ ```bash
109
+ chmod +x ./shell/run_inference.sh
110
+ ./shell/run_inference.sh
111
+ ```
112
+ <br />
113
+
114
+ 可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下。
115
+ 对于 `MiniCPM-V-2_6`,需要将 `model_name`设置为 `minicpmv26`:
116
+ ```bash
117
+ # 指定 TextVQA 评测所有图片和问题的路径
118
+ --textVQA_image_dir
119
+ --textVQA_ann_path
120
+ # 指定 DocVQA 评测所有图片和问题的路径
121
+ --docVQA_image_dir
122
+ --docVQA_ann_path
123
+ # 指定 DocVQATest 评测所有图片和问题的路径
124
+ --docVQATest_image_dir
125
+ --docVQATest_ann_path
126
+
127
+ # 决定是否评测某个任务,eval_all 设置为 True 表示所有任务都评测
128
+ --eval_textVQA
129
+ --eval_docVQA
130
+ --eval_docVQATest
131
+ --eval_all
132
+
133
+ # 模型名称、模型路径(从指定路径加载模型)
134
+ --model_name
135
+ --model_path
136
+ # 从 checkpoint 加载模型
137
+ --ckpt
138
+ # 模型处理输入数据的方式,interleave 表示图文交错式,old 表示非交错式
139
+ --generate_method
140
+ # 推理时的批处理规模,建议推理时设置为 1
141
+ --batchsize
142
+
143
+ # 输出内容保存的路径
144
+ --answer_path
145
+ ```
146
+ <br />
147
+
148
+ 评测三个任务需要设置的参数如下:
149
+ ###### TextVQA
150
+ ```bash
151
+ --eval_textVQA
152
+ --textVQA_image_dir ./downloads/TextVQA/train_images
153
+ --textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
154
+ ```
155
+
156
+ ###### DocVQA
157
+ ```bash
158
+ --eval_docVQA
159
+ --docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
160
+ --docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
161
+ ```
162
+
163
+ ###### DocVQATest
164
+ ```bash
165
+ --eval_docVQATest
166
+ --docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
167
+ --docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
168
+ ```
169
+ <br />
170
+
171
+ 对于 DocVQATest 任务,为了将推理结果上传到[官方网站](https://rrc.cvc.uab.es/?ch=17)进行评测,还需要运行 `shell/run_transform.sh` 进行格式转换。其中,`input_file_path` 对应原始输出的 json 的路径,`output_file_path` 为自定义的转换后的 json 的路径:
172
+ ```bash
173
+ chmod +x ./shell/run_transform.sh
174
+ ./shell/run_transform.sh
175
+ ```
176
+ <br />
177
+
178
+ ## MiniCPM-Llama3-V-2_5
179
+
180
+ <details>
181
+ <summary>展开</summary>
182
+
183
+ ### opencompass
184
+ 首先,进入 `vlmevalkit` 目录下,安装必要的依赖:
185
+ ```bash
186
+ cd vlmevalkit
187
+ pip install -r requirements.txt
188
+ ```
189
+ <br />
190
+
191
+ 然后,运行 `scripts/run_inference.sh`,该脚本依次接收三个输入参数:`MODELNAME`, `DATALIST`, `MODE`。`MODELNAME` 为模型名称,`DATALIST` 为目标数据集,`MODE` 为评测模式。
192
+ ```bash
193
+ chmod +x ./scripts/run_inference.sh
194
+ ./scripts/run_inference.sh $MODELNAME $DATALIST $MODE
195
+ ```
196
+ <br />
197
+
198
+ `MODELNAME` 有三种选择,位于 `vlmeval/config.py` 中:
199
+ ```bash
200
+ ungrouped = {
201
+ 'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
202
+ 'MiniCPM-V-2':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
203
+ 'MiniCPM-Llama3-V-2_5':partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
204
+ }
205
+ ```
206
+ <br />
207
+
208
+ 可选的所有 `DATALIST` 位于 `vlmeval/utils/dataset_config.py` 中,评测单个数据集时,直接调用数据集名称,不加引号;评测多个数据集时,将不同数据集名称以空格隔开,两端加引号:
209
+ ```bash
210
+ $DATALIST="POPE ScienceQA_TEST ChartQA_TEST"
211
+ ```
212
+ <br />
213
+
214
+ 直接对各 benchmark 进行评分时,设置 `MODE=all`。如果仅需要推理结果,则设置 `MODE=infer`
215
+ 为了复现出首页展示的表格中的各项结果(MME 到 RealWorldQA 之间的列),需要按照如下设置运行:
216
+ ```bash
217
+ # 一次性运行 7 个数据集
218
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 "MME MMBench_TEST_EN MMBench_TEST_CN MMMU_DEV_VAL MathVista_MINI LLaVABench RealWorldQA" all
219
+
220
+ # 以下是单独运行 1 个数据集的指令
221
+ # MME
222
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MME all
223
+ # MMBench_TEST_EN
224
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_EN all
225
+ # MMBench_TEST_CN
226
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMBench_TEST_CN all
227
+ # MMMU_DEV_VAL
228
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MMMU_DEV_VAL all
229
+ # MathVista_MINI
230
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 MathVista_MINI all
231
+ # LLaVABench
232
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 LLaVABench all
233
+ # RealWorldQA
234
+ ./scripts/run_inference.sh MiniCPM-Llama3-V-2_5 RealWorldQA all
235
+ ```
236
+ <br />
237
+
238
+ ### vqadataset
239
+ 首先,进入 `vqaeval` 目录下,安装必要的依赖,并创建 `downloads` 子目录,用于存储下载的数据集:
240
+ ```bash
241
+ cd vqaeval
242
+ pip install -r requirements.txt
243
+ mkdir downloads
244
+ ```
245
+ <br />
246
+
247
+ 然后,从下列各地址下载数据集并置于指定目录下:
248
+ ###### TextVQA
249
+ ```bash
250
+ cd downloads
251
+ mkdir TextVQA && cd TextVQA
252
+ wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
253
+ unzip train_val_images.zip && rm train_val_images.zip
254
+ mv train_val_images/train_images . && rm -rf train_val_images
255
+ wget https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
256
+ cd ../..
257
+ ```
258
+
259
+ ###### DocVQA / DocVQATest
260
+ ```bash
261
+ cd downloads
262
+ mkdir DocVQA && cd DocVQA && mkdir spdocvqa_images
263
+ # 在 https://rrc.cvc.uab.es/?ch=17&com=downloads 下载 Task 1 - Single Page Document Visual Question Answering 下的 Images 和 Annotations
264
+ # 将下载得到的 spdocvqa_images.tar.gz 以及 spdocvqa_qas.zip 置于 DocVQA 目录下
265
+ tar -zxvf spdocvqa_images.tar.gz -C spdocvqa_images && rm spdocvqa_images.tar.gz
266
+ unzip spdocvqa_qas.zip && rm spdocvqa_qas.zip
267
+ cp spdocvqa_qas/val_v1.0_withQT.json . && cp spdocvqa_qas/test_v1.0.json . && rm -rf spdocvqa_qas
268
+ cd ../..
269
+ ```
270
+ <br />
271
+
272
+ `downloads` 目录应当按照下列结构组织:
273
+ ```bash
274
+ downloads
275
+ ├── TextVQA
276
+ │ ├── train_images
277
+ │ │ ├── ...
278
+ │ ├── TextVQA_0.5.1_val.json
279
+ ├── DocVQA
280
+ │ ├── spdocvqa_images
281
+ │ │ ├── ...
282
+ │ ├── val_v1.0_withQT.json
283
+ │ ├── test_v1.0.json
284
+ ```
285
+ <br />
286
+
287
+ 准备好相应的数据集之后,修改 `shell/run_inference.sh` 的参数,运行推理:
288
+
289
+ ```bash
290
+ chmod +x ./shell/run_inference.sh
291
+ ./shell/run_inference.sh
292
+ ```
293
+ <br />
294
+
295
+ 可以传入的参数位于 `eval_utils/getargs.py` 中,各主要参数的含义如下。
296
+ 对于 `MiniCPM-Llama3-V-2_5`,需要将 `model_name` 设置为 `minicpmv`:
297
+ ```bash
298
+ # 指定 TextVQA 评测所有图片和问题的路径
299
+ --textVQA_image_dir
300
+ --textVQA_ann_path
301
+ # 指定 DocVQA 评测所有图片和问题的路径
302
+ --docVQA_image_dir
303
+ --docVQA_ann_path
304
+ # 指定 DocVQATest 评测所有图片和问题的路径
305
+ --docVQATest_image_dir
306
+ --docVQATest_ann_path
307
+
308
+ # 决定是否评测某个任务,eval_all 设置为 True 表示所有任务都评测
309
+ --eval_textVQA
310
+ --eval_docVQA
311
+ --eval_docVQATest
312
+ --eval_all
313
+
314
+ # 模型名称、模型路径(从指定路径加载模型)
315
+ --model_name
316
+ --model_path
317
+ # 从 checkpoint 加载模型
318
+ --ckpt
319
+ # 模型处理输入数据的方式,interleave 表示图文交错式,old 表示非交错式
320
+ --generate_method
321
+ # 推理时的批处理规模,建议推理时设置为 1
322
+ --batchsize
323
+
324
+ # 输出内容保存的路径
325
+ --answer_path
326
+ ```
327
+ <br />
328
+
329
+ 评测三个任务需要设置的参数如下:
330
+ ###### TextVQA
331
+ ```bash
332
+ --eval_textVQA
333
+ --textVQA_image_dir ./downloads/TextVQA/train_images
334
+ --textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json
335
+ ```
336
+
337
+ ###### DocVQA
338
+ ```bash
339
+ --eval_docVQA
340
+ --docVQA_image_dir ./downloads/DocVQA/spdocvqa_images
341
+ --docVQA_ann_path ./downloads/DocVQA/val_v1.0_withQT.json
342
+ ```
343
+
344
+ ###### DocVQATest
345
+ ```bash
346
+ --eval_docVQATest
347
+ --docVQATest_image_dir ./downloads/DocVQA/spdocvqa_images
348
+ --docVQATest_ann_path ./downloads/DocVQA/test_v1.0.json
349
+ ```
350
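+
+ For reference, the task flags above can be combined with the model flags into a single run. The sketch below is illustrative only: the entry-script name (`eval.py`) and the answer path are placeholders rather than the literal contents of `shell/run_inference.sh`, so substitute the script and paths used in your setup.
+ ```bash
+ # Illustrative sketch only; eval.py and ./answers are placeholders.
+ python eval.py \
+     --model_name minicpmv \
+     --model_path <path_to_MiniCPM-Llama3-V-2_5> \
+     --generate_method interleave \
+     --batchsize 1 \
+     --eval_textVQA \
+     --textVQA_image_dir ./downloads/TextVQA/train_images \
+     --textVQA_ann_path ./downloads/TextVQA/TextVQA_0.5.1_val.json \
+     --answer_path ./answers
+ ```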
+ <br />
+
+ For the DocVQATest task, in order to upload the inference results to the [official website](https://rrc.cvc.uab.es/?ch=17) for evaluation, you also need to run `shell/run_transform.sh` to convert the output format; a usage sketch follows the block below. Here, `input_file_path` is the path of the originally produced json, and `output_file_path` is the path you choose for the converted json:
+ ```bash
+ chmod +x ./shell/run_transform.sh
+ ./shell/run_transform.sh
+ ```
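+
+ As a rough sketch of how the pieces fit together (an illustration, not the script's literal interface; how `input_file_path` and `output_file_path` are actually set is defined inside `shell/run_transform.sh` itself):
+ ```bash
+ # 1. Run DocVQATest inference; the raw json is written under the directory given by --answer_path.
+ ./shell/run_inference.sh
+ # 2. Point input_file_path at that raw json and output_file_path at the file you
+ #    will upload to https://rrc.cvc.uab.es/?ch=17, then run the conversion.
+ chmod +x ./shell/run_transform.sh
+ ./shell/run_transform.sh
+ ```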
+
+ </details>
eval_mm/vlmevalkit/requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ decord
2
+ gradio
3
+ huggingface_hub
4
+ imageio
5
+ matplotlib
6
+ moviepy
7
+ numpy>=1.23.4
8
+ omegaconf
9
+ openai==1.3.5
10
+ opencv-python>=4.4.0.46
11
+ openpyxl
12
+ pandas
13
+ peft
14
+ pillow
15
+ portalocker
16
+ python-dotenv
17
+ requests
18
+ rich
19
+ sentencepiece
20
+ setuptools
21
+ sty
22
+ tabulate
23
+ tiktoken
24
+ timeout-decorator
25
+ torch>=2.0.1
26
+ tqdm
27
+ transformers
28
+ typing_extensions==4.7.1
29
+ validators
30
+ xlsxwriter
eval_mm/vlmevalkit/requirements/docs.txt ADDED
@@ -0,0 +1,11 @@
1
+ docutils==0.18.1
2
+ modelindex
3
+ myst-parser
4
+ -e git+https://github.com/open-compass/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
5
+ sphinx==6.1.3
6
+ sphinx-copybutton
7
+ sphinx-design
8
+ sphinx-notfound-page
9
+ sphinx-tabs
10
+ sphinxcontrib-jquery
11
+ tabulate
eval_mm/vlmevalkit/run.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ from vlmeval.config import supported_VLM
5
+ from vlmeval.dataset import build_dataset
6
+ from vlmeval.inference import infer_data_job
7
+ from vlmeval.inference_video import infer_data_job_video
8
+ from vlmeval.inference_mt import infer_data_job_mt
9
+ from vlmeval.smp import *
10
+ from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser()
15
+ # Essential Args
16
+ parser.add_argument('--data', type=str, nargs='+', required=True)
17
+ parser.add_argument('--model', type=str, nargs='+', required=True)
18
+ # Args that only apply to Video Dataset
19
+ parser.add_argument('--nframe', type=int, default=8)
20
+ parser.add_argument('--pack', action='store_true')
21
+ parser.add_argument('--use-subtitle', action='store_true')
22
+ # Work Dir
23
+ parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
24
+ # Infer + Eval or Infer Only
25
+ parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
26
+ # API Kwargs, Apply to API VLMs and Judge API LLMs
27
+ parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
28
+ parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
29
+ # Explicitly Set the Judge Model
30
+ parser.add_argument('--judge', type=str, default=None)
31
+ # Logging Utils
32
+ parser.add_argument('--verbose', action='store_true')
33
+ # Configuration for Resume
34
+ # Ignore: will not rerun failed VLM inference
35
+ parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
36
+ # Rerun: will remove all evaluation temp files
37
+ parser.add_argument('--rerun', action='store_true')
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def main():
43
+ logger = get_logger('RUN')
44
+
45
+ args = parse_args()
46
+ assert len(args.data), '--data should be a list of data files'
47
+
48
+ if args.retry is not None:
49
+ for k, v in supported_VLM.items():
50
+ if hasattr(v, 'keywords') and 'retry' in v.keywords:
51
+ v.keywords['retry'] = args.retry
52
+ supported_VLM[k] = v
53
+ if hasattr(v, 'keywords') and 'verbose' in v.keywords:
54
+ v.keywords['verbose'] = args.verbose
55
+ supported_VLM[k] = v
56
+
57
+ rank, world_size = get_rank_and_world_size()
58
+ if world_size > 1:
59
+ local_rank = os.environ.get('LOCAL_RANK', 0)
60
+ torch.cuda.set_device(int(local_rank))
61
+ dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10800))
62
+
63
+ for _, model_name in enumerate(args.model):
64
+ model = None
65
+
66
+ pred_root = osp.join(args.work_dir, model_name)
67
+ os.makedirs(pred_root, exist_ok=True)
68
+
69
+ for _, dataset_name in enumerate(args.data):
70
+ dataset_kwargs = {}
71
+ if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
72
+ dataset_kwargs['model'] = model_name
73
+ if dataset_name == 'MMBench-Video':
74
+ dataset_kwargs['pack'] = args.pack
75
+ if dataset_name == 'Video-MME':
76
+ dataset_kwargs['use_subtitle'] = args.use_subtitle
77
+
78
+ # If distributed, first build the dataset on the main process for doing preparation works
79
+ if world_size > 1:
80
+ dataset = build_dataset(dataset_name, **dataset_kwargs) if rank == 0 else None
81
+ dist.barrier()
82
+ dataset_list = [dataset]
83
+ dist.broadcast_object_list(dataset_list, src=0)
84
+ dataset = dataset_list[0]
85
+ else:
86
+ dataset = build_dataset(dataset_name, **dataset_kwargs)
87
+ if dataset is None:
88
+ logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
89
+ continue
90
+
91
+ result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
92
+ if dataset_name in ['MMBench-Video']:
93
+ packstr = 'pack' if args.pack else 'nopack'
94
+ result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
95
+ elif dataset.MODALITY == 'VIDEO':
96
+ if args.pack:
97
+ logger.info(f'{dataset_name} does not support Pack Mode, will directly change to unpack')
98
+ args.pack = False
99
+ packstr = 'pack' if args.pack else 'nopack'
100
+ result_file = f'{pred_root}/{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
101
+ if dataset_name in ['Video-MME']:
102
+ subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
103
+ result_file = result_file.replace('.xlsx', f'_{subtitlestr}.xlsx')
104
+
105
+ if dataset.TYPE == 'MT':
106
+ result_file = result_file.replace('.xlsx', '.tsv')
107
+
108
+ if osp.exists(result_file) and args.rerun:
109
+ for keyword in ['openai', 'gpt', 'auxmatch']:
110
+ os.system(f'rm {pred_root}/{model_name}_{dataset_name}_{keyword}*')
111
+
112
+ if model is None:
113
+ model = model_name # which is only a name
114
+
115
+ # Perform the Inference
116
+ if dataset.MODALITY == 'VIDEO':
117
+ model = infer_data_job_video(
118
+ model,
119
+ work_dir=pred_root,
120
+ model_name=model_name,
121
+ dataset=dataset,
122
+ nframe=args.nframe,
123
+ pack=args.pack,
124
+ verbose=args.verbose,
125
+ subtitle=args.use_subtitle,
126
+ api_nproc=args.nproc)
127
+ elif dataset.TYPE == 'MT':
128
+ model = infer_data_job_mt(
129
+ model,
130
+ work_dir=pred_root,
131
+ model_name=model_name,
132
+ dataset=dataset,
133
+ verbose=args.verbose,
134
+ api_nproc=args.nproc,
135
+ ignore_failed=args.ignore)
136
+ else:
137
+ model = infer_data_job(
138
+ model,
139
+ work_dir=pred_root,
140
+ model_name=model_name,
141
+ dataset=dataset,
142
+ verbose=args.verbose,
143
+ api_nproc=args.nproc,
144
+ ignore_failed=args.ignore)
145
+
146
+ # Set the judge kwargs first before evaluation or dumping
147
+ judge_kwargs = {
148
+ 'nproc': args.nproc,
149
+ 'verbose': args.verbose,
150
+ }
151
+ if args.retry is not None:
152
+ judge_kwargs['retry'] = args.retry
153
+ if args.judge is not None:
154
+ judge_kwargs['model'] = args.judge
155
+ else:
156
+ if dataset.TYPE in ['MCQ', 'Y/N']:
157
+ judge_kwargs['model'] = 'chatgpt-0125'
158
+ elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
159
+ judge_kwargs['model'] = 'gpt-4-turbo'
160
+ elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI'], dataset_name):
161
+ judge_kwargs['model'] = 'gpt-4o'
162
+ if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']):
163
+ judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE']
164
+ if 'OPENAI_API_BASE_JUDGE' in os.environ and len(os.environ['OPENAI_API_BASE_JUDGE']):
165
+ judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']
166
+
167
+ if rank == 0:
168
+ if dataset_name in ['MMMU_TEST']:
169
+ result_json = MMMU_result_transfer(result_file)
170
+ logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
171
+ f'json file saved in {result_json}') # noqa: E501
172
+ continue
173
+ elif 'MMT-Bench_ALL' in dataset_name:
174
+ submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
175
+ logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
176
+ f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
177
+ f'submission file saved in {submission_file}') # noqa: E501
178
+ continue
179
+ elif 'MLLMGuard_DS' in dataset_name:
180
+ logger.info('The evaluation of MLLMGuard_DS is not supported yet. ') # noqa: E501
181
+ continue
182
+ elif 'AesBench_TEST' == dataset_name:
183
+ logger.info(f'The results are saved in {result_file}. '
184
+ f'Please send it to the AesBench Team via [email protected].') # noqa: E501
185
+ continue
186
+
187
+ if dataset_name in [
188
+ 'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
189
+ 'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
190
+ ]:
191
+ if not MMBenchOfficialServer(dataset_name):
192
+ logger.error(
193
+ f'Can not evaluate {dataset_name} on non-official servers, '
194
+ 'will skip the evaluation. '
195
+ )
196
+ continue
197
+
198
+ eval_proxy = os.environ.get('EVAL_PROXY', None)
199
+ old_proxy = os.environ.get('HTTP_PROXY', '')
200
+
201
+ if rank == 0 and args.mode == 'all':
202
+ if eval_proxy is not None:
203
+ proxy_set(eval_proxy)
204
+
205
+ eval_results = dataset.evaluate(result_file, **judge_kwargs)
206
+ if eval_results is not None:
207
+ assert isinstance(eval_results, dict) or isinstance(eval_results, pd.DataFrame)
208
+ logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
209
+ logger.info('Evaluation Results:')
210
+ if isinstance(eval_results, dict):
211
+ logger.info('\n' + json.dumps(eval_results, indent=4))
212
+ elif isinstance(eval_results, pd.DataFrame):
213
+ if len(eval_results) < len(eval_results.columns):
214
+ eval_results = eval_results.T
215
+ logger.info('\n' + tabulate(eval_results))
216
+
217
+ if eval_proxy is not None:
218
+ proxy_set(old_proxy)
219
+
220
+
221
+ if __name__ == '__main__':
222
+ load_env()
223
+ main()
eval_mm/vlmevalkit/scripts/run_inference.sh ADDED
@@ -0,0 +1,31 @@
1
+ export PATH=/usr/local/cuda/bin:$PATH
2
+
3
+ export HF_ENDPOINT=https://hf-mirror.com
4
+ export OMP_NUM_THREADS=1
5
+ export timestamp=`date +"%Y%m%d%H%M%S"`
6
+ export OLD_VERSION='False'
7
+ export PYTHONPATH=$(dirname $SELF_DIR):$PYTHONPATH
8
+
9
+ # gpu consumed
10
+ # fp16 17-18G
11
+ # int4 7-8G
12
+
13
+ # model to be used
14
+ # Example: MODELNAME=MiniCPM_V_2_6
15
+ MODELNAME=$1
16
+ # datasets to be tested
17
+ # Example: DATALIST="MMMU_DEV_VAL MathVista_MINI MMVet MMBench_DEV_EN_V11 MMBench_DEV_CN_V11 MMStar HallusionBench AI2D_TEST"
18
+ DATALIST=$2
19
+ # test mode, all or infer
20
+ MODE=$3
21
+
22
+ echo "Starting inference with model $MODELNAME on datasets $DATALIST"
23
+ # run on multi gpus with torchrun command
24
+ # remember to run twice, the first run may fail
25
+ torchrun --nproc_per_node=8 run.py --data $DATALIST --model $MODELNAME --mode $MODE
26
+ torchrun --nproc_per_node=8 run.py --data $DATALIST --model $MODELNAME --mode $MODE
27
+ # run on single gpu with python command
28
+ # python run.py --data $DATALIST --model $MODELNAME --verbose --mode $MODE
29
+ # python run.py --data $DATALIST --model $MODELNAME --verbose --mode $MODE
30
+
31
+ ls
eval_mm/vlmevalkit/setup.py ADDED
@@ -0,0 +1,122 @@
1
+ import re
2
+ import sys
3
+ from os.path import exists
4
+ from setuptools import find_packages, setup
5
+
6
+
7
+ def parse_requirements(fname='requirements.txt', with_version=True):
8
+ """Parse the package dependencies listed in a requirements file but strips
9
+ specific versioning information.
10
+
11
+ Args:
12
+ fname (str): path to requirements file
13
+ with_version (bool, default=True): if True include version specs
14
+
15
+ Returns:
16
+ List[str]: list of requirements items
17
+
18
+ CommandLine:
19
+ python -c "import setup; print(setup.parse_requirements())"
20
+ """
21
+
22
+ require_fpath = fname
23
+
24
+ def parse_line(line):
25
+ """Parse information from a line in a requirements text file."""
26
+ if line.startswith('-r '):
27
+ # Allow specifying requirements in other files
28
+ target = line.split(' ')[1]
29
+ for info in parse_require_file(target):
30
+ yield info
31
+ else:
32
+ info = {'line': line}
33
+ if line.startswith('-e '):
34
+ info['package'] = line.split('#egg=')[1]
35
+ elif '@git+' in line:
36
+ info['package'] = line
37
+ else:
38
+ # Remove versioning from the package
39
+ pat = '(' + '|'.join(['>=', '==', '>']) + ')'
40
+ parts = re.split(pat, line, maxsplit=1)
41
+ parts = [p.strip() for p in parts]
42
+
43
+ info['package'] = parts[0]
44
+ if len(parts) > 1:
45
+ op, rest = parts[1:]
46
+ if ';' in rest:
47
+ # Handle platform specific dependencies
48
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
49
+ version, platform_deps = map(str.strip,
50
+ rest.split(';'))
51
+ info['platform_deps'] = platform_deps
52
+ else:
53
+ version = rest # NOQA
54
+ info['version'] = (op, version)
55
+ yield info
56
+
57
+ def parse_require_file(fpath):
58
+ with open(fpath, 'r') as f:
59
+ for line in f.readlines():
60
+ line = line.strip()
61
+ if line and not line.startswith('#'):
62
+ for info in parse_line(line):
63
+ yield info
64
+
65
+ def gen_packages_items():
66
+ if exists(require_fpath):
67
+ for info in parse_require_file(require_fpath):
68
+ parts = [info['package']]
69
+ if with_version and 'version' in info:
70
+ parts.extend(info['version'])
71
+ if not sys.version.startswith('3.4'):
72
+ # apparently package_deps are broken in 3.4
73
+ platform_deps = info.get('platform_deps')
74
+ if platform_deps is not None:
75
+ parts.append(';' + platform_deps)
76
+ item = ''.join(parts)
77
+ yield item
78
+
79
+ packages = list(gen_packages_items())
80
+ return packages
81
+
82
+
83
+ with open('README.md') as f:
84
+ readme = f.read()
85
+
86
+
87
+ def do_setup():
88
+ setup(
89
+ name='vlmeval',
90
+ version='0.1.0',
91
+ description='OpenCompass VLM Evaluation Kit',
92
+ author='Haodong Duan',
93
+ author_email='[email protected]',
94
+ maintainer='Haodong Duan',
95
+ maintainer_email='[email protected]',
96
+ long_description=readme,
97
+ long_description_content_type='text/markdown',
98
+ cmdclass={},
99
+ install_requires=parse_requirements('requirements.txt'),
100
+ setup_requires=[],
101
+ python_requires='>=3.7.0',
102
+ packages=find_packages(exclude=[
103
+ 'test*',
104
+ 'paper_test*',
105
+ ]),
106
+ keywords=['AI', 'NLP', 'in-context learning'],
107
+ entry_points={
108
+ 'console_scripts': ['vlmutil = vlmeval:cli']
109
+ },
110
+ classifiers=[
111
+ 'Programming Language :: Python :: 3.7',
112
+ 'Programming Language :: Python :: 3.8',
113
+ 'Programming Language :: Python :: 3.9',
114
+ 'Programming Language :: Python :: 3.10',
115
+ 'Intended Audience :: Developers',
116
+ 'Intended Audience :: Education',
117
+ 'Intended Audience :: Science/Research',
118
+ ])
119
+
120
+
121
+ if __name__ == '__main__':
122
+ do_setup()
eval_mm/vlmevalkit/vlmeval/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ try:
2
+ import torch
3
+ except ImportError:
4
+ pass
5
+
6
+ from .smp import *
7
+ from .api import *
8
+ from .dataset import *
9
+ from .utils import *
10
+ from .vlm import *
11
+ from .config import *
12
+ from .tools import cli
13
+
14
+ load_env()
15
+
16
+ __version__ = '0.2rc1'
eval_mm/vlmevalkit/vlmeval/api/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .gpt import OpenAIWrapper, GPT4V
2
+
3
+ __all__ = [
4
+ 'OpenAIWrapper', 'GPT4V'
5
+ ]
eval_mm/vlmevalkit/vlmeval/api/base.py ADDED
@@ -0,0 +1,265 @@
1
+ import time
2
+ import random as rd
3
+ from abc import abstractmethod
4
+ import os.path as osp
5
+ import copy as cp
6
+ from ..smp import get_logger, parse_file, concat_images_vlmeval
7
+
8
+
9
+ class BaseAPI:
10
+
11
+ allowed_types = ['text', 'image']
12
+ INTERLEAVE = True
13
+ INSTALL_REQ = False
14
+
15
+ def __init__(self,
16
+ retry=10,
17
+ wait=3,
18
+ system_prompt=None,
19
+ verbose=True,
20
+ fail_msg='Failed to obtain answer via API.',
21
+ **kwargs):
22
+ """Base Class for all APIs.
23
+
24
+ Args:
25
+ retry (int, optional): The retry times for `generate_inner`. Defaults to 10.
26
+ wait (int, optional): The wait time after each failed retry of `generate_inner`. Defaults to 3.
27
+ system_prompt (str, optional): Defaults to None.
28
+ verbose (bool, optional): Defaults to True.
29
+ fail_msg (str, optional): The message to return when failed to obtain answer.
30
+ Defaults to 'Failed to obtain answer via API.'.
31
+ **kwargs: Other kwargs for `generate_inner`.
32
+ """
33
+
34
+ self.wait = wait
35
+ self.retry = retry
36
+ self.system_prompt = system_prompt
37
+ self.verbose = verbose
38
+ self.fail_msg = fail_msg
39
+ self.logger = get_logger('ChatAPI')
40
+
41
+ if len(kwargs):
42
+ self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
43
+ self.logger.info('Will try to use them as kwargs for `generate`. ')
44
+ self.default_kwargs = kwargs
45
+
46
+ @abstractmethod
47
+ def generate_inner(self, inputs, **kwargs):
48
+ """The inner function to generate the answer.
49
+
50
+ Returns:
51
+ tuple(int, str, str): ret_code, response, log
52
+ """
53
+ self.logger.warning('For APIBase, generate_inner is an abstract method. ')
54
+ assert 0, 'generate_inner not defined'
55
+ ret_code, answer, log = None, None, None
56
+ # if ret_code is 0, means succeed
57
+ return ret_code, answer, log
58
+
59
+ def working(self):
60
+ """If the API model is working, return True, else return False.
61
+
62
+ Returns:
63
+ bool: If the API model is working, return True, else return False.
64
+ """
65
+ self.old_timeout = None
66
+ if hasattr(self, 'timeout'):
67
+ self.old_timeout = self.timeout
68
+ self.timeout = 120
69
+
70
+ retry = 5
71
+ while retry > 0:
72
+ ret = self.generate('hello')
73
+ if ret is not None and ret != '' and self.fail_msg not in ret:
74
+ if self.old_timeout is not None:
75
+ self.timeout = self.old_timeout
76
+ return True
77
+ retry -= 1
78
+
79
+ if self.old_timeout is not None:
80
+ self.timeout = self.old_timeout
81
+ return False
82
+
83
+ def check_content(self, msgs):
84
+ """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.
85
+
86
+ Args:
87
+ msgs: Raw input messages.
88
+
89
+ Returns:
90
+ str: The message type.
91
+ """
92
+ if isinstance(msgs, str):
93
+ return 'str'
94
+ if isinstance(msgs, dict):
95
+ return 'dict'
96
+ if isinstance(msgs, list):
97
+ types = [self.check_content(m) for m in msgs]
98
+ if all(t == 'str' for t in types):
99
+ return 'liststr'
100
+ if all(t == 'dict' for t in types):
101
+ return 'listdict'
102
+ return 'unknown'
103
+
104
+ def preproc_content(self, inputs):
105
+ """Convert the raw input messages to a list of dicts.
106
+
107
+ Args:
108
+ inputs: raw input messages.
109
+
110
+ Returns:
111
+ list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
112
+ """
113
+ if self.check_content(inputs) == 'str':
114
+ return [dict(type='text', value=inputs)]
115
+ elif self.check_content(inputs) == 'dict':
116
+ assert 'type' in inputs and 'value' in inputs
117
+ return [inputs]
118
+ elif self.check_content(inputs) == 'liststr':
119
+ res = []
120
+ for s in inputs:
121
+ mime, pth = parse_file(s)
122
+ if mime is None or mime == 'unknown':
123
+ res.append(dict(type='text', value=s))
124
+ else:
125
+ res.append(dict(type=mime.split('/')[0], value=pth))
126
+ return res
127
+ elif self.check_content(inputs) == 'listdict':
128
+ for item in inputs:
129
+ assert 'type' in item and 'value' in item
130
+ mime, s = parse_file(item['value'])
131
+ if mime is None:
132
+ assert item['type'] == 'text', item['value']
133
+ else:
134
+ assert mime.split('/')[0] == item['type']
135
+ item['value'] = s
136
+ return inputs
137
+ else:
138
+ return None
139
+
140
+ # May exceed the context windows size, so try with different turn numbers.
141
+ def chat_inner(self, inputs, **kwargs):
142
+ _ = kwargs.pop('dataset', None)
143
+ while len(inputs):
144
+ try:
145
+ return self.generate_inner(inputs, **kwargs)
146
+ except:
147
+ inputs = inputs[1:]
148
+ while len(inputs) and inputs[0]['role'] != 'user':
149
+ inputs = inputs[1:]
150
+ continue
151
+ return -1, self.fail_msg + ': ' + 'Failed with all possible conversation turns.', None
152
+
153
+ def chat(self, messages, **kwargs1):
154
+ """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
155
+ assert hasattr(self, 'chat_inner'), 'The API model should have the `chat_inner` method. '
156
+ for msg in messages:
157
+ assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
158
+ assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
159
+ msg['content'] = self.preproc_content(msg['content'])
160
+ # merge kwargs
161
+ kwargs = cp.deepcopy(self.default_kwargs)
162
+ kwargs.update(kwargs1)
163
+
164
+ answer = None
165
+ # a very small random delay [0s - 0.5s]
166
+ T = rd.random() * 0.5
167
+ time.sleep(T)
168
+
169
+ assert messages[-1]['role'] == 'user'
170
+
171
+ for i in range(self.retry):
172
+ try:
173
+ ret_code, answer, log = self.chat_inner(messages, **kwargs)
174
+ if ret_code == 0 and self.fail_msg not in answer and answer != '':
175
+ if self.verbose:
176
+ print(answer)
177
+ return answer
178
+ elif self.verbose:
179
+ if not isinstance(log, str):
180
+ try:
181
+ log = log.text
182
+ except:
183
+ self.logger.warning(f'Failed to parse {log} as an http response. ')
184
+ self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
185
+ except Exception as err:
186
+ if self.verbose:
187
+ self.logger.error(f'An error occurred during try {i}:')
188
+ self.logger.error(err)
189
+ # delay before each retry
190
+ T = rd.random() * self.wait * 2
191
+ time.sleep(T)
192
+
193
+ return self.fail_msg if answer in ['', None] else answer
194
+
195
+ def generate(self, message, **kwargs1):
196
+ """The main function to generate the answer. Will call `generate_inner` with the preprocessed input messages.
197
+
198
+ Args:
199
+ message: raw input messages.
200
+
201
+ Returns:
202
+ str: The generated answer, or the fail message if no answer could be obtained.
203
+ """
204
+ assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
205
+ message = self.preproc_content(message)
206
+ assert message is not None and self.check_content(message) == 'listdict'
207
+ for item in message:
208
+ assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
209
+
210
+ # merge kwargs
211
+ kwargs = cp.deepcopy(self.default_kwargs)
212
+ kwargs.update(kwargs1)
213
+
214
+ answer = None
215
+ # a very small random delay [0s - 0.5s]
216
+ T = rd.random() * 0.5
217
+ time.sleep(T)
218
+
219
+ for i in range(self.retry):
220
+ try:
221
+ ret_code, answer, log = self.generate_inner(message, **kwargs)
222
+ if ret_code == 0 and self.fail_msg not in answer and answer != '':
223
+ if self.verbose:
224
+ print(answer)
225
+ return answer
226
+ elif self.verbose:
227
+ if not isinstance(log, str):
228
+ try:
229
+ log = log.text
230
+ except:
231
+ self.logger.warning(f'Failed to parse {log} as an http response. ')
232
+ self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
233
+ except Exception as err:
234
+ if self.verbose:
235
+ self.logger.error(f'An error occurred during try {i}:')
236
+ self.logger.error(err)
237
+ # delay before each retry
238
+ T = rd.random() * self.wait * 2
239
+ time.sleep(T)
240
+
241
+ return self.fail_msg if answer in ['', None] else answer
242
+
243
+ def message_to_promptimg(self, message, dataset=None):
244
+ assert not self.INTERLEAVE
245
+ model_name = self.__class__.__name__
246
+ import warnings
247
+ warnings.warn(
248
+ f'Model {model_name} does not support interleaved input. '
249
+ 'Will use the first image and aggregated texts as prompt. ')
250
+ num_images = len([x for x in message if x['type'] == 'image'])
251
+ if num_images == 0:
252
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
253
+ image = None
254
+ elif num_images == 1:
255
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
256
+ image = [x['value'] for x in message if x['type'] == 'image'][0]
257
+ else:
258
+ prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
259
+ if dataset == 'BLINK':
260
+ image = concat_images_vlmeval(
261
+ [x['value'] for x in message if x['type'] == 'image'],
262
+ target_size=512)
263
+ else:
264
+ image = [x['value'] for x in message if x['type'] == 'image'][0]
265
+ return prompt, image
eval_mm/vlmevalkit/vlmeval/api/gpt.py ADDED
@@ -0,0 +1,248 @@
1
+ from ..smp import *
2
+ import os
3
+ import sys
4
+ from .base import BaseAPI
5
+
6
+ APIBASES = {
7
+ 'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
8
+ }
9
+
10
+
11
+ def GPT_context_window(model):
12
+ length_map = {
13
+ 'gpt-4': 8192,
14
+ 'gpt-4-0613': 8192,
15
+ 'gpt-4-turbo-preview': 128000,
16
+ 'gpt-4-1106-preview': 128000,
17
+ 'gpt-4-0125-preview': 128000,
18
+ 'gpt-4-vision-preview': 128000,
19
+ 'gpt-4-turbo': 128000,
20
+ 'gpt-4-turbo-2024-04-09': 128000,
21
+ 'gpt-3.5-turbo': 16385,
22
+ 'gpt-3.5-turbo-0125': 16385,
23
+ 'gpt-3.5-turbo-1106': 16385,
24
+ 'gpt-3.5-turbo-instruct': 4096,
25
+ }
26
+ if model in length_map:
27
+ return length_map[model]
28
+ else:
29
+ return 128000
30
+
31
+
32
+ class OpenAIWrapper(BaseAPI):
33
+
34
+ is_api: bool = True
35
+
36
+ def __init__(self,
37
+ model: str = 'gpt-3.5-turbo-0613',
38
+ retry: int = 5,
39
+ wait: int = 5,
40
+ key: str = None,
41
+ verbose: bool = True,
42
+ system_prompt: str = None,
43
+ temperature: float = 0,
44
+ timeout: int = 60,
45
+ api_base: str = None,
46
+ max_tokens: int = 1024,
47
+ img_size: int = 512,
48
+ img_detail: str = 'low',
49
+ use_azure: bool = False,
50
+ **kwargs):
51
+
52
+ self.model = model
53
+ self.cur_idx = 0
54
+ self.fail_msg = 'Failed to obtain answer via API. '
55
+ self.max_tokens = max_tokens
56
+ self.temperature = temperature
57
+ self.use_azure = use_azure
58
+
59
+ if 'step-1v' in model:
60
+ env_key = os.environ.get('STEPAI_API_KEY', '')
61
+ if key is None:
62
+ key = env_key
63
+ elif 'yi-vision' in model:
64
+ env_key = os.environ.get('YI_API_KEY', '')
65
+ if key is None:
66
+ key = env_key
67
+ else:
68
+ if use_azure:
69
+ env_key = os.environ.get('AZURE_OPENAI_API_KEY', None)
70
+ assert env_key is not None, 'Please set the environment variable AZURE_OPENAI_API_KEY. '
71
+
72
+ if key is None:
73
+ key = env_key
74
+ assert isinstance(key, str), (
75
+ 'Please set the environment variable AZURE_OPENAI_API_KEY to your openai key. '
76
+ )
77
+ else:
78
+ env_key = os.environ.get('OPENAI_API_KEY', '')
79
+ if key is None:
80
+ key = env_key
81
+ assert isinstance(key, str) and key.startswith('sk-'), (
82
+ f'Illegal openai_key {key}. '
83
+ 'Please set the environment variable OPENAI_API_KEY to your openai key. '
84
+ )
85
+
86
+ self.key = key
87
+ assert img_size > 0 or img_size == -1
88
+ self.img_size = img_size
89
+ assert img_detail in ['high', 'low']
90
+ self.img_detail = img_detail
91
+ self.timeout = timeout
92
+
93
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
94
+
95
+ if use_azure:
96
+ api_base_template = (
97
+ '{endpoint}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}'
98
+ )
99
+ endpoint = os.getenv('AZURE_OPENAI_ENDPOINT', None)
100
+ assert endpoint is not None, 'Please set the environment variable AZURE_OPENAI_ENDPOINT. '
101
+ deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME', None)
102
+ assert deployment_name is not None, 'Please set the environment variable AZURE_OPENAI_DEPLOYMENT_NAME. '
103
+ api_version = os.getenv('OPENAI_API_VERSION', None)
104
+ assert api_version is not None, 'Please set the environment variable OPENAI_API_VERSION. '
105
+
106
+ self.api_base = api_base_template.format(
107
+ endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
108
+ deployment_name=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
109
+ api_version=os.getenv('OPENAI_API_VERSION')
110
+ )
111
+ else:
112
+ if api_base is None:
113
+ if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
114
+ self.logger.info('Environment variable OPENAI_API_BASE is set. Will use it as api_base. ')
115
+ api_base = os.environ['OPENAI_API_BASE']
116
+ else:
117
+ api_base = 'OFFICIAL'
118
+
119
+ assert api_base is not None
120
+
121
+ if api_base in APIBASES:
122
+ self.api_base = APIBASES[api_base]
123
+ elif api_base.startswith('http'):
124
+ self.api_base = api_base
125
+ else:
126
+ self.logger.error('Unknown API Base. ')
127
+ sys.exit(-1)
128
+
129
+ self.logger.info(f'Using API Base: {self.api_base}; API Key: {self.key}')
130
+
131
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
132
+ # content can be a string or a list of image & text
133
+ def prepare_itlist(self, inputs):
134
+ assert np.all([isinstance(x, dict) for x in inputs])
135
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
136
+ if has_images:
137
+ content_list = []
138
+ for msg in inputs:
139
+ if msg['type'] == 'text':
140
+ content_list.append(dict(type='text', text=msg['value']))
141
+ elif msg['type'] == 'image':
142
+ from PIL import Image
143
+ img = Image.open(msg['value'])
144
+ b64 = encode_image_to_base64(img, target_size=self.img_size)
145
+ img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
146
+ content_list.append(dict(type='image_url', image_url=img_struct))
147
+ else:
148
+ assert all([x['type'] == 'text' for x in inputs])
149
+ text = '\n'.join([x['value'] for x in inputs])
150
+ content_list = [dict(type='text', text=text)]
151
+ return content_list
152
+
153
+ def prepare_inputs(self, inputs):
154
+ input_msgs = []
155
+ if self.system_prompt is not None:
156
+ input_msgs.append(dict(role='system', content=self.system_prompt))
157
+ assert isinstance(inputs, list) and isinstance(inputs[0], dict)
158
+ assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
159
+ if 'role' in inputs[0]:
160
+ assert inputs[-1]['role'] == 'user', inputs[-1]
161
+ for item in inputs:
162
+ input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
163
+ else:
164
+ input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
165
+ return input_msgs
166
+
167
+ def generate_inner(self, inputs, **kwargs) -> str:
168
+ input_msgs = self.prepare_inputs(inputs)
169
+ temperature = kwargs.pop('temperature', self.temperature)
170
+ max_tokens = kwargs.pop('max_tokens', self.max_tokens)
171
+
172
+ context_window = GPT_context_window(self.model)
173
+ max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
174
+ if 0 < max_tokens <= 100:
175
+ self.logger.warning(
176
+ 'Less than 100 tokens left, '
177
+ 'may exceed the context window with some additional meta symbols. '
178
+ )
179
+ if max_tokens <= 0:
180
+ return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
181
+
182
+ # Will send a raw HTTP request when using Azure; unclear how to route this through the openai client
183
+ if self.use_azure:
184
+ headers = {'Content-Type': 'application/json', 'api-key': self.key}
185
+ else:
186
+ headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
187
+ payload = dict(
188
+ model=self.model,
189
+ messages=input_msgs,
190
+ max_tokens=max_tokens,
191
+ n=1,
192
+ temperature=temperature,
193
+ **kwargs)
194
+ response = requests.post(
195
+ self.api_base,
196
+ headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
197
+ ret_code = response.status_code
198
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
199
+ answer = self.fail_msg
200
+ try:
201
+ resp_struct = json.loads(response.text)
202
+ answer = resp_struct['choices'][0]['message']['content'].strip()
203
+ except:
204
+ pass
205
+ return ret_code, answer, response
206
+
207
+ def get_image_token_len(self, img_path, detail='low'):
208
+ import math
209
+ if detail == 'low':
210
+ return 85
211
+
212
+ im = Image.open(img_path)
213
+ height, width = im.size
214
+ if width > 1024 or height > 1024:
215
+ if width > height:
216
+ height = int(height * 1024 / width)
217
+ width = 1024
218
+ else:
219
+ width = int(width * 1024 / height)
220
+ height = 1024
221
+
222
+ h = math.ceil(height / 512)
223
+ w = math.ceil(width / 512)
224
+ total = 85 + 170 * h * w
225
+ return total
226
+
227
+ def get_token_len(self, inputs) -> int:
228
+ import tiktoken
229
+ try:
230
+ enc = tiktoken.encoding_for_model(self.model)
231
+ except:
232
+ enc = tiktoken.encoding_for_model('gpt-4')
233
+ assert isinstance(inputs, list)
234
+ tot = 0
235
+ for item in inputs:
236
+ if 'role' in item:
237
+ tot += self.get_token_len(item['content'])
238
+ elif item['type'] == 'text':
239
+ tot += len(enc.encode(item['value']))
240
+ elif item['type'] == 'image':
241
+ tot += self.get_image_token_len(item['value'], detail=self.img_detail)
242
+ return tot
243
+
244
+
245
+ class GPT4V(OpenAIWrapper):
246
+
247
+ def generate(self, message, dataset=None):
248
+ return super(GPT4V, self).generate(message)
eval_mm/vlmevalkit/vlmeval/config.py ADDED
@@ -0,0 +1,19 @@
1
+ from vlmeval.vlm import *
2
+ from vlmeval.api import *
3
+ from functools import partial
4
+
5
+ minicpm_series = {
6
+ 'MiniCPM-V': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
7
+ 'MiniCPM-V-2': partial(MiniCPM_V, model_path='openbmb/MiniCPM-V-2'),
8
+ 'MiniCPM-Llama3-V-2_5': partial(MiniCPM_Llama3_V, model_path='openbmb/MiniCPM-Llama3-V-2_5'),
9
+ 'MiniCPM-V-2_6': partial(MiniCPM_V_2_6, model_path='openbmb/MiniCPM-V-2_6'),
10
+ }
11
+
12
+ supported_VLM = {}
13
+
14
+ model_groups = [
15
+ minicpm_series
16
+ ]
17
+
18
+ for grp in model_groups:
19
+ supported_VLM.update(grp)
eval_mm/vlmevalkit/vlmeval/dataset/__init__.py ADDED
@@ -0,0 +1,186 @@
1
+ import warnings
2
+
3
+ from .image_base import img_root_map, ImageBaseDataset
4
+ from .image_caption import ImageCaptionDataset
5
+ from .image_yorn import ImageYORNDataset
6
+ from .image_mcq import ImageMCQDataset, MMMUDataset, CustomMCQDataset, MUIRDataset, GMAIMMBenchDataset
7
+ from .image_mt import MMDUDataset
8
+ from .image_vqa import (
9
+ ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, CustomVQADataset
10
+ )
11
+
12
+ from .vcr import VCRDataset
13
+ from .mmlongbench import MMLongBench
14
+ from .dude import DUDE
15
+ from .slidevqa import SlideVQA
16
+
17
+ from .mmbench_video import MMBenchVideo
18
+ from .text_mcq import CustomTextMCQDataset, TextMCQDataset
19
+ from .videomme import VideoMME
20
+ from .mvbench import MVBench, MVBench_MP4
21
+ from .utils import *
22
+ from ..smp import *
23
+
24
+
25
+ class ConcatDataset(ImageBaseDataset):
26
+ # This dataset takes multiple dataset names as input and aggregate them into a single dataset.
27
+ # Each single dataset should not have a field named `SUB_DATASET`
28
+
29
+ DATASET_SETS = {
30
+ 'MMMB': ['MMMB_ar', 'MMMB_cn', 'MMMB_en', 'MMMB_pt', 'MMMB_ru', 'MMMB_tr'],
31
+ 'MTL_MMBench_DEV': [
32
+ 'MMBench_dev_ar', 'MMBench_dev_cn', 'MMBench_dev_en',
33
+ 'MMBench_dev_pt', 'MMBench_dev_ru', 'MMBench_dev_tr'
34
+ ]
35
+ }
36
+
37
+ def __init__(self, dataset):
38
+ datasets = self.DATASET_SETS[dataset]
39
+ self.dataset_map = {}
40
+ # The name of the compilation
41
+ self.dataset_name = dataset
42
+ self.datasets = datasets
43
+ for dname in datasets:
44
+ dataset = build_dataset(dname)
45
+ assert dataset is not None, dataset
46
+ self.dataset_map[dname] = dataset
47
+ TYPES = [x.TYPE for x in self.dataset_map.values()]
48
+ MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
49
+ assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
50
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
51
+ self.TYPE = TYPES[0]
52
+ self.MODALITY = MODALITIES[0]
53
+ data_all = []
54
+ for dname in datasets:
55
+ data = self.dataset_map[dname].data
56
+ data['SUB_DATASET'] = [dname] * len(data)
57
+ data_new = localize_df(data, dname, nproc=16)
58
+ data_all.append(data_new)
59
+
60
+ data = pd.concat(data_all)
61
+ data['original_index'] = data.pop('index')
62
+ data['index'] = np.arange(len(data))
63
+ self.data = data
64
+
65
+ def build_prompt(self, line):
66
+ if isinstance(line, int):
67
+ line = self.data.iloc[line]
68
+ idx = line['original_index']
69
+ dname = line['SUB_DATASET']
70
+ org_data = self.dataset_map[dname].data
71
+ org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
72
+ return self.dataset_map[dname].build_prompt(org_line)
73
+
74
+ def dump_image(self, line):
75
+ # Assert all images are pre-dumped
76
+ assert 'image' not in line
77
+ assert 'image_path' in line
78
+ tgt_path = toliststr(line['image_path'])
79
+ return tgt_path
80
+
81
+ @classmethod
82
+ def supported_datasets(cls):
83
+ return list(cls.DATASET_SETS)
84
+
85
+ def evaluate(self, eval_file, **judge_kwargs):
86
+ suffix = eval_file.split('.')[-1]
87
+ # First, split the eval_file by dataset
88
+ data_all = load(eval_file)
89
+ for dname in self.datasets:
90
+ tgt = eval_file.replace(self.dataset_name, dname)
91
+ data_sub = data_all[data_all['SUB_DATASET'] == dname]
92
+ data_sub.pop('index')
93
+ data_sub['index'] = data_sub.pop('original_index')
94
+ data_sub.pop('SUB_DATASET')
95
+ dump(data_sub, tgt)
96
+ # Then, evaluate each dataset separately
97
+ results_all = []
98
+ for dname in self.datasets:
99
+ tgt = eval_file.replace(self.dataset_name, dname)
100
+ res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
101
+ assert isinstance(res, pd.DataFrame)
102
+ res['DATASET'] = [dname] * len(res)
103
+ results_all.append(res)
104
+ result = pd.concat(results_all)
105
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
106
+ dump(result, score_file)
107
+ return result
108
+
109
+
110
+ # Add new supported dataset class here
111
+ IMAGE_DATASET = [
112
+ ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
113
+ MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset,
114
+ MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset, GMAIMMBenchDataset
115
+ ]
116
+
117
+ VIDEO_DATASET = [
118
+ MMBenchVideo, VideoMME, MVBench, MVBench_MP4
119
+ ]
120
+
121
+ TEXT_DATASET = [
122
+ TextMCQDataset
123
+ ]
124
+
125
+ CUSTOM_DATASET = [
126
+ CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset
127
+ ]
128
+
129
+ DATASET_COLLECTION = [ConcatDataset]
130
+
131
+ DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION
132
+ SUPPORTED_DATASETS = []
133
+ for DATASET_CLS in DATASET_CLASSES:
134
+ SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets())
135
+
136
+
137
+ def DATASET_TYPE(dataset):
138
+ for cls in DATASET_CLASSES:
139
+ if dataset in cls.supported_datasets():
140
+ if hasattr(cls, 'TYPE'):
141
+ return cls.TYPE
142
+ # Have to add specific routine to handle ConcatDataset
143
+ if dataset in ConcatDataset.DATASET_SETS:
144
+ dataset_list = ConcatDataset.DATASET_SETS[dataset]
145
+ TYPES = [DATASET_TYPE(dname) for dname in dataset_list]
146
+ assert np.all([x == TYPES[0] for x in TYPES]), (dataset_list, TYPES)
147
+ return TYPES[0]
148
+
149
+ if 'openended' in dataset.lower():
150
+ return 'VQA'
151
+ warnings.warn(f'Dataset {dataset} is a custom one and not annotated as `openended`, will treat as MCQ. ')
152
+ return 'MCQ'
153
+
154
+
155
+ def build_dataset(dataset_name, **kwargs):
156
+ for cls in DATASET_CLASSES:
157
+ if dataset_name in cls.supported_datasets():
158
+ return cls(dataset=dataset_name, **kwargs)
159
+
160
+ warnings.warn(f'Dataset {dataset_name} is not officially supported. ')
161
+
162
+ data_file = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
163
+ if not osp.exists(data_file):
164
+ warnings.warn(f'Data file {data_file} does not exist. Dataset building failed. ')
165
+ return None
166
+
167
+ data = load(data_file)
168
+ if 'question' not in [x.lower() for x in data.columns]:
169
+ warnings.warn(f'Data file {data_file} does not have a `question` column. Dataset building failed. ')
170
+ return None
171
+
172
+ if 'A' in data and 'B' in data:
173
+ if 'image' in data or 'image_path' in data:
174
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom MCQ dataset. ')
175
+ return CustomMCQDataset(dataset=dataset_name, **kwargs)
176
+ else:
177
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom Text MCQ dataset. ')
178
+ return CustomTextMCQDataset(dataset=dataset_name, **kwargs)
179
+ else:
180
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom VQA dataset. ')
181
+ return CustomVQADataset(dataset=dataset_name, **kwargs)
182
+
183
+
184
+ __all__ = [
185
+ 'build_dataset', 'img_root_map', 'build_judge', 'extract_answer_from_item', 'prefetch_answer', 'DEBUG_MESSAGE'
186
+ ] + [cls.__name__ for cls in DATASET_CLASSES]
eval_mm/vlmevalkit/vlmeval/dataset/dude.py ADDED
@@ -0,0 +1,210 @@
1
+ import math
2
+ from typing import List
3
+
4
+ from .utils.judge_util import build_judge
5
+ from .image_base import ImageBaseDataset
6
+ from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
7
+ from ..smp import *
8
+
9
+
10
+ FAIL_MSG = 'Failed to obtain answer via API.'
11
+
12
+
13
+ def DUDE_acc(result_file):
14
+ data = load(result_file)
15
+ overall_score = 0.0
16
+ score_list = list()
17
+ for i in range(len(data)):
18
+ item = data.iloc[i]
19
+ if isinstance(item['answer'], float) and math.isnan(item['answer']):
20
+ item['answer'] = 'Not answerable'
21
+
22
+ item['answer'] = item['answer'].lower()
23
+ item['pred'] = item['pred'].lower()
24
+ score = anls_compute(item['answer'], item['pred'])
25
+ score_list.append(score)
26
+ overall_score += score
27
+
28
+ data['score'] = score_list
29
+ dump(data, result_file)
30
+
31
+ res = dict()
32
+ res['category'], res['num'], res['avg_score'] = ['anls'], [len(data)], [overall_score / len(data)]
33
+ res = pd.DataFrame(res)
34
+ return res
35
+
36
+
37
+ class DUDE(ImageBaseDataset):
38
+
39
+ TYPE = 'VQA'
40
+
41
+ DATASET_URL = {
42
+ 'DUDE': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE.tsv',
43
+ 'DUDE_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE_MINI.tsv',
44
+ }
45
+ DATASET_MD5 = {
46
+ 'DUDE': '130d860d08206e1e407cd77150c10d88',
47
+ 'DUDE_MINI': 'e0c0d998114f0cca7516d12039d2b538',
48
+ }
49
+
50
+ SUPPORTED_MODELS = {
51
+ 'GPT4': (1, 1),
52
+ 'GPT4V': (1, 1),
53
+ 'GPT4V_HIGH': (1, 1),
54
+ 'GPT4o': (1, 1),
55
+ 'GPT4o_HIGH': (1, 1),
56
+ 'GPT4o_MINI': (1, 1),
57
+ 'XComposer2d5': (1, -1),
58
+ 'XComposer2_4KHD': (1, -1),
59
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
60
+ 'InternVL-Chat-V1-5': (5, 2),
61
+ }
62
+
63
+ def __init__(self, dataset, **kwargs):
64
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
65
+ model_name = kwargs['model']
66
+ if not listinstr(self.model_list, model_name):
67
+ raise AssertionError("{} doesn't support the evaluation on DUDE.".format(model_name))
68
+ super(DUDE, self).__init__(dataset)
69
+
70
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
71
+ self.max_pages = 120
72
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
73
+ self.concat_num = concat_num
74
+ self.column_num = column_num
75
+
76
+ def prepare_tsv(self, url, file_md5=None):
77
+ data_root = LMUDataRoot()
78
+ os.makedirs(data_root, exist_ok=True)
79
+ file_name = url.split('/')[-1]
80
+ data_path = osp.join(data_root, file_name)
81
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
82
+ pass
83
+ else:
84
+ warnings.warn('The dataset tsv is not downloaded')
85
+ download_file(url, data_path)
86
+ return load(data_path)
87
+
88
+ def dump_image(self, origin_line):
89
+ os.makedirs(self.img_root, exist_ok=True)
90
+ try:
91
+ import fitz
92
+ except:
93
+ warnings.warn('Please use `pip install pymupdf` to parse PDF files.')
94
+
95
+ line = origin_line.copy()
96
+ if not isinstance(line['image_path'], List):
97
+ line['image_path'] = [line['image_path']]
98
+ line['image_path'] = line['image_path'][:self.max_pages]
99
+ skip_pdf_parse = True
100
+ for im_name in line['image_path']:
101
+ path = osp.join(self.img_root, im_name)
102
+ if not read_ok(path):
103
+ skip_pdf_parse = False
104
+ break
105
+
106
+ # Just for being compatible with the zipped loop: zip(line['image'], line['image_path'])
107
+ if skip_pdf_parse:
108
+ line['image'] = line['image_path']
109
+ else:
110
+ pdf_data = base64.b64decode(line['image'])
111
+ pdf_file = io.BytesIO(pdf_data)
112
+ encoded_images = []
113
+ with fitz.open(stream=pdf_file, filetype='pdf') as doc:
114
+ doc = doc[:self.max_pages]
115
+ for page in doc:
116
+ image = page.get_pixmap(dpi=144)
117
+ image_file = io.BytesIO(image.tobytes(output='png'))
118
+ image = Image.open(image_file)
119
+ encoded_image = encode_image_to_base64(image)
120
+ encoded_images.append(encoded_image)
121
+ line['image'] = encoded_images
122
+ print('process {}'.format(line['doc_id']))
123
+
124
+ if 'image' in line:
125
+ if isinstance(line['image'], list):
126
+ tgt_path = []
127
+ assert 'image_path' in line
128
+ for img, im_name in zip(line['image'], line['image_path']):
129
+ path = osp.join(self.img_root, im_name)
130
+ if not read_ok(path):
131
+ decode_base64_to_image_file(img, path)
132
+ tgt_path.append(path)
133
+ else:
134
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
135
+ if not read_ok(tgt_path):
136
+ decode_base64_to_image_file(line['image'], tgt_path)
137
+ tgt_path = [tgt_path]
138
+ else:
139
+ assert 'image_path' in line
140
+ tgt_path = toliststr(line['image_path'])
141
+
142
+ if self.concat_num > 0 and not self.is_api:
143
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
144
+
145
+ old_tgt_path = tgt_path
146
+ assert isinstance(old_tgt_path, list)
147
+ if self.column_num != -1:
148
+ tgt_path = [
149
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
150
+ for i in range(len(concatenated_images))
151
+ ]
152
+ else:
153
+ tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
154
+
155
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
156
+ if not read_ok(path):
157
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
158
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
159
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
160
+ return tgt_path
161
+
162
+ @classmethod
163
+ def evaluate(self, eval_file, **judge_kwargs):
164
+ logger = get_logger('Evaluation')
165
+ model = judge_kwargs['model']
166
+
167
+ suffix = eval_file.split('.')[-1]
168
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
169
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
170
+
171
+ if osp.exists(storage):
172
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in DUDE_eval. ')
173
+ else:
174
+ data = load(eval_file)
175
+ model = build_judge(max_tokens=128, **judge_kwargs)
176
+ lt = len(data)
177
+ lines = [data.iloc[i] for i in range(lt)]
178
+ tups = [(model, line) for line in lines]
179
+ indices = [line['index'] for line in lines]
180
+
181
+ ans = {}
182
+ if osp.exists(tmp_file):
183
+ ans = load(tmp_file)
184
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
185
+ indices = [i for i in indices if i not in ans]
186
+
187
+ if len(indices):
188
+ new_results = list()
189
+ for model, line in tqdm(tups):
190
+ res = MMLongBench_auxeval(model, line)
191
+ new_results.append(res)
192
+
193
+ log_map, res_map, pred_map = {}, {}, {}
194
+ all_inds = [line['index'] for line in lines]
195
+ for k, v in zip(all_inds, new_results):
196
+ log_map[k] = v['log']
197
+ res_map[k] = v['res']
198
+ pred_map[k] = v['pred']
199
+ data['res'] = [res_map[idx] for idx in data['index']]
200
+ data['log'] = [log_map[idx] for idx in data['index']]
201
+ data['pred'] = [pred_map[idx] for idx in data['index']]
202
+ dump(data, storage)
203
+
204
+ score = DUDE_acc(storage)
205
+ score_pth = storage.replace('.xlsx', '_score.csv')
206
+
207
+ dump(score, score_pth)
208
+ logger.info(f'DUDE successfully finished evaluating {eval_file}, results saved in {score_pth}')
209
+ logger.info('Score: ')
210
+ logger.info(score)
eval_mm/vlmevalkit/vlmeval/dataset/image_base.py ADDED
@@ -0,0 +1,165 @@
1
+ import pandas as pd
2
+ from abc import abstractmethod
3
+ from ..smp import *
4
+
5
+
6
+ def img_root_map(dataset):
7
+ if 'OCRVQA' in dataset:
8
+ return 'OCRVQA'
9
+ if 'COCO_VAL' == dataset:
10
+ return 'COCO'
11
+ if 'MMMU' in dataset:
12
+ return 'MMMU'
13
+ mmbench_root_map = {
14
+ 'MMBench_DEV_EN': 'MMBench', 'MMBench_TEST_EN': 'MMBench',
15
+ 'MMBench_DEV_CN': 'MMBench', 'MMBench_TEST_CN': 'MMBench',
16
+ 'MMBench': 'MMBench', 'MMBench_CN': 'MMBench',
17
+ 'MMBench_DEV_EN_V11': 'MMBench_V11', 'MMBench_TEST_EN_V11': 'MMBench_V11',
18
+ 'MMBench_DEV_CN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_V11',
19
+ 'MMBench_V11': 'MMBench', 'MMBench_CN_V11': 'MMBench',
20
+ }
21
+ if dataset in mmbench_root_map:
22
+ return mmbench_root_map[dataset]
23
+ return dataset
24
+
25
+
26
+ class ImageBaseDataset:
27
+
28
+ MODALITY = 'IMAGE'
29
+ DATASET_URL = {}
30
+ DATASET_MD5 = {}
31
+
32
+ def __init__(self, dataset='MMBench', skip_noimg=True):
33
+ ROOT = LMUDataRoot()
34
+ # You can override this variable to save image files to a different directory
35
+ self.dataset_name = dataset
36
+ self.img_root = osp.join(ROOT, 'images', img_root_map(dataset))
37
+
38
+ data = self.load_data(dataset)
39
+ self.skip_noimg = skip_noimg
40
+ if skip_noimg and 'image' in data:
41
+ data = data[~pd.isna(data['image'])]
42
+
43
+ data['index'] = [str(x) for x in data['index']]
44
+
45
+ self.meta_only = True
46
+
47
+ # The image field can store the base64 encoded image or another question index (for saving space)
48
+ if 'image' in data:
49
+ data['image'] = [str(x) for x in data['image']]
50
+ image_map = {x: y for x, y in zip(data['index'], data['image'])}
51
+ for k in image_map:
52
+ if len(image_map[k]) <= 64:
53
+ idx = image_map[k]
54
+ assert idx in image_map and len(image_map[idx]) > 64
55
+ image_map[k] = image_map[idx]
56
+
57
+ images = [toliststr(image_map[k]) for k in data['index']]
58
+ data['image'] = [x[0] if len(x) == 1 else x for x in images]
59
+ self.meta_only = False
60
+
61
+ if 'image_path' in data:
62
+ paths = [toliststr(x) for x in data['image_path']]
63
+ data['image_path'] = [x[0] if len(x) == 1 else x for x in paths]
64
+
65
+ if np.all([istype(x, int) for x in data['index']]):
66
+ data['index'] = [int(x) for x in data['index']]
67
+
68
+ self.data = data
69
+ self.post_build(dataset)
70
+
71
+ def __len__(self):
72
+ return len(self.data)
73
+
74
+ def __getitem__(self, idx):
75
+ return dict(self.data.iloc[idx])
76
+
77
+ def prepare_tsv(self, url, file_md5=None):
78
+ data_root = LMUDataRoot()
79
+ os.makedirs(data_root, exist_ok=True)
80
+ update_flag = False
81
+ file_name = url.split('/')[-1]
82
+ data_path = osp.join(data_root, file_name)
83
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
84
+ pass
85
+ else:
86
+ warnings.warn('The dataset tsv is not downloaded')
87
+ download_file(url, data_path)
88
+ update_flag = True
89
+
90
+ if file_size(data_path, 'GB') > 1:
91
+ local_path = data_path.replace('.tsv', '_local.tsv')
92
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
93
+ from ..tools import LOCALIZE
94
+ LOCALIZE(data_path, local_path)
95
+ data_path = local_path
96
+ return load(data_path)
97
+
98
+ def dump_image(self, line):
99
+ os.makedirs(self.img_root, exist_ok=True)
100
+
101
+ if 'image' in line:
102
+ if isinstance(line['image'], list):
103
+ tgt_path = []
104
+ assert 'image_path' in line
105
+ for img, im_name in zip(line['image'], line['image_path']):
106
+ path = osp.join(self.img_root, im_name)
107
+ if not read_ok(path):
108
+ decode_base64_to_image_file(img, path)
109
+ tgt_path.append(path)
110
+ else:
111
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
112
+ if not read_ok(tgt_path):
113
+ decode_base64_to_image_file(line['image'], tgt_path)
114
+ tgt_path = [tgt_path]
115
+ else:
116
+ assert 'image_path' in line
117
+ tgt_path = toliststr(line['image_path'])
118
+
119
+ return tgt_path
120
+
121
+ def display(self, line):
122
+ if isinstance(line, int):
123
+ line = self.data.iloc[line]
124
+ assert isinstance(line, pd.Series) or isinstance(line, dict)
125
+ mmqa_display(line)
126
+
127
+ # Return a list of dataset names that are supported by this class, can override
128
+ @classmethod
129
+ def supported_datasets(cls):
130
+ return list(cls.DATASET_URL)
131
+
132
+ # Given the dataset name, return the dataset as a pandas dataframe, can override
133
+ def load_data(self, dataset):
134
+ url = self.DATASET_URL[dataset]
135
+ file_md5 = self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
136
+ return self.prepare_tsv(url, file_md5)
137
+
138
+ # Post built hook, will be called after the dataset is built, can override
139
+ def post_build(self, dataset):
140
+ pass
141
+
142
+ # Given one data record, return the built prompt (a multi-modal message), can override
143
+ def build_prompt(self, line):
144
+ if isinstance(line, int):
145
+ line = self.data.iloc[line]
146
+
147
+ if self.meta_only:
148
+ tgt_path = toliststr(line['image_path'])
149
+ else:
150
+ tgt_path = self.dump_image(line)
151
+
152
+ question = line['question']
153
+
154
+ msgs = []
155
+ if isinstance(tgt_path, list):
156
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
157
+ else:
158
+ msgs = [dict(type='image', value=tgt_path)]
159
+ msgs.append(dict(type='text', value=question))
160
+ return msgs
161
+
162
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
163
+ @abstractmethod
164
+ def evaluate(self, eval_file, **judge_kwargs):
165
+ pass
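Note: ImageBaseDataset above fixes the lifecycle shared by the image benchmarks in this kit — load_data fetches and md5-checks the TSV, dump_image writes base64-encoded images to disk, build_prompt turns one record into a multi-modal message list, and evaluate (abstract) scores a prediction file. A minimal subclass sketch under those assumptions, assuming the vlmeval package from this kit is importable; the dataset name, URL, and md5 below are placeholders, not real endpoints:

from vlmeval.dataset.image_base import ImageBaseDataset


class MyToyDataset(ImageBaseDataset):
    TYPE = 'VQA'
    # Hypothetical entries: point these at a real TSV and its md5 checksum before use.
    DATASET_URL = {'MyToy': 'https://example.com/MyToy.tsv'}
    DATASET_MD5 = {'MyToy': None}

    def build_prompt(self, line):
        # Reuse the default image/text message construction, then append a task instruction.
        msgs = super().build_prompt(line)
        msgs[-1]['value'] += '\nAnswer with a single word or phrase.'
        return msgs


# ds = MyToyDataset('MyToy')      # would download the TSV into LMUDataRoot()
# print(len(ds)); print(ds.build_prompt(0))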
eval_mm/vlmevalkit/vlmeval/dataset/image_caption.py ADDED
@@ -0,0 +1,75 @@
1
+ from .image_base import ImageBaseDataset
2
+ from ..smp import *
3
+
4
+
5
+ class COCO_Caption_Scorer():
6
+ def __init__(self, ref, gt):
7
+ from pycocoevalcap.bleu.bleu import Bleu
8
+ from pycocoevalcap.rouge.rouge import Rouge
9
+ from pycocoevalcap.cider.cider import Cider
10
+
11
+ self.ref = ref
12
+ self.gt = gt
13
+ print('setting up scorers...')
14
+ self.scorers = [
15
+ (Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
16
+ (Rouge(), 'ROUGE_L'),
17
+ (Cider(), 'CIDEr'),
18
+ ]
19
+
20
+ def compute_scores(self):
21
+ total_scores = {}
22
+ for scorer, method in self.scorers:
23
+ print('computing %s score...' % (scorer.method()))
24
+ score, scores = scorer.compute_score(self.gt, self.ref)
25
+ if isinstance(method, list):
26
+ for sc, scs, m in zip(score, scores, method):
27
+ print('%s: %0.3f' % (m, sc * 100))
28
+ total_scores['Bleu'] = [x * 100 for x in score]
29
+ else:
30
+ print('%s: %0.3f' % (method, score * 100))
31
+ total_scores[method] = score * 100
32
+
33
+ print('*****DONE*****')
34
+ for key, value in total_scores.items():
35
+ print('{}:{}'.format(key, value))
36
+ return total_scores
37
+
38
+
39
+ class ImageCaptionDataset(ImageBaseDataset):
40
+
41
+ TYPE = 'Caption'
42
+
43
+ DATASET_URL = {
44
+ 'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
45
+ }
46
+
47
+ DATASET_MD5 = {
48
+ 'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
49
+ }
50
+
51
+ def load_data(self, dataset):
52
+ data = super().load_data(dataset)
53
+ if 'question' not in data:
54
+ data['question'] = [(
55
+ 'Please describe this image in general. Directly provide the description, '
56
+ 'do not include prefix like "This image depicts". '
57
+ )] * len(data)
58
+ return data
59
+
60
+ # It returns a dictionary of scores
61
+ @classmethod
62
+ def evaluate(self, eval_file, **kwargs):
63
+ data = load(eval_file)
64
+ lt = len(data)
65
+ lines = [data.iloc[i] for i in range(lt)]
66
+ ref, gt = {}, {}
67
+ for i, line in enumerate(lines):
68
+ ref[str(i)] = [str(line['prediction'])]
69
+ gt[str(i)] = eval(line['answer'])
70
+
71
+ scorer = COCO_Caption_Scorer(ref, gt)
72
+ coco_caption_score_dict = scorer.compute_scores()
73
+ score_pth = eval_file.replace('.xlsx', '_score.json')
74
+ dump(coco_caption_score_dict, score_pth)
75
+ return coco_caption_score_dict
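Note: COCO_Caption_Scorer above expects two dicts keyed by the same string ids — ref maps each id to a single-element list holding the model prediction, gt maps it to the list of ground-truth captions — mirroring how evaluate builds them from the prediction file. A small standalone sketch (requires pycocoevalcap; the captions are made-up examples):

from vlmeval.dataset.image_caption import COCO_Caption_Scorer

ref = {'0': ['a dog runs across a grassy field']}                  # model predictions
gt = {'0': ['a dog running on grass', 'a brown dog in a field']}   # reference captions

scorer = COCO_Caption_Scorer(ref, gt)
scores = scorer.compute_scores()   # {'Bleu': [b1, b2, b3, b4], 'ROUGE_L': ..., 'CIDEr': ...}
print(scores)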
eval_mm/vlmevalkit/vlmeval/dataset/image_mcq.py ADDED
@@ -0,0 +1,484 @@
1
+ import warnings
2
+
3
+ from .image_base import ImageBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from ..smp import *
6
+
7
+
8
+ MMMB_URLS = {
9
+ 'MMMB_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ar.tsv',
10
+ 'MMMB_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_cn.tsv',
11
+ 'MMMB_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_en.tsv',
12
+ 'MMMB_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_pt.tsv',
13
+ 'MMMB_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ru.tsv',
14
+ 'MMMB_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_tr.tsv',
15
+ }
16
+
17
+ MTL_MMBench_URLS = {
18
+ 'MMBench_dev_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ar.tsv',
19
+ 'MMBench_dev_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_cn.tsv',
20
+ 'MMBench_dev_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_en.tsv',
21
+ 'MMBench_dev_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_pt.tsv',
22
+ 'MMBench_dev_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_tr.tsv',
23
+ 'MMBench_dev_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ru.tsv',
24
+ }
25
+
26
+ MMMB_MD5 = {
27
+ 'MMMB_ar': 'f3a18b6385f1d9701840aa42de27aead', 'MMMB_cn': '13ed82fa89730037292fcaa27f08f430',
28
+ 'MMMB_en': '1cd781a71ec5a2983c090b84105d6a01', 'MMMB_pt': '548ea2b3bb2da991790386f0015d30d1',
29
+ 'MMMB_ru': 'ce1cc8a0533425ab0d86b326ebfc2984', 'MMMB_tr': '0733739d43090327975294292bc5cd67'
30
+ }
31
+
32
+ MTL_MMBench_MD5 = {
33
+ 'MMBench_dev_ar': '4271b4a0d0200e1a86380a878e0d64a4', 'MMBench_dev_cn': '2ed5135326fed02c8e51ea50dda8222f',
34
+ 'MMBench_dev_en': 'd9ab776fc018b3d45785e9a5c23431c2', 'MMBench_dev_pt': '4ddfbcd27ef12444b908c03831cd0295',
35
+ 'MMBench_dev_tr': '4fab39d501389d3d6cc90264bb708f11', 'MMBench_dev_ru': '5ba1171ff2e68f80637bf78349e402a5'
36
+ }
37
+
38
+
39
+ class ImageMCQDataset(ImageBaseDataset):
40
+
41
+ TYPE = 'MCQ'
42
+
43
+ DATASET_URL = {
44
+ # MMBench v1.0
45
+ 'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
46
+ 'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
47
+ 'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
48
+ 'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
49
+ 'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Internal Only
50
+ 'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Internal Only
51
+ # MMBench v1.1
52
+ 'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN_V11.tsv',
53
+ 'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN_V11.tsv',
54
+ 'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN_V11.tsv',
55
+ 'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN_V11.tsv',
56
+ 'MMBench_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_V11.tsv', # Internal Only
57
+ 'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN_V11.tsv', # Internal Only
58
+ # SEEDBench Series
59
+ 'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
60
+ 'SEEDBench2': 'https://huggingface.co/datasets/VLMEval/SEEDBench2/resolve/main/SEEDBench2.tsv',
61
+ 'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench2_Plus.tsv',
62
+ # ScienceQA Series
63
+ 'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
64
+ 'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
65
+ # MMT-Bench
66
+ 'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_ALL_MI.tsv',
67
+ 'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_ALL.tsv',
68
+ 'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_VAL_MI.tsv',
69
+ 'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMT-Bench_VAL.tsv',
70
+ # AesBench
71
+ 'AesBench_VAL': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_VAL.tsv',
72
+ 'AesBench_TEST': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_TEST.tsv',
73
+ # Q-Bench1
74
+ 'Q-Bench1_VAL': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_VAL.tsv',
75
+ 'Q-Bench1_TEST': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_TEST.tsv',
76
+ # A-Bench
77
+ 'A-Bench_VAL': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_VAL.tsv',
78
+ 'A-Bench_TEST': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_TEST.tsv',
79
+ # Other Benchmarks
80
+ 'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
81
+ 'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
82
+ 'AI2D_TEST_NO_MASK': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST_NO_MASK.tsv',
83
+ 'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
84
+ 'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
85
+ 'MLLMGuard_DS': 'https://opencompass.openxlab.space/utils/VLMEval/MLLMGuard_DS.tsv',
86
+ 'BLINK': 'https://opencompass.openxlab.space/utils/VLMEval/BLINK.tsv',
87
+ 'TaskMeAnything_v1_imageqa_random': (
88
+ 'https://huggingface.co/datasets/weikaih/TaskMeAnything-v1-imageqa-random/'
89
+ 'resolve/main/TaskMeAnything-v1-imageqa-random.tsv'
90
+ ),
91
+ 'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv'
92
+ }
93
+
94
+ DATASET_MD5 = {
95
+ # MMBench v1.0
96
+ 'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
97
+ 'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
98
+ 'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
99
+ 'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
100
+ 'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
101
+ 'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
102
+ # MMBench v1.1
103
+ 'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
104
+ 'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
105
+ 'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
106
+ 'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
107
+ 'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
108
+ 'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
109
+ # SEEDBench
110
+ 'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
111
+ 'SEEDBench2': '4ec15cf864c4f16274112284f531813e',
112
+ 'SEEDBench2_Plus': 'e32d3216dc4f452b0fe497a52015d1fd',
113
+ # ScienceQA
114
+ 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
115
+ 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
116
+ # MMT-Bench
117
+ 'MMT-Bench_ALL_MI': '5272157097e19cdd7cb41e412ab3b7c7',
118
+ 'MMT-Bench_ALL': 'b273a2f4c596fe4f2605de0494cd632f',
119
+ 'MMT-Bench_VAL_MI': 'c7d7b998eb5cd9aa36c7d4f721472462',
120
+ 'MMT-Bench_VAL': '8dd4b730f53dbf9c3aed90ca31c928e0',
121
+ # AesBench
122
+ 'AesBench_VAL': '3edb0c319e9187aa0b97fe7a11700a8c',
123
+ 'AesBench_TEST': '58b1f7ba2cc32e1d68896d6ee716bbf8',
124
+ # Q-Bench1
125
+ 'Q-Bench1_VAL': '837bdb6cd2da571713543462815187b7',
126
+ 'Q-Bench1_TEST': '15e759bfd58c9d5f30b23a317d347153',
127
+ # A-Bench
128
+ 'A-Bench_VAL': '218563ec50d34bb336c814143a5bb9c1',
129
+ 'A-Bench_TEST': '567013fb033a20cf23f51d8e865bd16c',
130
+ # Other Benchmarks
131
+ 'CCBench': 'f5dde47f24dc5a6fb6e595b409b466ac',
132
+ 'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
133
+ 'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
134
+ 'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
135
+ 'RealWorldQA': '92321028d2bc29040284b6674721e48f',
136
+ 'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
137
+ 'BLINK': '3b6649b6a662184ea046908e5506260e',
138
+ 'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889'
139
+ }
140
+
141
+ DATASET_URL.update(MMMB_URLS)
142
+ DATASET_URL.update(MTL_MMBench_URLS)
143
+ DATASET_MD5.update(MMMB_MD5)
144
+ DATASET_MD5.update(MTL_MMBench_MD5)
145
+
146
+ def build_prompt(self, line):
147
+
148
+ if isinstance(line, int):
149
+ line = self.data.iloc[line]
150
+
151
+ if self.meta_only:
152
+ tgt_path = toliststr(line['image_path'])
153
+ else:
154
+ tgt_path = self.dump_image(line)
155
+
156
+ question = line['question']
157
+ options = {
158
+ cand: line[cand]
159
+ for cand in string.ascii_uppercase
160
+ if cand in line and not pd.isna(line[cand])
161
+ }
162
+ options_prompt = 'Options:\n'
163
+ for key, item in options.items():
164
+ options_prompt += f'{key}. {item}\n'
165
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
166
+ prompt = ''
167
+ if hint is not None:
168
+ prompt += f'Hint: {hint}\n'
169
+ prompt += f'Question: {question}\n'
170
+ if len(options):
171
+ prompt += options_prompt
172
+ prompt += 'Please select the correct answer from the options above. \n'
173
+
174
+ msgs = []
175
+ if isinstance(tgt_path, list):
176
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
177
+ else:
178
+ msgs = [dict(type='image', value=tgt_path)]
179
+ msgs.append(dict(type='text', value=prompt))
180
+
181
+ return msgs
182
+
183
+ def evaluate(self, eval_file, **judge_kwargs):
184
+ from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
185
+ # assert dataset is not None
186
+ dataset_map = {
187
+ 'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
188
+ 'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
189
+ }
190
+ dataset = self.dataset_name
191
+ if dataset in dataset_map:
192
+ dataset = dataset_map[dataset]
193
+ nproc = judge_kwargs.pop('nproc', 4)
194
+
195
+ circular = False
196
+ if listinstr(['mmbench', 'ccbench'], dataset.lower()):
197
+ data = load(eval_file)
198
+ data['index'] = [int(x) for x in data['index']]
199
+ dump(data, eval_file)
200
+ circular = True
201
+
202
+ suffix = eval_file.split('.')[-1]
203
+ model = judge_kwargs.get('model', 'exact_matching')
204
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
205
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
206
+ name_str = name_str_map[model] if model in name_str_map else model
207
+
208
+ if model == 'exact_matching':
209
+ model = None
210
+ elif gpt_key_set():
211
+ model = build_judge(**judge_kwargs)
212
+ if not model.working():
213
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
214
+ warnings.warn(DEBUG_MESSAGE)
215
+ model = None
216
+ else:
217
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
218
+ model = None
219
+
220
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
221
+
222
+ data = load(eval_file)
223
+ data = data.sort_values(by='index')
224
+ data['prediction'] = [str(x) for x in data['prediction']]
225
+ # Lower-case every column name that is not a single-letter choice label (A, B, C, ...)
226
+ for k in data.keys():
227
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
228
+
229
+ meta = self.data
230
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
231
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
232
+ for k in data_map:
233
+ assert k in meta_q_map, (
234
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
235
+ )
236
+
237
+ if circular:
238
+ data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
239
+ else:
240
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
241
+
242
+ # load split
243
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
244
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
245
+
246
+ # May have different report acc functions for different datasets
247
+ if 'MMT' in dataset:
248
+ acc = report_acc_MMT(data)
249
+ else:
250
+ acc = report_acc(data)
251
+
252
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
253
+ dump(acc, score_file)
254
+
255
+ if dataset == 'AesBench_VAL':
256
+ warnings.warn('Note that AesBench VAL is just a toy version of AesBench TEST. For full results, \
257
+ please evaluate on AesBench TEST. The AesBench TEST dataset is more than 20 times \
258
+ larger than the VAL dataset and the leaderboard results are based on AesBench TEST.')
259
+ return acc
260
+
261
+
262
+ class MMMUDataset(ImageMCQDataset):
263
+
264
+ DATASET_URL = {
265
+ 'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
266
+ 'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
267
+ }
268
+
269
+ DATASET_MD5 = {
270
+ 'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
271
+ 'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
272
+ }
273
+
274
+ @staticmethod
275
+ def split_MMMU(msgs):
276
+ text, images = None, []
277
+ for s in msgs:
278
+ if s['type'] == 'image':
279
+ images.append(s['value'])
280
+ elif s['type'] == 'text':
281
+ assert text is None
282
+ text = s['value']
283
+ text_segs = text.split('<image ')
284
+ if len(text_segs) == 1:
285
+ return msgs
286
+
287
+ segs = [dict(type='text', value=text_segs[0])]
288
+ for i, seg in enumerate(text_segs):
289
+ if i == 0:
290
+ continue
291
+ assert istype(seg[0], int) and seg[1] == '>'
292
+ image_idx = int(seg[0]) - 1
293
+ segs.append(dict(type='image', value=images[image_idx]))
294
+ segs.append(dict(type='text', value=seg[2:]))
295
+ return segs
296
+
297
+ def build_prompt(self, line):
298
+ msgs = super().build_prompt(line)
299
+ msgs = self.split_MMMU(msgs)
300
+ return msgs
301
+
302
+
303
+ class MUIRDataset(ImageMCQDataset):
304
+
305
+ DATASET_URL = {
306
+ 'MUIRBench': 'http://opencompass.openxxlab.com/utils/VLMEval/MUIRBench.tsv'
307
+ }
308
+
309
+ DATASET_MD5 = {
310
+ 'MUIRBench': '2e5e6fd7699761b08a7cb3ab8c0c2ec8'
311
+ }
312
+
313
+ @staticmethod
314
+ def split_MUIR(msgs):
315
+ text, images = None, []
316
+
317
+ # Separate images and text from msgs
318
+ for s in msgs:
319
+ if s['type'] == 'image':
320
+ images.append(s['value'])
321
+ elif s['type'] == 'text':
322
+ assert text is None # Ensure only one text entry is expected
323
+ text = s['value']
324
+
325
+ # Split text by <image> tags
326
+ text_segs = text.split('<image>')
327
+
328
+ # Initialize the segments list
329
+ segs = []
330
+
331
+ # Iterate through the text segments and images
332
+ for i, seg in enumerate(text_segs):
333
+ # Append the image if this is not the first segment and there are still images left
334
+ if i > 0 and i - 1 < len(images):
335
+ segs.append(dict(type='image', value=images[i - 1]))
336
+ # Append the text segment (if it's non-empty)
337
+ if len(seg) > 0:
338
+ segs.append(dict(type='text', value=seg))
339
+
340
+ return segs
341
+
342
+ def build_prompt(self, line):
343
+
344
+ if isinstance(line, int):
345
+ line = self.data.iloc[line]
346
+
347
+ if self.meta_only:
348
+ tgt_path = toliststr(line['image_path'])
349
+ else:
350
+ tgt_path = self.dump_image(line)
351
+
352
+ question = line['question']
353
+ options = {
354
+ cand: line[cand]
355
+ for cand in string.ascii_uppercase
356
+ if cand in line and not pd.isna(line[cand])
357
+ }
358
+ # options_prompt = ''
359
+ options_prompt = '\n'.join([f'{key}. {item}' for key, item in options.items()])
360
+ # for key, item in options.items():
361
+ # options_prompt += f'{key}. {item}\n'
362
+
363
+ prompt = ''
364
+
365
+ prompt += f'{question}\n'
366
+ if len(options):
367
+ prompt += options_prompt
368
+ prompt += "\nAnswer with the option's letter from the given choices directly."
369
+
370
+ msgs = []
371
+ if isinstance(tgt_path, list):
372
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
373
+ else:
374
+ msgs = [dict(type='image', value=tgt_path)]
375
+ msgs.append(dict(type='text', value=prompt))
376
+
377
+ msgs = self.split_MUIR(msgs)
378
+ return msgs
379
+
380
+
381
+ class GMAIMMBenchDataset(ImageMCQDataset):
382
+
383
+ DATASET_URL = {
384
+ 'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv'
385
+ }
386
+
387
+ DATASET_MD5 = {
388
+ 'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324'
389
+ }
390
+
391
+ def report_acc_by_groups(self, df, group_column):
392
+ res = defaultdict(list)
393
+
394
+ # Check for the 'split' column
395
+ if 'split' in df:
396
+ splits = list(set(df['split']))
397
+ res['split'] = splits
398
+ else:
399
+ df['split'] = ['none'] * len(df)
400
+ res['split'] = ['none']
401
+
402
+ res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
403
+
404
+ if group_column not in df:
405
+ raise ValueError(f"Column '{group_column}' not found in dataframe.")
406
+
407
+ abilities = list(set(df[group_column]))
408
+ abilities = ['None' if isinstance(ab, float) and pd.isna(ab) else ab for ab in abilities]
409
+ abilities.sort()
410
+
411
+ for ab in abilities:
412
+ ab_name = ab
413
+ sub_df = df[df[group_column] == ab]
414
+ res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
415
+
416
+ return pd.DataFrame(res)
417
+
418
+ def evaluate(self, eval_file, **judge_kwargs):
419
+ from .utils.multiple_choice import report_acc, mcq_vanilla_eval
420
+ nproc = judge_kwargs.pop('nproc', 4)
421
+
422
+ suffix = eval_file.split('.')[-1]
423
+ model = judge_kwargs.get('model', 'exact_matching')
424
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
425
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
426
+ name_str = name_str_map[model] if model in name_str_map else model
427
+
428
+ if model == 'exact_matching':
429
+ model = None
430
+ elif gpt_key_set():
431
+ model = build_judge(**judge_kwargs)
432
+ if not model.working():
433
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
434
+ warnings.warn(DEBUG_MESSAGE)
435
+ model = None
436
+ else:
437
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
438
+ model = None
439
+
440
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
441
+
442
+ data = load(eval_file)
443
+ data = data.sort_values(by='index')
444
+ data['prediction'] = [str(x) for x in data['prediction']]
445
+ # Lower-case every column name that is not a single-letter choice label (A, B, C, ...)
446
+ for k in data.keys():
447
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
448
+
449
+ meta = self.data
450
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
451
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
452
+ for k in data_map:
453
+ assert k in meta_q_map, (
454
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
455
+ )
456
+
457
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
458
+
459
+ # load split
460
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
461
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
462
+
463
+ acc = report_acc(data)
464
+
465
+ for group_col in ['clinical vqa task', 'department', 'perceptual granularity']:
466
+ acc_grouped = self.report_acc_by_groups(data, group_col)
467
+ score_file_grouped = eval_file.replace(f'.{suffix}', f'_{group_col}_acc.csv')
468
+ dump(acc_grouped, score_file_grouped)
469
+
470
+ return acc
471
+
472
+
473
+ class CustomMCQDataset(ImageMCQDataset):
474
+
475
+ def load_data(self, dataset):
476
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
477
+
478
+ if file_size(data_path, 'GB') > 1:
479
+ local_path = data_path.replace('.tsv', '_local.tsv')
480
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
481
+ from ..tools import LOCALIZE
482
+ LOCALIZE(data_path, local_path)
483
+ data_path = local_path
484
+ return load(data_path)
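Note: CustomMCQDataset above simply looks for <dataset_name>.tsv under LMUDataRoot() and inherits the MCQ prompting and scoring defined earlier in this file. A sketch of the minimal TSV a user might drop there, assuming the columns consumed by build_prompt and the evaluator (index, image_path or a base64 'image' column, question, A–D, answer); the row content is illustrative only:

import pandas as pd

rows = [dict(
    index=0,
    image_path='demo/cat.jpg',            # or an 'image' column holding base64 data
    question='What animal is shown?',
    A='cat', B='dog', C='horse', D='bird',
    answer='A',
)]
pd.DataFrame(rows).to_csv('MyCustomMCQ.tsv', sep='\t', index=False)
# Place the file under LMUDataRoot() and pass the name (MyCustomMCQ) as the dataset.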
eval_mm/vlmevalkit/vlmeval/dataset/image_mt.py ADDED
@@ -0,0 +1,128 @@
1
+ from .image_base import ImageBaseDataset
2
+ from .utils.judge_util import build_judge
3
+ from ..smp import *
4
+ from ..utils import track_progress_rich
5
+
6
+
7
+ class ImageMTDataset(ImageBaseDataset):
8
+
9
+ TYPE = 'MT'
10
+
11
+ def build_prompt(self, line):
12
+ if isinstance(line, int):
13
+ line = self.data.iloc[line]
14
+
15
+ if self.meta_only:
16
+ tgt_path = toliststr(line['image_path'])
17
+ else:
18
+ tgt_path = self.dump_image(line)
19
+
20
+ questions = toliststr(line['question'])
21
+ if 'answer' in line:
22
+ answers = toliststr(line['answer'])
23
+ else:
24
+ answers = [''] * len(questions)
25
+ assert len(questions) == len(answers)
26
+
27
+ dlgs, pics_number = [], 0
28
+ for i in range(len(questions)):
29
+ q, a = questions[i], answers[i]
30
+ if '<ImageHere>' in q:
31
+ content = []
32
+ tag_number = q.count('<ImageHere>')
33
+ images = tgt_path[pics_number: pics_number + tag_number]
34
+ pics_number += tag_number
35
+ q_split = q.split('<ImageHere>')
36
+ for i in range(tag_number):
37
+ qsp, im = q_split[i], images[i]
38
+ if qsp != '':
39
+ content.append(dict(type='text', value=qsp))
40
+ content.append(dict(type='image', value=im))
41
+ if q_split[-1] != '':
42
+ content.append(dict(type='text', value=q_split[-1]))
43
+ else:
44
+ content = [dict(type='text', value=q)]
45
+ dlgs.append(dict(role='user', content=content))
46
+ assert '<ImageHere>' not in a, 'We currently do not support images in the answer. '
47
+ content = [dict(type='text', value=a)]
48
+ dlgs.append(dict(role='assistant', content=content))
49
+ return dlgs
50
+
51
+
52
+ class MMDUDataset(ImageMTDataset):
53
+
54
+ DATASET_URL = {'MMDU': 'https://opencompass.openxlab.space/utils/VLMEval/MMDU.tsv'}
55
+ DATASET_MD5 = {'MMDU': '848b635a88a078f49aebcc6e39792061'}
56
+ DIMS = [
57
+ 'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
58
+ 'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
59
+ ]
60
+
61
+ def calculat_metric(self, ans):
62
+ all = defaultdict(lambda: 0)
63
+ tot = defaultdict(lambda: 0)
64
+ valid = defaultdict(lambda: 0)
65
+ for k in ans:
66
+ res = ans[k]['res']
67
+ assert isinstance(res, pd.DataFrame)
68
+ lt = len(res)
69
+ for i in range(lt):
70
+ line = res.iloc[i]
71
+ for k in self.DIMS:
72
+ tot[k] += 1
73
+ if k in line and line[k] is not None:
74
+ try:
75
+ score = int(line[k])
76
+ score = np.clip(score, 0, 10)
77
+ all[k] += score
78
+ valid[k] += 1
79
+ except Exception as e:
80
+ print(f'Failed to parse the score: {str(e)}')
81
+ sp1 = {'set': 'all'}
82
+ sp1.update({k: all[k] / tot[k] * 10 for k in self.DIMS})
83
+ sp2 = {'set': 'valid'}
84
+ sp2.update({k: all[k] / valid[k] * 10 for k in self.DIMS})
85
+
86
+ return pd.DataFrame([sp1, sp2])
87
+
88
+ def evaluate(self, eval_file, **judge_kwargs):
89
+ suffix = eval_file.split('.')[-1]
90
+ model = judge_kwargs['model']
91
+
92
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
93
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
94
+ nproc = judge_kwargs.pop('nproc', 4)
95
+
96
+ data = load(eval_file)
97
+ model = judge_kwargs.pop('model', 'gpt-4o')
98
+ judge_model = build_judge(model=model, **judge_kwargs)
99
+
100
+ lt = len(data)
101
+ lines = [data.iloc[i] for i in range(lt)]
102
+ tups = [(judge_model, line) for line in lines]
103
+ indices = [line['index'] for line in lines]
104
+
105
+ ans = {}
106
+ if osp.exists(tmp_file):
107
+ ans = load(tmp_file)
108
+
109
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
110
+ indices = [i for i in indices if i not in ans]
111
+
112
+ from .utils.mmdu import mmdu_score
113
+
114
+ if len(indices):
115
+ new_results = track_progress_rich(
116
+ mmdu_score,
117
+ tups,
118
+ nproc=nproc,
119
+ chunksize=nproc,
120
+ keys=indices,
121
+ save=tmp_file,)
122
+ ans = load(tmp_file)
123
+ for k, v in zip(indices, new_results):
124
+ assert k in ans
125
+
126
+ metric = self.calculat_metric(ans)
127
+ dump(metric, score_file)
128
+ return metric
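Note: ImageMTDataset.build_prompt above expands each '<ImageHere>' placeholder into an interleaved image/text content list and alternates user/assistant turns, which is what the multi-turn inference loop consumes. A sketch of the structure it produces for a hypothetical two-turn record (field values are illustrative):

record = dict(
    index=0,
    image_path=['demo/img1.jpg'],
    question=['<ImageHere>What is in the picture?', 'And what color is it?'],
    answer=['A red bicycle.', ''],
)
# build_prompt(record) yields, roughly:
# [
#   {'role': 'user', 'content': [{'type': 'image', 'value': 'demo/img1.jpg'},
#                                {'type': 'text',  'value': 'What is in the picture?'}]},
#   {'role': 'assistant', 'content': [{'type': 'text', 'value': 'A red bicycle.'}]},
#   {'role': 'user', 'content': [{'type': 'text', 'value': 'And what color is it?'}]},
#   {'role': 'assistant', 'content': [{'type': 'text', 'value': ''}]},
# ]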
eval_mm/vlmevalkit/vlmeval/dataset/image_vqa.py ADDED
@@ -0,0 +1,433 @@
1
+ from functools import partial
2
+
3
+ from .image_base import ImageBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from ..smp import *
6
+ from ..utils import track_progress_rich
7
+
8
+
9
+ class ImageVQADataset(ImageBaseDataset):
10
+ TYPE = 'VQA'
11
+
12
+ DATASET_URL = {
13
+ 'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
14
+ 'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
15
+ 'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
16
+ 'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
17
+ 'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
18
+ 'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
19
+ 'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
20
+ 'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
21
+ }
22
+
23
+ DATASET_MD5 = {
24
+ 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
25
+ 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
26
+ 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
27
+ 'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
28
+ 'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
29
+ 'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
30
+ 'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
31
+ 'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
32
+ }
33
+
34
+ def build_prompt(self, line):
35
+ msgs = super().build_prompt(line)
36
+ assert msgs[-1]['type'] == 'text'
37
+ msgs[-1]['value'] += '\nAnswer the question using a single word or phrase.'
38
+ return msgs
39
+
40
+ # It returns a DataFrame
41
+ def evaluate(self, eval_file, **judge_kwargs):
42
+ from .utils.vqa_eval import hit_calculate, process_line
43
+
44
+ data = load(eval_file)
45
+ dataset = self.dataset_name
46
+ assert 'answer' in data and 'prediction' in data
47
+ data['prediction'] = [str(x) for x in data['prediction']]
48
+ data['answer'] = [str(x) for x in data['answer']]
49
+ lt = len(data)
50
+ pool = mp.Pool(16)
51
+ lines = [data.iloc[i] for i in range(lt)]
52
+ if listinstr(['TextVQA'], dataset):
53
+ res = pool.map(partial(process_line, method='vqa_score'), lines)
54
+ elif listinstr(['ChartQA'], dataset):
55
+ res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
56
+ elif listinstr(['OCRVQA'], dataset):
57
+ res = pool.map(partial(process_line, method='accuracy'), lines)
58
+ elif listinstr(['DocVQA', 'InfoVQA'], dataset):
59
+ res = pool.map(partial(process_line, method='anls'), lines)
60
+ else: # default using vqa_score to calculate score
61
+ res = pool.map(process_line, lines)
62
+ hit = hit_calculate(res, dataset)
63
+ ret = dict()
64
+ if 'split' in data:
65
+ splits = set(data['split'])
66
+ for sp in splits:
67
+ sub = [r for l, r in zip(lines, res) if l['split'] == sp]
68
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
69
+ hit = hit_calculate(sub, dataset)
70
+ ret[sp] = np.mean(hit) * 100
71
+ sub = [r for l, r in zip(lines, res)]
72
+ hit = hit_calculate(sub, dataset)
73
+ ret['Overall'] = np.mean(hit) * 100
74
+ else:
75
+ ret['Overall'] = np.mean(hit) * 100
76
+ if 'category' in data:
77
+ cates = list(set(data['category']))
78
+ cates.sort()
79
+ for c in cates:
80
+ sub = [r for l, r in zip(lines, res) if l['category'] == c]
81
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
82
+ hit = hit_calculate(sub, dataset)
83
+ ret[c] = np.mean(hit) * 100
84
+ ret = d2df(ret)
85
+ ret = ret.round(2)  # DataFrame.round is not in-place; keep the rounded copy
86
+
87
+ suffix = eval_file.split('.')[-1]
88
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
89
+ dump(ret, result_file)
90
+ return ret
91
+
92
+
93
+ class OCRBench(ImageBaseDataset):
94
+ TYPE = 'VQA'
95
+ DATASET_URL = {
96
+ 'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv'
97
+ }
98
+ DATASET_MD5 = {'OCRBench': 'e953d98a987cc6e26ef717b61260b778'}
99
+
100
+ # It returns a dictionary
101
+ @classmethod
102
+ def evaluate(self, eval_file, **judge_kwargs):
103
+ OCRBench_score = {
104
+ 'Regular Text Recognition': 0,
105
+ 'Irregular Text Recognition': 0,
106
+ 'Artistic Text Recognition': 0,
107
+ 'Handwriting Recognition': 0,
108
+ 'Digit String Recognition': 0,
109
+ 'Non-Semantic Text Recognition': 0,
110
+ 'Scene Text-centric VQA': 0,
111
+ 'Doc-oriented VQA': 0,
112
+ 'Key Information Extraction': 0,
113
+ 'Handwritten Mathematical Expression Recognition': 0,
114
+ }
115
+
116
+ data = load(eval_file)
117
+ lt = len(data)
118
+ lines = [data.iloc[i] for i in range(lt)]
119
+ for i in tqdm(range(len(lines))):
120
+ line = lines[i]
121
+ predict = str(line['prediction'])
122
+ answers = eval(line['answer'])
123
+ category = line['category']
124
+ if category == 'Handwritten Mathematical Expression Recognition':
125
+ for j in range(len(answers)):
126
+ answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
127
+ predict = predict.strip().replace('\n', ' ').replace(' ', '')
128
+ if answer in predict:
129
+ OCRBench_score[category] += 1
130
+ break
131
+ else:
132
+ for j in range(len(answers)):
133
+ answer = answers[j].lower().strip().replace('\n', ' ')
134
+ predict = predict.lower().strip().replace('\n', ' ')
135
+ if answer in predict:
136
+ OCRBench_score[category] += 1
137
+ break
138
+
139
+ final_score_dict = {}
140
+ final_score_dict['Text Recognition'] = \
141
+ (OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
142
+ + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
143
+ + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition'])
144
+ final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
145
+ final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
146
+ final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
147
+ final_score_dict['Handwritten Mathematical Expression Recognition'] = \
148
+ (OCRBench_score['Handwritten Mathematical Expression Recognition'])
149
+ final_score_dict['Final Score'] = \
150
+ (final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
151
+ + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
152
+ + final_score_dict['Handwritten Mathematical Expression Recognition'])
153
+ final_score_dict['Final Score Norm'] = (float(final_score_dict['Final Score']) / 10)
154
+ score_pth = eval_file.replace('.xlsx', '_score.json')
155
+ dump(final_score_dict, score_pth)
156
+ return final_score_dict
157
+
158
+
159
+ class MathVista(ImageBaseDataset):
160
+ TYPE = 'VQA'
161
+ DATASET_URL = {
162
+ 'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv'
163
+ }
164
+ DATASET_MD5 = {'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464'}
165
+
166
+ # It returns a DataFrame
167
+ @classmethod
168
+ def evaluate(self, eval_file, **judge_kwargs):
169
+ from .utils.mathvista import MathVista_auxeval, MathVista_acc
170
+
171
+ model = judge_kwargs['model']
172
+ suffix = eval_file.split('.')[-1]
173
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
174
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
175
+ nproc = judge_kwargs.pop('nproc', 4)
176
+
177
+ if not osp.exists(storage):
178
+ data = load(eval_file)
179
+ model = build_judge(max_tokens=128, **judge_kwargs)
180
+ assert model.working(), ('MathVista evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
181
+ lt = len(data)
182
+ lines = [data.iloc[i] for i in range(lt)]
183
+ tups = [(model, line) for line in lines]
184
+ indices = [line['index'] for line in lines]
185
+
186
+ ans = {}
187
+ if osp.exists(tmp_file):
188
+ ans = load(tmp_file)
189
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
190
+ indices = [i for i in indices if i not in ans]
191
+
192
+ if len(indices):
193
+ new_results = track_progress_rich(
194
+ MathVista_auxeval,
195
+ tups,
196
+ nproc=nproc,
197
+ chunksize=nproc,
198
+ keys=indices,
199
+ save=tmp_file,
200
+ )
201
+ ans = load(tmp_file)
202
+ for k, v in zip(indices, new_results):
203
+ assert k in ans
204
+ assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
205
+
206
+ data['res'] = [ans[idx]['res'] for idx in data['index']]
207
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
208
+ dump(data, storage)
209
+
210
+ score = MathVista_acc(storage)
211
+ score_pth = storage.replace('.xlsx', '_score.csv')
212
+ dump(score, score_pth)
213
+ return score
214
+
215
+
216
+ class MathVision(ImageBaseDataset):
217
+ TYPE = 'VQA'
218
+ DATASET_URL = {
219
+ 'MathVision': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
220
+ 'MathVision_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision_MINI.tsv'
221
+ }
222
+ DATASET_MD5 = {
223
+ 'MathVision': '93f6de14f7916e598aa1b7165589831e',
224
+ 'MathVision_MINI': '060fe4fa5d868987ce179307bd5f8a33'
225
+ }
226
+
227
+ # It returns a DataFrame
228
+ @classmethod
229
+ def evaluate(self, eval_file, **judge_kwargs):
230
+ from .utils.mathv import MATH_V_auxeval, MATH_V_acc
231
+
232
+ if 'model' in judge_kwargs:
233
+ model = judge_kwargs['model']
234
+ else:
235
+ model = os.path.basename(os.environ.get('LOCAL_LLM'))
236
+ suffix = eval_file.split('.')[-1]
237
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
238
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
239
+ nproc = judge_kwargs.pop('nproc', 4)
240
+
241
+ if not osp.exists(storage):
242
+ data = load(eval_file)
243
+ model = build_judge(max_tokens=128, **judge_kwargs)
244
+ assert model.working(), ('MATH-Vision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
245
+ lt = len(data)
246
+ lines = [data.iloc[i] for i in range(lt)]
247
+ tups = [(model, line) for line in lines]
248
+ indices = [line['index'] for line in lines]
249
+
250
+ ans = {}
251
+ if osp.exists(tmp_file):
252
+ ans = load(tmp_file)
253
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
254
+ indices = [i for i in indices if i not in ans]
255
+
256
+ if len(indices):
257
+ new_results = track_progress_rich(
258
+ MATH_V_auxeval,
259
+ tups,
260
+ nproc=nproc,
261
+ chunksize=nproc,
262
+ keys=indices,
263
+ save=tmp_file,
264
+ )
265
+ ans = load(tmp_file)
266
+ for k, v in zip(indices, new_results):
267
+ assert k in ans
268
+ assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
269
+
270
+ data['res'] = [ans[idx]['res'] for idx in data['index']]
271
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
272
+ dump(data, storage)
273
+
274
+ score = MATH_V_acc(storage)
275
+ score_pth = storage.replace('.xlsx', '_score.csv')
276
+ dump(score, score_pth)
277
+ return score
278
+
279
+
280
+ class LLaVABench(ImageBaseDataset):
281
+ TYPE = 'VQA'
282
+ DATASET_URL = {'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv'}
283
+ DATASET_MD5 = {'LLaVABench': 'd382a093f749a697820d3dadd61c8428'}
284
+
285
+ # It returns a DataFrame
286
+ @classmethod
287
+ def evaluate(self, eval_file, **judge_kwargs):
288
+ from .utils.llavabench import (
289
+ build_prompt,
290
+ LLaVABench_atomeval,
291
+ LLaVABench_score,
292
+ )
293
+
294
+ suffix = '.' + eval_file.split('.')[-1]
295
+ record_file = eval_file.replace(suffix, '_openai_result' + suffix)
296
+ score_file = eval_file.replace(suffix, '_score.csv')
297
+ nproc = judge_kwargs.pop('nproc', 4)
298
+ system_prompt = 'You are a helpful and precise assistant for checking the quality of the answer.'
299
+
300
+ if not osp.exists(record_file):
301
+ data = load(eval_file)
302
+ lines = [data.iloc[i] for i in range(len(data))]
303
+ model = build_judge(temperature=0.2, system_prompt=system_prompt, **judge_kwargs)
304
+ assert model.working(), ('LLaVABench evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
305
+
306
+ prompts = [build_prompt(line) for line in lines]
307
+ tups = [(model, prompt) for prompt in prompts]
308
+ scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc)
309
+ data['gpt4_score'] = [x[0] for x in scores]
310
+ data['score'] = [x[1] for x in scores]
311
+ dump(data, record_file)
312
+
313
+ data = load(record_file)
314
+ ret = LLaVABench_score(data).round(1)
315
+ dump(ret, score_file)
316
+ return ret
317
+
318
+
319
+ class MMVet(ImageBaseDataset):
320
+ TYPE = 'VQA'
321
+ DATASET_URL = {
322
+ 'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv'
323
+ }
324
+ DATASET_MD5 = {'MMVet': '748aa6d4aa9d4de798306a63718455e3'}
325
+
326
+ # It returns a DataFrame
327
+ @classmethod
328
+ def evaluate(self, eval_file, **judge_kwargs):
329
+ from .utils.mmvet import MMVet_auxeval, MMVet_acc
330
+
331
+ suffix = eval_file.split('.')[-1]
332
+ model = judge_kwargs['model']
333
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
334
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
335
+ nproc = judge_kwargs.pop('nproc', 4)
336
+ if not osp.exists(storage):
337
+ data = load(eval_file)
338
+ model = build_judge(max_tokens=3, **judge_kwargs)
339
+ assert model.working(), ('MMVet evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
340
+
341
+ lt = len(data)
342
+ lines = [data.iloc[i] for i in range(lt)]
343
+ tups = [(model, line) for line in lines]
344
+ indices = [line['index'] for line in lines]
345
+
346
+ ans = load(tmp_file) if osp.exists(tmp_file) else {}
347
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
348
+ indices = [i for i in indices if i not in ans]
349
+
350
+ if len(indices):
351
+ new_results = track_progress_rich(
352
+ MMVet_auxeval,
353
+ tups,
354
+ nproc=nproc,
355
+ chunksize=nproc,
356
+ keys=indices,
357
+ save=tmp_file,
358
+ )
359
+ ans = load(tmp_file)
360
+ for k, v in zip(indices, new_results):
361
+ assert k in ans
362
+ assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score']
363
+ data['score'] = [ans[idx]['score'] for idx in data['index']]
364
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
365
+ dump(data, storage)
366
+
367
+ score, score_fine = MMVet_acc(storage)
368
+ score_pth = storage.replace('.xlsx', '_score.csv')
369
+ score_fine_pth = storage.replace('.xlsx', '_score_fine.csv')
370
+ dump(score, score_pth)
371
+ dump(score_fine, score_fine_pth)
372
+ return score
373
+
374
+
375
+ class MTVQADataset(ImageBaseDataset):
376
+ TYPE = 'VQA'
377
+ DATASET_URL = {'MTVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MTVQA_TEST.tsv'}
378
+ DATASET_MD5 = {'MTVQA_TEST': 'd87c17dbab934b7cd89c0a3c1c5657f4'}
379
+
380
+ @classmethod
381
+ def evaluate(self, eval_file, **judge_kwargs):
382
+ data = load(eval_file)
383
+ assert 'answer' in data and 'prediction' in data and 'category' in data
384
+ data['prediction'] = [str(x) for x in data['prediction']]
385
+ data['answer'] = [str(x) for x in data['answer']]
386
+ if 'split' in data:
387
+ assert np.all([x.lower() == 'test' for x in data['split']]), 'We only support MTVQA_TEST for now. '
388
+ lt = len(data)
389
+ category_scores = defaultdict(list)
390
+ for i in range(lt):
391
+ line = data.iloc[i]
392
+ ans = line['answer'].strip().lower().replace('.', '')
393
+ pred = line['prediction'].strip().lower().replace('.', '')
394
+ cate = line['category']
395
+ score = 1.0 if ans in pred else 0.0
396
+ category_scores[cate].append(score)
397
+ category_scores['Average'].append(score)
398
+ # Calculate the average score for each category, the score is normalized to [0, 100]
399
+ category_averages = {category: np.mean(scores) * 100 for category, scores in category_scores.items()}
400
+
401
+ suffix = eval_file.split('.')[-1]
402
+ result_file = eval_file.replace(f'.{suffix}', '_acc.json')
403
+ dump(category_averages, result_file)
404
+
405
+ return category_averages
406
+
407
+ # MT-VQA adopts a custom prompt
408
+ def build_prompt(self, line):
409
+ msgs = super().build_prompt(line)
410
+ assert sum([x['type'] == 'text' for x in msgs]) == 1
411
+ for item in msgs:
412
+ if item['type'] == 'text':
413
+ item['value'] += '\nAnswer the question using a word or phrase in the language of the question.'
414
+ return msgs
415
+
416
+
417
+ class CustomVQADataset(ImageBaseDataset):
418
+ TYPE = 'VQA'
419
+
420
+ def load_data(self, dataset):
421
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
422
+
423
+ if file_size(data_path, 'GB') > 1:
424
+ local_path = data_path.replace('.tsv', '_local.tsv')
425
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
426
+ from ..tools import LOCALIZE
427
+
428
+ LOCALIZE(data_path, local_path)
429
+ data_path = local_path
430
+ return load(data_path)
431
+
432
+ def evaluate(self, eval_file, **judge_kwargs):
433
+ raise NotImplementedError
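Note: the DocVQA/InfoVQA branch above scores with ANLS (Average Normalized Levenshtein Similarity): each prediction takes the best 1 - normalized-edit-distance against the ground-truth answers, zeroed out below a 0.5 threshold. A self-contained sketch of that standard formulation; the kit's own implementation lives in utils/vqa_eval.py and may differ in normalization details:

def anls(prediction: str, answers: list, threshold: float = 0.5) -> float:
    """ANLS for a single sample: best similarity over all reference answers."""
    def edit_distance(a, b):
        # Classic dynamic-programming Levenshtein distance, one row at a time.
        dp = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            prev, dp[0] = dp[0], i
            for j, cb in enumerate(b, 1):
                prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
        return dp[-1]

    best = 0.0
    for ans in answers:
        a, b = prediction.strip().lower(), ans.strip().lower()
        nl = edit_distance(a, b) / max(len(a), len(b), 1)   # normalized edit distance
        best = max(best, 1 - nl if nl < threshold else 0.0)
    return best


# anls('the red bus', ['red bus', 'a red bus'])  -> a value in [0, 1]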
eval_mm/vlmevalkit/vlmeval/dataset/image_yorn.py ADDED
@@ -0,0 +1,88 @@
1
+ from ..smp import *
2
+ from ..utils import *
3
+ from .image_base import ImageBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+
6
+
7
+ class ImageYORNDataset(ImageBaseDataset):
8
+
9
+ TYPE = 'Y/N'
10
+
11
+ DATASET_URL = {
12
+ 'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
13
+ 'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
14
+ 'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
15
+ }
16
+
17
+ DATASET_MD5 = {
18
+ 'MME': 'b36b43c3f09801f5d368627fb92187c3',
19
+ 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
20
+ 'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
21
+ }
22
+
23
+ # It returns a dataframe
24
+ def evaluate(self, eval_file, **judge_kwargs):
25
+ from .utils.yorn import YOrN_Extraction, YOrN_auxeval
26
+ from .utils.yorn import default_rating, MME_rating, Hallusion_rating, POPE_rating
27
+
28
+ dataset = self.dataset_name
29
+ data = load(eval_file)
30
+ data['prediction'] = [str(x) for x in data['prediction']]
31
+ storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')
32
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
33
+ nproc = judge_kwargs.pop('nproc', 4)
34
+
35
+ if not osp.exists(storage):
36
+ ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])}
37
+ if osp.exists(tmp_file):
38
+ tmp = load(tmp_file)
39
+ for k in tmp:
40
+ if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown':
41
+ ans_map[k] = tmp[k]
42
+
43
+ data['extracted'] = [ans_map[x] for x in data['index']]
44
+ unknown = data[data['extracted'] == 'Unknown']
45
+
46
+ model = judge_kwargs.get('model', 'exact_matching')
47
+ if model == 'exact_matching':
48
+ model = None
49
+ elif gpt_key_set():
50
+ model = build_judge(**judge_kwargs)
51
+ if not model.working():
52
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
53
+ warnings.warn(DEBUG_MESSAGE)
54
+ model = None
55
+ else:
56
+ model = None
57
+ warnings.warn('OPENAI_API_KEY is not working properly, will use exact matching for evaluation')
58
+
59
+ if model is not None:
60
+ lt = len(unknown)
61
+ lines = [unknown.iloc[i] for i in range(lt)]
62
+ tups = [(model, line) for line in lines]
63
+ indices = list(unknown['index'])
64
+ if len(tups):
65
+ res = track_progress_rich(
66
+ YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
67
+ for k, v in zip(indices, res):
68
+ ans_map[k] = v
69
+
70
+ data['extracted'] = [ans_map[x] for x in data['index']]
71
+ dump(data, storage)
72
+
73
+ data = load(storage)
74
+ data['score'] = (data['answer'] == data['extracted'])
75
+ dump(data, storage)
76
+
77
+ if dataset is not None and listinstr(['MME'], dataset):
78
+ score = MME_rating(storage)
79
+ elif dataset is not None and listinstr(['Hallusion'], dataset):
80
+ score = Hallusion_rating(storage)
81
+ elif dataset is not None and listinstr(['POPE'], dataset):
82
+ score = POPE_rating(storage)
83
+ else:
84
+ score = default_rating(storage)
85
+
86
+ score_tgt = eval_file.replace('.xlsx', '_score.csv')
87
+ dump(score, score_tgt)
88
+ return score
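Note: the Y/N pipeline above first maps each free-form prediction to 'Yes' / 'No' / 'Unknown' (via YOrN_Extraction, with an optional LLM judge as a fallback for the 'Unknown' cases), then scores by exact match against the answer column. A simplified stand-in sketch of that first step; it is not the kit's actual extraction logic, which lives in utils/yorn.py:

def naive_yorn_extraction(prediction: str) -> str:
    # Very rough heuristic; the real YOrN_Extraction handles far more phrasing variants.
    words = prediction.strip().lower().replace('.', ' ').replace(',', ' ').split()
    has_yes, has_no = 'yes' in words, 'no' in words
    if has_yes and not has_no:
        return 'Yes'
    if has_no and not has_yes:
        return 'No'
    return 'Unknown'


# naive_yorn_extraction('Yes, the object is present.')  -> 'Yes'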
eval_mm/vlmevalkit/vlmeval/dataset/mmbench_video.py ADDED
@@ -0,0 +1,252 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from ..utils import track_progress_rich
6
+
7
+
8
+ FAIL_MSG = 'Failed to obtain answer via API.'
9
+
10
+
11
+ def unwrap_hf_pkl(pth, suffix='.mp4'):
12
+ base_dir = os.path.join(pth, 'video_pkl/')
13
+ target_dir = os.path.join(pth, 'video/')
14
+ pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
15
+ pickle_files.sort()
16
+
17
+ if not os.path.exists(target_dir):
18
+ os.makedirs(target_dir, exist_ok=True)
19
+ for pickle_file in pickle_files:
20
+ with open(pickle_file, 'rb') as file:
21
+ video_data = pickle.load(file)
22
+ # For each video file in the pickle file, write its contents to a new mp4 file
23
+ for video_name, video_content in video_data.items():
24
+ output_path = os.path.join(target_dir, f'{video_name}{suffix}')
25
+ with open(output_path, 'wb') as output_file:
26
+ output_file.write(video_content)
27
+ print('The video file has been restored and stored from the pickle file.')
28
+ else:
29
+ print('The video file already exists.')
30
+
31
+
32
+ class MMBenchVideo(VideoBaseDataset):
33
+
34
+ MD5 = '98f7df3eb1007fc375ea6fe88a98e2ff'
35
+ SYS = 'You are an AI assistant responsible for answering questions about videos.'
36
+ FRAMES_TMPL_PACK = """
37
+ You will be provided with {} separate frames uniformly sampled from a video, \
38
+ the frames are provided in chronological order of the video.
39
+ Please analyze these images and provide the answer / answers to the \
40
+ following question / questions about the video content.
41
+ If multiple questions are provided (with indices I1, I2, I3, ...), \
42
+ you should organize your answers in the following json format:
43
+ {{
44
+ 'I1': 'Answer to Question I1',
45
+ 'I2': 'Answer to Question I2',
46
+ ...
47
+ }}
48
+ Otherwise, please directly reply with your response to the only question.
49
+ Even if the information in these separate frames is not enough to give an answer,
50
+ PLEASE GIVE A RESPONSE TO EACH OF THE QUESTIONS IN THE FORMAT DESCRIBED ABOVE.
51
+ """
52
+
53
+ FRAMES_TMPL_NOPACK = """
54
+ You will be provided with {} separate frames uniformly sampled from a video, \
55
+ the frames are provided in chronological order of the video.
56
+ Please analyze these images and provide the answer to the question about the video content.
57
+ Please directly reply with your response to the only question.
58
+ """
59
+
60
+ TYPE = 'VQA'
61
+
62
+ def __init__(self, dataset='MMBench-Video', pack=False):
63
+ super().__init__(dataset=dataset, pack=pack)
64
+
65
+ @classmethod
66
+ def supported_datasets(cls):
67
+ return ['MMBench-Video']
68
+
69
+ def prepare_dataset(self, dataset_name='MMBench-Video', repo_id='nebulae09/MMBench-Video'):
70
+ def check_integrity(pth):
71
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
72
+ if md5(data_file) != self.MD5:
73
+ return False
74
+ data = load(data_file)
75
+ for video_pth in data['video_path']:
76
+ if not osp.exists(osp.join(pth, video_pth)):
77
+ return False
78
+ return True
79
+
80
+ cache_path = get_cache_path(repo_id)
81
+ if cache_path is not None and check_integrity(cache_path):
82
+ dataset_path = cache_path
83
+ else:
84
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
85
+ unwrap_hf_pkl(dataset_path)
86
+ self.video_path = osp.join(dataset_path, 'video/')
87
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
88
+
89
+ return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
90
+
91
+ def build_prompt_pack(self, line, num_frames):
92
+ if isinstance(line, int):
93
+ assert line < len(self)
94
+ video = self.videos[line]
95
+ elif isinstance(line, pd.Series):
96
+ video = line['video']
97
+ elif isinstance(line, str):
98
+ video = line
99
+
100
+ frames = self.save_video_frames(video, num_frames)
101
+ sub = self.data[self.data['video'] == video]
102
+ sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(num_frames)
103
+ message = [dict(type='text', value=sys_prompt)]
104
+ for im in frames:
105
+ message.append(dict(type='image', value=im))
106
+ nq = len(sub)
107
+ prompt = 'Questions: \n{}\nAnswers: \n'
108
+ qs = {int(sub.iloc[i]['index']): sub.iloc[i]['question'] for i in range(nq)}
109
+ prompt = prompt.format(json.dumps(qs))
110
+ message.append(dict(type='text', value=prompt))
111
+ return message
112
+
113
+ def build_prompt_nopack(self, line, num_frames, video_llm):
114
+ if isinstance(line, int):
115
+ assert line < len(self)
116
+ line = self.data.iloc[line]
117
+ if video_llm:
118
+ question = line['question']
119
+ prefix, video_idx_path = os.path.split(line['video_path'])
120
+ message = [dict(type='text', value=question)]
121
+ message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
122
+ return message
123
+ else:
124
+ frames = self.save_video_frames(line['video'], num_frames)
125
+ sys_prompt = self.FRAMES_TMPL_NOPACK.format(num_frames)
126
+ message = [dict(type='text', value=sys_prompt)]
127
+ for im in frames:
128
+ message.append(dict(type='image', value=im))
129
+ prompt = 'Question: {}\nAnswer: '.format(line['question'])
130
+ message.append(dict(type='text', value=prompt))
131
+ return message
132
+
133
+ def build_prompt(self, line, num_frames, video_llm):
134
+ if self.pack and not video_llm:
135
+ return self.build_prompt_pack(line, num_frames)
136
+ else:
137
+ return self.build_prompt_nopack(line, num_frames, video_llm)
138
+
139
+ @staticmethod
140
+ def remove_side_quote(s, syms=[',', '"', "'"]):
141
+ if np.all([x in syms for x in s]):
142
+ return ''
143
+ while s[0] in syms:
144
+ s = s[1:]
145
+ while s[-1] in syms:
146
+ s = s[:-1]
147
+ return s
148
+
149
+ @staticmethod
150
+ def robust_json_load(s):
151
+ try:
152
+ jsons = list(extract_json_objects(s))
153
+ assert len(jsons) == 1
154
+ return jsons[0]
155
+ except:
156
+ if '{' in s and s.find('{') == s.rfind('{'):
157
+ sub_str = s[s.find('{') + 1:].strip()
158
+ lines = sub_str.split('\n')
159
+ res = {}
160
+ for l in lines:
161
+ l = l.strip()
162
+ if ': ' in l:
163
+ key = l.split(': ')[0].strip()
164
+ val = l.split(': ')[1].strip()
165
+ key = MMBenchVideo.remove_side_quote(key)
166
+ val = MMBenchVideo.remove_side_quote(val)
167
+ if len(key) and len(val):
168
+ res[key] = val
169
+ return res
170
+ return None
171
+
172
+ def load_pack_answers(self, data_raw):
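+ # Parse each video's packed JSON answer string, then scatter the per-question answers back
+ # onto the dataframe rows, tracking generation / parsing failure statistics in vstats.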
173
+ vstats = defaultdict(lambda: 0)
174
+ data = defaultdict(lambda: {})
175
+
176
+ for k in data_raw:
177
+ ans = data_raw[k].strip()
178
+ if FAIL_MSG in ans:
179
+ vstats['GEN_FAIL'] += 1
180
+ continue
181
+ res = self.robust_json_load(ans)
182
+ if res is not None:
183
+ data[k] = res
184
+ vstats['PARSE_OK'] += 1
185
+ else:
186
+ vstats['PARSE_FAIL'] += 1
187
+
188
+ # return data
189
+ meta = cp.deepcopy(self.data)
190
+ lt = len(meta)
191
+ prediction = []
192
+ for i in range(lt):
193
+ line = meta.iloc[i]
194
+ vid = line['video']
195
+ idx = str(line['index'])
196
+ prediction.append(data[vid][idx] if idx in data[vid] else None)
197
+ meta['prediction'] = prediction
198
+ vstats['VALIDQ'] = len([x for x in prediction if x is not None])
199
+ vstats['INVALIDQ'] = len([x for x in prediction if x is None])
200
+ return meta, vstats
201
+
202
+ # It returns a dictionary
203
+ @classmethod
204
+ def evaluate(self, eval_file, **judge_kwargs):
205
+ from .utils.mmbench_video import get_dimension_rating, system_prompt, build_prompt
206
+
207
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
208
+ judge = judge_kwargs['model']
209
+ nproc = judge_kwargs.pop('nproc', 4)
210
+
211
+ tmp_file = eval_file.replace('.xlsx', f'_{judge}_tmp.pkl')
212
+ tgt_file = eval_file.replace('.xlsx', f'_{judge}_rating.json')
213
+ score_file = eval_file.replace('.xlsx', f'_{judge}_score.xlsx')
214
+
215
+ model = build_judge(system_prompt=system_prompt, **judge_kwargs)
216
+ assert model.working(), 'MMBench-Video evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
217
+
218
+ if not osp.exists(score_file):
219
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
220
+ res = {k: v for k, v in res.items() if model.fail_msg not in v}
221
+
222
+ data = load(eval_file)
223
+ data_un = data[~data['index'].isin(res)]
224
+ data_un = data_un[~pd.isna(data_un['prediction'])]
225
+ lt = len(data_un)
226
+ prompts = [build_prompt(data_un.iloc[i]) for i in range(lt)]
227
+ indices = [data_un.iloc[i]['index'] for i in range(lt)]
228
+
229
+ if len(prompts):
230
+ _ = track_progress_rich(
231
+ model.generate,
232
+ prompts,
233
+ keys=indices,
234
+ save=tmp_file,
235
+ nproc=nproc,
236
+ chunksize=nproc
237
+ )
238
+ score_map = load(tmp_file)
239
+ data['score'] = [score_map[idx] if idx in score_map else -1 for idx in data['index']]
240
+ rejected = [x for x in score_map.values() if FAIL_MSG in x]
241
+ data['score'] = [int(x) if istype(x, int) else -1 for x in data['score']]
242
+ print(
243
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(score_map)} questions, '
244
+ f'failed to obtain the score for another {len(rejected)} questions. '
245
+ f'Those questions will be counted as 0 score in ALL rating, and will not be counted in VALID rating.'
246
+ )
247
+
248
+ dump(data, score_file)
249
+
250
+ rating = get_dimension_rating(score_file)
251
+ dump(rating, tgt_file)
252
+ return rating
eval_mm/vlmevalkit/vlmeval/dataset/mmlongbench.py ADDED
@@ -0,0 +1,582 @@
1
+ import re
2
+ import math
3
+ from urllib.request import urlopen
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import torchvision.transforms as transforms
6
+
7
+ from vlmeval.dataset.utils import build_judge, levenshtein_distance
8
+ from vlmeval.smp import *
9
+ from .image_base import ImageBaseDataset
10
+
11
+ FAIL_MSG = 'Failed to obtain answer via API.'
12
+
13
+
14
+ def get_gpt4_ICE():
15
+ example_1 = """
16
+ ---
17
+ Question: List the primary questions asked about the services in this report.
18
+ Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n
19
+ 1. Is the service safe?\n
20
+ 2. Is the service effective?\n
21
+ 3. Is the service caring?\n
22
+ 4. Is the service responsive?\n
23
+ 5. Is the service well-led?
24
+ Extracted answer: [
25
+ 'Is the service safe?',
26
+ 'Is the service effective?',
27
+ 'Is the service caring?',
28
+ 'Is the service responsive?',
29
+ 'Is the service well-led?'
30
+ ]
31
+ Answer format: List\n
32
+ """
33
+
34
+ example_2 = """
35
+ ---
36
+ Question: How many regulations of the HSCA 2008 are breached in all according to this report?
37
+ Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities)
38
+ Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and
39
+ improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11:
40
+ Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17:
41
+ Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9.
42
+ Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement
43
+ the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing,
44
+ safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to
45
+ notify the CQC of incidents.
46
+ Extracted answer: 10
47
+ Answer format: Integer\n
48
+ """
49
+
50
+ example_3 = """
51
+ ---
52
+ Question: According to the survey, what is the percentage of Chinese who are paying more or
53
+ about the same attention to politics after Trump's election?
54
+ Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying
55
+ more or about the same attention to politics after Trump's election. The report focuses primarily on American
56
+ demographics and does not include specific details about the Chinese population in relation to this question. If
57
+ you need information about a different demographic or a summary of the findings from the American demographic,
58
+ I can certainly help with that!
59
+ Extracted answer: Not answerable
60
+ Answer format: String\n
61
+ """
62
+
63
+ example_4 = """
64
+ ---
65
+ Question: How many quotations from male respondent over 50 years old are included in this report?
66
+ Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the
67
+ text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be
68
+ able to help you with your question.
69
+ Extracted answer: Fail to answer
70
+ Answer format: String\n
71
+ """
72
+
73
+ return [example_1, example_2, example_3, example_4]
74
+
75
+
76
+ def build_mmlongbench_gpt4_prompt(line):
77
+ task_description = """
78
+ Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
79
+ - Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List.
80
+ If you find from the analysis that the question cannot be answered from the given documents, type "Not answerable".
81
+ Exception: If the analysis only tells you that it can not read/understand the images or documents,
82
+ type "Fail to answer".
83
+ - Please make your response as concise as possible. Also note that your response should be formatted as below:
84
+ ```
85
+ Extracted answer: [answer]
86
+ Answer format: [answer format]
87
+ ```
88
+ Please read the following example, then extract the answer from the model response
89
+ and type it at the end of the prompt.\n
90
+ """
91
+ question = line['question']
92
+ prediction = str(line['prediction'])
93
+ prompt = task_description
94
+ examples = get_gpt4_ICE()
95
+ for example in examples:
96
+ prompt += example
97
+ prompt += '---\nQuestion:' + question + '\n'
98
+ prompt += 'Analysis: ' + prediction
99
+ return prompt
100
+
101
+
102
+ def anls_compute(groundtruth, prediction, threshold=0.5):
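+ # Thresholded ANLS: 1 - normalized edit distance between the two strings; similarities at or
+ # below `threshold` are zeroed out.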
103
+ dist = levenshtein_distance(groundtruth, prediction)
104
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
105
+ value = 0.0 if length == 0 else float(dist) / float(length)
106
+ anls = 1.0 - value
107
+ if anls <= threshold:
108
+ anls = 0.0
109
+ return anls
110
+
111
+
112
+ def is_float_equal(reference, prediction, include_percentage: bool = False, is_close: bool = False) -> bool:
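+ # Numeric match that tolerates percentage rescaling (x100 / /100 when include_percentage is set)
+ # and accepts either a 1% relative tolerance (is_close) or agreement after rounding to the
+ # shared decimal precision.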
113
+ def get_precision(gt_ans: float) -> int:
114
+ precision = 3
115
+ if '.' in str(gt_ans):
116
+ precision = len(str(gt_ans).split('.')[-1])
117
+ return precision
118
+
119
+ reference = float(str(reference).strip().rstrip('%').strip())
120
+ try:
121
+ prediction = float(str(prediction).strip().rstrip('%').strip())
122
+ except:
123
+ return False
124
+
125
+ if include_percentage:
126
+ gt_result = [reference / 100, reference, reference * 100]
127
+ else:
128
+ gt_result = [reference]
129
+ for item in gt_result:
130
+ try:
131
+ if is_close:
132
+ if math.isclose(item, prediction, rel_tol=0.01):
133
+ return True
134
+ precision = max(min(get_precision(prediction), get_precision(item)), 2)
135
+ if round(prediction, precision) == round(item, precision):
136
+ return True
137
+ except Exception:
138
+ continue
139
+ return False
140
+
141
+
142
+ def get_clean_string(s):
143
+ s = str(s).lower().strip()
144
+ if s.endswith('mile'):
145
+ s = s.rstrip('mile').strip()
146
+ if s.endswith('miles'):
147
+ s = s.rstrip('miles').strip()
148
+ if s.endswith('million'):
149
+ s = s.rstrip('million').strip()
150
+ # remove parenthesis
151
+ s = re.sub(r'\s*\([^)]*\)', '', s).strip()
152
+ # remove quotes
153
+ s = re.sub(r"^['\"]|['\"]$", '', s).strip()
154
+ s = s.strip().lstrip('$').strip()
155
+ s = s.strip().rstrip('%').strip()
156
+ return s
157
+
158
+
159
+ def is_exact_match(s):
160
+ flag = False
161
+ # Website
162
+ if 'https://' in s:
163
+ flag = True
164
+ # code file
165
+ if s.endswith('.py') or s.endswith('ipynb'):
166
+ flag = True
167
+ if s.startswith('page'):
168
+ flag = True
169
+ # telephone number
170
+ if re.fullmatch(r'\b\d+(-\d+|\s\d+)?\b', s):
171
+ flag = True
172
+ # time
173
+ if 'a.m.' in s or 'p.m.' in s:
174
+ flag = True
175
+ # YYYY-MM-DD
176
+ if re.fullmatch(r'\b\d{4}[-\s]\d{2}[-\s]\d{2}\b', s):
177
+ flag = True
178
+ # YYYY-MM
179
+ if re.fullmatch(r'\b\d{4}[-\s]\d{2}\b', s):
180
+ flag = True
181
+ # Email address
182
+ if re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', s):
183
+ flag = True
184
+ return flag
185
+
186
+
187
+ def isfloat(num):
188
+ try:
189
+ float(num)
190
+ return True
191
+ except ValueError:
192
+ return False
193
+
194
+
195
+ def get_font():
196
+ try:
197
+ truetype_url = 'http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf'
198
+ ff = urlopen(truetype_url)
199
+ font = ImageFont.truetype(ff, size=40)
200
+ except:
201
+ print('Fail to download the font. Use the default one.')
202
+ font = ImageFont.load_default(size=40)
203
+ return font
204
+
205
+
206
+ def frame2img(img_path_list, font, save_path=None, idx_start=0):
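+ # Resize each page so its long side is 1120 px, then stitch the pages into a single image with
+ # an '<IMAGE i>' banner and a divider between consecutive pages (stacked vertically for
+ # landscape pages, side by side for portrait ones).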
207
+ imgs = [Image.open(img_path) for img_path in img_path_list]
208
+
209
+ new_imgs = []
210
+ for img in imgs:
211
+ w, h = img.size
212
+ scale = w / h
213
+ if w > h:
214
+ new_w = 560 * 2
215
+ new_h = int(560 * 2 / scale)
216
+ else:
217
+ new_w = int(560 * 2 * scale)
218
+ new_h = 560 * 2
219
+ img = transforms.functional.resize(img, [new_h, new_w],)
220
+ new_imgs.append(img)
221
+ imgs = new_imgs
222
+ new_w = 0
223
+ new_h = 0
224
+ pad = 40
225
+ if w > h:
226
+ for im in imgs:
227
+ w, h = im.size
228
+ new_w = max(new_w, w)
229
+ new_h += h + 10 + pad
230
+ new_img = Image.new('RGB', (new_w, new_h), 'white')
231
+ draw = ImageDraw.Draw(new_img)
232
+ curr_h = 0
233
+ for idx, im in enumerate(imgs):
234
+ w, h = im.size
235
+ new_img.paste(im, (0, pad + curr_h))
236
+ draw.text((0, curr_h), f'<IMAGE {idx+idx_start}>', font=font, fill='black')
237
+ if idx + 1 < len(imgs):
238
+ draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
239
+ curr_h += h + 10 + pad
240
+ else:
241
+ for im in imgs:
242
+ w, h = im.size
243
+ new_w += w + 10
244
+ new_h = max(new_h, h)
245
+ new_h += pad
246
+ new_img = Image.new('RGB', (new_w, new_h), 'white')
247
+ draw = ImageDraw.Draw(new_img)
248
+ curr_w = 0
249
+ for idx, im in enumerate(imgs):
250
+ w, h = im.size
251
+ new_img.paste(im, (curr_w, pad))
252
+ draw.text((curr_w, 0), f'<IMAGE {idx+idx_start}>', font=font, fill='black')
253
+ if idx + 1 < len(imgs):
254
+ draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
255
+ curr_w += w + 10
256
+
257
+ if save_path is not None:
258
+ new_img.save(save_path)
259
+
260
+ return new_img
261
+
262
+
263
+ def concat_images(image_list, max_concat=1, column_num=1):
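+ # Split the page images into at most `max_concat` batches and merge each batch into one image:
+ # with column_num == -1 each batch is capped at roughly 20 pages and rendered via frame2img
+ # (with '<IMAGE i>' labels); otherwise pages are pasted into a plain grid of `column_num` columns.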
264
+ concatenated_images = []
265
+ if column_num == -1:
266
+ MAX_COLUMN_NUM = 20
267
+ max_concat = 1
268
+ while len(image_list) / max_concat > MAX_COLUMN_NUM:
269
+ max_concat += 1
270
+ interval = max(math.ceil(len(image_list) / max_concat), 1)
271
+ for i in range(0, len(image_list), interval):
272
+ batch_images = image_list[i:i + interval]
273
+ concatenated_image = frame2img(batch_images, font=get_font(), idx_start=i)
274
+ concatenated_images.append(concatenated_image)
275
+ else:
276
+ interval = max(math.ceil(len(image_list) / max_concat), 1)
277
+ for i in range(0, len(image_list), interval):
278
+ batch_images = [Image.open(filename) for filename in image_list[i:i + interval]]
279
+ if column_num == 1:
280
+ total_height = batch_images[0].height * len(batch_images)
281
+ else:
282
+ total_height = batch_images[0].height * ((len(batch_images) - 1) // column_num + 1)
283
+ concatenated_image = Image.new('RGB', (batch_images[0].width * column_num, total_height), 'white')
284
+
285
+ x_offset, y_offset = 0, 0
286
+ for count, image in enumerate(batch_images):
287
+ concatenated_image.paste(image, (x_offset, y_offset))
288
+ x_offset += image.width
289
+ if (count + 1) % column_num == 0:
290
+ y_offset += image.height
291
+ x_offset = 0
292
+ concatenated_images.append(concatenated_image)
293
+ return concatenated_images
294
+
295
+
296
+ def eval_score(gt, pred, answer_type):
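+ # Score a single QA pair by answer type: exact match for Int, tolerant numeric match for Float,
+ # exact match or ANLS for Str (depending on whether the ground truth looks like a URL / date /
+ # number / email), and element-wise comparison of sorted lists otherwise (length mismatch scores 0).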
297
+ if answer_type == 'Int':
298
+ try:
299
+ gt, pred = int(gt), int(float(pred))
300
+ except:
301
+ pred = ''
302
+ score = (gt == pred)
303
+ elif answer_type == 'Float':
304
+ try:
305
+ gt = float(get_clean_string(str(gt)))
306
+ pred = float(get_clean_string(str(pred)))
307
+ except:
308
+ pred = ''
309
+ score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
310
+ elif answer_type == 'Str':
311
+ gt = get_clean_string(gt)
312
+ pred = get_clean_string(pred)
313
+ if is_exact_match(gt):
314
+ score = (gt == pred)
315
+ else:
316
+ score = anls_compute(gt, pred)
317
+ else:
318
+ if isinstance(gt, str) and gt.startswith('['):
319
+ gt = eval(gt)
320
+ if not isinstance(gt, list):
321
+ gt = [gt]
322
+ if isinstance(pred, str) and pred.startswith('['):
323
+ pred = eval(pred)
324
+ if not isinstance(pred, list):
325
+ pred = [pred]
326
+ print(len(gt), len(pred))
327
+ if len(gt) != len(pred):
328
+ score = 0.0
329
+ else:
330
+ gt = sorted([get_clean_string(a) for a in gt])
331
+ pred = sorted([get_clean_string(a) for a in pred])
332
+ print(gt, pred)
333
+ if isfloat(gt[0]) or is_exact_match(gt[0]):
334
+ score = ('-'.join(gt) == '-'.join(pred))
335
+ else:
336
+ score = min([anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred)])
337
+
338
+ return float(score)
339
+
340
+
341
+ def MMLongBench_auxeval(model, line):
342
+ prompt = build_mmlongbench_gpt4_prompt(line)
343
+ log = ''
344
+ retry = 5
345
+
346
+ for i in range(retry):
347
+ prediction = line['prediction']
348
+ res = model.generate(prompt, temperature=i * 0.5)
349
+
350
+ if FAIL_MSG in res:
351
+ log += f'Try {i}: output is {prediction}, failed to parse.\n'
352
+ else:
353
+ log += 'Succeed'
354
+ try:
355
+ pred = res.split('Answer format:')[0].split('Extracted answer:')[1].strip()
356
+ except:
357
+ pred = ''
358
+ return dict(log=log, res=res, pred=pred)
359
+ log += 'All 5 retries failed.\n'
360
+ return dict(log=log, res='', pred='')
361
+
362
+
363
+ def get_f1(data):
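+ # F1 over answerability: recall averages scores on questions whose ground truth is answerable,
+ # precision on questions the model predicted as answerable; returns their harmonic mean.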
364
+ gt_pos_data = data[data.apply(lambda k: k['answer'] != 'Not answerable', axis=1)]
365
+ pred_pos_data = data[data.apply(lambda k: k['pred'] != 'Not answerable', axis=1)]
366
+ recall = sum(gt_pos_data['score'].tolist()) / len(gt_pos_data)
367
+ precision = sum(pred_pos_data['score'].tolist()) / len(pred_pos_data)
368
+ return 2 * recall * precision / (recall + precision)
369
+
370
+
371
+ def MMLongBench_acc(result_file):
372
+ data = load(result_file)
373
+ overall_score = 0.0
374
+ score_list = list()
375
+ for i in range(len(data)):
376
+ item = data.iloc[i]
377
+ try:
378
+ score = eval_score(item['answer'], item['pred'], item['answer_format'])
379
+ except:
380
+ score = 0.0
381
+ score_list.append(score)
382
+ overall_score += score
383
+
384
+ data['score'] = score_list
385
+ dump(data, result_file)
386
+
387
+ data_chart = data[data.apply(lambda k: 'Chart' in eval(k['evidence_sources']), axis=1)]
388
+ data_table = data[data.apply(lambda k: 'Table' in eval(k['evidence_sources']), axis=1)]
389
+ data_image = data[data.apply(lambda k: 'Figure' in eval(k['evidence_sources']), axis=1)]
390
+ data_text = data[data.apply(lambda k: 'Pure-text (Plain-text)' in eval(k['evidence_sources']), axis=1)]
391
+ data_layout = data[data.apply(lambda k: 'Generalized-text (Layout)' in eval(k['evidence_sources']), axis=1)]
392
+
393
+ data_single = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 1, axis=1)]
394
+ data_multi = data[data.apply(lambda k: len(eval(k['evidence_pages'])) > 1, axis=1)]
395
+ data_unans = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 0, axis=1)]
396
+
397
+ res = dict()
398
+ res['category'] = [
399
+ 'overall_f1', 'overall_acc', 'text', 'layout', 'table', 'chart',
400
+ 'image', 'single-page', 'multi-page', 'unanswerable'
401
+ ]
402
+ res['num'] = [
403
+ len(data), len(data), len(data_text), len(data_layout), len(data_table),
404
+ len(data_chart), len(data_image), len(data_single), len(data_multi), len(data_unans)
405
+ ]
406
+ res['avg_score'] = [
407
+ get_f1(data),
408
+ overall_score / len(data),
409
+ sum(data_text['score'].tolist()) / len(data_text) if len(data_text) > 0 else 0.0,
410
+ sum(data_layout['score'].tolist()) / len(data_layout) if len(data_layout) > 0 else 0.0,
411
+ sum(data_table['score'].tolist()) / len(data_table) if len(data_table) > 0 else 0.0,
412
+ sum(data_chart['score'].tolist()) / len(data_chart) if len(data_chart) > 0 else 0.0,
413
+ sum(data_image['score'].tolist()) / len(data_image) if len(data_image) > 0 else 0.0,
414
+ sum(data_single['score'].tolist()) / len(data_single) if len(data_single) > 0 else 0.0,
415
+ sum(data_multi['score'].tolist()) / len(data_multi) if len(data_multi) > 0 else 0.0,
416
+ sum(data_unans['score'].tolist()) / len(data_unans) if len(data_unans) > 0 else 0.0,
417
+ ]
418
+ res = pd.DataFrame(res)
419
+ return res
420
+
421
+
422
+ class MMLongBench(ImageBaseDataset):
423
+
424
+ TYPE = 'VQA'
425
+
426
+ DATASET_URL = {
427
+ 'MMLongBench_DOC': 'https://opencompass.openxlab.space/utils/VLMEval/MMLongBench_DOC.tsv',
428
+ }
429
+ DATASET_MD5 = {
430
+ 'MMLongBench_DOC': '9b393e1f4c52718380d50586197eac9b',
431
+ }
432
+
433
+ SUPPORTED_MODELS = {
434
+ 'GPT4': (1, 1),
435
+ 'GPT4V': (1, 1),
436
+ 'GPT4V_HIGH': (1, 1),
437
+ 'GPT4o': (1, 1),
438
+ 'GPT4o_HIGH': (1, 1),
439
+ 'GPT4o_MINI': (1, 1),
440
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
441
+ 'InternVL-Chat-V1-5': (5, 2),
442
+ 'XComposer2_4KHD': (1, 5),
443
+ 'XComposer2d5': (1, -1),
444
+ }
445
+
446
+ def __init__(self, dataset, **kwargs):
447
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
448
+ model_name = kwargs['model']
449
+ if not listinstr(self.model_list, model_name):
450
+ raise AssertionError("{} doesn't support the evaluation on MMLongBench_DOC.".format(model_name))
451
+ super(MMLongBench, self).__init__(dataset)
452
+
453
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
454
+ self.max_pages = 120
455
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
456
+ self.concat_num = concat_num
457
+ self.column_num = column_num
458
+
459
+ def dump_image(self, origin_line):
460
+ os.makedirs(self.img_root, exist_ok=True)
461
+ try:
462
+ import fitz
463
+ except:
464
+ warnings.warn('Please use `pip install pymupdf` to parse PDF files.')
465
+
466
+ line = origin_line.copy()
467
+ line['image_path'] = line['image_path'][:self.max_pages]
468
+ skip_pdf_parse = True
469
+ for im_name in line['image_path']:
470
+ path = osp.join(self.img_root, im_name)
471
+ if not read_ok(path):
472
+ skip_pdf_parse = False
473
+ break
474
+
475
+ # Just for being compatible with the zipped loop below: zip(line['image'], line['image_path'])
476
+ if skip_pdf_parse:
477
+ line['image'] = line['image_path']
478
+ else:
479
+ pdf_data = base64.b64decode(line['image'])
480
+ pdf_file = io.BytesIO(pdf_data)
481
+ encoded_images = []
482
+ with fitz.open(stream=pdf_file, filetype='pdf') as doc:
483
+ doc = doc[:self.max_pages]
484
+ for page in doc:
485
+ image = page.get_pixmap(dpi=144)
486
+ image_file = io.BytesIO(image.tobytes(output='png'))
487
+ image = Image.open(image_file)
488
+ encoded_image = encode_image_to_base64(image)
489
+ encoded_images.append(encoded_image)
490
+ line['image'] = encoded_images
491
+ print('process {}'.format(line['doc_id']))
492
+
493
+ if 'image' in line:
494
+ if isinstance(line['image'], list):
495
+ tgt_path = []
496
+ assert 'image_path' in line
497
+ for img, im_name in zip(line['image'], line['image_path']):
498
+ path = osp.join(self.img_root, im_name)
499
+ if not read_ok(path):
500
+ decode_base64_to_image_file(img, path)
501
+ tgt_path.append(path)
502
+ else:
503
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
504
+ if not read_ok(tgt_path):
505
+ decode_base64_to_image_file(line['image'], tgt_path)
506
+ tgt_path = [tgt_path]
507
+ else:
508
+ assert 'image_path' in line
509
+ tgt_path = toliststr(line['image_path'])
510
+
511
+ if self.concat_num > 0 and not self.is_api:
512
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
513
+
514
+ old_tgt_path = tgt_path
515
+ assert isinstance(old_tgt_path, list)
516
+ if self.column_num != -1:
517
+ tgt_path = [
518
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
519
+ for i in range(len(concatenated_images))
520
+ ]
521
+ else:
522
+ tgt_path = [
523
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all_{}.jpg'.format(i)
524
+ for i in range(len(concatenated_images))
525
+ ]
526
+
527
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
528
+ if not read_ok(path):
529
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
530
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
531
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
532
+ return tgt_path
533
+
534
+ @classmethod
535
+ def evaluate(self, eval_file, **judge_kwargs):
536
+ logger = get_logger('Evaluation')
537
+ model = judge_kwargs['model']
538
+
539
+ suffix = eval_file.split('.')[-1]
540
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
541
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
542
+
543
+ if osp.exists(storage):
544
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MMLongBench_eval. ')
545
+ else:
546
+ data = load(eval_file)
547
+ model = build_judge(max_tokens=128, **judge_kwargs)
548
+ lt = len(data)
549
+ lines = [data.iloc[i] for i in range(lt)]
550
+ tups = [(model, line) for line in lines]
551
+ indices = [line['index'] for line in lines]
552
+
553
+ ans = {}
554
+ if osp.exists(tmp_file):
555
+ ans = load(tmp_file)
556
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
557
+ indices = [i for i in indices if i not in ans]
558
+
559
+ if len(indices):
560
+ new_results = list()
561
+ for model, line in tqdm(tups):
562
+ res = MMLongBench_auxeval(model, line)
563
+ new_results.append(res)
564
+
565
+ log_map, res_map, pred_map = {}, {}, {}
566
+ all_inds = [line['index'] for line in lines]
567
+ for k, v in zip(all_inds, new_results):
568
+ log_map[k] = v['log']
569
+ res_map[k] = v['res']
570
+ pred_map[k] = v['pred']
571
+ data['res'] = [res_map[idx] for idx in data['index']]
572
+ data['log'] = [log_map[idx] for idx in data['index']]
573
+ data['pred'] = [pred_map[idx] for idx in data['index']]
574
+ dump(data, storage)
575
+
576
+ score = MMLongBench_acc(storage)
577
+ score_pth = storage.replace('.xlsx', '_score.csv')
578
+
579
+ dump(score, score_pth)
580
+ logger.info(f'MMLongBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
581
+ logger.info('Score: ')
582
+ logger.info(score)
eval_mm/vlmevalkit/vlmeval/dataset/mvbench.py ADDED
@@ -0,0 +1,577 @@
1
+ import huggingface_hub
2
+ from huggingface_hub import snapshot_download
3
+ from ..smp import *
4
+ from .video_base import VideoBaseDataset
5
+ from .utils import build_judge, DEBUG_MESSAGE
6
+ from ..utils import track_progress_rich
7
+ import torchvision.transforms as T
8
+ from torchvision import transforms
9
+ from torchvision.transforms.functional import InterpolationMode
10
+ from decord import VideoReader, cpu
11
+ import imageio
12
+ import cv2
13
+ import zipfile
14
+ import os
15
+ import glob
16
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
17
+ import moviepy.config_defaults
18
+ from .utils.mvbench import *
19
+
20
+ FAIL_MSG = 'Failed to obtain answer via API.'
21
+ moviepy.config_defaults.LOGGER_LEVEL = logging.CRITICAL + 1
22
+
23
+
24
+ class MVBench(VideoBaseDataset):
25
+
26
+ MD5 = 'ae2a2607e2f8618155709220c6e927a6'
27
+ SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
28
+ the detail and movement of objects, and the action and pose of persons. \
29
+ Based on your observations, select the best option that accurately addresses the question.
30
+ """
31
+
32
+ TYPE = 'MCQ'
33
+
34
+ def __init__(self, dataset='MVBench', pack=False):
35
+ self.type_data_list = {
36
+ 'Action Sequence': ('action_sequence.json',
37
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
38
+ 'Action Prediction': ('action_prediction.json',
39
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
40
+ 'Action Antonym': ('action_antonym.json',
41
+ 'your_data_path/ssv2_video/', 'video', False),
42
+ 'Fine-grained Action': ('fine_grained_action.json',
43
+ 'your_data_path/Moments_in_Time_Raw/videos/', 'video', False),
44
+ 'Unexpected Action': ('unexpected_action.json',
45
+ 'your_data_path/FunQA_test/test/', 'video', False),
46
+ 'Object Existence': ('object_existence.json',
47
+ 'your_data_path/clevrer/video_validation/', 'video', False),
48
+ 'Object Interaction': ('object_interaction.json',
49
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
50
+ 'Object Shuffle': ('object_shuffle.json',
51
+ 'your_data_path/perception/videos/', 'video', False),
52
+ 'Moving Direction': ('moving_direction.json',
53
+ 'your_data_path/clevrer/video_validation/', 'video', False),
54
+ 'Action Localization': ('action_localization.json',
55
+ 'your_data_path/sta/sta_video/', 'video', True), # has start & end
56
+ 'Scene Transition': ('scene_transition.json',
57
+ 'your_data_path/scene_qa/video/', 'video', False),
58
+ 'Action Count': ('action_count.json',
59
+ 'your_data_path/perception/videos/', 'video', False),
60
+ 'Moving Count': ('moving_count.json',
61
+ 'your_data_path/clevrer/video_validation/', 'video', False),
62
+ 'Moving Attribute': ('moving_attribute.json',
63
+ 'your_data_path/clevrer/video_validation/', 'video', False),
64
+ 'State Change': ('state_change.json',
65
+ 'your_data_path/perception/videos/', 'video', False),
66
+ 'Fine-grained Pose': ('fine_grained_pose.json',
67
+ 'your_data_path/nturgbd/', 'video', False),
68
+ 'Character Order': ('character_order.json',
69
+ 'your_data_path/perception/videos/', 'video', False),
70
+ 'Egocentric Navigation': ('egocentric_navigation.json',
71
+ 'your_data_path/vlnqa/', 'video', False),
72
+ 'Episodic Reasoning': ('episodic_reasoning.json',
73
+ 'your_data_path/tvqa/frames_fps3_hq/', 'frame', True), # has start & end, read frame
74
+ 'Counterfactual Inference': ('counterfactual_inference.json',
75
+ 'your_data_path/clevrer/video_validation/', 'video', False),
76
+ }
77
+ super().__init__(dataset=dataset, pack=pack)
78
+
79
+ @classmethod
80
+ def supported_datasets(cls):
81
+ return ['MVBench']
82
+
83
+ def prepare_dataset(self, dataset_name='MVBench', repo_id='OpenGVLab/MVBench'):
84
+ def check_integrity(pth):
85
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
86
+
87
+ if not os.path.exists(data_file):
88
+ return False
89
+
90
+ if md5(data_file) != self.MD5:
91
+ return False
92
+
93
+ data = load(data_file)
94
+ for idx, item in data.iterrows():
95
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
96
+ return False
97
+ return True
98
+
99
+ cache_path = get_cache_path(repo_id, branch='main')
100
+ if cache_path is not None and check_integrity(cache_path):
101
+ dataset_path = cache_path
102
+ else:
103
+ def unzip_hf_zip(pth):
104
+ pth = os.path.join(pth, 'video/')
105
+ for filename in os.listdir(pth):
106
+ if filename.endswith('.zip'):
107
+ # Build the full path of the zip file
108
+ zip_path = os.path.join(pth, filename)
109
+
110
+ # Extract the ZIP archive
111
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
112
+ zip_ref.extractall(pth)
113
+
114
+ def generate_tsv(pth):
115
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
116
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
117
+ return
118
+ json_data_dir = os.path.join(dataset_path, 'json')
119
+ self.data_list = []
120
+ for k, v in self.type_data_list.items():
121
+ with open(os.path.join(json_data_dir, v[0]), 'r') as f:
122
+ json_data = json.load(f)
123
+ for data in json_data:
124
+ self.data_list.append({
125
+ 'task_type': k,
126
+ 'prefix': v[1].replace('your_data_path', os.path.join(dataset_path, 'video')),
127
+ 'data_type': v[2],
128
+ 'bound': v[3],
129
+ 'start': data['start'] if 'start' in data.keys() else None,
130
+ 'end': data['end'] if 'end' in data.keys() else None,
131
+ 'video': data['video'],
132
+ 'question': data['question'],
133
+ 'answer': data['answer'],
134
+ 'candidates': data['candidates']
135
+ })
136
+
137
+ data_df = pd.DataFrame(self.data_list)
138
+ data_df = data_df.assign(index=range(len(data_df)))
139
+ data_df.to_csv(data_file, sep='\t', index=False)
140
+
141
+ def move_files(pth):
142
+ # special for mvbench
143
+ src_folder = os.path.join(pth, 'video/data0613')
144
+ for subdir in os.listdir(src_folder):
145
+ subdir_path = os.path.join(src_folder, subdir)
146
+ if os.path.isdir(subdir_path):
147
+ for subsubdir in os.listdir(subdir_path):
148
+ subsubdir_path = os.path.join(subdir_path, subsubdir)
149
+ if os.path.isdir(subsubdir_path):
150
+ for item in os.listdir(subsubdir_path):
151
+ item_path = os.path.join(subsubdir_path, item)
152
+ target_folder = os.path.join(pth, 'video', subdir, subsubdir, item)
153
+ if not os.path.exists(target_folder):
154
+ shutil.move(item_path, os.path.join(target_folder, item))
155
+
156
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
157
+ huggingface_hub.login(hf_token)
158
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
159
+ move_files(dataset_path)
160
+ unzip_hf_zip(dataset_path)
161
+ generate_tsv(dataset_path)
162
+
163
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
164
+
165
+ self.decord_method = {
166
+ 'video': self.read_video,
167
+ 'gif': self.read_gif,
168
+ 'frame': self.read_frame,
169
+ }
170
+
171
+ self.nframe = 8
172
+ self.resolution = 224
173
+ self.frame_fps = 3
174
+
175
+ # transform
176
+ crop_size = self.resolution
177
+ scale_size = self.resolution
178
+ input_mean = [0.48145466, 0.4578275, 0.40821073]
179
+ input_std = [0.26862954, 0.26130258, 0.27577711]
180
+ self.transform = T.Compose([
181
+ GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
182
+ GroupCenterCrop(crop_size),
183
+ Stack(),
184
+ ToTorchFormatTensor(),
185
+ GroupNormalize(input_mean, input_std)
186
+ ])
187
+
188
+ return dict(root=dataset_path, data_file=data_file)
189
+
190
+ def get_index(self, bound, fps, max_frame, first_idx=0):
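+ # Uniformly sample one frame index from the centre of each of self.num_segments equal spans
+ # inside the optional start/end bound, expressed in frame units.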
191
+ if bound:
192
+ start, end = bound[0], bound[1]
193
+ else:
194
+ start, end = -100000, 100000
195
+ start_idx = max(first_idx, round(start * fps))
196
+ end_idx = min(round(end * fps), max_frame)
197
+ seg_size = float(end_idx - start_idx) / self.num_segments
198
+ frame_indices = np.array([
199
+ int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
200
+ for idx in range(self.num_segments)
201
+ ])
202
+ return frame_indices
203
+
204
+ def read_video(self, video_path, bound=None):
205
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
206
+ max_frame = len(vr) - 1
207
+ fps = float(vr.get_avg_fps())
208
+
209
+ images_group = list()
210
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
211
+ for frame_index in frame_indices:
212
+ img = Image.fromarray(vr[frame_index].asnumpy())
213
+ images_group.append(img)
214
+ torch_imgs = self.transform(images_group)
215
+ return torch_imgs
216
+
217
+ def read_gif(self, video_path, bound=None, fps=25):
218
+ gif = imageio.get_reader(video_path)
219
+ max_frame = len(gif) - 1
220
+
221
+ images_group = list()
222
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
223
+ for index, frame in enumerate(gif):
224
+ if index in frame_indices:
225
+ img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
226
+ img = Image.fromarray(img)
227
+ images_group.append(img)
228
+ torch_imgs = self.transform(images_group)
229
+ return torch_imgs
230
+
231
+ def read_frame(self, video_path, bound=None, fps=3):
232
+ max_frame = len(os.listdir(video_path))
233
+ images_group = list()
234
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
235
+ for frame_index in frame_indices:
236
+ img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg'))
237
+ images_group.append(img)
238
+ torch_imgs = self.transform(images_group)
239
+ return torch_imgs
240
+
241
+ def save_video_frames(self, imgs, video_name, frames):
242
+
243
+ frame_paths = self.frame_paths(video_name, frames)
244
+ flag = np.all([osp.exists(p) for p in frame_paths])
245
+
246
+ if not flag:
247
+ block_size = imgs.size(0) // frames
248
+ split_tensors = torch.split(imgs, block_size)
249
+ to_pil = transforms.ToPILImage()
250
+ images = [to_pil(arr) for arr in split_tensors]
251
+ for im, pth in zip(images, frame_paths):
252
+ if not osp.exists(pth):
253
+ im.save(pth)
254
+
255
+ return frame_paths
256
+
257
+ def qa_template(self, data):
258
+ question = f"Question: {data['question']}\n"
259
+ question += 'Options:\n'
260
+ answer = data['answer']
261
+ answer_idx = -1
262
+ for idx, c in enumerate(eval(data['candidates'])):
263
+ question += f"({chr(ord('A') + idx)}) {c}\n"
264
+ if c == answer:
265
+ answer_idx = idx
266
+ question = question.rstrip()
267
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
268
+ return question, answer
269
+
270
+ def load_into_video_and_process(self, line):
271
+ video_path = os.path.join(line['prefix'], line['video'])
272
+
273
+ if line['data_type'] in ['gif'] or os.path.splitext(video_path)[1] in ['.webm']:
274
+ processed_video_path = video_path.replace(os.path.splitext(video_path)[1], '.mp4')
275
+ if not os.path.exists(processed_video_path):
276
+ # using MoviePy to transform GIF, webm into mp4 format
277
+ gif_clip = VideoFileClip(video_path)
278
+ gif_clip.write_videofile(processed_video_path, codec='libx264')
279
+ gif_clip.close()
280
+ elif line['data_type'] in ['frame']:
281
+ input_images = os.path.join(video_path, '*.jpg')
282
+ processed_video_path = f'{video_path}.mp4'
283
+ if not os.path.exists(processed_video_path):
284
+ # using MoviePy to transform images into mp4
285
+ image_files = sorted(glob.glob(input_images))
286
+ image_clip = ImageSequenceClip(image_files, fps=self.frame_fps)
287
+ image_clip.write_videofile(processed_video_path, codec='libx264')
288
+ image_clip.close()
289
+ else:
290
+ processed_video_path = video_path
291
+
292
+ if line['bound']:
293
+ base_name, suffix = os.path.splitext(processed_video_path)
294
+ output_video_path = f'{base_name}_processed{suffix}'
295
+ if not os.path.exists(output_video_path):
296
+ video_clip = VideoFileClip(processed_video_path)
297
+ clip = video_clip.subclip(line['start'], min(line['end'], video_clip.duration))
298
+ clip.write_videofile(output_video_path)
299
+ clip.close()
300
+ else:
301
+ output_video_path = processed_video_path
302
+
303
+ return output_video_path
304
+
305
+ def build_prompt(self, line, num_frames, video_llm):
306
+ if isinstance(line, int):
307
+ assert line < len(self)
308
+ line = self.data.iloc[line]
309
+
310
+ question, answer = self.qa_template(line)
311
+ message = [dict(type='text', value=self.SYS)]
312
+ message.append(dict(type='text', value=question))
313
+ if video_llm:
314
+ new_video_path = self.load_into_video_and_process(line)
315
+ message.append(dict(type='video', value=new_video_path))
316
+ else:
317
+ bound = None
318
+ if line['bound']:
319
+ bound = (
320
+ line['start'],
321
+ line['end'],
322
+ )
323
+ video_path = os.path.join(line['prefix'], line['video'])
324
+ decord_method = self.decord_method[line['data_type']]
325
+ self.num_segments = num_frames if num_frames > 0 else self.nframe
326
+ torch_imgs = decord_method(video_path, bound)
327
+ img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
328
+ for im in img_frame_paths:
329
+ message.append(dict(type='image', value=im))
330
+ message.append(dict(type='text', value='\nOnly give the best option.'))
331
+ message.append(dict(type='text', value='Best option:('))
332
+ return message
333
+
334
+ @classmethod
335
+ def evaluate(self, eval_file, **judge_kwargs):
336
+
337
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
338
+
339
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
340
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
341
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
342
+
343
+ if not osp.exists(score_file):
344
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
345
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
346
+
347
+ data = load(eval_file)
348
+ data_un = data[~pd.isna(data['prediction'])]
349
+
350
+ for idx in data['index']:
351
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
352
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
353
+ options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
354
+ answer_idx = -1
355
+ for id, c in enumerate(options):
356
+ if c == ans:
357
+ answer_idx = id
358
+ ans = f"({chr(ord('A') + answer_idx)}) {ans}"
359
+
360
+ if FAIL_MSG in pred:
361
+ data.loc[idx, 'score'] = -1
362
+ else:
363
+ data.loc[idx, 'score'] = int(check_ans(pred, ans))
364
+
365
+ rejected = [x for x in data['score'] if x == -1]
366
+
367
+ print(
368
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
369
+ f'failed to obtain the score for another {len(rejected)} questions. '
370
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
371
+ )
372
+
373
+ dump(data, score_file)
374
+
375
+ rating = get_dimension_rating(score_file)
376
+ dump(rating, tgt_file)
377
+ return rating
378
+
379
+
380
+ class MVBench_MP4(VideoBaseDataset):
381
+
382
+ MP4_MD5 = '7b4608045347904c28c153015a7a2b6b'
383
+ SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
384
+ the detail and movement of objects, and the action and pose of persons. \
385
+ Based on your observations, select the best option that accurately addresses the question.
386
+ """
387
+ TYPE = 'MCQ'
388
+
389
+ def __init__(self, dataset='MVBench_MP4', pack=False):
390
+ super().__init__(dataset=dataset, pack=pack)
391
+
392
+ @classmethod
393
+ def supported_datasets(cls):
394
+ return ['MVBench_MP4']
395
+
396
+ def prepare_dataset(self, dataset_name='MVBench_MP4', repo_id='OpenGVLab/MVBench'):
397
+ def check_integrity(pth):
398
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
399
+
400
+ if not os.path.exists(data_file):
401
+ return False
402
+
403
+ if md5(data_file) != self.MP4_MD5:
404
+ return False
405
+
406
+ data = load(data_file)
407
+ for idx, item in data.iterrows():
408
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
409
+ return False
410
+ return True
411
+
412
+ cache_path = get_cache_path(repo_id, branch='video')
413
+ if cache_path is not None and check_integrity(cache_path):
414
+ dataset_path = cache_path
415
+ else:
416
+ def generate_tsv(pth):
417
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
418
+ if os.path.exists(data_file) and md5(data_file) == self.MP4_MD5:
419
+ return
420
+ json_data_path = os.path.join(dataset_path, 'test.json')
421
+ json_data = load(json_data_path)
422
+ root_data_dict = json_data['root']
423
+ self.data_list = []
424
+ for k, v in json_data['meta'].items():
425
+ for item in v:
426
+ self.data_list.append({
427
+ 'task_type': k,
428
+ 'prefix': root_data_dict[k],
429
+ 'video': item['video'],
430
+ 'question': item['question'],
431
+ 'answer': item['answer'],
432
+ 'candidates': item['candidates']
433
+ })
434
+ data_df = pd.DataFrame(self.data_list)
435
+ data_df = data_df.assign(index=range(len(data_df)))
436
+ data_df.to_csv(data_file, sep='\t', index=False)
437
+
438
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
439
+ huggingface_hub.login(hf_token)
440
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset', revision='video')
441
+ generate_tsv(dataset_path)
442
+
443
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
444
+
445
+ self.nframe = 8
446
+ self.resolution = 224
447
+
448
+ # transform
449
+ crop_size = self.resolution
450
+ scale_size = self.resolution
451
+ input_mean = [0.48145466, 0.4578275, 0.40821073]
452
+ input_std = [0.26862954, 0.26130258, 0.27577711]
453
+ self.transform = T.Compose([
454
+ GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
455
+ GroupCenterCrop(crop_size),
456
+ Stack(),
457
+ ToTorchFormatTensor(),
458
+ GroupNormalize(input_mean, input_std)
459
+ ])
460
+
461
+ return dict(root=dataset_path, data_file=data_file)
462
+
463
+ def qa_template(self, data):
464
+ question = f"Question: {data['question']}\n"
465
+ question += 'Options:\n'
466
+ answer = data['answer']
467
+ answer_idx = -1
468
+ for idx, c in enumerate(eval(data['candidates'])):
469
+ question += f"({chr(ord('A') + idx)}) {c}\n"
470
+ if c == answer:
471
+ answer_idx = idx
472
+ question = question.rstrip()
473
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
474
+ return question, answer
475
+
476
+ def get_index(self, max_frame):
477
+ seg_size = float(max_frame) / self.num_segments
478
+ frame_indices = np.array([
479
+ int((seg_size / 2) + np.round(seg_size * idx))
480
+ for idx in range(self.num_segments)
481
+ ])
482
+ return frame_indices
483
+
484
+ def read_video(self, video_path, bound=None):
485
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
486
+ max_frame = len(vr) - 1
487
+
488
+ images_group = list()
489
+ frame_indices = self.get_index(max_frame)
490
+ for frame_index in frame_indices:
491
+ img = Image.fromarray(vr[frame_index].asnumpy())
492
+ images_group.append(img)
493
+ torch_imgs = self.transform(images_group)
494
+ return torch_imgs
495
+
496
+ def save_video_frames(self, imgs, video_name, frames):
497
+
498
+ frame_paths = self.frame_paths(video_name, frames)
499
+ flag = np.all([osp.exists(p) for p in frame_paths])
500
+
501
+ if not flag:
502
+ block_size = imgs.size(0) // frames
503
+ split_tensors = torch.split(imgs, block_size)
504
+ to_pil = transforms.ToPILImage()
505
+ images = [to_pil(arr) for arr in split_tensors]
506
+ for im, pth in zip(images, frame_paths):
507
+ if not osp.exists(pth):
508
+ im.save(pth)
509
+
510
+ return frame_paths
511
+
512
+ def build_prompt(self, line, num_frames, video_llm):
513
+ if isinstance(line, int):
514
+ assert line < len(self)
515
+ line = self.data.iloc[line]
516
+
517
+ question, answer = self.qa_template(line)
518
+ message = [dict(type='text', value=self.SYS)]
519
+ message.append(dict(type='text', value=question))
520
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
521
+ if video_llm:
522
+ message.append(dict(type='video', value=video_path))
523
+ else:
524
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
525
+ self.num_segments = num_frames if num_frames > 0 else self.nframe
526
+ torch_imgs = self.read_video(video_path)
527
+ img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
528
+ for im in img_frame_paths:
529
+ message.append(dict(type='image', value=im))
530
+ message.append(dict(type='text', value='\nOnly give the best option.'))
531
+ message.append(dict(type='text', value='Best option:('))
532
+ return message
533
+
534
+ @classmethod
535
+ def evaluate(self, eval_file, **judge_kwargs):
536
+
537
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
538
+
539
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
540
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
541
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
542
+
543
+ if not osp.exists(score_file):
544
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
545
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
546
+
547
+ data = load(eval_file)
548
+ data_un = data[~pd.isna(data['prediction'])]
549
+
550
+ for idx in data['index']:
551
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
552
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
553
+ options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
554
+ answer_idx = -1
555
+ for id, c in enumerate(options):
556
+ if c == ans:
557
+ answer_idx = id
558
+ ans = f"({chr(ord('A') + answer_idx)}) {ans}"
559
+
560
+ if FAIL_MSG in pred:
561
+ data.loc[idx, 'score'] = -1
562
+ else:
563
+ data.loc[idx, 'score'] = int(check_ans(pred, ans))
564
+
565
+ rejected = [x for x in data['score'] if x == -1]
566
+
567
+ print(
568
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
569
+ f'failed to obtain the score for another {len(rejected)} questions. '
570
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
571
+ )
572
+
573
+ dump(data, score_file)
574
+
575
+ rating = get_dimension_rating(score_file)
576
+ dump(rating, tgt_file)
577
+ return rating
eval_mm/vlmevalkit/vlmeval/dataset/slidevqa.py ADDED
@@ -0,0 +1,189 @@
1
+ import re
2
+ import math
3
+ from typing import List
4
+
5
+ from vlmeval.dataset.utils.judge_util import build_judge
6
+ from vlmeval.smp import *
7
+ from .image_base import ImageBaseDataset
8
+ from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
9
+
10
+
11
+ FAIL_MSG = 'Failed to obtain answer via API.'
12
+
13
+
14
+ def get_f1(gt, pred):
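+ # Token-level F1: precision/recall of whitespace-separated tokens shared between the
+ # prediction and the ground truth.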
15
+ gt_bow, pred_bow = gt.strip().split(), pred.strip().split()
16
+ if not gt_bow or not pred_bow:
17
+ return 0.0
18
+
19
+ recall = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(gt_bow)
20
+ precision = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(pred_bow)
21
+ f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 1e-4 else 0.0
22
+ return f1
23
+
24
+
25
+ def SlideVQA_acc(result_file):
26
+ data = load(result_file)
27
+ anls_list, em_list, f1_list = list(), list(), list()
28
+ for i in range(len(data)):
29
+ item = data.iloc[i]
30
+ if isinstance(item['answer'], float) and math.isnan(item['answer']):
31
+ item['answer'] = 'Not answerable'
32
+
33
+ item['answer'] = re.sub('\n', '', item['answer']).lower()
34
+ item['pred'] = str(item['pred']).lower()
35
+ anls_score = anls_compute(item['answer'], item['pred'])
36
+ em_score = (item['answer'].strip() == item['pred'].strip())
37
+ f1_score = get_f1(item['answer'], item['pred'])
38
+ anls_list.append(anls_score)
39
+ em_list.append(em_score)
40
+ f1_list.append(f1_score)
41
+ print('---------------------')
42
+ print(item['answer'], item['pred'], anls_score, em_score, f1_score)
43
+
44
+ data['anls'] = anls_list
45
+ data['em'] = em_list
46
+ data['f1'] = f1_list
47
+ dump(data, result_file)
48
+
49
+ res = dict()
50
+ res['category'], res['num'] = ['anls', 'EM', 'F1'], [len(data), len(data), len(data)]
51
+ res['avg'] = [sum(anls_list) / len(data), sum(em_list) / len(data), sum(f1_list) / len(data)]
52
+ res = pd.DataFrame(res)
53
+ return res
54
+
55
+
56
+ class SlideVQA(ImageBaseDataset):
57
+
58
+ TYPE = 'VQA'
59
+
60
+ DATASET_URL = {
61
+ 'SLIDEVQA_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA_MINI.tsv',
62
+ 'SLIDEVQA': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA.tsv',
63
+ }
64
+ DATASET_MD5 = {
65
+ 'SLIDEVQA_MINI': '6d9a8d8814fa5b7669deb2af3a3208eb',
66
+ 'SLIDEVQA': '5e822c2f800e94c1e23badfd478326b6',
67
+ }
68
+
69
+ SUPPORTED_MODELS = {
70
+ 'GPT4': (1, 1),
71
+ 'GPT4V': (1, 1),
72
+ 'GPT4V_HIGH': (1, 1),
73
+ 'GPT4o': (1, 1),
74
+ 'GPT4o_HIGH': (1, 1),
75
+ 'GPT4o_MINI': (1, 1),
76
+ 'XComposer2d5': (1, -1),
77
+ 'XComposer2_4KHD': (1, -1),
78
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
79
+ 'InternVL-Chat-V1-5': (5, 2),
80
+ }
81
+
82
+ def __init__(self, dataset, **kwargs):
83
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
84
+ model_name = kwargs['model']
85
+ if not listinstr(self.model_list, model_name):
86
+ raise AssertionError("{} doesn't support the evaluation on SlideVQA.".format(model_name))
87
+ super(SlideVQA, self).__init__(dataset)
88
+
89
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
90
+ self.max_pages = 120
91
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
92
+ self.concat_num = concat_num
93
+ self.column_num = column_num
94
+
95
+ def dump_image(self, origin_line):
96
+ os.makedirs(self.img_root, exist_ok=True)
97
+
98
+ line = origin_line.copy()
99
+ if not isinstance(line['image_path'], List):
100
+ line['image_path'] = [line['image_path']]
101
+ line['image_path'] = line['image_path'][:self.max_pages]
102
+
103
+ if 'image' in line:
104
+ if isinstance(line['image'], list):
105
+ tgt_path = []
106
+ assert 'image_path' in line
107
+ for img, im_name in zip(line['image'], line['image_path']):
108
+ path = osp.join(self.img_root, im_name)
109
+ if not read_ok(path):
110
+ decode_base64_to_image_file(img, path)
111
+ tgt_path.append(path)
112
+ else:
113
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
114
+ if not read_ok(tgt_path):
115
+ decode_base64_to_image_file(line['image'], tgt_path)
116
+ tgt_path = [tgt_path]
117
+ else:
118
+ assert 'image_path' in line
119
+ tgt_path = toliststr(line['image_path'])
120
+
121
+ if self.concat_num > 0 and not self.is_api:
122
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
123
+
124
+ old_tgt_path = tgt_path
125
+ assert isinstance(old_tgt_path, list)
126
+ if self.column_num != -1:
127
+ tgt_path = [
128
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
129
+ for i in range(len(concatenated_images))
130
+ ]
131
+ else:
132
+ tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
133
+
134
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
135
+ if not read_ok(path):
136
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
137
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
138
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
139
+ return tgt_path
140
+
141
+ @classmethod
142
+ def evaluate(self, eval_file, **judge_kwargs):
143
+ logger = get_logger('Evaluation')
144
+ model = judge_kwargs['model']
145
+
146
+ suffix = eval_file.split('.')[-1]
147
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
148
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
149
+
150
+ if osp.exists(storage):
151
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in SlideVQA_eval. ')
152
+ else:
153
+ data = load(eval_file)
154
+ model = build_judge(max_tokens=128, **judge_kwargs)
155
+ lt = len(data)
156
+ lines = [data.iloc[i] for i in range(lt)]
157
+ tups = [(model, line) for line in lines]
158
+ indices = [line['index'] for line in lines]
159
+
160
+ ans = {}
161
+ if osp.exists(tmp_file):
162
+ ans = load(tmp_file)
163
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
164
+ indices = [i for i in indices if i not in ans]
165
+
166
+ if len(indices):
167
+ new_results = list()
168
+ for model, line in tqdm(tups):
169
+ res = MMLongBench_auxeval(model, line)
170
+ new_results.append(res)
171
+
172
+ log_map, res_map, pred_map = {}, {}, {}
173
+ all_inds = [line['index'] for line in lines]
174
+ for k, v in zip(all_inds, new_results):
175
+ log_map[k] = v['log']
176
+ res_map[k] = v['res']
177
+ pred_map[k] = v['pred']
178
+ data['res'] = [res_map[idx] for idx in data['index']]
179
+ data['log'] = [log_map[idx] for idx in data['index']]
180
+ data['pred'] = [pred_map[idx] for idx in data['index']]
181
+ dump(data, storage)
182
+
183
+ score = SlideVQA_acc(storage)
184
+ score_pth = storage.replace('.xlsx', '_score.csv')
185
+
186
+ dump(score, score_pth)
187
+ logger.info(f'SlideVQA successfully finished evaluating {eval_file}, results saved in {score_pth}')
188
+ logger.info('Score: ')
189
+ logger.info(score)
eval_mm/vlmevalkit/vlmeval/dataset/text_base.py ADDED
@@ -0,0 +1,88 @@
1
+ from abc import abstractmethod
2
+ from ..smp import *
3
+
4
+
5
+ class TextBaseDataset:
6
+ MODALITY = 'TEXT'
7
+ DATASET_URL = {}
8
+ DATASET_MD5 = {}
9
+
10
+ def __init__(self, dataset='MMBench', **kwargs):
11
+ self.dataset_name = dataset
12
+
13
+ data = self.load_data(dataset)
14
+
15
+ data['index'] = [str(x) for x in data['index']]
16
+
17
+ if np.all([istype(x, int) for x in data['index']]):
18
+ data['index'] = [int(x) for x in data['index']]
19
+
20
+ self.data = data
21
+ self.post_build(dataset)
22
+
23
+ def __len__(self):
24
+ return len(self.data)
25
+
26
+ def __getitem__(self, idx):
27
+ return dict(self.data.iloc[idx])
28
+
29
+ def prepare_tsv(self, url, file_md5=None):
30
+ data_root = LMUDataRoot()
31
+ os.makedirs(data_root, exist_ok=True)
32
+ update_flag = False
33
+ file_name = url.split('/')[-1]
34
+ data_path = osp.join(data_root, file_name)
35
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
36
+ pass
37
+ else:
38
+ warnings.warn('The dataset tsv is not downloaded')
39
+ download_file(url, data_path)
40
+ update_flag = True
41
+
42
+ if file_size(data_path, 'GB') > 1:
43
+ local_path = data_path.replace('.tsv', '_local.tsv')
44
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
45
+ from ..tools import LOCALIZE
46
+ LOCALIZE(data_path, local_path)
47
+ data_path = local_path
48
+ return load(data_path)
49
+
50
+ def dump_image(self, line):
51
+ return []
52
+
53
+ def display(self, line):
54
+ if isinstance(line, int):
55
+ line = self.data.iloc[line]
56
+ assert isinstance(line, pd.Series) or isinstance(line, dict)
57
+ mmqa_display(line)
58
+
59
+ # Return a list of dataset names that are supported by this class, can override
60
+ @classmethod
61
+ def supported_datasets(cls):
62
+ return list(cls.DATASET_URL)
63
+
64
+ # Given the dataset name, return the dataset as a pandas dataframe, can override
65
+ def load_data(self, dataset):
66
+ url = self.DATASET_URL[dataset]
67
+ file_md5 = self.DATASET_MD5[dataset]
68
+ return self.prepare_tsv(url, file_md5)
69
+
70
+ # Post built hook, will be called after the dataset is built, can override
71
+ def post_build(self, dataset):
72
+ pass
73
+
74
+ # Given one data record, return the built prompt (a multi-modal message), can override
75
+ def build_prompt(self, line):
76
+ if isinstance(line, int):
77
+ line = self.data.iloc[line]
78
+
79
+ question = line['question']
80
+
81
+ msgs = []
82
+ msgs.append(dict(type='text', value=question))
83
+ return msgs
84
+
85
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
86
+ @abstractmethod
87
+ def evaluate(self, eval_file, **judge_kwargs):
88
+ pass
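For context, here is a minimal sketch of how a new text-only dataset could hook into the `TextBaseDataset` class above. This is illustration only (not part of the uploaded files): the dataset name, URL, and TSV columns are hypothetical, and it assumes the `vlmeval` package under `eval_mm/vlmevalkit` is installed.

```python
# Illustrative sketch only -- not part of the uploaded files.
# Assumes the vlmeval package (eval_mm/vlmevalkit) is installed and that the
# hypothetical TSV at the URL below has `index`, `question`, and `answer` columns.
from vlmeval.smp import load
from vlmeval.dataset.text_base import TextBaseDataset


class MyTextQA(TextBaseDataset):
    DATASET_URL = {'MyTextQA': 'https://example.com/MyTextQA.tsv'}  # hypothetical URL
    DATASET_MD5 = {'MyTextQA': None}  # md5 check is skipped when None

    def evaluate(self, eval_file, **judge_kwargs):
        # Placeholder metric: exact-match accuracy between prediction and answer.
        data = load(eval_file)
        hits = [str(p).strip().lower() == str(a).strip().lower()
                for p, a in zip(data['prediction'], data['answer'])]
        return {'accuracy': sum(hits) / len(hits)}
```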
eval_mm/vlmevalkit/vlmeval/dataset/text_mcq.py ADDED
@@ -0,0 +1,123 @@
1
+ from .text_base import TextBaseDataset
2
+ from .utils import build_judge, DEBUG_MESSAGE
3
+ from ..smp import *
4
+
5
+
6
+ class TextMCQDataset(TextBaseDataset):
7
+ TYPE = 'MCQ'
8
+
9
+ DATASET_URL = {}
10
+
11
+ DATASET_MD5 = {}
12
+
13
+ def build_prompt(self, line):
14
+
15
+ if isinstance(line, int):
16
+ line = self.data.iloc[line]
17
+
18
+ question = line['question']
19
+ options = {
20
+ cand: line[cand]
21
+ for cand in string.ascii_uppercase
22
+ if cand in line and not pd.isna(line[cand])
23
+ }
24
+ options_prompt = 'Options:\n'
25
+ for key, item in options.items():
26
+ options_prompt += f'{key}. {item}\n'
27
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
28
+ prompt = ''
29
+ if hint is not None:
30
+ prompt += f'Hint: {hint}\n'
31
+ prompt += f'Question: {question}\n'
32
+ if len(options):
33
+ prompt += options_prompt
34
+ prompt += 'Please select the correct answer from the options above. \n'
35
+
36
+ msgs = []
37
+
38
+ msgs.append(dict(type='text', value=prompt))
39
+
40
+ return msgs
41
+
42
+ def evaluate(self, eval_file, **judge_kwargs):
43
+ from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
44
+ # assert dataset is not None
45
+ dataset_map = {
46
+ 'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
47
+ 'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
48
+ }
49
+ dataset = self.dataset_name
50
+ if dataset in dataset_map:
51
+ dataset = dataset_map[dataset]
52
+ nproc = judge_kwargs.pop('nproc', 4)
53
+
54
+ circular = False
55
+
56
+ suffix = eval_file.split('.')[-1]
57
+ model = judge_kwargs.get('model', 'exact_matching')
58
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
59
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
60
+ name_str = name_str_map[model] if model in name_str_map else model
61
+
62
+ if model == 'exact_matching':
63
+ model = None
64
+ elif gpt_key_set():
65
+ model = build_judge(**judge_kwargs)
66
+ if not model.working():
67
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
68
+ warnings.warn(DEBUG_MESSAGE)
69
+ model = None
70
+ else:
71
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
72
+ model = None
73
+
74
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
75
+
76
+ data = load(eval_file)
77
+ data = data.sort_values(by='index')
78
+ data['prediction'] = [str(x) for x in data['prediction']]
79
+ # If not choice label, then use lower case
80
+ for k in data.keys():
81
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
82
+
83
+ meta = self.data
84
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
85
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
86
+ for k in data_map:
87
+ assert k in meta_q_map, (
88
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
89
+ )
90
+
91
+ if circular:
92
+ data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
93
+ else:
94
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
95
+
96
+ # load split
97
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
98
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
99
+
100
+ # May have different report acc functions for different datasets
101
+ if 'MMT' in dataset:
102
+ acc = report_acc_MMT(data)
103
+ else:
104
+ acc = report_acc(data)
105
+
106
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
107
+ dump(acc, score_file)
108
+
109
+ return acc
110
+
111
+
112
+ class CustomTextMCQDataset(TextMCQDataset):
113
+
114
+ def load_data(self, dataset):
115
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
116
+
117
+ if file_size(data_path, 'GB') > 1:
118
+ local_path = data_path.replace('.tsv', '_local.tsv')
119
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
120
+ from ..tools import LOCALIZE
121
+ LOCALIZE(data_path, local_path)
122
+ data_path = local_path
123
+ return load(data_path)
eval_mm/vlmevalkit/vlmeval/dataset/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .judge_util import build_judge, DEBUG_MESSAGE
2
+ from .multiple_choice import extract_answer_from_item, prefetch_answer
3
+ from .vqa_eval import levenshtein_distance
4
+
5
+
6
+ __all__ = [
7
+ 'build_judge', 'extract_answer_from_item', 'prefetch_answer',
8
+ 'levenshtein_distance', 'DEBUG_MESSAGE'
9
+ ]
eval_mm/vlmevalkit/vlmeval/dataset/utils/judge_util.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ from ...api import OpenAIWrapper
3
+ from ...smp import load_env
4
+
5
+ INTERNAL = os.environ.get('INTERNAL', 0)
6
+
7
+
8
+ def build_judge(**kwargs):
9
+ model = kwargs.pop('model', None)
10
+ kwargs.pop('nproc', None)
11
+ load_env()
12
+ LOCAL_LLM = os.environ.get('LOCAL_LLM', None)
13
+ if LOCAL_LLM is None:
14
+ model_map = {
15
+ 'gpt-4-turbo': 'gpt-4-1106-preview',
16
+ 'gpt-4-0613': 'gpt-4-0613',
17
+ 'gpt-4-0125': 'gpt-4-0125-preview',
18
+ 'gpt-4-0409': 'gpt-4-turbo-2024-04-09',
19
+ 'chatgpt-1106': 'gpt-3.5-turbo-1106',
20
+ 'chatgpt-0125': 'gpt-3.5-turbo-0125',
21
+ 'gpt-4o': 'gpt-4o-2024-05-13',
22
+ 'gpt-4o-mini': 'gpt-4o-mini-2024-07-18',
23
+ }
24
+ model_version = model_map[model]
25
+ else:
26
+ model_version = LOCAL_LLM
27
+ model = OpenAIWrapper(model_version, **kwargs)
28
+ return model
29
+
30
+
31
+ DEBUG_MESSAGE = """
32
+ To debug the OpenAI API, you can try the following scripts in python:
33
+ ```python
34
+ from vlmeval.api import OpenAIWrapper
35
+ model = OpenAIWrapper('gpt-4-1106-preview', verbose=True)
36
+ msgs = [dict(type='text', value='Hello!')]
37
+ code, answer, resp = model.generate_inner(msgs)
38
+ print(code, answer, resp)
39
+ ```
40
+ You can see the specific error if the API call fails.
41
+ """
eval_mm/vlmevalkit/vlmeval/dataset/utils/llavabench.py ADDED
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from ...smp import *
4
+
5
+ rule_dict = {
6
+ 'llava_bench_conv': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'}, # noqa: E501
7
+ 'llava_bench_detail': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'}, # noqa: E501
8
+ 'llava_bench_complex': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'} # noqa: E501
9
+ }
10
+
11
+
12
+ def get_eval(judge, content):
13
+ return judge.generate(content)
14
+
15
+
16
+ def parse_score(review):
17
+ logger = get_logger('Evaluation')
18
+ try:
19
+ score_pair = review.split('\n')[0]
20
+ score_pair = score_pair.replace(',', ' ')
21
+ sp = score_pair.split(' ')
22
+ if len(sp) == 2:
23
+ return [float(sp[0]), float(sp[1])]
24
+ else:
25
+ logger.error('error', review)
26
+ return [-1, -1]
27
+ except Exception as e:
28
+ logger.error(e, 'error', review)
29
+ return [-1, -1]
30
+
31
+
32
+ def build_prompt(line):
33
+ cap_str = line['caption']
34
+ question = line['question']
35
+ ans1 = line['gpt4_ans']
36
+ ans2 = line['prediction']
37
+ category = 'llava_bench_' + line['category']
38
+ rule = rule_dict[category]
39
+ role, prompt = rule['role'], rule['prompt']
40
+
41
+ content = (f'[Context]\n{cap_str}\n\n'
42
+ f'[Question]\n{question}\n\n'
43
+ f'[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n'
44
+ f'[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n'
45
+ f'[System]\n{prompt}\n\n')
46
+ return content
47
+
48
+
49
+ def LLaVABench_atomeval(model, prompt):
50
+ review = get_eval(model, prompt)
51
+ scores = parse_score(review)
52
+ return scores
53
+
54
+
55
+ def LLaVABench_score(data):
56
+ cates = ['overall'] + list(set(data['category']))
57
+ ret = defaultdict(list)
58
+
59
+ for c in cates:
60
+ ret['split'].append(c)
61
+ sub = data[data['category'] == c] if c != 'overall' else data
62
+ ret['Relative Score (main)'].append(np.mean(sub['score']) / np.mean(sub['gpt4_score']) * 100)
63
+ ret['VLM Score'].append(np.mean(sub['score']) * 10)
64
+ ret['GPT4 Score'].append(np.mean(sub['gpt4_score']) * 10)
65
+ return pd.DataFrame(ret)
eval_mm/vlmevalkit/vlmeval/dataset/utils/mathv.py ADDED
@@ -0,0 +1,170 @@
1
+ from ...smp import *
2
+ from ...utils import can_infer
3
+ try:
4
+ from latex2sympy2 import latex2sympy
5
+ except ImportError:
6
+ print('Please install latex2sympy2 by running "pip install latex2sympy2"')
7
+
8
+ FAIL_MSG = 'Failed to obtain answer via API.'
9
+
10
+
11
+ def is_equal(asw: str, gt_asw: str) -> bool:
12
+ if not isinstance(asw, str) or not isinstance(gt_asw, str):
13
+ print('Warning: input is not string')
14
+ print(asw, gt_asw)
15
+ asw = str(asw).lower().strip()
16
+ gt_asw = str(gt_asw).lower().strip()
17
+ if gt_asw == asw:
18
+ return True
19
+ try:
20
+ a = eval(gt_asw)
21
+ b = eval(asw)
22
+ if abs(a - b) < 1e-6:
23
+ return True
24
+ except:
25
+ pass
26
+ try:
27
+ a = latex2sympy(gt_asw)
28
+ b = latex2sympy(asw)
29
+ if abs(eval(str(a)) - eval(str(b))) < 1e-6:
30
+ return True
31
+ if abs(a - b) < 1e-6:
32
+ return True
33
+ except:
34
+ pass
35
+ return False
36
+
37
+
38
+ def get_gpt4_ICE():
39
+ example_1 = """
40
+ Hint: Please answer the question and provide the final answer at the end.\n
41
+ Question: Which number is missing?\n
42
+ Model response: The number missing in the sequence is 14.\n
43
+ Extracted answer: 14
44
+ """
45
+
46
+ example_2 = """
47
+ Hint: Please answer the question and provide the final answer at the end.\n
48
+ Question: What is the fraction of females facing the camera?\n
49
+ Model response: The fraction of females facing the camera is 0.6,
50
+ which means that six out of ten females in the group are facing the camera.\n
51
+ Extracted answer: 0.6
52
+ """
53
+
54
+ example_3 = """
55
+ Hint: Please answer the question and provide the final answer at the end.\n
56
+ Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
57
+ Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
58
+ Extracted answer: 1.45
59
+ """
60
+
61
+ example_4 = """
62
+ Hint: Please answer the question and provide the final answer at the end.\n
63
+ Question: Between which two years does the line graph saw its maximum peak?\n
64
+ Model response: The line graph saw its maximum peak between 2007 and 2008.\n
65
+ Extracted answer: [2007, 2008]
66
+ """
67
+
68
+ example_5 = """
69
+ Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
70
+ Question: What fraction of the shape is blue?\n
71
+ Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
72
+ Model response: The correct answer is (B) 8/11.\n
73
+ Extracted answer: B
74
+ """
75
+
76
+ return [example_1, example_2, example_3, example_4, example_5]
77
+
78
+
79
+ def build_mathv_gpt4_prompt(line):
80
+ task_description = """
81
+ Please read the following example.
82
+ Then extract the answer from the model response and type it at the end of the prompt.\n
83
+ """
84
+ question = line['question']
85
+ prediction = str(line['prediction'])
86
+ prompt = task_description
87
+ examples = get_gpt4_ICE()
88
+ for example in examples:
89
+ prompt += example + '\n'
90
+ prompt += question + '\n'
91
+ prompt += 'Model response: ' + prediction
92
+ prompt += 'Extracted answer:'
93
+ return prompt
94
+
95
+
96
+ def list_to_dict(lst):
97
+ return {chr(65 + i): val for i, val in enumerate(lst)}
98
+
99
+
100
+ def post_check(line, prefetch=False):
101
+ res = None
102
+ ans = line['answer']
103
+ response = line['prediction'] if prefetch else line['res']
104
+ try:
105
+ if len(eval(line['choices'])) > 0:
106
+ ans = line['answer']
107
+ choices = list_to_dict(eval(line['choices']))
108
+ res = can_infer(response, choices)
109
+ if prefetch:
110
+ return res
111
+ else:
112
+ res = str(response)
113
+ ans = str(ans)
114
+ except ValueError:
115
+ pass
116
+
117
+ if is_equal(res, ans):
118
+ return res if prefetch else True
119
+ else:
120
+ return False
121
+
122
+
123
+ def MATH_V_auxeval(model, line):
124
+ prompt = build_mathv_gpt4_prompt(line)
125
+ log = ''
126
+ retry = 5
127
+ if post_check(line, prefetch=True):
128
+ res = post_check(line, prefetch=True)
129
+ return dict(log='Prefetch succeed', res=res)
130
+ for i in range(retry):
131
+ prediction = line['prediction']
132
+ res = model.generate(prompt, temperature=i * 0.5)
133
+
134
+ if FAIL_MSG in res:
135
+ log += f'Try {i}: output is {prediction}, failed to parse.\n'
136
+ else:
137
+ log += 'Succeed'
138
+ return dict(log=log, res=res)
139
+ log += 'All 5 retries failed.\n'
140
+ return dict(log=log, res='')
141
+
142
+
143
+ def MATH_V_acc(result_file):
144
+ data = load(result_file)
145
+ tot = defaultdict(lambda: 0)
146
+ fetch = defaultdict(lambda: 0)
147
+ hit = defaultdict(lambda: 0)
148
+ lt = len(data)
149
+ for i in range(lt):
150
+ item = data.iloc[i]
151
+ cate = item['category']
152
+ tot['Overall'] += 1
153
+ tot[cate] += 1
154
+ if item['log'] == 'Prefetch succeed':
155
+ fetch['Overall'] += 1
156
+ fetch[cate] += 1
157
+ if post_check(item, prefetch=False):
158
+ hit['Overall'] += 1
159
+ hit[cate] += 1
160
+
161
+ res = defaultdict(list)
162
+ for k in tot.keys():
163
+ res['Subject'].append(k)
164
+ res['tot'].append(tot[k])
165
+ res['prefetch'].append(fetch[k])
166
+ res['hit'].append(hit[k])
167
+ res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
168
+ res['acc'].append(hit[k] / tot[k] * 100)
169
+ res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
170
+ return res
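A few sanity checks for the `is_equal()` helper above, showing the plain-numeric and LaTeX equivalence paths. Illustration only (not part of the uploaded files); the LaTeX case assumes `latex2sympy2` is installed, per the import guard in the module.

```python
# Illustrative sketch only -- not part of the uploaded files.
# The LaTeX branch assumes latex2sympy2 is installed (see the import guard above).
from vlmeval.dataset.utils.mathv import is_equal

assert is_equal('0.5', '1/2')            # numeric equivalence via eval()
assert is_equal(r'\frac{1}{2}', '0.5')   # LaTeX handled by latex2sympy
assert not is_equal('3', '4')
```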
eval_mm/vlmevalkit/vlmeval/dataset/utils/mathvista.py ADDED
@@ -0,0 +1,164 @@
1
+ from ...smp import *
2
+ from ...utils import can_infer
3
+
4
+
5
+ FAIL_MSG = 'Failed to obtain answer via API.'
6
+
7
+
8
+ def get_gpt4_ICE():
9
+ example_1 = """
10
+ Hint: Please answer the question requiring an integer answer and provide the final value,
11
+ e.g., 1, 2, 3, at the end.\n
12
+ Question: Which number is missing?\n
13
+ Model response: The number missing in the sequence is 14.\n
14
+ Extracted answer: 14
15
+ """
16
+
17
+ example_2 = """
18
+ Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value,
19
+ e.g., 1.2, 1.3, 1.4, at the end.\n
20
+ Question: What is the fraction of females facing the camera?\n
21
+ Model response: The fraction of females facing the camera is 0.6,
22
+ which means that six out of ten females in the group are facing the camera.\n
23
+ Extracted answer: 0.6
24
+ """
25
+
26
+ example_3 = """
27
+ Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value,
28
+ e.g., 1.23, 1.34, 1.45, at the end.\n
29
+ Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
30
+ Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
31
+ Extracted answer: 1.45
32
+ """
33
+
34
+ example_4 = """
35
+ Hint: Please answer the question requiring a Python list as an answer and provide the final list,
36
+ e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
37
+ Question: Between which two years does the line graph saw its maximum peak?\n
38
+ Model response: The line graph saw its maximum peak between 2007 and 2008.\n
39
+ Extracted answer: [2007, 2008]
40
+ """
41
+
42
+ example_5 = """
43
+ Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
44
+ Question: What fraction of the shape is blue?\n
45
+ Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
46
+ Model response: The correct answer is (B) 8/11.\n
47
+ Extracted answer: B
48
+ """
49
+
50
+ return [example_1, example_2, example_3, example_4, example_5]
51
+
52
+
53
+ def build_mathvista_gpt4_prompt(line):
54
+ task_description = """
55
+ Please read the following example.
56
+ Then extract the answer from the model response and type it at the end of the prompt.\n
57
+ """
58
+ question = line['question']
59
+ prediction = str(line['prediction'])
60
+ prompt = task_description
61
+ examples = get_gpt4_ICE()
62
+ for example in examples:
63
+ prompt += example + '\n'
64
+ prompt += question + '\n'
65
+ prompt += 'Model respone: ' + prediction
66
+ prompt += 'Extracted answer:'
67
+ return prompt
68
+
69
+
70
+ def list_to_dict(lst):
71
+ return {chr(65 + i): val for i, val in enumerate(lst)}
72
+
73
+
74
+ def post_check(line, prefetch=False):
75
+ res = None
76
+ ans = line['answer']
77
+ response = line['prediction'] if prefetch else line['res']
78
+ try:
79
+ if line['question_type'] == 'multi_choice':
80
+ ans = line['answer_option']
81
+ choices = list_to_dict(eval(line['choices']))
82
+ res = can_infer(response, choices)
83
+ if prefetch:
84
+ return res
85
+ else:
86
+ if line['answer_type'] == 'integer':
87
+ res = int(response)
88
+ ans = int(line['answer'])
89
+ elif line['answer_type'] == 'float':
90
+ res = float(response)
91
+ ans = float(line['answer'])
92
+ else:
93
+ res = str(res)
94
+ ans = str(ans)
95
+ except ValueError:
96
+ pass
97
+
98
+ if res == ans:
99
+ return res if prefetch else True
100
+ else:
101
+ return False
102
+
103
+
104
+ def MathVista_auxeval(model, line):
105
+ prompt = build_mathvista_gpt4_prompt(line)
106
+ log = ''
107
+ retry = 5
108
+ if post_check(line, prefetch=True):
109
+ res = post_check(line, prefetch=True)
110
+ return dict(log='Prefetch succeed', res=res)
111
+ for i in range(retry):
112
+ prediction = line['prediction']
113
+ res = model.generate(prompt, temperature=i * 0.5)
114
+
115
+ if FAIL_MSG in res:
116
+ log += f'Try {i}: output is {prediction}, failed to parse.\n'
117
+ else:
118
+ log += 'Succeed'
119
+ return dict(log=log, res=res)
120
+ log += 'All 5 retries failed.\n'
121
+ return dict(log=log, res='')
122
+
123
+
124
+ def MathVista_acc(result_file):
125
+ data = load(result_file)
126
+ tot = defaultdict(lambda: 0)
127
+ fetch = defaultdict(lambda: 0)
128
+ hit = defaultdict(lambda: 0)
129
+ lt = len(data)
130
+ skill_list = []
131
+ for i in range(lt):
132
+ item = data.iloc[i]
133
+ cate = item['task']
134
+ tot['Overall'] += 1
135
+ try:
136
+ skills = eval(item['skills'])
137
+ except SyntaxError:
138
+ skills = [item['skills']]
139
+ for skill in skills:
140
+ if skill not in skill_list:
141
+ skill_list.append(skill)
142
+ tot[skill] += 1
143
+ tot[cate] += 1
144
+ if item['log'] == 'Prefetch succeed':
145
+ fetch['Overall'] += 1
146
+ fetch[cate] += 1
147
+ for skill in skills:
148
+ fetch[skill] += 1
149
+ if post_check(item, prefetch=False):
150
+ hit['Overall'] += 1
151
+ hit[cate] += 1
152
+ for skill in skills:
153
+ hit[skill] += 1
154
+
155
+ res = defaultdict(list)
156
+ for k in tot.keys():
157
+ res['Task&Skill'].append(k)
158
+ res['tot'].append(tot[k])
159
+ res['prefetch'].append(fetch[k])
160
+ res['hit'].append(hit[k])
161
+ res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
162
+ res['acc'].append(hit[k] / tot[k] * 100)
163
+ res = pd.DataFrame(res)
164
+ return res
eval_mm/vlmevalkit/vlmeval/dataset/utils/mmbench_video.py ADDED
@@ -0,0 +1,70 @@
1
+ from ...smp import *
2
+ import numpy as np
3
+
4
+ FAIL_MSG = 'Failed to obtain answer via API.'
5
+
6
+ system_prompt = """
7
+ As an AI assistant, your task is to evaluate a candidate answer in comparison to a given correct answer.
8
+ The question itself, the correct 'groundtruth' answer, and the candidate answer will be provided to you.
9
+ Your assessment should range from 0 to 3, \
10
+ based solely on the semantic similarity between the groundtruth and the candidate answer, \
11
+ disregarding any grammatical differences.
12
+ A rating of 0 suggests no similarity, implying the candidate answer is entirely incorrect.
13
+ A rating of 1 suggests low similarity, meaning the candidate answer is largely incorrect.
14
+ A rating of 2 suggests high similarity, meaning the candidate answer is largely correct.
15
+ Lastly, a rating of 3 indicates complete similarity, which means the candidate answer is entirely correct.
16
+ Your response should be a single integer from 0, 1, 2, or 3.
17
+ """
18
+
19
+ MMV_DIMENSIONS = {
20
+ 'CP': ['Video Topic', 'Video Emotion', 'Video Scene', 'Video Style'],
21
+ 'FP-S': ['OCR', 'Object Recognition', 'Attribute Recognition', 'Event Recognition', 'Human Motion', 'Counting'],
22
+ 'FP-C': ['Spatial Relationship', 'Human-object Interaction', 'Human Interaction'],
23
+ 'HL': ['Hallucination'],
24
+ 'LR': ['Structuralized Image-Text Understanding', 'Mathematical Calculation'],
25
+ 'AR': ['Physical Property', 'Function Reasoning', 'Identity Reasoning'],
26
+ 'RR': ['Natural Relation', 'Physical Relation', 'Social Relation'],
27
+ 'CSR': ['Common Sense Reasoning'],
28
+ 'TR': ['Counterfactual Reasoning', 'Causal Reasoning', 'Future Prediction'],
29
+ }
30
+ L3_DIMS = []
31
+ for k, v in MMV_DIMENSIONS.items():
32
+ L3_DIMS.extend(v)
33
+
34
+ MMV_DIMENSIONS['Perception'] = []
35
+ MMV_DIMENSIONS['Reasoning'] = []
36
+ MMV_DIMENSIONS['Overall'] = []
37
+ for k in ['CP', 'FP-C', 'FP-S', 'HL']:
38
+ MMV_DIMENSIONS['Perception'].extend(MMV_DIMENSIONS[k])
39
+ MMV_DIMENSIONS['Overall'].extend(MMV_DIMENSIONS[k])
40
+ for k in ['LR', 'AR', 'RR', 'CSR', 'TR']:
41
+ MMV_DIMENSIONS['Reasoning'].extend(MMV_DIMENSIONS[k])
42
+ MMV_DIMENSIONS['Overall'].extend(MMV_DIMENSIONS[k])
43
+
44
+
45
+ def get_dimension_rating(data_path):
46
+ data = load(data_path)
47
+ coarse_rating = {k: [] for k in MMV_DIMENSIONS}
48
+ fine_rating = {k: [] for k in L3_DIMS}
49
+
50
+ for i in range(len(data)):
51
+ cate = data.iloc[i]['dimensions']
52
+ cates = eval(cate)
53
+
54
+ for c in cates:
55
+ fine_rating[c].append(data.iloc[i]['score'])
56
+
57
+ for d in MMV_DIMENSIONS:
58
+ if np.any([x in MMV_DIMENSIONS[d] for x in cates]):
59
+ coarse_rating[d].append(data.iloc[i]['score'])
60
+
61
+ coarse_all = {k: f'{np.mean([max(x, 0) for x in v]):.2f}' for k, v in coarse_rating.items()}
62
+ coarse_valid = {k: f'{np.mean([x for x in v if x >= 0]):.2f}' for k, v in coarse_rating.items()}
63
+ fine_all = {k: f'{np.mean([max(x, 0) for x in v]):.2f}' for k, v in fine_rating.items()}
64
+ fine_valid = {k: f'{np.mean([x for x in v if x >= 0]):.2f}' for k, v in fine_rating.items()}
65
+ return dict(coarse_all=coarse_all, coarse_valid=coarse_valid, fine_all=fine_all, fine_valid=fine_valid)
66
+
67
+
68
+ def build_prompt(item):
69
+ tmpl = 'Question: {}\nGroundtruth answer: {}\nCandidate answer: {}\nYour response: '
70
+ return tmpl.format(item['question'], item['answer'], item['prediction'])
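A minimal sketch of assembling the 0-3 judging prompt defined above for a single record; how the system prompt is actually attached to the judge may differ in the dataset class. Illustration only; the record fields below are hypothetical.

```python
# Illustrative sketch only -- not part of the uploaded files.
# The record below is hypothetical; a real record comes from the MMBench-Video TSV.
from vlmeval.dataset.utils.mmbench_video import system_prompt, build_prompt

item = dict(question='What sport is shown in the video?',
            answer='Skiing',
            prediction='A person is skiing downhill.')
prompt = system_prompt + build_prompt(item)
# judge.generate(prompt) is expected to return a single integer from 0 to 3.
```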
eval_mm/vlmevalkit/vlmeval/dataset/utils/mmdu.py ADDED
@@ -0,0 +1,126 @@
1
+ from ...smp import *
2
+
3
+ meta_prompt = """
4
+ You are an assistant skilled at evaluating the quality of creative text.
5
+ Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to \
6
+ the user question displayed below. You'll need to assess the response on the following dimensions: \
7
+ Creativity, Richness, Visual Perception, Logical Coherence, Answer Accuracy and Image Relationship Understanding. \
8
+ We will provide you with a creative question and the AI model's response and a reference answer for your evaluation. \
9
+ As you begin your assessment, follow this process:
10
+ 1. Evaluate the AI model's answers on different dimensions, pointing out its strengths or weaknesses \
11
+ in each dimension and assigning a score of 1 to 10 for each.
12
+ 2. Finally, based on the assessments across dimensions, \
13
+ provide an overall score of 1 to 10 for the AI model's response.
14
+ 3. Your scoring should be as stringent as possible and follow the scoring rules below:
15
+ In general, the higher the quality of the model's response and its strict adherence to user needs, \
16
+ the higher the score. Responses that do not meet user needs will receive lower scores.
17
+ Scoring rules:
18
+ Creativity:
19
+ Scores 1-2 when there is no innovation or uniqueness in the content.
20
+ Scores 3-4 when providing partially original content but with low creative quality.
21
+ Scores 5-6 when mostly creative but lacks significant novelty, with moderate quality.
22
+ Scores 7-8 when having novelty and high-quality content.
23
+ Scores 9-10 when highly novel and of exceptional quality compared to the reference answer.
24
+ Richness:
25
+ Scores 1-2 when lacking depth and breadth, with very limited information.
26
+ Scores 3-4 when limited in depth and breadth, with fewer explanations and examples, showing low diversity.
27
+ Scores 5-6 when limited in depth and breadth but provides basic necessary information.
28
+ Scores 7-8 when providing depth and useful additional information.
29
+ Scores 9-10 when providing exceptional depth, breadth, and high diversity compared to the reference answer.
30
+ Visual Perception:
31
+ Scores 1-2 when the description of the visual information in the image contains errors or \
32
+ is significantly inconsistent with the content of the image.
33
+ Scores 3-4 When the description of the visual information in the image reflects only a small amount \
34
+ of the image's information and contains some errors.
35
+ Scores 5-6 when the description of the visual information in the image includes the basic information \
36
+ of the image but contains minimal information.
37
+ Scores 7-8 when the description of the visual information in the image matches the image well and is rich in content, \
38
+ providing a substantial amount of information about the image.
39
+ Scores 9-10 when the description of the visual information in the image not only matches the image \
40
+ but also is more detailed and informative compared to the reference answer, providing more information about the image.
41
+ Logical Coherence:
42
+ Scores 1-2 when entirely incoherent, lacking any logic, and not matching the question or known information.
43
+ Scores 3-4 when somewhat coherent but with many logical errors or inconsistencies.
44
+ Scores 5-6 when mostly coherent, with few errors, but may struggle to maintain complete coherence in complex situations.
45
+ Scores 7-8 when excellent logical handling, very few errors.
46
+ Scores 9-10 when flawless logic, impeccable in handling complexity, \
47
+ and significantly higher logical coherence compared to the reference answer.
48
+ Answer Accuracy:
49
+ Scores 1-2 when the answer is significantly inconsistent with the question or contains obvious errors.
50
+ Scores 3-4 when the answer is partially correct but contains some errors or is incomplete.
51
+ Scores 5-6 when the answer is basically correct but lacks details or is not sufficiently detailed.
52
+ Scores 7-8 when the answer is accurate and detailed, fully corresponding to the question.
53
+ Scores 9-10 when the answer is not only accurate and detailed but also provides additional useful information, \
54
+ exceeding expectations.
55
+ Image Relationship Understanding:
56
+ Scores 1-2 when there are significant errors or confusion in distinguishing and describing different images, \
57
+ unable to correctly identify and relate the content of the images.
58
+ Scores 3-4 when the description of different images reflects only minimal distinguishing information, \
59
+ contains some errors and confusion, and fails to clearly differentiate and relate the images.
60
+ Scores 5-6 when the description of different images includes basic distinguishing information, \
61
+ is able to correctly identify and relate the images in a basic manner, \
62
+ but the information provided is minimal and lacks detail.
63
+ Scores 7-8 when the description of different images is accurate and detailed, \
64
+ clearly distinguishing and relating the images, \
65
+ with rich content that points out the main commonalities and differences between the images.
66
+ Scores 9-10 when the description of different images is not only accurate and detailed but also \
67
+ provides richer information and analysis, clearly distinguishing and relating the images, \
68
+ more comprehensively pointing out the commonalities and differences \
69
+ between the images compared to the reference answer.
70
+ Overall Score:
71
+ Scores 1-2 when irrelevant to the question, factually incorrect, or generates harmful content.
72
+ Scores 3-4 when no serious errors, mostly harmless, but of low quality and does not meet requirements.
73
+ Scores 5-6 when basically meeting requirements but performing poorly in some dimensions, with moderate quality.
74
+ Scores 7-8 when performing well in all dimensions.
75
+ Scores 9-10 when fully addressing user questions and all requirements, significantly surpassing the reference answer.
76
+ Please remember, you must evaluate and explain before scoring. After your explanation for each dimension, \
77
+ add the score for that dimension. Finally, at the end of your response, \
78
+ in the format of the dictionary (including brackets), return all your scoring results, \
79
+ ensuring your scores are integers:
80
+ {'Dimension One': Score, 'Dimension Two': Score, ..., 'Overall Score': Score}, \
81
+ for example: {'Creativity': 9, 'Richness': 6, ..., 'Overall Score': 7}.\n
82
+ """
83
+ question_begin_prompt = '[Question]'
84
+ reference_begin_prompt = '[The Start of Reference Answer]'
85
+ reference_end_prompt = '[The End of Reference Answer]'
86
+ answers_begin_prompt = '[The Start of Assistant’s Answer]'
87
+ answers_end_prompt = '[The End of Assistant’s Answer]'
88
+
89
+
90
+ def mmdu_score(model, line):
91
+ question = eval(line['question'])
92
+ gt = eval(line['answer'])
93
+ prediction = eval(line['prediction'])
94
+
95
+ DIMS = [
96
+ 'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
97
+ 'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
98
+ ]
99
+
100
+ all_result_dict = []
101
+ logs = []
102
+ for j in range(len(question)):
103
+ try:
104
+ prompt = meta_prompt + question_begin_prompt + '\n' + question[j] + '\n\n' + \
105
+ reference_begin_prompt + '\n' + gt[j] + '\n' + reference_end_prompt + '\n\n' + \
106
+ answers_begin_prompt + '\n' + prediction[j] + '\n' + answers_end_prompt
107
+ response = model.generate(prompt)
108
+ start_index = response.find('{')
109
+ end_index = response.rfind('}') + 1
110
+ dictionary_str = response[start_index: end_index]
111
+ result_dict = eval(dictionary_str)
112
+ all_result_dict.append(result_dict)
113
+ if all([x in result_dict for x in DIMS]):
114
+ logs.append('Succeed')
115
+ else:
116
+ logs.append(
117
+ f'Following Dims are not in results of turn {j}: '
118
+ f'{",".join([x for x in DIMS if x not in result_dict])}'
119
+ )
120
+ except Exception as e:
121
+ print(e)
122
+ all_result_dict.append({d: None for d in DIMS})
123
+ logs.append(str(e))
124
+
125
+ df = pd.DataFrame(all_result_dict)
126
+ return dict(res=df, log='\n'.join(logs))
eval_mm/vlmevalkit/vlmeval/dataset/utils/mmvet.py ADDED
@@ -0,0 +1,106 @@
1
+ from ...smp import *
2
+
3
+
4
+ def build_mmvet_gpt4_prompt(line):
5
+ question = line['question']
6
+ gt = str(line['answer'])
7
+ prediction = str(line['prediction'])
8
+ prompt = """
9
+ Compare the ground truth and prediction from AI models, to give a correctness score for the prediction.
10
+ <AND> in the ground truth means it is totally right
11
+ only when all elements in the ground truth are present in the prediction,
12
+ and <OR> means it is totally right when any one element in the ground truth is present in the prediction.
13
+ The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right).
14
+ Just complete the last space of the correctness score.
15
+
16
+ Question | Ground truth | Prediction | Correctness
17
+ --- | --- | --- | ---
18
+ What is x in the equation? | -1 <AND> -5 | x = 3 | 0.0
19
+ What is x in the equation? | -1 <AND> -5 | x = -1 | 0.5
20
+ What is x in the equation? | -1 <AND> -5 | x = -5 | 0.5
21
+ What is x in the equation? | -1 <AND> -5 | x = -5 or 5 | 0.5
22
+ What is x in the equation? | -1 <AND> -5 | x = -1 or x = -5 | 1.0
23
+ Can you explain this meme? | This meme is poking fun at the fact that the names of the countries
24
+ Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes,
25
+ while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues
26
+ because the names of these countries do not accurately represent their landscapes. |
27
+ The meme talks about Iceland and Greenland. It's pointing out that despite their names,
28
+ Iceland is not very icy and Greenland isn't very green. | 0.4
29
+ Can you explain this meme? | This meme is poking fun at the fact that the names of the countries
30
+ Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes,
31
+ while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues
32
+ because the names of these countries do not accurately represent their landscapes. |
33
+ The meme is using humor to point out the misleading nature of Iceland's and Greenland's names.
34
+ Iceland, despite its name, has lush green landscapes while Greenland is mostly covered in ice and snow.
35
+ The text 'This is why I have trust issues' is a playful way to suggest
36
+ that these contradictions can lead to distrust or confusion.
37
+ The humor in this meme is derived from the unexpected contrast between the names of the countries
38
+ and their actual physical characteristics. | 1.0
39
+ """
40
+ gpt4_prompt = prompt + '\n' + ' | '.join(
41
+ [question, gt.replace('<AND>', ' <AND> ').replace('<OR>', ' <OR> '), prediction, ''])
42
+ return gpt4_prompt
43
+
44
+
45
+ def MMVet_auxeval(model, line):
46
+ def float_cvt(s):
47
+ try:
48
+ return float(s)
49
+ except ValueError:
50
+ return None
51
+
52
+ prompt = build_mmvet_gpt4_prompt(line)
53
+ log = ''
54
+ retry = 5
55
+ for i in range(retry):
56
+ output = model.generate(prompt, temperature=i * 0.5)
57
+ score = float_cvt(output)
58
+ if score is None:
59
+ log += f'Try {i}: output is {output}, failed to parse.\n'
60
+ elif score < 0 or score > 1:
61
+ log += f'Try {i}: output is {output}, invalid score: {score}.\n'
62
+ else:
63
+ log += 'Succeed'
64
+ return dict(log=log, score=score)
65
+ log += 'All 5 retries failed.\n'
66
+ return dict(log=log, score=0.0)
67
+
68
+
69
+ def MMVet_acc(result_file):
70
+ data = load(result_file)
71
+ tot = defaultdict(lambda: 0)
72
+ score = defaultdict(lambda: 0)
73
+ lt = len(data)
74
+ cate2_list = []
75
+ for i in range(lt):
76
+ item = data.iloc[i]
77
+ cate = item['category']
78
+ cate2 = cate.replace(',', '_')
79
+ if cate2 not in cate2_list:
80
+ cate2_list.append(cate2)
81
+ grade = float(item['score'])
82
+ cate_list = ['rec', 'ocr', 'know', 'gen', 'spat', 'math']
83
+ for capa in cate_list:
84
+ if capa in cate:
85
+ tot[capa] += 1
86
+ score[capa] += grade
87
+ tot['Overall'] += 1
88
+ tot[cate2] += 1
89
+ score['Overall'] += grade
90
+ score[cate2] += grade
91
+
92
+ res = defaultdict(list)
93
+ res2 = defaultdict(list)
94
+ cate_list.append('Overall')
95
+ cate2_list.append('Overall')
96
+ for k in cate_list:
97
+ res['Category'].append(k)
98
+ res['tot'].append(tot[k])
99
+ res['acc'].append(score[k] / tot[k] * 100)
100
+ for v in cate2_list:
101
+ res2['Category'].append(v)
102
+ res2['tot'].append(tot[v])
103
+ res2['acc'].append(score[v] / tot[v] * 100)
104
+ res = pd.DataFrame(res)
105
+ res2 = pd.DataFrame(res2)
106
+ return res, res2
eval_mm/vlmevalkit/vlmeval/dataset/utils/multiple_choice.py ADDED
@@ -0,0 +1,442 @@
1
+ import pandas as pd
2
+ from ...utils import can_infer, track_progress_rich
3
+ from ...smp import *
4
+ import numpy as np
5
+
6
+ MMB_abbrs = {
7
+ 'coarse_perception': 'CP',
8
+ 'finegrained_perception (instance-level)': 'FP-S',
9
+ 'finegrained_perception (cross-instance)': 'FP-C',
10
+ 'logic_reasoning': 'LR',
11
+ 'relation_reasoning': 'RR',
12
+ 'attribute_reasoning': 'AR'
13
+ }
14
+
15
+ MMT_abbrs = {
16
+ 'visual_recognition': 'VR',
17
+ 'localization': 'Loc',
18
+ 'ocr': 'OCR',
19
+ 'counting': 'Count',
20
+ 'hallucination': 'HLN',
21
+ 'image_retrieval': 'IR',
22
+ 'threed': '3D',
23
+ 'visual_captioning': 'VC',
24
+ 'visual_grounding': 'VG',
25
+ 'doc_understanding': 'DU',
26
+ 'action_recognition': 'AR',
27
+ 'pixel_level_perception': 'PLP',
28
+ 'image-to-image_translation': 'I2IT',
29
+ 'relation_reasoning': 'RR',
30
+ 'intelligence_quotient_test': 'IQT',
31
+ 'emotion': 'Emo',
32
+ 'visual_illusion': 'VI',
33
+ 'meme_understanding': 'MemU',
34
+ 'visual_prompt_understanding': 'VPU',
35
+ 'anomaly_detection': 'AND',
36
+ 'keypoint_detection': 'KD',
37
+ 'visual_commonsense_reasoning': 'VCR',
38
+ 'image_evaluation_judgement': 'IEJ',
39
+ 'multiple_image_analysis': 'MIA',
40
+ 'cross_image_matching': 'CIM',
41
+ 'temporal_understanding': 'TU',
42
+ 'visual_code': 'VP',
43
+ 'medical_understanding': 'MedU',
44
+ 'autonomous_driving': 'AUD',
45
+ 'discipline_knowledge_reasoning': 'DKR',
46
+ 'embodied_ai': 'EA',
47
+ 'gui_navigation': 'GN'
48
+ }
49
+
50
+
51
+ def MMMU_preproc(data):
52
+ logger = get_logger('Evaluation')
53
+ cnt = 0
54
+ As, Bs, Ans = list(data['A']), list(data['B']), list(data['answer'])
55
+ lt = len(data)
56
+ for i in range(lt):
57
+ if pd.isna(As[i]):
58
+ As[i] = Ans[i]
59
+ Bs[i] = 'Other Answers'
60
+ cnt += 1
61
+ logger.info(f'During MMMU_preproc in Evaluation, {cnt} open questions are re-formulated to multi-choice ones. ')
62
+ data['A'] = As
63
+ data['B'] = Bs
64
+ return data
65
+
66
+
67
+ def report_acc(df):
68
+ # assert group in [None, 'category', 'l2-category']
69
+ res = defaultdict(list)
70
+
71
+ if 'split' in df:
72
+ splits = list(set(df['split']))
73
+ res['split'] = splits
74
+ else:
75
+ df['split'] = ['none'] * len(df)
76
+ res['split'] = ['none']
77
+
78
+ for group in [None, 'l2-category', 'category']:
79
+ if group is None:
80
+ res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
81
+ elif group not in df:
82
+ continue
83
+ else:
84
+ abilities = list(set(df[group]))
85
+ abilities.sort()
86
+ for ab in abilities:
87
+ ab_name = MMB_abbrs[ab] if ab in MMB_abbrs else ab
88
+ sub_df = df[df[group] == ab]
89
+ res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
90
+ return pd.DataFrame(res)
91
+
92
+
93
+ def report_acc_MMT(df):
94
+ # assert group in [None, 'category', 'l2-category']
95
+ res = defaultdict(list)
96
+ res['split'] = list()
97
+ res['Overall'] = list()
98
+ for _, name in MMT_abbrs.items():
99
+ res[name] = list()
100
+
101
+ if 'split' in df:
102
+ splits = list(set(df['split']))
103
+ res['split'] = splits
104
+
105
+ else:
106
+ df['split'] = ['none'] * len(df)
107
+ res['split'] = ['none']
108
+
109
+ for group in [None, 'category', 'l2-category']:
110
+ if group is None:
111
+ res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
112
+ res['Overall'].extend([np.mean(df['hit'])])
113
+ elif group not in df:
114
+ continue
115
+ elif group == 'category':
116
+ abilities = list(set(df[group]))
117
+ abilities.sort()
118
+ for ab in abilities:
119
+ ab_name = ab
120
+ sub_df = df[df[group] == ab]
121
+ res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
122
+ res[ab_name].extend([np.mean(sub_df['hit'])])
123
+ else:
124
+ abilities = list(set(df[group]))
125
+ abilities.sort()
126
+ for ab in abilities:
127
+ sub_task_name_list = df[df['l2-category'] == ab]['category'].unique()
128
+ sub_task_acc = []
129
+ for sub_task_name in sub_task_name_list:
130
+ sub_df = df[df['category'] == sub_task_name]
131
+ sub_task_acc.append([np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']])
132
+
133
+ new_acc = []
134
+ for i in range(len(sub_task_acc[0])):
135
+ new_acc.append(sum([_[i] for _ in sub_task_acc]) / len([_ for _ in sub_task_acc]))
136
+ ab_name = MMT_abbrs[ab] if ab in MMT_abbrs else ab
137
+ res[ab_name] = new_acc
138
+
139
+ sub_task_acc = []
140
+ for sub_task_name in sub_task_name_list:
141
+ sub_df = df[df['category'] == sub_task_name]
142
+ sub_task_acc.append([np.mean(sub_df['hit'])])
143
+ new_acc = []
144
+ for i in range(len(sub_task_acc[0])):
145
+ new_acc.append(sum([_[i] for _ in sub_task_acc]) / len([_ for _ in sub_task_acc]))
146
+
147
+ res[ab_name].extend(new_acc)
148
+
149
+ res['split'].append('ALL')
150
+ return pd.DataFrame(res)
151
+
152
+
153
+ def build_prompt(question, options, prediction):
154
+ tmpl = (
155
+ 'You are an AI assistant who will help me to match '
156
+ 'an answer with several options of a single-choice question. '
157
+ 'You are provided with a question, several options, and an answer, '
158
+ 'and you need to find which option is most similar to the answer. '
159
+ 'If the meaning of all options are significantly different from the answer, output Z. '
160
+ 'You should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n'
161
+ 'Example 1: \n'
162
+ 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
163
+ 'Answer: a cute teddy bear\nYour output: A\n'
164
+ 'Example 2: \n'
165
+ 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
166
+ 'Answer: Spider\nYour output: Z\n'
167
+ 'Example 3: \n'
168
+ 'Question: {}?\nOptions: {}\nAnswer: {}\nYour output: '
169
+ )
170
+ return tmpl.format(question, options, prediction)
171
+
172
+
173
+ def build_prompt_blink(question, options, prediction):
174
+ tmpl = (
175
+ 'You are an AI assistant who will help me to match an answer with several options of a single-choice question. '
176
+ 'You are provided with a question, several options, and an answer, '
177
+ 'and you need to find which option is most similar to the answer. '
178
+ "If the answer says things like refuse to answer, I'm sorry cannot help, etc., output Z."
179
+ 'If the meaning of all options are significantly different from the answer, '
180
+ 'or the answer does not select any option, output Z. '
181
+ 'Your should output one of the choices, A, B, C, D (if they are valid options), or Z.\n'
182
+ 'Example 1: \n'
183
+ 'Question: Which point is closer to the camera?\nSelect from the following choices.\n'
184
+ 'Options: A. Point A\nB. Point B\n(Z) Failed\n'
185
+ 'Answer: Point B, where the child is sitting, is closer to the camera.\nYour output: (B)\n'
186
+ 'Example 2: \n'
187
+ 'Question: Which point is closer to the camera?\nSelect from the following choices.\n'
188
+ 'Options: (A) Point A\n(B) Point B\n(Z) Failed\n'
189
+ "Answer: I'm sorry, but I can't assist with that request.\nYour output: (Z)\n"
190
+ 'Example 3: \n'
191
+ 'Question: Which point is corresponding to the reference point?\nSelect from the following choices.\n'
192
+ 'Options: (A) Point A\n(B) Point B\n(Z) Failed\n'
193
+ 'Answer:The reference point (REF) on the first image is at the tip of the pot, '
194
+ 'which is the part used to Poke if the pots were used for that action. Looking at the second image, '
195
+ 'we need to find the part of the object that would correspond to poking.\n'
196
+ "(A) Point A is at the tip of the spoon's handle, which is not used for poking.\n"
197
+ '(B) Point B is at the bottom of the spoon, which is not used for poking.\n'
198
+ '(C) Point C is on the side of the pspoonot, which is not used for poking.\n'
199
+ '(D) Point D is at the tip of the spoon, which is not used for poking.\n'
200
+ '\nTherefore, there is no correct answer in the choices\nYour output: (Z)\n'
201
+ 'Example 4: \n'
202
+ 'Question: {}?\nOptions: {}\n(Z) Failed\nAnswer: {}\nYour output: '
203
+ )
204
+ return tmpl.format(question, options, prediction)
205
+
206
+
207
+ def build_prompt_cn(question, options, prediction):
208
+ tmpl = (
209
+ '你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。'
210
+ '你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。'
211
+ '如果所有选项的意义都与答案显著不同,则输出 Z。'
212
+ '你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。'
213
+ '例 1:'
214
+ '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n'
215
+ '例 2: \n'
216
+ '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n'
217
+ '例 3: \n'
218
+ '问题: {}?\n选项: {}\n答案: {}\n输出: '
219
+ )
220
+ return tmpl.format(question, options, prediction)
221
+
222
+
223
+ def build_choices(item):
224
+ ret = {}
225
+ for ch in string.ascii_uppercase:
226
+ if ch in item and (not pd.isna(item[ch])):
227
+ ret[ch] = item[ch]
228
+ return ret
229
+
230
+
231
+ def prefetch_answer(item):
232
+ choices = build_choices(item)
233
+ return can_infer(item['prediction'], choices)
234
+
235
+
236
+ def extract_answer_from_item(model, item, dataset_name=None):
237
+ logger = get_logger('Evaluation')
238
+ # It will return: (pred, raw, llm_time)
239
+ choices = build_choices(item)
240
+ option_str = build_option_str(choices)
241
+
242
+ if dataset_name == 'BLINK':
243
+ prompt = build_prompt_blink(item['question'], option_str, item['prediction'])
244
+ elif cn_string(item['question']):
245
+ prompt = build_prompt_cn(item['question'], option_str, item['prediction'])
246
+ else:
247
+ prompt = build_prompt(item['question'], option_str, item['prediction'])
248
+ retry = 3
249
+
250
+ ret = can_infer(item['prediction'], choices)
251
+ if ret:
252
+ return dict(opt=ret, log=item['prediction'])
253
+ if model is None:
254
+ return dict(opt='Z', log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')
255
+
256
+ while retry:
257
+ ans = model.generate(prompt)
258
+ if 'Failed to obtain answer via API' in ans:
259
+ logger.warning('GPT API failed to answer. ')
260
+ else:
261
+ ret = can_infer(ans, choices)
262
+ if ret:
263
+ return dict(opt=ret, log=ans)
264
+ else:
265
+ logger.warning(f'Output includes 0 / > 1 letter among candidates {set(choices)} and Z: {ans}')
266
+ retry -= 1
267
+
268
+ if retry == 0:
269
+ options = list(choices) + ['Z'] if 'Z' not in choices else []
270
+ return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ')
271
+
272
+
273
+ # For Circular Evaluation
274
+ def prefetch_circular_group(sub_data, verbose=False):
275
+ lt = len(sub_data)
276
+ GT, PRED = [], []
277
+ for i in range(lt):
278
+ item = sub_data.iloc[i]
279
+ GT.append(item['GT'])
280
+ PRED.append(prefetch_answer(item))
281
+ if PRED[-1] and (GT[-1] != PRED[-1]):
282
+ log = (
283
+ f'Failed in Prefetching Rolling {i}: Answer is {GT[-1]}, '
284
+ f"Prediction is {item['prediction']}, Pre-fetched is {PRED[-1]}. "
285
+ )
286
+ return dict(hit=0, log=log)
287
+ flag = True
288
+ for g, p in zip(GT, PRED):
289
+ if g != p:
290
+ flag = False
291
+ ret = (dict(hit=1, log='Succeed During Pre-fetching'), ) if flag else (None, )
292
+ ret = ret + (GT, PRED) if verbose else ret
293
+ return ret if len(ret) > 1 else ret[0]
294
+
295
+
296
+ def eval_vanilla(model, item, dataset_name=None):
297
+ res = extract_answer_from_item(model, item, dataset_name=dataset_name)
298
+ opt, match_log = res['opt'], res['log']
299
+ if opt == item['GT']:
300
+ return dict(hit=1, log=f'Match Log: {match_log}. ')
301
+ else:
302
+ return dict(hit=0, log=f'Match Log: {match_log}. ')
303
+
304
+
305
+ # For Circular Evaluation
306
+ def eval_circular_group(model, sub_data, dataset_name=None):
307
+ res, GT, PRED = prefetch_circular_group(sub_data, verbose=True)
308
+ if res is not None:
309
+ return res
310
+
311
+ lt = len(sub_data)
312
+ log = ''
313
+ for i in range(lt):
314
+ if PRED[i]:
315
+ log += f'Rolling {i} Matched.\n'
316
+ else:
317
+ res = extract_answer_from_item(model, sub_data.iloc[i], dataset_name=dataset_name)
318
+ opt, match_log = res['opt'], res['log']
319
+ PRED[i] = opt
320
+ if PRED[i] != GT[i]:
321
+ log += (
322
+ f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {sub_data.iloc[i]['prediction']}; "
323
+ f'Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n'
324
+ )
325
+ return dict(hit=0, log=log)
326
+ else:
327
+ log += (
328
+ f"Rolling {i}: Answer is {GT[i]}, Prediction is {sub_data.iloc[i]['prediction']}, "
329
+ f'Pre-fetched is {PRED[i]}.\n'
330
+ )
331
+
332
+ return dict(hit=1, log=log)
333
+
334
+
335
+ # data, meta are pd.DataFrame, result_file is a path
336
+ def mcq_vanilla_eval(model, data, meta, nproc, result_file, dataset_name=None):
337
+ result = {}
338
+ if osp.exists(result_file):
339
+ result = load(result_file)
340
+ answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}
341
+
342
+ if 'MMMU' in dataset_name:
343
+ data = MMMU_preproc(data)
344
+ answer_map = {k: (v if v in list(string.ascii_uppercase) else 'A') for k, v in answer_map.items()}
345
+
346
+ data = data[data['index'].isin(answer_map)]
347
+ data['GT'] = [answer_map[idx] for idx in data['index']]
348
+ items = []
349
+
350
+ for i in range(len(data)):
351
+ # Dealing with the normal part
352
+ item = data.iloc[i]
353
+ if item['index'] not in result:
354
+ items.append(item)
355
+
356
+ tups = [dict(model=model, item=x, dataset_name=dataset_name) for x in items]
357
+ keys = [x['index'] for x in items]
358
+ if len(tups):
359
+ res = track_progress_rich(eval_vanilla, tups, nproc=nproc, chunksize=nproc, save=result_file, keys=keys)
360
+ result = load(result_file)
361
+ for k, v in zip(keys, res):
362
+ if k in result:
363
+ assert result[k]['hit'] == v['hit'] and result[k]['log'] == v['log']
364
+ else:
365
+ result[k] = v
366
+ data['hit'] = [result[i]['hit'] for i in data['index']]
367
+ data['log'] = [result[i]['log'] for i in data['index']]
368
+ if 'GT' in data:
369
+ data.pop('GT')
370
+ return data
371
+
372
+
373
+ # data, meta are pd.DataFrame, result_file is a path
374
+ def mcq_circular_eval(model, data, meta, nproc, result_file, dataset_name=None):
375
+ result = {}
376
+ if osp.exists(result_file):
377
+ result = load(result_file)
378
+ # Build Answer Map
379
+ answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}
380
+
381
+ for idx in list(meta['index']) + list(data['index']):
382
+ assert istype(idx, int)
383
+
384
+ # Only keep those lines in the meta data
385
+ data = data[data['index'].isin(answer_map)]
386
+ data['GT'] = [answer_map[idx] for idx in data['index']]
387
+ data_main = data[data['index'] < int(1e6)]
388
+
389
+ data_groups = []
390
+ for i in range(len(data_main)):
391
+ # Dealing with the normal part
392
+ idx = data_main.iloc[i]['index']
393
+ if idx not in result:
394
+ sub_data = data[data['index'] % int(1e6) == idx]
395
+ data_groups.append(sub_data)
396
+
397
+ if len(data_groups):
398
+ prefetched = [prefetch_circular_group(g, verbose=False) for g in data_groups]
399
+ remain = []
400
+ for dg, pf in zip(data_groups, prefetched):
401
+ if pf is not None:
402
+ result[dg.iloc[0]['index'] % 1e6] = pf
403
+ else:
404
+ remain.append(dg)
405
+ dump(result, result_file)
406
+
407
+ tups = [dict(model=model, sub_data=x, dataset_name=dataset_name) for x in remain]
408
+ keys = [x.iloc[0]['index'] % 1e6 for x in remain]
409
+
410
+ if len(tups) == 0:
411
+ pass
412
+ elif model is None:
413
+ logger = get_logger('Evaluation')
414
+ logger.warning('Exact Matching mode, will not do GPT-based answer matching. ')
415
+ for k in keys:
416
+ result[k] = dict(
417
+ hit=0, log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')
418
+ else:
419
+ res = track_progress_rich(
420
+ eval_circular_group,
421
+ tups,
422
+ nproc=nproc,
423
+ chunksize=nproc,
424
+ save=result_file,
425
+ keys=keys)
426
+ result = load(result_file)
427
+ for k, v in zip(keys, res):
428
+ if k in result:
429
+ assert result[k]['hit'] == v['hit'] and result[k]['log'] == v['log']
430
+ else:
431
+ result[k] = v
432
+
433
+ tmp_pth = f'/tmp/{timestr()}.xlsx'
434
+ dump(data_main, tmp_pth)
435
+ data_main = load(tmp_pth)
436
+ indices = data_main['index']
437
+ data_main['hit'] = [result[i]['hit'] for i in indices]
438
+ data_main['log'] = [result[i]['log'] for i in indices]
439
+ if 'GT' in data_main:
440
+ data_main.pop('GT')
441
+
442
+ return data_main
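
For context, the circular (rotation-based) scoring implemented above only counts a question as a hit when every rotated copy of it is answered correctly. A minimal self-contained sketch of that rule, using hypothetical (ground truth, prediction) pairs rather than the real DataFrame groups handled by eval_circular_group:

# Sketch of the circular-evaluation rule: one question = several rotations; all must be correct.
def circular_hit(rotations):
    """rotations: list of (ground_truth_letter, predicted_letter) for one question."""
    return int(all(gt == pred for gt, pred in rotations))

print(circular_hit([('A', 'A'), ('B', 'B'), ('C', 'D'), ('D', 'D')]))  # 0 - one rotation wrong, whole group fails
print(circular_hit([('A', 'A'), ('B', 'B'), ('C', 'C'), ('D', 'D')]))  # 1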
eval_mm/vlmevalkit/vlmeval/dataset/utils/mvbench.py ADDED
@@ -0,0 +1,450 @@
1
+ from ...smp import *
2
+ from PIL import Image, ImageOps
3
+ import torchvision
4
+ import random
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
+ def get_dimension_rating(data_path):
11
+ data = load(data_path)
12
+ result_board = {}
13
+ for idx, item in data.iterrows():
14
+ if item['task_type'] not in result_board:
15
+ result_board[item['task_type']] = [0, 0]
16
+ result_board[item['task_type']][1] += 1
17
+ if item['score']:
18
+ result_board[item['task_type']][0] += 1
19
+
20
+ correct = 0
21
+ total = 0
22
+ for key, value in result_board.items():
23
+ correct += value[0]
24
+ total += value[1]
25
+ result_board[key].append(f'{value[0] / value[1] * 100 :.2f}%')
26
+
27
+ result_board['overall'] = [correct, total, f'{correct / total * 100 :.2f}%']
28
+
29
+ return result_board
30
+
31
+
32
+ def check_ans(pred, gt):
33
+ flag = False
34
+
35
+ pred_list = pred.lower().split(' ')
36
+ pred_option, _ = pred_list[0], ' '.join(pred_list[1:])
37
+ gt_list = gt.lower().split(' ')
38
+ gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:])
39
+ if gt_content[-1] == '.':
40
+ gt_content = gt_content[:-1]
41
+
42
+ if pred_option.replace('.', '') in gt_option:
43
+ flag = True
44
+ elif gt_option in pred_option:
45
+ flag = True
46
+
47
+ return flag
48
+
49
+
50
+ class GroupRandomCrop(object):
51
+ def __init__(self, size):
52
+ if isinstance(size, numbers.Number):
53
+ self.size = (int(size), int(size))
54
+ else:
55
+ self.size = size
56
+
57
+ def __call__(self, img_group):
58
+
59
+ w, h = img_group[0].size
60
+ th, tw = self.size
61
+
62
+ out_images = list()
63
+
64
+ x1 = random.randint(0, w - tw)
65
+ y1 = random.randint(0, h - th)
66
+
67
+ for img in img_group:
68
+ assert (img.size[0] == w and img.size[1] == h)
69
+ if w == tw and h == th:
70
+ out_images.append(img)
71
+ else:
72
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
73
+
74
+ return out_images
75
+
76
+
77
+ class MultiGroupRandomCrop(object):
78
+ def __init__(self, size, groups=1):
79
+ if isinstance(size, numbers.Number):
80
+ self.size = (int(size), int(size))
81
+ else:
82
+ self.size = size
83
+ self.groups = groups
84
+
85
+ def __call__(self, img_group):
86
+
87
+ w, h = img_group[0].size
88
+ th, tw = self.size
89
+
90
+ out_images = list()
91
+
92
+ for i in range(self.groups):
93
+ x1 = random.randint(0, w - tw)
94
+ y1 = random.randint(0, h - th)
95
+
96
+ for img in img_group:
97
+ assert (img.size[0] == w and img.size[1] == h)
98
+ if w == tw and h == th:
99
+ out_images.append(img)
100
+ else:
101
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
102
+
103
+ return out_images
104
+
105
+
106
+ class GroupCenterCrop(object):
107
+ def __init__(self, size):
108
+ self.worker = torchvision.transforms.CenterCrop(size)
109
+
110
+ def __call__(self, img_group):
111
+ return [self.worker(img) for img in img_group]
112
+
113
+
114
+ class GroupRandomHorizontalFlip(object):
115
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
116
+ """
117
+
118
+ def __init__(self, is_flow=False):
119
+ self.is_flow = is_flow
120
+
121
+ def __call__(self, img_group, is_flow=False):
122
+ v = random.random()
123
+ if v < 0.5:
124
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
125
+ if self.is_flow:
126
+ for i in range(0, len(ret), 2):
127
+ # invert flow pixel values when flipping
128
+ ret[i] = ImageOps.invert(ret[i])
129
+ return ret
130
+ else:
131
+ return img_group
132
+
133
+
134
+ class GroupNormalize(object):
135
+ def __init__(self, mean, std):
136
+ self.mean = mean
137
+ self.std = std
138
+
139
+ def __call__(self, tensor):
140
+ rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
141
+ rep_std = self.std * (tensor.size()[0] // len(self.std))
142
+
143
+ # TODO: make efficient
144
+ for t, m, s in zip(tensor, rep_mean, rep_std):
145
+ t.sub_(m).div_(s)
146
+
147
+ return tensor
148
+
149
+
150
+ class GroupScale(object):
151
+ """ Rescales the input PIL.Image to the given 'size'.
152
+ 'size' will be the size of the smaller edge.
153
+ For example, if height > width, then image will be
154
+ rescaled to (size * height / width, size)
155
+ size: size of the smaller edge
156
+ interpolation: Default: PIL.Image.BILINEAR
157
+ """
158
+
159
+ def __init__(self, size, interpolation=Image.BILINEAR):
160
+ self.worker = torchvision.transforms.Resize(size, interpolation)
161
+
162
+ def __call__(self, img_group):
163
+ return [self.worker(img) for img in img_group]
164
+
165
+
166
+ class GroupOverSample(object):
167
+ def __init__(self, crop_size, scale_size=None, flip=True):
168
+ self.crop_size = crop_size if not isinstance(
169
+ crop_size, int) else (crop_size, crop_size)
170
+
171
+ if scale_size is not None:
172
+ self.scale_worker = GroupScale(scale_size)
173
+ else:
174
+ self.scale_worker = None
175
+ self.flip = flip
176
+
177
+ def __call__(self, img_group):
178
+
179
+ if self.scale_worker is not None:
180
+ img_group = self.scale_worker(img_group)
181
+
182
+ image_w, image_h = img_group[0].size
183
+ crop_w, crop_h = self.crop_size
184
+
185
+ offsets = GroupMultiScaleCrop.fill_fix_offset(
186
+ False, image_w, image_h, crop_w, crop_h)
187
+ oversample_group = list()
188
+ for o_w, o_h in offsets:
189
+ normal_group = list()
190
+ flip_group = list()
191
+ for i, img in enumerate(img_group):
192
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
193
+ normal_group.append(crop)
194
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
195
+
196
+ if img.mode == 'L' and i % 2 == 0:
197
+ flip_group.append(ImageOps.invert(flip_crop))
198
+ else:
199
+ flip_group.append(flip_crop)
200
+
201
+ oversample_group.extend(normal_group)
202
+ if self.flip:
203
+ oversample_group.extend(flip_group)
204
+ return oversample_group
205
+
206
+
207
+ class GroupFullResSample(object):
208
+ def __init__(self, crop_size, scale_size=None, flip=True):
209
+ self.crop_size = crop_size if not isinstance(
210
+ crop_size, int) else (crop_size, crop_size)
211
+
212
+ if scale_size is not None:
213
+ self.scale_worker = GroupScale(scale_size)
214
+ else:
215
+ self.scale_worker = None
216
+ self.flip = flip
217
+
218
+ def __call__(self, img_group):
219
+
220
+ if self.scale_worker is not None:
221
+ img_group = self.scale_worker(img_group)
222
+
223
+ image_w, image_h = img_group[0].size
224
+ crop_w, crop_h = self.crop_size
225
+
226
+ w_step = (image_w - crop_w) // 4
227
+ h_step = (image_h - crop_h) // 4
228
+
229
+ offsets = list()
230
+ offsets.append((0 * w_step, 2 * h_step)) # left
231
+ offsets.append((4 * w_step, 2 * h_step)) # right
232
+ offsets.append((2 * w_step, 2 * h_step)) # center
233
+
234
+ oversample_group = list()
235
+ for o_w, o_h in offsets:
236
+ normal_group = list()
237
+ flip_group = list()
238
+ for i, img in enumerate(img_group):
239
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
240
+ normal_group.append(crop)
241
+ if self.flip:
242
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
243
+
244
+ if img.mode == 'L' and i % 2 == 0:
245
+ flip_group.append(ImageOps.invert(flip_crop))
246
+ else:
247
+ flip_group.append(flip_crop)
248
+
249
+ oversample_group.extend(normal_group)
250
+ oversample_group.extend(flip_group)
251
+ return oversample_group
252
+
253
+
254
+ class GroupMultiScaleCrop(object):
255
+
256
+ def __init__(self, input_size, scales=None, max_distort=1,
257
+ fix_crop=True, more_fix_crop=True):
258
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
259
+ self.max_distort = max_distort
260
+ self.fix_crop = fix_crop
261
+ self.more_fix_crop = more_fix_crop
262
+ self.input_size = input_size if not isinstance(input_size, int) else [
263
+ input_size, input_size]
264
+ self.interpolation = Image.BILINEAR
265
+
266
+ def __call__(self, img_group):
267
+
268
+ im_size = img_group[0].size
269
+
270
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
271
+ crop_img_group = [
272
+ img.crop(
273
+ (offset_w,
274
+ offset_h,
275
+ offset_w + crop_w,
276
+ offset_h + crop_h)) for img in img_group]
277
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
278
+ for img in crop_img_group]
279
+ return ret_img_group
280
+
281
+ def _sample_crop_size(self, im_size):
282
+ image_w, image_h = im_size[0], im_size[1]
283
+
284
+ # find a crop size
285
+ base_size = min(image_w, image_h)
286
+ crop_sizes = [int(base_size * x) for x in self.scales]
287
+ crop_h = [
288
+ self.input_size[1] if abs(
289
+ x - self.input_size[1]) < 3 else x for x in crop_sizes]
290
+ crop_w = [
291
+ self.input_size[0] if abs(
292
+ x - self.input_size[0]) < 3 else x for x in crop_sizes]
293
+
294
+ pairs = []
295
+ for i, h in enumerate(crop_h):
296
+ for j, w in enumerate(crop_w):
297
+ if abs(i - j) <= self.max_distort:
298
+ pairs.append((w, h))
299
+
300
+ crop_pair = random.choice(pairs)
301
+ if not self.fix_crop:
302
+ w_offset = random.randint(0, image_w - crop_pair[0])
303
+ h_offset = random.randint(0, image_h - crop_pair[1])
304
+ else:
305
+ w_offset, h_offset = self._sample_fix_offset(
306
+ image_w, image_h, crop_pair[0], crop_pair[1])
307
+
308
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
309
+
310
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
311
+ offsets = self.fill_fix_offset(
312
+ self.more_fix_crop, image_w, image_h, crop_w, crop_h)
313
+ return random.choice(offsets)
314
+
315
+ @staticmethod
316
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
317
+ w_step = (image_w - crop_w) // 4
318
+ h_step = (image_h - crop_h) // 4
319
+
320
+ ret = list()
321
+ ret.append((0, 0)) # upper left
322
+ ret.append((4 * w_step, 0)) # upper right
323
+ ret.append((0, 4 * h_step)) # lower left
324
+ ret.append((4 * w_step, 4 * h_step)) # lower right
325
+ ret.append((2 * w_step, 2 * h_step)) # center
326
+
327
+ if more_fix_crop:
328
+ ret.append((0, 2 * h_step)) # center left
329
+ ret.append((4 * w_step, 2 * h_step)) # center right
330
+ ret.append((2 * w_step, 4 * h_step)) # lower center
331
+ ret.append((2 * w_step, 0 * h_step)) # upper center
332
+
333
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
334
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
335
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
336
+ ret.append((3 * w_step, 3 * h_step)) # lower right quarter
337
+
338
+ return ret
339
+
340
+
341
+ class GroupRandomSizedCrop(object):
342
+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
343
+ and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
344
+ This is popularly used to train the Inception networks
345
+ size: size of the smaller edge
346
+ interpolation: Default: PIL.Image.BILINEAR
347
+ """
348
+
349
+ def __init__(self, size, interpolation=Image.BILINEAR):
350
+ self.size = size
351
+ self.interpolation = interpolation
352
+
353
+ def __call__(self, img_group):
354
+ for attempt in range(10):
355
+ area = img_group[0].size[0] * img_group[0].size[1]
356
+ target_area = random.uniform(0.08, 1.0) * area
357
+ aspect_ratio = random.uniform(3. / 4, 4. / 3)
358
+
359
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
360
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
361
+
362
+ if random.random() < 0.5:
363
+ w, h = h, w
364
+
365
+ if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
366
+ x1 = random.randint(0, img_group[0].size[0] - w)
367
+ y1 = random.randint(0, img_group[0].size[1] - h)
368
+ found = True
369
+ break
370
+ else:
371
+ found = False
372
+ x1 = 0
373
+ y1 = 0
374
+
375
+ if found:
376
+ out_group = list()
377
+ for img in img_group:
378
+ img = img.crop((x1, y1, x1 + w, y1 + h))
379
+ assert (img.size == (w, h))
380
+ out_group.append(
381
+ img.resize(
382
+ (self.size, self.size), self.interpolation))
383
+ return out_group
384
+ else:
385
+ # Fallback
386
+ scale = GroupScale(self.size, interpolation=self.interpolation)
387
+ crop = GroupRandomCrop(self.size)
388
+ return crop(scale(img_group))
389
+
390
+
391
+ class ConvertDataFormat(object):
392
+ def __init__(self, model_type):
393
+ self.model_type = model_type
394
+
395
+ def __call__(self, images):
396
+ if self.model_type == '2D':
397
+ return images
398
+ tc, h, w = images.size()
399
+ t = tc // 3
400
+ images = images.view(t, 3, h, w)
401
+ images = images.permute(1, 0, 2, 3)
402
+ return images
403
+
404
+
405
+ class Stack(object):
406
+
407
+ def __init__(self, roll=False):
408
+ self.roll = roll
409
+
410
+ def __call__(self, img_group):
411
+ if img_group[0].mode == 'L':
412
+ return np.concatenate([np.expand_dims(x, 2)
413
+ for x in img_group], axis=2)
414
+ elif img_group[0].mode == 'RGB':
415
+ if self.roll:
416
+ return np.concatenate([np.array(x)[:, :, ::-1]
417
+ for x in img_group], axis=2)
418
+ else:
419
+ # print(np.concatenate(img_group, axis=2).shape)
420
+ # print(img_group[0].shape)
421
+ return np.concatenate(img_group, axis=2)
422
+
423
+
424
+ class ToTorchFormatTensor(object):
425
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
426
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
427
+
428
+ def __init__(self, div=True):
429
+ self.div = div
430
+
431
+ def __call__(self, pic):
432
+ if isinstance(pic, np.ndarray):
433
+ # handle numpy array
434
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
435
+ else:
436
+ # handle PIL Image
437
+ img = torch.ByteTensor(
438
+ torch.ByteStorage.from_buffer(
439
+ pic.tobytes()))
440
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
441
+ # put it from HWC to CHW format
442
+ # yikes, this transpose takes 80% of the loading time/CPU
443
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
444
+ return img.float().div(255) if self.div else img.float()
445
+
446
+
447
+ class IdentityTransform(object):
448
+
449
+ def __call__(self, data):
450
+ return data
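
For reference, the group transforms defined in this file are designed to be chained over a list of PIL frames. A minimal usage sketch (hedged: the import path, frame sizes, and normalization statistics below are illustrative assumptions, not taken from this repo):

import torchvision
from PIL import Image
# Assumed import path, based on this repo's package layout.
from vlmeval.dataset.utils.mvbench import (
    GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, ConvertDataFormat)

frames = [Image.new('RGB', (320, 240)) for _ in range(8)]   # hypothetical 8-frame clip

transform = torchvision.transforms.Compose([
    GroupScale(256),            # resize the smaller edge of every frame to 256
    GroupCenterCrop(224),       # center-crop each frame to 224 x 224
    Stack(),                    # concatenate frames along channels -> H x W x (3 * T)
    ToTorchFormatTensor(),      # -> float tensor of shape (3 * T, H, W) in [0, 1]
    GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ConvertDataFormat('3D'),    # reshape to (3, T, H, W) for 3D backbones
])
clip = transform(frames)        # torch.Size([3, 8, 224, 224])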
eval_mm/vlmevalkit/vlmeval/dataset/utils/ocrbench.py ADDED
@@ -0,0 +1,65 @@
1
+ from ...smp import *
2
+
3
+
4
+ def OCRBench_eval(eval_file):
5
+ OCRBench_score = {
6
+ 'Regular Text Recognition': 0,
7
+ 'Irregular Text Recognition': 0,
8
+ 'Artistic Text Recognition': 0,
9
+ 'Handwriting Recognition': 0,
10
+ 'Digit String Recognition': 0,
11
+ 'Non-Semantic Text Recognition': 0,
12
+ 'Scene Text-centric VQA': 0,
13
+ 'Doc-oriented VQA': 0,
14
+ 'Key Information Extraction': 0,
15
+ 'Handwritten Mathematical Expression Recognition': 0
16
+ }
17
+
18
+ logger = get_logger('Evaluation')
19
+
20
+ data = load(eval_file)
21
+ lt = len(data)
22
+ lines = [data.iloc[i] for i in range(lt)]
23
+ for i in tqdm(range(len(lines))):
24
+ line = lines[i]
25
+ predict = str(line['prediction'])
26
+ answers = eval(line['answer'])
27
+ category = line['category']
28
+ if category == 'Handwritten Mathematical Expression Recognition':
29
+ for j in range(len(answers)):
30
+ answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
31
+ predict = predict.strip().replace('\n', ' ').replace(' ', '')
32
+ if answer in predict:
33
+ OCRBench_score[category] += 1
34
+ break
35
+ else:
36
+ for j in range(len(answers)):
37
+ answer = answers[j].lower().strip().replace('\n', ' ')
38
+ predict = predict.lower().strip().replace('\n', ' ')
39
+ if answer in predict:
40
+ OCRBench_score[category] += 1
41
+ break
42
+
43
+ final_score_dict = {}
44
+ final_score_dict['Text Recognition'] = (
45
+ OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
46
+ + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
47
+ + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
48
+ )
49
+ final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
50
+ final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
51
+ final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
52
+ final_score_dict['Handwritten Mathematical Expression Recognition'] = \
53
+ OCRBench_score['Handwritten Mathematical Expression Recognition']
54
+ final_score_dict['Final Score'] = (
55
+ final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
56
+ + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
57
+ + final_score_dict['Handwritten Mathematical Expression Recognition']
58
+ )
59
+ final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score']) / 10
60
+ score_pth = eval_file.replace('.xlsx', '_score.json')
61
+ dump(final_score_dict, score_pth)
62
+ logger.info(f'OCRBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
63
+ logger.info('Score: ')
64
+ for key, value in final_score_dict.items():
65
+ logger.info('{}:{}'.format(key, value))
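
The scoring rule above is plain substring containment after normalization (whitespace is additionally stripped for the handwritten-math category). A small self-contained sketch of just that rule, with made-up inputs:

def ocrbench_match(prediction, answers, category):
    # Mirrors the containment check above: 1 if any ground-truth answer appears in the prediction.
    if category == 'Handwritten Mathematical Expression Recognition':
        norm = lambda s: s.strip().replace('\n', ' ').replace(' ', '')
    else:
        norm = lambda s: s.lower().strip().replace('\n', ' ')
    return int(any(norm(a) in norm(prediction) for a in answers))

print(ocrbench_match('The sign reads "STOP".', ['stop'], 'Scene Text-centric VQA'))  # 1
print(ocrbench_match('x^2 + 1', ['x^{2}+1'], 'Handwritten Mathematical Expression Recognition'))  # 0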
eval_mm/vlmevalkit/vlmeval/dataset/utils/videomme.py ADDED
@@ -0,0 +1,140 @@
1
+ from ...smp import *
2
+ import numpy as np
3
+ import re
4
+
5
+ FAIL_MSG = 'Failed to obtain answer via API.'
6
+
7
+ DURATIONS = [
8
+ 'short',
9
+ 'medium',
10
+ 'long',
11
+ ]
12
+
13
+ DOMAINS = [
14
+ 'Knowledge',
15
+ 'Film & Television',
16
+ 'Sports Competition',
17
+ 'Artistic Performance',
18
+ 'Life Record',
19
+ 'Multilingual'
20
+ ]
21
+
22
+ SUB_CATEGORIES = [
23
+ 'Humanity & History',
24
+ 'Literature & Art',
25
+ 'Biology & Medicine',
26
+ 'Finance & Commerce',
27
+ 'Astronomy',
28
+ 'Geography',
29
+ 'Law',
30
+ 'Life Tip',
31
+ 'Technology',
32
+ 'Animation',
33
+ 'Movie & TV Show',
34
+ 'Documentary',
35
+ 'News Report',
36
+ 'Esports',
37
+ 'Basketball',
38
+ 'Football',
39
+ 'Athletics',
40
+ 'Other Sports',
41
+ 'Stage Play',
42
+ 'Magic Show',
43
+ 'Variety Show',
44
+ 'Acrobatics',
45
+ 'Handicraft',
46
+ 'Food',
47
+ 'Fashion',
48
+ 'Daily Life',
49
+ 'Travel',
50
+ 'Pet & Animal',
51
+ 'Exercise',
52
+ 'Multilingual'
53
+ ]
54
+
55
+ TASK_CATEGORIES = [
56
+ 'Temporal Perception',
57
+ 'Spatial Perception',
58
+ 'Attribute Perception',
59
+ 'Action Recognition',
60
+ 'Object Recognition',
61
+ 'OCR Problems',
62
+ 'Counting Problem',
63
+ 'Temporal Reasoning',
64
+ 'Spatial Reasoning',
65
+ 'Action Reasoning',
66
+ 'Object Reasoning',
67
+ 'Information Synopsis',
68
+ ]
69
+
70
+
71
+ def get_dimension_rating(data_path):
72
+ data = load(data_path)
73
+
74
+ duration_rating = {k: {} for k in DURATIONS}
75
+ for duration in DURATIONS + ['overall']:
76
+ duration_rating[duration] = {
77
+ 'overall': '',
78
+ 'domain': {k: [] for k in DOMAINS},
79
+ 'sub_category': {k: [] for k in SUB_CATEGORIES},
80
+ 'task_type': {k: [] for k in TASK_CATEGORIES}
81
+ }
82
+
83
+ for i in range(len(data)):
84
+
85
+ domain = data.iloc[i]['domain']
86
+ sub_ctg = data.iloc[i]['sub_category']
87
+ task_ctg = data.iloc[i]['task_type']
88
+
89
+ duration = data.iloc[i]['duration']
90
+ duration_rating[duration]['domain'][domain].append(data.iloc[i]['score'])
91
+ duration_rating[duration]['sub_category'][sub_ctg].append(data.iloc[i]['score'])
92
+ duration_rating[duration]['task_type'][task_ctg].append(data.iloc[i]['score'])
93
+
94
+ duration_rating['overall']['domain'][domain].append(data.iloc[i]['score'])
95
+ duration_rating['overall']['sub_category'][sub_ctg].append(data.iloc[i]['score'])
96
+ duration_rating['overall']['task_type'][task_ctg].append(data.iloc[i]['score'])
97
+
98
+ for duration in DURATIONS + ['overall']:
99
+
100
+ overall_res_dur = f'{np.mean([x for x in sum(duration_rating[duration]["domain"].values(), []) if x >= 0]):.2f}'
101
+ duration_rating[duration]['overall'] = overall_res_dur
102
+
103
+ for domain in DOMAINS:
104
+ domain_res_dur = f'{np.mean([x for x in duration_rating[duration]["domain"][domain] if x >= 0]):.2f}'
105
+ duration_rating[duration]['domain'][domain] = domain_res_dur
106
+
107
+ for sub_ctg in SUB_CATEGORIES:
108
+ sub_res_dur = f'{np.mean([x for x in duration_rating[duration]["sub_category"][sub_ctg] if x >= 0]):.2f}'
109
+ duration_rating[duration]['sub_category'][sub_ctg] = sub_res_dur
110
+
111
+ for task_ctg in TASK_CATEGORIES:
112
+ task_res_dur = f'{np.mean([x for x in duration_rating[duration]["task_type"][task_ctg] if x >= 0]):.2f}'
113
+ duration_rating[duration]['task_type'][task_ctg] = task_res_dur
114
+
115
+ return duration_rating
116
+
117
+
118
+ def extract_characters_regex(s):
119
+ s = s.strip()
120
+ answer_prefixes = [
121
+ 'The best answer is',
122
+ 'The correct answer is',
123
+ 'The answer is',
124
+ 'The answer',
125
+ 'The best option is',
126
+ 'The correct option is',
127
+ 'Best answer:',
128
+ 'Best option:',
129
+ 'Answer:',
130
+ 'Option:',
131
+ ]
132
+ for answer_prefix in answer_prefixes:
133
+ s = s.replace(answer_prefix, '')
134
+
135
+ if len(s.split()) > 10 and not re.search('[ABCD]', s):
136
+ return ''
137
+ matches = re.search(r'[ABCD]', s)
138
+ if matches is None:
139
+ return ''
140
+ return matches[0]
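
A few hedged examples of how the letter-extraction helper above behaves on typical (made-up) model outputs, assuming it is importable as below given this repo's package layout:

from vlmeval.dataset.utils.videomme import extract_characters_regex

print(extract_characters_regex('The best answer is (B) because the player scores.'))  # 'B'
print(extract_characters_regex('Answer: C'))                                          # 'C'
print(extract_characters_regex('I am not sure, the clip is far too blurry to pick any of the listed choices here'))  # '' (long answer, no A-D)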
eval_mm/vlmevalkit/vlmeval/dataset/utils/vqa_eval.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # Partly adopted from https://github.com/GT-Vision-Lab/VQA
3
+ # Copyright (c) 2014, Aishwarya Agrawal
4
+
5
+ from ...smp import *
6
+ from typing import Optional
7
+
8
+
9
+ def _process_digit_article(inText):
10
+ outText = []
11
+ tempText = inText.lower().split()
12
+ articles = ['a', 'an', 'the']
13
+ manualMap = {
14
+ 'none': '0',
15
+ 'zero': '0',
16
+ 'one': '1',
17
+ 'two': '2',
18
+ 'three': '3',
19
+ 'four': '4',
20
+ 'five': '5',
21
+ 'six': '6',
22
+ 'seven': '7',
23
+ 'eight': '8',
24
+ 'nine': '9',
25
+ 'ten': '10',
26
+ }
27
+ contractions = {
28
+ 'aint': "ain't",
29
+ 'arent': "aren't",
30
+ 'cant': "can't",
31
+ 'couldve': "could've",
32
+ 'couldnt': "couldn't",
33
+ "couldn'tve": "couldn't've",
34
+ "couldnt've": "couldn't've",
35
+ 'didnt': "didn't",
36
+ 'doesnt': "doesn't",
37
+ 'dont': "don't",
38
+ 'hadnt': "hadn't",
39
+ "hadnt've": "hadn't've",
40
+ "hadn'tve": "hadn't've",
41
+ 'hasnt': "hasn't",
42
+ 'havent': "haven't",
43
+ 'hed': "he'd",
44
+ "hed've": "he'd've",
45
+ "he'dve": "he'd've",
46
+ 'hes': "he's",
47
+ 'howd': "how'd",
48
+ 'howll': "how'll",
49
+ 'hows': "how's",
50
+ "Id've": "I'd've",
51
+ "I'dve": "I'd've",
52
+ 'Im': "I'm",
53
+ 'Ive': "I've",
54
+ 'isnt': "isn't",
55
+ 'itd': "it'd",
56
+ "itd've": "it'd've",
57
+ "it'dve": "it'd've",
58
+ 'itll': "it'll",
59
+ "let's": "let's",
60
+ 'maam': "ma'am",
61
+ 'mightnt': "mightn't",
62
+ "mightnt've": "mightn't've",
63
+ "mightn'tve": "mightn't've",
64
+ 'mightve': "might've",
65
+ 'mustnt': "mustn't",
66
+ 'mustve': "must've",
67
+ 'neednt': "needn't",
68
+ 'notve': "not've",
69
+ 'oclock': "o'clock",
70
+ 'oughtnt': "oughtn't",
71
+ "ow's'at": "'ow's'at",
72
+ "'ows'at": "'ow's'at",
73
+ "'ow'sat": "'ow's'at",
74
+ 'shant': "shan't",
75
+ "shed've": "she'd've",
76
+ "she'dve": "she'd've",
77
+ "she's": "she's",
78
+ 'shouldve': "should've",
79
+ 'shouldnt': "shouldn't",
80
+ "shouldnt've": "shouldn't've",
81
+ "shouldn'tve": "shouldn't've",
82
+ "somebody'd": 'somebodyd',
83
+ "somebodyd've": "somebody'd've",
84
+ "somebody'dve": "somebody'd've",
85
+ 'somebodyll': "somebody'll",
86
+ 'somebodys': "somebody's",
87
+ 'someoned': "someone'd",
88
+ "someoned've": "someone'd've",
89
+ "someone'dve": "someone'd've",
90
+ 'someonell': "someone'll",
91
+ 'someones': "someone's",
92
+ 'somethingd': "something'd",
93
+ "somethingd've": "something'd've",
94
+ "something'dve": "something'd've",
95
+ 'somethingll': "something'll",
96
+ 'thats': "that's",
97
+ 'thered': "there'd",
98
+ "thered've": "there'd've",
99
+ "there'dve": "there'd've",
100
+ 'therere': "there're",
101
+ 'theres': "there's",
102
+ 'theyd': "they'd",
103
+ "theyd've": "they'd've",
104
+ "they'dve": "they'd've",
105
+ 'theyll': "they'll",
106
+ 'theyre': "they're",
107
+ 'theyve': "they've",
108
+ 'twas': "'twas",
109
+ 'wasnt': "wasn't",
110
+ "wed've": "we'd've",
111
+ "we'dve": "we'd've",
112
+ 'weve': "we've",
113
+ 'werent': "weren't",
114
+ 'whatll': "what'll",
115
+ 'whatre': "what're",
116
+ 'whats': "what's",
117
+ 'whatve': "what've",
118
+ 'whens': "when's",
119
+ 'whered': "where'd",
120
+ 'wheres': "where's",
121
+ 'whereve': "where've",
122
+ 'whod': "who'd",
123
+ "whod've": "who'd've",
124
+ "who'dve": "who'd've",
125
+ 'wholl': "who'll",
126
+ 'whos': "who's",
127
+ 'whove': "who've",
128
+ 'whyll': "why'll",
129
+ 'whyre': "why're",
130
+ 'whys': "why's",
131
+ 'wont': "won't",
132
+ 'wouldve': "would've",
133
+ 'wouldnt': "wouldn't",
134
+ "wouldnt've": "wouldn't've",
135
+ "wouldn'tve": "wouldn't've",
136
+ 'yall': "y'all",
137
+ "yall'll": "y'all'll",
138
+ "y'allll": "y'all'll",
139
+ "yall'd've": "y'all'd've",
140
+ "y'alld've": "y'all'd've",
141
+ "y'all'dve": "y'all'd've",
142
+ 'youd': "you'd",
143
+ "youd've": "you'd've",
144
+ "you'dve": "you'd've",
145
+ 'youll': "you'll",
146
+ 'youre': "you're",
147
+ 'youve': "you've",
148
+ }
149
+ for word in tempText:
150
+ word = manualMap.setdefault(word, word)
151
+ if word not in articles:
152
+ outText.append(word)
153
+ for wordId, word in enumerate(outText):
154
+ if word in contractions:
155
+ outText[wordId] = contractions[word]
156
+ outText = ' '.join(outText)
157
+ return outText
158
+
159
+
160
+ def hit_calculate(result, dataset_name, anls_threshold=0.5):
161
+ if listinstr(['TextVQA'], dataset_name):
162
+ return [np.mean(x['match']) for x in result]
163
+ elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
164
+ return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
165
+ elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
166
+ return [np.max(x['match']) for x in result]
167
+ else: # default using vqa_score to calculate score
168
+ return [np.mean(x['match']) for x in result]
169
+
170
+
171
+ # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
172
+ def relaxed_correctness(target: str,
173
+ prediction: str,
174
+ max_relative_change: float = 0.05) -> bool:
175
+ """Calculates relaxed correctness.
176
+
177
+ The correctness tolerates certain error ratio defined by max_relative_change.
178
+ See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
179
+ “Following Methani et al. (2020), we use a relaxed accuracy measure for the
180
+ numeric answers to allow a minor inaccuracy that may result from the automatic
181
+ data extraction process. We consider an answer to be correct if it is within
182
+ 5% of the gold answer. For non-numeric answers, we still need an exact match
183
+ to consider an answer to be correct.”
184
+
185
+ Args:
186
+ target: Target string.
187
+ prediction: Predicted string.
188
+ max_relative_change: Maximum relative change.
189
+
190
+ Returns:
191
+ Whether the prediction was correct given the specified tolerance.
192
+ """
193
+
194
+ def _to_float(text: str) -> Optional[float]:
195
+ try:
196
+ if text.endswith('%'):
197
+ # Convert percentages to floats.
198
+ return float(text.rstrip('%')) / 100.0
199
+ else:
200
+ return float(text)
201
+ except ValueError:
202
+ return None
203
+ prediction = str(prediction)
204
+ target = str(target)
205
+ prediction_float = _to_float(prediction)
206
+ target_float = _to_float(target)
207
+ if prediction_float is not None and target_float:
208
+ relative_change = abs(prediction_float - target_float) / abs(target_float)
209
+ return relative_change <= max_relative_change
210
+ else:
211
+ return prediction.lower() == target.lower()
212
+
213
+
214
+ def levenshtein_distance(s1, s2):
215
+ if len(s1) > len(s2):
216
+ s1, s2 = s2, s1
217
+
218
+ distances = range(len(s1) + 1)
219
+ for i2, c2 in enumerate(s2):
220
+ distances_ = [i2 + 1]
221
+ for i1, c1 in enumerate(s1):
222
+ if c1 == c2:
223
+ distances_.append(distances[i1])
224
+ else:
225
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
226
+ distances = distances_
227
+ return distances[-1]
228
+
229
+
230
+ def anls_compute(groundtruth, prediction):
231
+ gt_answer = ' '.join(groundtruth.strip().lower().split())
232
+ det_answer = ' '.join(prediction.strip().lower().split())
233
+ dist = levenshtein_distance(gt_answer, det_answer)
234
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
235
+ values = 0.0 if length == 0 else float(dist) / float(length)
236
+ return values
237
+
238
+
239
+ def process_answer(answer):
240
+ answer = answer.replace('\n', ' ')
241
+ answer = answer.replace('\t', ' ')
242
+ answer = answer.strip()
243
+ answer = process_punctuation(answer)
244
+ answer = _process_digit_article(answer)
245
+ return answer
246
+
247
+
248
+ def process_line(line, method='vqa_score'):
249
+ ret = {}
250
+ if istype(line['answer'], list):
251
+ answers = eval(line['answer'])
252
+ else:
253
+ answers = [line['answer']]
254
+ if method == 'vqa_score':
255
+ ret['gt'] = [process_answer(x) for x in answers]
256
+ ret['pred'] = process_answer(line['prediction'])
257
+ ret['match'] = []
258
+ for current_idx, gtAnsDatum in enumerate(ret['gt']):
259
+ otherGTAns = [
260
+ item for ret_gt_idx, item in enumerate(ret['gt'])
261
+ if ret_gt_idx != current_idx
262
+ ]
263
+ matchingAns = [
264
+ item for item in otherGTAns if item == ret['pred']
265
+ ]
266
+ acc = min(1, float(len(matchingAns)) / 3)
267
+ ret['match'].append(acc)
268
+ elif method == 'anls':
269
+ ret['gt'] = answers
270
+ ret['pred'] = line['prediction']
271
+ ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
272
+ elif method == 'relaxed_accuracy':
273
+ ret['gt'] = answers
274
+ ret['pred'] = line['prediction'].strip()
275
+ ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
276
+ elif method == 'accuracy':
277
+ ret['gt'] = answers
278
+ ret['pred'] = line['prediction'].strip()
279
+ ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0) for x in ret['gt']]
280
+ else: # default using vqa_score to calculate score
281
+ ret['gt'] = [process_answer(x) for x in answers]
282
+ ret['pred'] = process_answer(line['prediction'])
283
+ ret['match'] = [x == ret['pred'] for x in ret['gt']]
284
+
285
+ return ret
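
To make the scoring conventions above concrete, a few worked numbers on hypothetical strings (assuming the functions are importable as below): the VQA-v2 rule scores min(1, matches_with_other_annotators / 3) per annotator, relaxed accuracy tolerates a 5% numeric error, and anls_compute returns the normalized edit distance (hit_calculate later converts it to 1 - distance and zeroes anything below the 0.5 threshold).

from vlmeval.dataset.utils.vqa_eval import relaxed_correctness, levenshtein_distance, anls_compute

assert relaxed_correctness('2.5', '2.6')        # |2.6 - 2.5| / 2.5 = 0.04 <= 0.05 -> counted correct
assert not relaxed_correctness('100', '94')     # 6% relative error -> wrong
assert levenshtein_distance('kitten', 'sitting') == 3
assert abs(anls_compute('hello world', 'helo world') - 1 / 11) < 1e-9   # 1 edit / 11 characters

# VQA-v2 style accuracy with hypothetical annotations: 'cat' matches 2 of the other
# annotators for the first three entries and all 3 of them for the last one.
gt, pred = ['cat', 'cat', 'cat', 'dog'], 'cat'
per_annotator = [min(1, sum(o == pred for j, o in enumerate(gt) if j != i) / 3) for i in range(len(gt))]
# per_annotator == [2/3, 2/3, 2/3, 1.0]; the sample score is their mean.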
eval_mm/vlmevalkit/vlmeval/dataset/utils/yorn.py ADDED
@@ -0,0 +1,203 @@
1
+ from ...smp import *
2
+
3
+
4
+ def MME_rating(data_file):
5
+ data = load(data_file)
6
+ stats = defaultdict(dict)
7
+ lt = len(data)
8
+ for i in range(lt):
9
+ item = data.iloc[i]
10
+ category = item['category']
11
+ image_path = item['image_path']
12
+ score = item['score']
13
+ if image_path not in stats[category]:
14
+ stats[category][image_path] = []
15
+ stats[category][image_path].append(score)
16
+
17
+ def acc(key, mode='normal'):
18
+ res = stats[key]
19
+ values = []
20
+ for val in res.values():
21
+ if mode == 'normal':
22
+ values.extend(val)
23
+ elif mode == 'plus':
24
+ values.append(val[0] * val[1])
25
+ return np.mean(values) * 100
26
+
27
+ scores = {}
28
+ for k in stats:
29
+ scores[k] = acc(k) + acc(k, 'plus')
30
+
31
+ super_cates = dict(
32
+ perception=[
33
+ 'OCR', 'artwork', 'celebrity', 'color', 'count', 'existence',
34
+ 'landmark', 'position', 'posters', 'scene'
35
+ ],
36
+ reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation']
37
+ )
38
+
39
+ ret = {}
40
+ for sc, cate_list in super_cates.items():
41
+ base = 0
42
+ for c in cate_list:
43
+ base += scores[c]
44
+ ret[sc] = base
45
+ ret.update(scores)
46
+ ret = d2df(ret)
47
+ return ret
48
+
49
+
50
+ def Hallusion_rating(data_file):
51
+ def calc_fAcc(data):
52
+ res = defaultdict(list)
53
+ lt = len(data)
54
+ for i in range(lt):
55
+ line = data.iloc[i]
56
+ res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score'])
57
+ return np.mean([np.all(x) for x in res.values()]) * 100
58
+
59
+ def calc_qAcc(data):
60
+ res = defaultdict(list)
61
+ lt = len(data)
62
+ for i in range(lt):
63
+ line = data.iloc[i]
64
+ res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score'])
65
+ return np.mean([np.all(x) for x in res.values()]) * 100
66
+
67
+ def calc_aAcc(data):
68
+ return np.mean(data['score']) * 100
69
+
70
+ data = load(data_file)
71
+ data['set_id'] = [x.split('_')[3] for x in data['index']]
72
+ data['figure_id'] = [x.split('_')[4] for x in data['index']]
73
+ data['question_id'] = [x.split('_')[5] for x in data['index']]
74
+
75
+ res = dict(split=[], aAcc=[], fAcc=[], qAcc=[])
76
+ res['split'].append('Overall')
77
+ res['aAcc'].append(calc_aAcc(data))
78
+ res['fAcc'].append(calc_fAcc(data))
79
+ res['qAcc'].append(calc_qAcc(data))
80
+
81
+ if 'category' in data:
82
+ cates = list(set(data['category']))
83
+ for c in cates:
84
+ sub = data[data['category'] == c]
85
+ res['split'].append(c)
86
+ res['aAcc'].append(calc_aAcc(sub))
87
+ res['fAcc'].append(calc_fAcc(sub))
88
+ res['qAcc'].append(calc_qAcc(sub))
89
+
90
+ if 'l2-category' in data:
91
+ cates = list(set(data['l2-category']))
92
+ for c in cates:
93
+ sub = data[data['l2-category'] == c]
94
+ res['split'].append(c)
95
+ res['aAcc'].append(calc_aAcc(sub))
96
+ res['fAcc'].append(calc_fAcc(sub))
97
+ res['qAcc'].append(calc_qAcc(sub))
98
+ ret = pd.DataFrame(res)
99
+ return ret
100
+
101
+
102
+ def POPE_rating(data_file):
103
+ def cal_f1_score(y_true, y_pred):
104
+ tp = sum((y_true == 1) & (y_pred == 1))
105
+ fp = sum((y_true == 0) & (y_pred == 1))
106
+ fn = sum((y_true == 1) & (y_pred == 0))
107
+
108
+ precision = tp / (tp + fp) if (tp + fp) != 0 else 0
109
+ recall = tp / (tp + fn) if (tp + fn) != 0 else 0
110
+ f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
111
+ return f1_score, precision, recall
112
+
113
+ data = load(data_file)
114
+ data = data.assign(category=data['category'].str.split(',')).explode('category')
115
+ data['index'] = range(len(data))
116
+ res = dict(split=[], Overall=[], acc=[], precision=[], recall=[])
117
+ y_true = np.array([1 if i == 'Yes' else 0 for i in data['answer']])
118
+ y_pred = np.array([1 if i == 'Yes' else 0 for i in data['extracted']])
119
+ f1_score, precision, recall = cal_f1_score(y_true, y_pred)
120
+ res['split'].append('Overall')
121
+ res['Overall'].append(f1_score * 100)
122
+ res['acc'].append(np.mean(data['score']) * 100)
123
+ res['precision'].append(precision * 100)
124
+ res['recall'].append(recall * 100)
125
+
126
+ if 'category' in data:
127
+ cates = list(set(data['category']))
128
+ cates = [c for c in cates if not pd.isna(c)]
129
+ for c in cates:
130
+ sub = data[data['category'] == c]
131
+ y_true = np.array([1 if i == 'Yes' else 0 for i in sub['answer']])
132
+ y_pred = np.array([1 if i == 'Yes' else 0 for i in sub['extracted']])
133
+ f1_score, precision, recall = cal_f1_score(y_true, y_pred)
134
+ res['split'].append(c)
135
+ res['Overall'].append(f1_score * 100)
136
+ res['acc'].append(np.mean(sub['score']) * 100)
137
+ res['precision'].append(precision * 100)
138
+ res['recall'].append(recall * 100)
139
+
140
+ ret = pd.DataFrame(res)
141
+ return ret
142
+
143
+
144
+ def default_rating(data_file):
145
+ data = load(data_file)
146
+ res = {}
147
+ res['Overall'] = np.mean(data['score']) * 100
148
+ if 'category' in data:
149
+ cates = list(set(data['category']))
150
+ cates = [c for c in cates if not pd.isna(c)]
151
+ cates.sort()
152
+ for c in cates:
153
+ sub = data[data['category'] == c]
154
+ res[c] = np.mean(sub['score']) * 100
155
+ if 'l2-category' in data:
156
+ cates = list(set(data['l2-category']))
157
+ cates = [c for c in cates if not pd.isna(c)]
158
+ cates.sort()
159
+ for c in cates:
160
+ sub = data[data['l2-category'] == c]
161
+ res[c] = np.mean(sub['score']) * 100
162
+ ret = d2df(res)
163
+ return ret
164
+
165
+
166
+ def YOrN_match_prompt(line):
167
+ tmpl = (
168
+ 'You are an AI assistant who will help me to match an answer with two options of a question. '
169
+ 'The options are only Yes / No. '
170
+ 'You are provided with a question and an answer, '
171
+ 'and you need to find which option (Yes / No) is most similar to the answer. '
172
+ 'If the meaning of all options are significantly different from the answer, output Unknown. '
173
+ 'You should output a single word among the following 3 choices: Yes, No, Unknown.\n'
174
+ 'Example 1: \n'
175
+ "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n"
176
+ 'Example 2: \n'
177
+ "Question: Is the word in this image 'Hello'?\n"
178
+ "Answer: The word in this image is not 'Hello'.\nYour output: No\n"
179
+ 'Example 3: \n'
180
+ 'Question: {}?\nAnswer: {}\nYour output: '
181
+ )
182
+ return tmpl.format(line['question'], line['prediction'])
183
+
184
+
185
+ def YOrN_Extraction(output):
186
+ s = output.lower()
187
+ words = process_punctuation(s).split()
188
+ if 'yes' in words and 'no' not in words:
189
+ return 'Yes'
190
+ if 'yes' not in words and 'no' in words:
191
+ return 'No'
192
+ return 'Unknown'
193
+
194
+
195
+ def YOrN_auxeval(model, line):
196
+ prompt = YOrN_match_prompt(line)
197
+ retry = 5
198
+ for i in range(retry):
199
+ output = model.generate(prompt, temperature=0.5 * i)
200
+ ans = YOrN_Extraction(output)
201
+ if ans != 'Unknown':
202
+ return ans
203
+ return 'Unknown'
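
A few hedged examples of the rule-based extraction above (made-up strings; the import path is assumed from this repo's layout). YOrN_auxeval only falls back to the judge model when this returns 'Unknown':

from vlmeval.dataset.utils.yorn import YOrN_Extraction

print(YOrN_Extraction('Yes, the word in the image is "Hello".'))   # 'Yes'
print(YOrN_Extraction('No.'))                                      # 'No'
print(YOrN_Extraction('The image shows a dog, not a cat.'))        # 'Unknown' (neither word present)
print(YOrN_Extraction('Yes and no, it depends on the lighting.'))  # 'Unknown' (both words present)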
eval_mm/vlmevalkit/vlmeval/dataset/vcr.py ADDED
@@ -0,0 +1,332 @@
1
+ import uuid
2
+ from functools import partial
3
+ from .image_base import ImageBaseDataset
4
+ from ..smp import *
5
+
6
+ rouge = None
7
+ nlp_en = None
8
+ nlp_zh = None
9
+ nlp = None
10
+
11
+
12
+ def initialize():
13
+ import evaluate
14
+ import spacy
15
+
16
+ global rouge, nlp_en, nlp_zh, nlp
17
+
18
+ try:
19
+ rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
20
+ except:
21
+ warnings.warn('Please first `pip install rouge_score`.')
22
+
23
+ try:
24
+ nlp_en = spacy.load('en_core_web_sm')
25
+ except:
26
+ warnings.warn('Will automatically download en_core_web_sm via spacy.')
27
+ spacy.cli.download('en_core_web_sm')
28
+ nlp_en = spacy.load('en_core_web_sm')
29
+
30
+ try:
31
+ nlp_zh = spacy.load('zh_core_web_sm')
32
+ except:
33
+ warnings.warn('Will automatically download zh_core_web_sm via spacy.')
34
+ spacy.cli.download('zh_core_web_sm')
35
+ nlp_zh = spacy.load('zh_core_web_sm')
36
+
37
+ nlp = {'en': nlp_en, 'zh': nlp_zh}
38
+
39
+
40
+ def rough_filter(answer_text):
41
+ if "I can't" in answer_text:
42
+ return False
43
+ elif 'I cannot' in answer_text:
44
+ return False
45
+ elif 'sorry' in answer_text.lower():
46
+ return False
47
+ if '无法' in answer_text:
48
+ return False
49
+ elif '抱歉' in answer_text:
50
+ return False
51
+ else:
52
+ return True
53
+
54
+
55
+ def zero_template(crossed_text):
56
+ return {
57
+ 'crossed_text': crossed_text,
58
+ 'max_sim_val': 0,
59
+ 'max_sim_string': '',
60
+ 'precision': 0,
61
+ 'recall': 0,
62
+ 'f1': 0,
63
+ 'jaccard': 0,
64
+ 'rouge1': 0,
65
+ 'exact_match': 0,
66
+ }
67
+
68
+
69
+ def tokenize(text, language):
70
+ """
71
+ Tokenize the text and return the tokens.
72
+
73
+ Parameters:
74
+ text (str): The text to tokenize.
75
+ language (str): The language of the text.
76
+
77
+ Returns:
78
+ list: The list of tokens.
79
+ """
80
+ assert language in ['en', 'zh']
81
+ nlp_language = nlp[language]
82
+ processed_text = nlp_language(text)
83
+ return [token.text for token in processed_text]
84
+
85
+
86
+ def find_best_match(needle, hay, language, rouge):
87
+ """
88
+ Finds the best matching n-gram in the haystack for the given needle.
89
+
90
+ Parameters:
91
+ needle (str): The string to find.
92
+ hay (str): The text to search within.
93
+
94
+ Returns:
95
+ tuple: The highest similarity value and the best matching string.
96
+ """
97
+ assert language in ['en', 'zh']
98
+ from nltk.util import ngrams
99
+ from difflib import SequenceMatcher as SM
100
+
101
+ tokens_hay = tokenize(hay, language)
102
+ tokens_needle = tokenize(needle, language)
103
+
104
+ splitter = '' if language == 'zh' else ' '
105
+ ngrams_ = ngrams(tokens_hay, len(tokens_needle))
106
+ max_sim_val = 0
107
+ max_sim_string = ''
108
+ max_sim_ngram = []
109
+ tokens_needle_set = set(tokens_needle)
110
+ ngrams_hasjoint = [
111
+ ngram
112
+ for ngram in ngrams_
113
+ if not set(ngram).isdisjoint(tokens_needle_set)
114
+ ]
115
+
116
+ for ngram in ngrams_hasjoint:
117
+ hay_ngram = splitter.join(ngram)
118
+ similarity = SM(None, hay_ngram, needle).ratio()
119
+ if similarity > max_sim_val:
120
+ max_sim_val = similarity
121
+ max_sim_string = hay_ngram
122
+ max_sim_ngram = ngram
123
+
124
+ # Evaluate
125
+ if len(max_sim_ngram) == 0:
126
+ return {
127
+ 'crossed_text': needle,
128
+ 'max_sim_val': 0,
129
+ 'max_sim_string': '',
130
+ 'precision': 0,
131
+ 'recall': 0,
132
+ 'f1': 0,
133
+ 'jaccard': 0,
134
+ 'rouge1': 0,
135
+ 'exact_match': 0,
136
+ }
137
+ pred_set = set(max_sim_ngram)
138
+ ref_set = set(tokens_needle)
139
+ correct_tokens = pred_set.intersection(ref_set)
140
+ len_correct_tokens = len(correct_tokens)
141
+
142
+ precision = len_correct_tokens / len(pred_set)
143
+ recall = len_correct_tokens / len(ref_set)
144
+ if (precision + recall) == 0:
145
+ f1 = 0
146
+ else:
147
+ f1 = 2 * precision * recall / (precision + recall)
148
+ union = pred_set.union(ref_set)
149
+ jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
150
+ rouge_1 = rouge.compute(
151
+ predictions=[max_sim_string],
152
+ references=[needle],
153
+ tokenizer=partial(tokenize, language=language),
154
+ rouge_types=['rouge1'],
155
+ )['rouge1']
156
+ exact_match = float(list(max_sim_ngram) == list(tokens_needle))
157
+ out = {
158
+ 'crossed_text': needle,
159
+ 'max_sim_string': max_sim_string,
160
+ 'max_sim_val': max_sim_val,
161
+ 'precision': precision,
162
+ 'recall': recall,
163
+ 'f1': f1,
164
+ 'jaccard': jaccard,
165
+ 'rouge1': rouge_1,
166
+ 'exact_match': exact_match,
167
+ }
168
+ return out
169
+
170
+
171
+ def process_match_single_new(
172
+ image_id, prediction, answer, language, progress
173
+ ):
174
+ """
175
+ process the inference results for a single image and calculate the metrics
176
+
177
+ Parameters:
178
+ image_id (int): The image id (question id).
179
+ prediction (str): The prediction text.
180
+ answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image.
181
+ language (str): The language of the text. Can be "en" or "zh".
182
+ rouge (rouge): The rouge metric object.
183
+ progress (multiprocessing.Queue): The progress queue.
184
+
185
+ Returns:
186
+ tuple: The image id (question_id, int) and the result per id (dict of dict of dict).
187
+ """
188
+ result_per_id = {image_id: {}}
189
+ if isinstance(answer, str):
190
+ answer = eval(answer)
191
+ assert isinstance(answer, list)
192
+ result = prediction.split('Assistant: ')[-1]
193
+ for i, crossed_text in enumerate(answer):
194
+ if rough_filter(result):
195
+ find_best_match_result = find_best_match(
196
+ crossed_text, result, language, rouge
197
+ )
198
+ if i == 0:
199
+ result_per_id[image_id] = {str(i): find_best_match_result}
200
+ else:
201
+ result_per_id[image_id][str(i)] = find_best_match_result
202
+ else:
203
+ if i == 0:
204
+ result_per_id[image_id] = {str(i): zero_template(crossed_text)}
205
+ else:
206
+ result_per_id[image_id][str(i)] = zero_template(crossed_text)
207
+ progress.put(1)
208
+ return image_id, result_per_id
209
+
210
+
211
+ class VCRDataset(ImageBaseDataset):
212
+ TYPE = 'VQA'
213
+
214
+ URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
215
+
216
+ DATASET_URL = {
217
+ 'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
218
+ 'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
219
+ 'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
220
+ 'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
221
+ 'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
222
+ 'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
223
+ 'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
224
+ 'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
225
+ 'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
226
+ 'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
227
+ 'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
228
+ 'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
229
+ }
230
+
231
+ DATASET_MD5 = {
232
+ 'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
233
+ 'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
234
+ 'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
235
+ 'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
236
+ 'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
237
+ 'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
238
+ 'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
239
+ 'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
240
+ 'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
241
+ 'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
242
+ 'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
243
+ 'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
244
+ }
245
+
246
+ def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
247
+ super().__init__(dataset, skip_noimg)
248
+
249
+ initialize()
250
+ self.language = 'en' if 'EN' in dataset else 'zh'
251
+ self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
252
+
253
+ # def build_prompt(self, line):
254
+ # msgs = super().build_prompt(line)
255
+ # assert msgs[-1]['type'] == 'text'
256
+ # if self.language == 'zh':
257
+ # msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
258
+ # else:
259
+ # msgs[-1]['value'] += ('What is the covered texts in the image? '
260
+ # 'Please restore the covered texts without outputting the explanations.')
261
+ # return msgs
262
+
263
+ def evaluate(self, eval_file, **judge_kwargs):
264
+ import multiprocessing
265
+
266
+ vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
267
+ vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
268
+ logger = get_logger('Evaluation')
269
+ data = load(eval_file)
270
+
271
+ lt = len(data)
272
+ lines = [data.iloc[i] for i in range(lt)]
273
+
274
+ pool = multiprocessing.Pool()
275
+ manager = multiprocessing.Manager()
276
+ progress_queue = manager.Queue()
277
+ results = []
278
+
279
+ overall_results = {str(image_id): {} for image_id in range(len(lines))}
280
+
281
+ for instance_id, instance in enumerate(lines):
282
+ results.append(
283
+ pool.apply_async(
284
+ process_match_single_new,
285
+ args=(
286
+ str(instance_id),
287
+ instance['prediction'],
288
+ instance['answer'],
289
+ self.language,
290
+ progress_queue,
291
+ ),
292
+ )
293
+ )
294
+ pool.close()
295
+
296
+ # Display progress bar
297
+ for _ in tqdm(range(len(results))):
298
+ progress_queue.get()
299
+
300
+ pool.join()
301
+
302
+ # Merging results into overall_result
303
+ for result in results:
304
+ image_id, result_per_id = result.get()
305
+ overall_results[str(image_id)].update(result_per_id[image_id])
306
+ for blank_id_str in result_per_id[image_id].keys():
307
+ vcr_score_list['Exact_Match'].append(
308
+ result_per_id[image_id][blank_id_str]['exact_match']
309
+ )
310
+ vcr_score_list['Jaccard'].append(
311
+ result_per_id[image_id][blank_id_str]['jaccard']
312
+ )
313
+ vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
314
+ vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
315
+ results_out = {
316
+ k: v for i in range(len(results)) for k, v in results[i].get()[1].items()
317
+ }
318
+ results_with_metrics = {
319
+ 'Exact_Match': vcr_score['Exact_Match'],
320
+ 'Jaccard': vcr_score['Jaccard'],
321
+ 'Predictions': results_out,
322
+ }
323
+ score_pth = eval_file.replace(
324
+ '.xlsx', f'{self.language}_{self.difficulty}_score.json'
325
+ )
326
+ dump(results_with_metrics, score_pth)
327
+ logger.info(
328
+ f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
329
+ )
330
+ logger.info('Score: ')
331
+ for key, value in vcr_score.items():
332
+ logger.info('{}:{}'.format(key, value))
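
The per-blank metrics above reduce to token-set precision / recall / F1 and Jaccard between the best-matching n-gram and the masked text. A small self-contained sketch of those formulas on a toy pair of token sets:

pred_tokens = {'quick', 'brown', 'fox'}   # tokens of the best-matching n-gram (hypothetical)
ref_tokens = {'quick', 'red', 'fox'}      # tokens of the masked ground-truth text (hypothetical)

correct = pred_tokens & ref_tokens                       # {'quick', 'fox'}
precision = len(correct) / len(pred_tokens)              # 2/3
recall = len(correct) / len(ref_tokens)                  # 2/3
f1 = 2 * precision * recall / (precision + recall)       # 2/3
jaccard = len(correct) / len(pred_tokens | ref_tokens)   # 2/4 = 0.5
print(precision, recall, f1, jaccard)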
eval_mm/vlmevalkit/vlmeval/dataset/video_base.py ADDED
@@ -0,0 +1,87 @@
1
+ from abc import abstractmethod
2
+ from ..smp import *
3
+
4
+
5
+ class VideoBaseDataset:
6
+
7
+ MODALITY = 'VIDEO'
8
+
9
+ def __init__(self,
10
+ dataset='MMBench-Video',
11
+ pack=False):
12
+ try:
13
+ import decord
14
+ except ImportError:
15
+ warnings.warn('Please install decord via `pip install decord`.')
16
+
17
+ self.dataset_name = dataset
18
+ ret = self.prepare_dataset(dataset)
19
+ assert ret is not None
20
+ lmu_root = LMUDataRoot()
21
+ self.frame_root = osp.join(lmu_root, 'images', dataset)
22
+ os.makedirs(self.frame_root, exist_ok=True)
23
+ self.frame_tmpl = 'frame-{}-of-{}.jpg'
24
+
25
+ self.data_root = ret['root']
26
+ self.data_file = ret['data_file']
27
+ self.data = load(self.data_file)
28
+
29
+ assert 'question' in self.data and 'video' in self.data
30
+ videos = list(set(self.data['video']))
31
+ videos.sort()
32
+ self.videos = videos
33
+ self.pack = pack
34
+
35
+ def __len__(self):
36
+ return len(self.videos) if self.pack else len(self.data)
37
+
38
+ def __getitem__(self, idx):
39
+ if self.pack:
40
+ assert idx < len(self.videos)
41
+ sub_data = self.data[self.data['video'] == self.videos[idx]]
42
+ return sub_data
43
+ else:
44
+ assert idx < len(self.data)
45
+ return dict(self.data.iloc[idx])
46
+
47
+ def frame_paths(self, video, num_frames=8):
48
+ frame_root = osp.join(self.frame_root, video)
49
+ os.makedirs(frame_root, exist_ok=True)
50
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
51
+
52
+ def save_video_frames(self, video, num_frames=8):
53
+ frame_paths = self.frame_paths(video, num_frames)
54
+ flag = np.all([osp.exists(p) for p in frame_paths])
55
+ if flag:
56
+ return frame_paths
57
+ vid_path = osp.join(self.data_root, video + '.mp4')
58
+ vid = decord.VideoReader(vid_path)
59
+ step_size = len(vid) / (num_frames + 1)
60
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
61
+ images = [vid[i].numpy() for i in indices]
62
+ images = [Image.fromarray(arr) for arr in images]
63
+ for im, pth in zip(images, frame_paths):
64
+ if not osp.exists(pth):
65
+ im.save(pth)
66
+ return frame_paths
67
+
68
+ # Return a list of dataset names supported by this class; subclasses may override this.
69
+ @classmethod
70
+ def supported_datasets(cls):
71
+ return ['MMBench-Video', 'Video-MME', 'MVBench']
72
+
73
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
74
+ @abstractmethod
75
+ def evaluate(self, eval_file, **judge_kwargs):
76
+ pass
77
+
78
+ @abstractmethod
79
+ def build_prompt(self, idx, num_frames=8):
80
+ pass
81
+
82
+ @abstractmethod
83
+ def prepare_dataset(self, dataset):
84
+ # The prepare_dataset function should return a dictionary containing:
85
+ # `root` (the directory containing the video files)
86
+ # `data_file` (the TSV dataset file)
87
+ pass
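
The uniform frame sampling in save_video_frames divides the clip into num_frames + 1 equal steps and takes the frame at each interior step, so the very first and last frames are skipped. Below is a self-contained sketch of just that index math; the 900-frame video length is an arbitrary example, not taken from the code above.

# Uniform sampling used by save_video_frames: num_frames interior points
# of a clip split into (num_frames + 1) equal steps.
def sample_indices(n_total_frames: int, num_frames: int = 8):
    step_size = n_total_frames / (num_frames + 1)
    return [int(i * step_size) for i in range(1, num_frames + 1)]

print(sample_indices(900, 8))  # [100, 200, 300, 400, 500, 600, 700, 800]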
eval_mm/vlmevalkit/vlmeval/dataset/videomme.py ADDED
@@ -0,0 +1,250 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+
5
+ FAIL_MSG = 'Failed to obtain answer via API.'
6
+
7
+
8
+ def unwrap_hf_pkl(pth, suffix='.mp4'):
9
+ base_dir = os.path.join(pth, 'video_pkl/')
10
+ target_dir = os.path.join(pth, 'video/')
11
+ pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
12
+ pickle_files.sort()
13
+
14
+ if not os.path.exists(target_dir):
15
+ os.makedirs(target_dir, exist_ok=True)
16
+ for pickle_file in pickle_files:
17
+ with open(pickle_file, 'rb') as file:
18
+ video_data = pickle.load(file)
19
+ # For each video file in the pickle file, write its contents to a new mp4 file
20
+ for video_name, video_content in video_data.items():
21
+ output_path = os.path.join(target_dir, f'{video_name}{suffix}')
22
+ with open(output_path, 'wb') as output_file:
23
+ output_file.write(video_content)
24
+ print('The video files have been restored from the pickle files.')
25
+ else:
26
+ print('The video file already exists.')
27
+
28
+
29
+ class VideoMME(VideoBaseDataset):
30
+
31
+ MD5 = '2f16cd40b1c125b67e661e59da2f6cd0'
32
+ SYS = ''
33
+
34
+ FRAMES_TMPL_NOSUB = """
35
+ These are the frames of a video. \
36
+ Select the best answer to the following multiple-choice question based on the video. \
37
+ Respond with only the letter (A, B, C, or D) of the correct option.
38
+ """
39
+
40
+ FRAMES_TMPL_SUB = """
41
+ These are the frames of a video. \
42
+ This video's subtitles are listed below:
43
+ {}
44
+ Select the best answer to the following multiple-choice question based on the video. \
45
+ Respond with only the letter (A, B, C, or D) of the correct option.
46
+ """
47
+
48
+ TYPE = 'MCQ'
49
+
50
+ def __init__(self, dataset='Video-MME', use_subtitle=False):
51
+ super().__init__(dataset=dataset)
52
+ self.use_subtitle = use_subtitle
53
+
54
+ @classmethod
55
+ def supported_datasets(cls):
56
+ return ['Video-MME']
57
+
58
+ def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):
59
+
60
+ def check_integrity(pth):
61
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
62
+
63
+ if not os.path.exists(data_file):
64
+ return False
65
+
66
+ if md5(data_file) != self.MD5:
67
+ return False
68
+ data = load(data_file)
69
+ for video_pth in data['video_path']:
70
+ if not osp.exists(osp.join(pth, video_pth)):
71
+ return False
72
+ return True
73
+
74
+ cache_path = get_cache_path(repo_id)
75
+ if cache_path is not None and check_integrity(cache_path):
76
+ dataset_path = cache_path
77
+ else:
78
+
79
+ def unzip_hf_zip(pth):
80
+ import zipfile
81
+ base_dir = pth
82
+ target_dir = os.path.join(pth, 'video/')
83
+ zip_files = [
84
+ os.path.join(base_dir, file) for file in os.listdir(base_dir)
85
+ if file.endswith('.zip') and file.startswith('video')
86
+ ]
87
+ zip_files.sort()
88
+
89
+ if not os.path.exists(target_dir):
90
+ os.makedirs(target_dir, exist_ok=True)
91
+ for zip_file in zip_files:
92
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
93
+ for member in zip_ref.namelist():
94
+ # Check if the member is a file (not a directory)
95
+ if not member.endswith('/'):
96
+ # Extract the file to the specified directory
97
+ source = zip_ref.open(member)
98
+ target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
99
+ with source, target:
100
+ target.write(source.read())
101
+ print('The video files have been extracted from the zip files.')
102
+ else:
103
+ print('The video file already exists.')
104
+
105
+ subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
106
+ subtitle_target_dir = os.path.join(base_dir, 'subtitle')
107
+
108
+ if not os.path.exists(subtitle_target_dir):
109
+ os.makedirs(subtitle_target_dir, exist_ok=True)
110
+ with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
111
+ for member in zip_ref.namelist():
112
+ # Check if the member is a file (not a directory)
113
+ if not member.endswith('/'):
114
+ # Extract the file to the specified directory
115
+ source = zip_ref.open(member)
116
+ target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
117
+ with source, target:
118
+ target.write(source.read())
119
+ print('The subtitle files have been extracted from the zip file.')
120
+ else:
121
+ print('The subtitle file already exists.')
122
+
123
+ def generate_tsv(pth):
124
+
125
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
126
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
127
+ return
128
+
129
+ data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
130
+ data_file = data_file.assign(index=range(len(data_file)))
131
+ data_file['video'] = data_file['videoID']
132
+ data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
133
+ data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
134
+ data_file['question'] += '\n' + data_file['options'].apply(lambda x: '\n'.join(x))
135
+
136
+ data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain',
137
+ 'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]
138
+
139
+ data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
140
+
141
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
142
+ unzip_hf_zip(dataset_path)
143
+ generate_tsv(dataset_path)
144
+
145
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
146
+
147
+ return dict(data_file=data_file, root=dataset_path)
148
+
149
+ def save_video_frames(self, video, num_frames=8):
150
+
151
+ vid_path = osp.join(self.data_root, 'video', video + '.mp4')
152
+ vid = decord.VideoReader(vid_path)
153
+ step_size = len(vid) / (num_frames + 1)
154
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
155
+
156
+ video_info = {
157
+ 'fps': vid.get_avg_fps(),
158
+ 'n_frames': len(vid),
159
+ }
160
+
161
+ frame_paths = self.frame_paths(video, num_frames)
162
+ flag = np.all([osp.exists(p) for p in frame_paths])
163
+
164
+ if not flag:
165
+ images = [vid[i].numpy() for i in indices]
166
+ images = [Image.fromarray(arr) for arr in images]
167
+ for im, pth in zip(images, frame_paths):
168
+ if not osp.exists(pth):
169
+ im.save(pth)
170
+
171
+ return frame_paths, indices, video_info
172
+
173
+ def build_prompt(self, line, num_frames, video_llm):
174
+ if isinstance(line, int):
175
+ assert line < len(self)
176
+ line = self.data.iloc[line]
177
+
178
+ frames, indices, video_info = self.save_video_frames(line['video'], num_frames)
179
+
180
+ if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
181
+ import pysubs2
182
+ subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
183
+ subtitles = []
184
+
185
+ for selected_frame_id in indices:
186
+ sub_text = ''
187
+ cur_time = pysubs2.make_time(fps=video_info['fps'], frames=selected_frame_id)
188
+ for sub in subs:
189
+ if sub.start < cur_time and sub.end > cur_time:
190
+ sub_text = sub.text.replace('\\N', ' ')
191
+ break
192
+ if sub_text.strip():
193
+ subtitles.append(sub_text)
194
+ subtitles = '\n'.join(subtitles)
195
+ else:
196
+ subtitles = ''
197
+
198
+ message = [dict(type='text', value=self.SYS)]
199
+ if video_llm:
200
+ message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
201
+ else:
202
+ for im in frames:
203
+ message.append(dict(type='image', value=im))
204
+
205
+ text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
206
+ message.append(dict(type='text', value=text_prompt))
207
+ prompt = 'Question: {}\nAnswer: '.format(line['question'])
208
+ message.append(dict(type='text', value=prompt))
209
+ return message
210
+
211
+ # Returns the per-dimension rating as a dictionary
212
+ @classmethod
213
+ def evaluate(self, eval_file, **judge_kwargs):
214
+ from .utils.videomme import get_dimension_rating, extract_characters_regex
215
+
216
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
217
+
218
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
219
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
220
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
221
+
222
+ if not osp.exists(score_file):
223
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
224
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
225
+
226
+ data = load(eval_file)
227
+ data_un = data[~pd.isna(data['prediction'])]
228
+
229
+ for idx in data['index']:
230
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
231
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
232
+
233
+ if extract_characters_regex(pred) == '':
234
+ data.loc[idx, 'score'] = -1
235
+ else:
236
+ data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
237
+
238
+ rejected = [x for x in data['score'] if x == -1]
239
+
240
+ print(
241
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
242
+ f'failed to obtain the score for another {len(rejected)} questions. '
243
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
244
+ )
245
+
246
+ dump(data, score_file)
247
+
248
+ rating = get_dimension_rating(score_file)
249
+ dump(rating, tgt_file)
250
+ return rating
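
The evaluate method above relies on extract_characters_regex (imported from .utils.videomme) to reduce a free-form response to a single option letter, scoring -1 when nothing can be extracted. The sketch below is a hypothetical stand-in that shows the kind of behavior being assumed; the real helper lives in utils.videomme and may handle more answer formats.

import re

# Hypothetical stand-in for utils.videomme.extract_characters_regex:
# return the first standalone option letter A-D, or '' if none is found.
def extract_option_letter(response: str) -> str:
    match = re.search(r'\b([A-D])\b', response.strip())
    return match.group(1) if match else ''

print(extract_option_letter('The answer is (B) because ...'))  # 'B'
print(extract_option_letter('I am not sure.'))                 # ''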
eval_mm/vlmevalkit/vlmeval/inference.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from vlmeval.config import supported_VLM
4
+ from vlmeval.utils import track_progress_rich
5
+ from vlmeval.smp import *
6
+
7
+ FAIL_MSG = 'Failed to obtain answer via API.'
8
+
9
+
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--data', type=str, nargs='+', required=True)
13
+ parser.add_argument('--model', type=str, nargs='+', required=True)
14
+ parser.add_argument('--nproc', type=int, default=4, required=True)
15
+ parser.add_argument('--verbose', action='store_true')
16
+ args = parser.parse_args()
17
+ return args
18
+
19
+
20
+ # Only API models are accepted
21
+ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
22
+ rank, world_size = get_rank_and_world_size()
23
+ assert rank == 0 and world_size == 1
24
+ dataset_name = dataset.dataset_name
25
+ data = dataset.data
26
+ if index_set is not None:
27
+ data = data[data['index'].isin(index_set)]
28
+
29
+ model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
30
+ assert getattr(model, 'is_api', False)
31
+
32
+ lt, indices = len(data), list(data['index'])
33
+ structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
34
+
35
+ out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
36
+ res = {}
37
+ if osp.exists(out_file):
38
+ res = load(out_file)
39
+ if ignore_failed:
40
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
41
+
42
+ structs = [s for i, s in zip(indices, structs) if i not in res]
43
+ indices = [i for i in indices if i not in res]
44
+
45
+ gen_func = model.generate
46
+ structs = [dict(message=struct, dataset=dataset_name) for struct in structs]
47
+
48
+ if len(structs):
49
+ track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
50
+
51
+ res = load(out_file)
52
+ if index_set is not None:
53
+ res = {k: v for k, v in res.items() if k in index_set}
54
+ os.remove(out_file)
55
+ return res
56
+
57
+
58
+ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
59
+ dataset_name = dataset.dataset_name
60
+ prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
61
+ res = load(prev_file) if osp.exists(prev_file) else {}
62
+ if osp.exists(out_file):
63
+ res.update(load(out_file))
64
+
65
+ rank, world_size = get_rank_and_world_size()
66
+ sheet_indices = list(range(rank, len(dataset), world_size))
67
+ lt = len(sheet_indices)
68
+ data = dataset.data.iloc[sheet_indices]
69
+ data_indices = [i for i in data['index']]
70
+
71
+ # If all samples are already finished, exit without building the model
72
+ all_finished = True
73
+ for i in range(lt):
74
+ idx = data.iloc[i]['index']
75
+ if idx not in res:
76
+ all_finished = False
77
+ if all_finished:
78
+ res = {k: res[k] for k in data_indices}
79
+ dump(res, out_file)
80
+ return
81
+
82
+ # Data that still needs to be inferred
83
+ data = data[~data['index'].isin(res)]
84
+ lt = len(data)
85
+
86
+ model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
87
+
88
+ is_api = getattr(model, 'is_api', False)
89
+ if is_api:
90
+ lt, indices = len(data), list(data['index'])
91
+ supp = infer_data_api(
92
+ work_dir=work_dir,
93
+ model_name=model_name,
94
+ dataset=dataset,
95
+ index_set=set(indices),
96
+ api_nproc=api_nproc)
97
+ for idx in indices:
98
+ assert idx in supp
99
+ res.update(supp)
100
+ res = {k: res[k] for k in data_indices}
101
+ dump(res, out_file)
102
+ return model_name
103
+ else:
104
+ model.set_dump_image(dataset.dump_image)
105
+
106
+ for i in tqdm(range(lt)):
107
+ idx = data.iloc[i]['index']
108
+ if idx in res:
109
+ continue
110
+
111
+ if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
112
+ struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
113
+ else:
114
+ struct = dataset.build_prompt(data.iloc[i])
115
+
116
+ response = model.generate(message=struct, dataset=dataset_name)
117
+ torch.cuda.empty_cache()
118
+
119
+ if verbose:
120
+ print(response, flush=True)
121
+
122
+ res[idx] = response
123
+ if (i + 1) % 20 == 0:
124
+ dump(res, out_file)
125
+
126
+ res = {k: res[k] for k in data_indices}
127
+ dump(res, out_file)
128
+ return model
129
+
130
+
131
+ # A wrapper for infer_data that handles the pre- and post-processing
132
+ def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
133
+ rank, world_size = get_rank_and_world_size()
134
+ dataset_name = dataset.dataset_name
135
+ result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
136
+
137
+ prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
138
+ if osp.exists(result_file):
139
+ if rank == 0:
140
+ data = load(result_file)
141
+ results = {k: v for k, v in zip(data['index'], data['prediction'])}
142
+ if not ignore_failed:
143
+ results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
144
+ dump(results, prev_file)
145
+ if world_size > 1:
146
+ dist.barrier()
147
+
148
+ tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
149
+ out_file = tmpl.format(rank)
150
+
151
+ model = infer_data(
152
+ model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
153
+ if world_size > 1:
154
+ dist.barrier()
155
+
156
+ if rank == 0:
157
+ data_all = {}
158
+ for i in range(world_size):
159
+ data_all.update(load(tmpl.format(i)))
160
+
161
+ data = dataset.data
162
+ for x in data['index']:
163
+ assert x in data_all
164
+ data['prediction'] = [str(data_all[x]) for x in data['index']]
165
+ if 'image' in data:
166
+ data.pop('image')
167
+
168
+ dump(data, result_file)
169
+ for i in range(world_size):
170
+ os.remove(tmpl.format(i))
171
+ return model
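
infer_data above shards work round-robin: rank r handles rows r, r + world_size, r + 2 * world_size, ..., writes its own pickle, and rank 0 later merges the shards in infer_data_job. A small sketch of that sharding invariant follows; the 10-row, 4-rank setup is purely illustrative.

# Round-robin sharding used by infer_data: every row lands on exactly one rank.
def shard_indices(n_rows: int, rank: int, world_size: int):
    return list(range(rank, n_rows, world_size))

n_rows, world_size = 10, 4
shards = [shard_indices(n_rows, r, world_size) for r in range(world_size)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
assert sorted(i for s in shards for i in s) == list(range(n_rows))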
eval_mm/vlmevalkit/vlmeval/inference_mt.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from vlmeval.config import supported_VLM
4
+ from vlmeval.utils import track_progress_rich
5
+ from vlmeval.smp import *
6
+
7
+ FAIL_MSG = 'Failed to obtain answer via API.'
8
+
9
+
10
+ def parse_args():
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--data', type=str, nargs='+', required=True)
13
+ parser.add_argument('--model', type=str, nargs='+', required=True)
14
+ parser.add_argument('--nproc', type=int, default=4, required=True)
15
+ parser.add_argument('--verbose', action='store_true')
16
+ args = parser.parse_args()
17
+ return args
18
+
19
+
20
+ def chat_mt(model, messages, dataset_name):
21
+ assert len(messages) % 2 == 0
22
+ nturn = len(messages) // 2
23
+ utter_stack = []
24
+ predictions = []
25
+
26
+ for i in range(nturn):
27
+ utter = messages[2 * i]
28
+ utter_stack.append(utter)
29
+ try:
30
+ resp = model.chat(utter_stack, dataset=dataset_name)
31
+ utter_stack.append(dict(role='assistant', content=resp))
32
+ except Exception:
33
+ resp = FAIL_MSG
34
+ utter_stack.append(dict(role='assistant', content=resp))
35
+ predictions.append(resp)
36
+ return predictions
37
+
38
+
39
+ # Only API models are accepted
40
+ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
41
+ rank, world_size = get_rank_and_world_size()
42
+ assert rank == 0 and world_size == 1
43
+ dataset_name = dataset.dataset_name
44
+ data = dataset.data
45
+ if index_set is not None:
46
+ data = data[data['index'].isin(index_set)]
47
+
48
+ model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
49
+ assert getattr(model, 'is_api', False)
50
+ assert hasattr(model, 'chat_inner')
51
+
52
+ lt, indices = len(data), list(data['index'])
53
+ structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
54
+
55
+ out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
56
+ res = {}
57
+ if osp.exists(out_file):
58
+ res = load(out_file)
59
+ if ignore_failed:
60
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
61
+
62
+ structs = [s for i, s in zip(indices, structs) if i not in res]
63
+ indices = [i for i in indices if i not in res]
64
+
65
+ structs = [dict(model=model, messages=struct, dataset_name=dataset_name) for struct in structs]
66
+
67
+ if len(structs):
68
+ track_progress_rich(chat_mt, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
69
+
70
+ res = load(out_file)
71
+ if index_set is not None:
72
+ res = {k: v for k, v in res.items() if k in index_set}
73
+ os.remove(out_file)
74
+ return res
75
+
76
+
77
+ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
78
+ dataset_name = dataset.dataset_name
79
+ res = {}
80
+ if osp.exists(out_file):
81
+ res.update(load(out_file))
82
+
83
+ rank, world_size = get_rank_and_world_size()
84
+ sheet_indices = list(range(rank, len(dataset), world_size))
85
+ lt = len(sheet_indices)
86
+ data = dataset.data.iloc[sheet_indices]
87
+ data_indices = [i for i in data['index']]
88
+
89
+ # If all samples are already finished, exit without building the model
90
+ all_finished = True
91
+ for i in range(lt):
92
+ idx = data.iloc[i]['index']
93
+ if idx not in res:
94
+ all_finished = False
95
+ if all_finished:
96
+ res = {k: res[k] for k in data_indices}
97
+ dump(res, out_file)
98
+ return
99
+
100
+ # Data that still needs to be inferred
101
+ data = data[~data['index'].isin(res)]
102
+ lt = len(data)
103
+
104
+ model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
105
+ assert hasattr(model, 'chat_inner')
106
+
107
+ is_api = getattr(model, 'is_api', False)
108
+ if is_api:
109
+ lt, indices = len(data), list(data['index'])
110
+ supp = infer_data_api(
111
+ work_dir=work_dir,
112
+ model_name=model_name,
113
+ dataset=dataset,
114
+ index_set=set(indices),
115
+ api_nproc=api_nproc)
116
+ for idx in indices:
117
+ assert idx in supp
118
+ res.update(supp)
119
+ res = {k: res[k] for k in data_indices}
120
+ dump(res, out_file)
121
+ return model_name
122
+ else:
123
+ model.set_dump_image(dataset.dump_image)
124
+
125
+ for i in tqdm(range(lt)):
126
+ idx = data.iloc[i]['index']
127
+ if idx in res:
128
+ continue
129
+
130
+ if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
131
+ struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
132
+ else:
133
+ struct = dataset.build_prompt(data.iloc[i])
134
+
135
+ response = chat_mt(model, struct, dataset_name)
136
+ torch.cuda.empty_cache()
137
+
138
+ if verbose:
139
+ print(response, flush=True)
140
+
141
+ res[idx] = response
142
+ if (i + 1) % 20 == 0:
143
+ dump(res, out_file)
144
+
145
+ res = {k: res[k] for k in data_indices}
146
+ dump(res, out_file)
147
+ return model
148
+
149
+
150
+ # A wrapper for infer_data that handles the pre- and post-processing
151
+ def infer_data_job_mt(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
152
+ rank, world_size = get_rank_and_world_size()
153
+ dataset_name = dataset.dataset_name
154
+ result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.tsv')
155
+
156
+ tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
157
+ out_file = tmpl.format(rank)
158
+
159
+ model = infer_data(
160
+ model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
161
+ if world_size > 1:
162
+ dist.barrier()
163
+
164
+ if rank == 0:
165
+ data_all = {}
166
+ for i in range(world_size):
167
+ data_all.update(load(tmpl.format(i)))
168
+
169
+ data = dataset.data
170
+ for x in data['index']:
171
+ assert x in data_all
172
+
173
+ data['prediction'] = [data_all[x] for x in data['index']]
174
+ if 'image' in data:
175
+ data.pop('image')
176
+
177
+ dump(data, result_file)
178
+ for i in range(world_size):
179
+ os.remove(tmpl.format(i))
180
+ return model
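
chat_mt above assumes an even-length message list in which the even positions are user turns; it reads only messages[2 * i] and appends the model's replies to its own utter_stack. The example below is a hypothetical illustration of that layout: the content of each user turn is whatever the dataset's build_prompt produces, and the field values shown here are made up.

# Even-length, alternating layout consumed by chat_mt: user turns sit at even
# indices; the odd slots are skipped by chat_mt, which collects the model's
# own replies instead.
messages = [
    dict(role='user', content=[dict(type='image', value='demo_0.jpg'),
                               dict(type='text', value='What is shown in the image?')]),
    dict(role='assistant', content=None),  # ignored by chat_mt
    dict(role='user', content=[dict(type='text', value='Summarize it in one word.')]),
    dict(role='assistant', content=None),  # ignored by chat_mt
]
assert len(messages) % 2 == 0
print(len(messages) // 2)  # 2 turns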