StarCycle commited on
Commit
35a83a5
·
verified ·
1 Parent(s): d2d310a

Upload mmbench.py

Browse files
modified_xtuner/xtuner/tools/mmbench.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import json
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+ import re
8
+ import string
9
+ import time
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import tqdm
15
+ from huggingface_hub import snapshot_download
16
+ from mmengine import mkdir_or_exist
17
+ from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
18
+ master_only)
19
+ from mmengine.utils.dl_utils import set_multi_processing
20
+ from peft import PeftModel
21
+ from rich.console import Console
22
+ from rich.table import Table
23
+ from torch.utils.data import Dataset
24
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
25
+ BitsAndBytesConfig, SiglipImageProcessor,
26
+ SiglipVisionModel, Dinov2Model,
27
+ GenerationConfig)
28
+
29
+ from xtuner.dataset.utils import decode_base64_to_image, expand2square
30
+ from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
31
+ from xtuner.tools.utils import get_stop_criteria, is_cn_string
32
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
33
+ PROMPT_TEMPLATE)
34
+
35
+ TORCH_DTYPE_MAP = dict(
36
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
37
+
38
+
39
+ def parse_args():
40
+ parser = argparse.ArgumentParser(description='MMBench')
41
+ parser.add_argument(
42
+ 'model_name_or_path', help='Hugging Face model name or path')
43
+ parser.add_argument('--data-path', default=None, help='data path')
44
+ parser.add_argument('--work-dir', help='the dir to save results')
45
+ parser.add_argument('--llava', default=None, help='llava name or path')
46
+ parser.add_argument(
47
+ '--siglip', default=None, help='siglip visual encoder name or path')
48
+ parser.add_argument(
49
+ '--visual-select-layer', default=-2, help='visual select layer')
50
+ parser.add_argument(
51
+ '--dino', default=None, help='dino visual encoder name or path')
52
+ parser.add_argument(
53
+ '--prompt-template',
54
+ choices=PROMPT_TEMPLATE.keys(),
55
+ default=None,
56
+ help='Specify a prompt template')
57
+ parser.add_argument(
58
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
59
+ parser.add_argument(
60
+ '--torch-dtype',
61
+ default='fp16',
62
+ choices=TORCH_DTYPE_MAP.keys(),
63
+ help='Override the default `torch.dtype` and load the model under '
64
+ 'a specific `dtype`.')
65
+ parser.add_argument(
66
+ '--bits',
67
+ type=int,
68
+ choices=[4, 8, None],
69
+ default=None,
70
+ help='LLM bits')
71
+ parser.add_argument(
72
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
73
+ parser.add_argument(
74
+ '--offload-folder',
75
+ default=None,
76
+ help='The folder in which to offload the model weights (or where the '
77
+ 'model weights are already offloaded).')
78
+ parser.add_argument(
79
+ '--max-new-tokens',
80
+ type=int,
81
+ default=100,
82
+ help='Maximum number of new tokens allowed in generated text')
83
+ parser.add_argument(
84
+ '--seed',
85
+ type=int,
86
+ default=0,
87
+ help='Random seed for reproducible text generation')
88
+ parser.add_argument(
89
+ '--launcher',
90
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
91
+ default='none',
92
+ help='job launcher')
93
+ args = parser.parse_args()
94
+ return args
95
+
96
+
97
+ @master_only
98
+ def master_print(msg):
99
+ print(msg)
100
+
101
+
102
+ class MMBenchDataset(Dataset):
103
+ ABBRS = {
104
+ 'coarse_perception': 'CP',
105
+ 'finegrained_perception (instance-level)': 'FP-S',
106
+ 'finegrained_perception (cross-instance)': 'FP-C',
107
+ 'logic_reasoning': 'LR',
108
+ 'relation_reasoning': 'RR',
109
+ 'attribute_reasoning': 'AR',
110
+ 'sketch_reasoning': 'Sketch Reasoning',
111
+ 'scenery_building': 'Scenery & Building',
112
+ 'food_clothes': 'Food & Clothes',
113
+ 'historical_figure': 'Historical Figure',
114
+ 'traditional_show': 'Traditional Show',
115
+ 'calligraphy_painting': 'Calligraphy Painting',
116
+ 'cultural_relic': 'Cultural Relic'
117
+ }
118
+
119
+ def __init__(self, data_file):
120
+ self.data_file = data_file
121
+ self.df = pd.read_csv(data_file, sep='\t')
122
+ self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test'
123
+ self.has_l2_category = 'l2-category' in self.df.columns.to_list()
124
+
125
+ def get_image(self, image):
126
+ while len(image) < 16:
127
+ image = self.df[self.df['index'] == int(image)]['image'].values
128
+ assert len(image) == 1
129
+ image = image[0]
130
+ image = decode_base64_to_image(image)
131
+ return image
132
+
133
+ def __len__(self):
134
+ return len(self.df)
135
+
136
+ def __getitem__(self, idx):
137
+ index = self.df.iloc[idx]['index']
138
+ image = self.df.iloc[idx]['image']
139
+ image = self.get_image(image)
140
+ question = self.df.iloc[idx]['question']
141
+ answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[
142
+ 0].keys() else None
143
+ category = self.df.iloc[idx]['category']
144
+
145
+ options = {
146
+ cand: self.load_from_df(idx, cand)
147
+ for cand in string.ascii_uppercase
148
+ if self.load_from_df(idx, cand) is not None
149
+ }
150
+ options_prompt = ''
151
+ for key, item in options.items():
152
+ options_prompt += f'{key}. {item}\n'
153
+
154
+ hint = self.load_from_df(idx, 'hint')
155
+ data = {
156
+ 'img': image,
157
+ 'question': question,
158
+ 'answer': answer,
159
+ 'options': options_prompt,
160
+ 'category': category,
161
+ 'options_dict': options,
162
+ 'index': index,
163
+ 'context': hint,
164
+ }
165
+ if self.has_l2_category:
166
+ data.update({'l2-category': self.df.iloc[idx]['l2-category']})
167
+ return data
168
+
169
+ def load_from_df(self, idx, key):
170
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
171
+ return self.df.iloc[idx][key]
172
+ else:
173
+ return None
174
+
175
+ @master_only
176
+ def eval_result(self, result_df, show=True):
177
+
178
+ def calc_acc(df, group='category'):
179
+ assert group in ['overall', 'category', 'l2-category']
180
+ if group == 'overall':
181
+ res = {'Average': np.mean(df['hit'])}
182
+ else:
183
+ res = {}
184
+ abilities = list(set(df[group]))
185
+ abilities.sort()
186
+ for ab in abilities:
187
+ sub_df = df[df[group] == ab]
188
+ ab = self.ABBRS[ab] if ab in self.ABBRS else ab
189
+ res[ab] = np.mean(sub_df['hit'])
190
+ return res
191
+
192
+ def eval_sub_data(sub_data, answer_map):
193
+ lt = len(sub_data)
194
+ for i in range(lt):
195
+ item = sub_data.iloc[i]
196
+ match = re.search(r'([A-D]+)', item['prediction'])
197
+ pred = match.group(1) if match else ''
198
+ gt = answer_map[item['index']]
199
+ if gt != pred:
200
+ return 0
201
+ return 1
202
+
203
+ def show_result(ret_json):
204
+ show_dict = ret_json.copy()
205
+ table = Table(title=f' MMBench ({self.data_file}) ')
206
+ console = Console()
207
+ table.add_column('Category', justify='left')
208
+ table.add_column('Accuracy (%)', justify='right')
209
+ average = show_dict.pop('Average') * 100
210
+ table.add_row('Average', f'{average:.1f}')
211
+ table.add_section()
212
+ for cat_name, cat_acc in show_dict.items():
213
+ table.add_row(cat_name, f'{cat_acc * 100:.1f}')
214
+ with console.capture() as capture:
215
+ console.print(table, end='')
216
+ print('\n' + capture.get())
217
+ print('Note: Please be cautious if you use the results in papers, '
218
+ "since we don't use ChatGPT as a helper for choice "
219
+ 'extraction')
220
+
221
+ data = result_df.sort_values(by='index')
222
+ data['prediction'] = [str(x) for x in data['prediction']]
223
+ for k in data.keys():
224
+ data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
225
+
226
+ data_main = data[data['index'] < int(1e6)]
227
+ cate_map = {
228
+ i: c
229
+ for i, c in zip(self.df['index'], self.df['category'])
230
+ }
231
+ if self.has_l2_category:
232
+ l2_cate_map = {
233
+ i: c
234
+ for i, c in zip(self.df['index'], self.df['l2-category'])
235
+ }
236
+ answer_map = {
237
+ i: c
238
+ for i, c in zip(self.df['index'], self.df['answer'])
239
+ }
240
+
241
+ lt = len(data_main)
242
+ hit, tot = 0, 0
243
+ result = {}
244
+ for i in range(lt):
245
+ item_main = data_main.iloc[i]
246
+ idx = item_main['index']
247
+ assert idx not in result
248
+ sub_data = data[data['index'] % int(1e6) == idx]
249
+ ret = eval_sub_data(sub_data, answer_map)
250
+ result[idx] = ret
251
+ hit += ret
252
+ tot += 1
253
+
254
+ indices = data_main['index']
255
+ data_main = data_main.copy()
256
+ data_main['hit'] = [result[i] for i in indices]
257
+ main_idx = data_main['index']
258
+ data_main['category'] = [cate_map[i] for i in main_idx]
259
+
260
+ ret_json = calc_acc(data_main, 'overall')
261
+
262
+ if self.has_l2_category:
263
+ data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
264
+ l2 = calc_acc(data_main, 'l2-category')
265
+ ret_json.update(l2)
266
+ else:
267
+ leaf = calc_acc(data_main, 'category')
268
+ ret_json.update(leaf)
269
+ if show:
270
+ show_result(ret_json)
271
+ return ret_json
272
+
273
+
274
+ def main():
275
+ args = parse_args()
276
+
277
+ torch.manual_seed(args.seed)
278
+
279
+ if args.launcher != 'none':
280
+ set_multi_processing(distributed=True)
281
+ init_dist(args.launcher)
282
+
283
+ rank, world_size = get_dist_info()
284
+ torch.cuda.set_device(rank)
285
+ else:
286
+ rank = 0
287
+ world_size = 1
288
+
289
+ # build llm
290
+ quantization_config = None
291
+ load_in_8bit = False
292
+ if args.bits == 4:
293
+ quantization_config = BitsAndBytesConfig(
294
+ load_in_4bit=True,
295
+ load_in_8bit=False,
296
+ llm_int8_threshold=6.0,
297
+ llm_int8_has_fp16_weight=False,
298
+ bnb_4bit_compute_dtype=torch.float16,
299
+ bnb_4bit_use_double_quant=True,
300
+ bnb_4bit_quant_type='nf4')
301
+ elif args.bits == 8:
302
+ load_in_8bit = True
303
+ model_kwargs = {
304
+ 'quantization_config': quantization_config,
305
+ 'load_in_8bit': load_in_8bit,
306
+ 'device_map': rank if world_size > 1 else 'auto',
307
+ 'offload_folder': args.offload_folder,
308
+ 'trust_remote_code': True,
309
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
310
+ }
311
+
312
+ # build llm
313
+ with LoadWoInit():
314
+ llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
315
+ **model_kwargs)
316
+ tokenizer = AutoTokenizer.from_pretrained(
317
+ args.model_name_or_path,
318
+ trust_remote_code=True,
319
+ encode_special_tokens=True)
320
+ master_print(f'Load LLM from {args.model_name_or_path}')
321
+
322
+ llava_path = snapshot_download(
323
+ repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
324
+
325
+ # build visual_encoder
326
+ if 'visual_encoder' in os.listdir(llava_path):
327
+ assert args.visual_encoder is None, (
328
+ "Please don't specify the `--visual-encoder` since passed "
329
+ '`--llava` contains a visual encoder!')
330
+ visual_encoder_path = osp.join(llava_path, 'visual_encoder')
331
+ else:
332
+ assert args.siglip is not None, (
333
+ 'Please specify the `--siglip`!')
334
+ assert args.dino is not None, (
335
+ 'Please specify the `--dino`!')
336
+ with LoadWoInit():
337
+ siglip = SiglipVisionModel.from_pretrained(
338
+ args.siglip,
339
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
340
+ image_processor = SiglipImageProcessor.from_pretrained(
341
+ args.siglip)
342
+ master_print(f'Load siglip from {args.siglip}')
343
+ dino = Dinov2Model.from_pretrained(
344
+ args.dino,
345
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
346
+ master_print(f'Load dino from {args.dino}')
347
+
348
+ # load adapter
349
+ if 'llm_adapter' in os.listdir(llava_path):
350
+ adapter_path = osp.join(llava_path, 'llm_adapter')
351
+
352
+ with LoadWoInit():
353
+ llm = PeftModel.from_pretrained(
354
+ llm, adapter_path, offload_folder=args.offload_folder)
355
+
356
+ master_print(f'Load LLM adapter from {args.llava}')
357
+
358
+ if 'visual_encoder_adapter' in os.listdir(llava_path):
359
+ adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
360
+ visual_encoder = PeftModel.from_pretrained(
361
+ visual_encoder, adapter_path, offload_folder=args.offload_folder)
362
+ master_print(f'Load visual_encoder adapter from {args.llava}')
363
+
364
+ # build projector
365
+ projector_path = osp.join(llava_path, 'projector')
366
+ with LoadWoInit():
367
+ projector = AutoModel.from_pretrained(
368
+ projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
369
+ master_print(f'Load projector from {args.llava}')
370
+
371
+ projector.cuda()
372
+ projector.eval()
373
+
374
+ siglip.cuda()
375
+ siglip.eval()
376
+ dino.cuda()
377
+ dino.eval()
378
+
379
+ llm.eval()
380
+
381
+ stop_words = args.stop_words
382
+ if args.prompt_template:
383
+ template = PROMPT_TEMPLATE[args.prompt_template]
384
+ stop_words += template.get('STOP_WORDS', [])
385
+ stop_criteria = get_stop_criteria(
386
+ tokenizer=tokenizer, stop_words=stop_words)
387
+
388
+ gen_config = GenerationConfig(
389
+ max_new_tokens=args.max_new_tokens,
390
+ do_sample=False,
391
+ eos_token_id=tokenizer.eos_token_id,
392
+ pad_token_id=tokenizer.pad_token_id
393
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
394
+ )
395
+
396
+ # work_dir
397
+ if args.work_dir is not None:
398
+ # update configs according to CLI args if args.work_dir is not None
399
+ save_dir = args.work_dir
400
+ else:
401
+ # use config filename as default work_dir
402
+ save_dir = osp.join('./work_dirs',
403
+ osp.splitext(osp.basename(args.data_path))[0])
404
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
405
+ save_dir = osp.join(save_dir, timestamp)
406
+
407
+ if rank == 0:
408
+ mkdir_or_exist(osp.abspath(save_dir))
409
+ print('=======================================================')
410
+ print(f'Dataset path: {osp.abspath(args.data_path)}\n'
411
+ f'Results will be saved to {osp.abspath(save_dir)}')
412
+ print('=======================================================')
413
+
414
+ args_path = osp.join(save_dir, 'args.json')
415
+ with open(args_path, 'w') as f:
416
+ json.dump(args.__dict__, f, indent=2)
417
+
418
+ results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
419
+ results_json_path = osp.join(save_dir, 'mmbench_result.json')
420
+
421
+ dataset = MMBenchDataset(args.data_path)
422
+
423
+ results = []
424
+ n_samples = len(dataset)
425
+ per_rank_samples = math.ceil(n_samples / world_size)
426
+
427
+ per_rank_ids = range(per_rank_samples * rank,
428
+ min(n_samples, per_rank_samples * (rank + 1)))
429
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
430
+ data_sample = dataset[i]
431
+ if data_sample['context'] is not None:
432
+ text = data_sample['context'] + '\n' + data_sample[
433
+ 'question'] + '\n' + data_sample['options']
434
+ else:
435
+ text = data_sample['question'] + '\n' + data_sample['options']
436
+
437
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
438
+
439
+ if is_cn_string(text):
440
+ text = text + '请直接回答选项字母。'
441
+ else:
442
+ text = text + ("Answer with the option's letter from the "
443
+ 'given choices directly.')
444
+
445
+ if args.prompt_template:
446
+ prompt_text = ''
447
+ template = PROMPT_TEMPLATE[args.prompt_template]
448
+ prompt_text += template['INSTRUCTION'].format(
449
+ input=text, round=1, bot_name=args.bot_name)
450
+ else:
451
+ prompt_text = text
452
+ inputs = prompt_text
453
+
454
+ image = data_sample['img'].convert('RGB')
455
+ image = expand2square(
456
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
457
+ image = image_processor.preprocess(
458
+ image, return_tensors='pt')['pixel_values'][0]
459
+ image = image.cuda().unsqueeze(0)
460
+
461
+ siglip_out = siglip(
462
+ image, output_hidden_states=True).hidden_states[args.visual_select_layer]
463
+ dino_out = dino(
464
+ image, output_hidden_states=True).hidden_states[-1][:, 1:]
465
+ visual_out = torch.cat((siglip_out, dino_out), dim=-1)
466
+ pixel_values = projector(visual_out)
467
+
468
+ chunk_encode = []
469
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
470
+ if idx == 0:
471
+ cur_encode = tokenizer.encode(chunk)
472
+ else:
473
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
474
+ chunk_encode.append(cur_encode)
475
+ assert len(chunk_encode) == 2
476
+ ids = []
477
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
478
+ ids.extend(cur_chunk_encode)
479
+ if idx != len(chunk_encode) - 1:
480
+ ids.append(IMAGE_TOKEN_INDEX)
481
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
482
+ mm_inputs = prepare_inputs_labels_for_multimodal(
483
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
484
+
485
+ generate_output = llm.generate(
486
+ **mm_inputs,
487
+ generation_config=gen_config,
488
+ streamer=None,
489
+ bos_token_id=tokenizer.bos_token_id,
490
+ stopping_criteria=stop_criteria)
491
+
492
+ predict = tokenizer.decode(
493
+ generate_output[0], skip_special_tokens=True).strip()
494
+ cur_result = {}
495
+ cur_result['question'] = data_sample.get('question')
496
+ cur_result.update(data_sample.get('options_dict'))
497
+ cur_result['prediction'] = predict
498
+ if data_sample.get('category') is not None:
499
+ cur_result['category'] = data_sample.get('category')
500
+ if data_sample.get('l2-category') is not None:
501
+ cur_result['l2-category'] = data_sample.get('l2-category')
502
+ cur_result['index'] = data_sample.get('index')
503
+ cur_result['split'] = data_sample.get('split')
504
+ cur_result['answer'] = data_sample.get('answer')
505
+ results.append(cur_result)
506
+
507
+ results = collect_results(results, n_samples)
508
+
509
+ if get_rank() == 0:
510
+
511
+ results_df = pd.DataFrame(results)
512
+ with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
513
+ results_df.to_excel(writer, index=False)
514
+
515
+ if dataset.split == 'dev':
516
+ results_dict = dataset.eval_result(results_df, show=True)
517
+ with open(results_json_path, 'w') as f:
518
+ json.dump(results_dict, f, indent=2)
519
+ else:
520
+ print('All done!')
521
+
522
+
523
+ if __name__ == '__main__':
524
+
525
+ main()