File size: 4,663 Bytes
9d8126f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f942c8
9d8126f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python

import argparse
import os
import sys
import time
from typing import List, Tuple

import onnxruntime
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Command-line options. They are parsed at import time because val_imagenet()
# reads the module-level `args` directly.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--onnx_model", default="model.onnx", help="Input onnx model")
# The loader reads images from the 'validation' subdirectory of this path.
parser.add_argument(
    "--data_dir",
    default="/workspace/dataset/imagenet",
    help="Directory of dataset")
parser.add_argument(
    "--batch_size", default=1, type=int, help="Evaluation batch size")
# When set, the session uses VitisAIExecutionProvider configured by
# --provider_config; otherwise CUDA with CPU fallback.
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    help="Path of the config file for setting provider_options.",
)
args = parser.parse_args()

class AverageMeter(object):
  """Tracks the latest value of a metric together with its weighted mean.

  Attributes:
    name: Label printed first by ``__str__``.
    fmt: Format spec (e.g. ':6.2f') applied to both ``val`` and ``avg``.
    val: The most recent value passed to ``update``.
    sum: Running sum of every value multiplied by its weight ``n``.
    count: Running total of all weights ``n``.
    avg: ``sum / count`` as of the last ``update``.
  """

  def __init__(self, name, fmt=':f'):
    self.name, self.fmt = name, fmt
    self.reset()

  def reset(self):
    """Zero out every accumulated statistic."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` with weight ``n`` and recompute the running mean."""
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

  def __str__(self):
    template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
    return template.format(**self.__dict__)

def accuracy(output: torch.Tensor,
             target: torch.Tensor,
             topk: Tuple[int, ...] = (1,)) -> List[torch.Tensor]:
  """Computes top-k accuracy, in percent, for each k in `topk`.

  A sample counts as correct for a given k when its ground-truth label is
  among the k highest-scoring entries of its row in `output`.

  Args:
    output: Prediction scores of the model; classes are ranked along dim 1,
      one row per sample.
    target: Ground-truth class indices, one per sample (`target.size(0)` is
      taken as the batch size).
    topk: The k values to evaluate. `output.topk` raises if `max(topk)`
      exceeds the number of classes in `output`.
  Returns:
    A list with one entry per element of `topk`, in the same order. Each
    entry is a 1-element tensor holding the accuracy as a percentage.
  """

  with torch.no_grad():
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed so that row j
    # holds every sample's (j+1)-th best prediction.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
      # topk returns distinct indices, so each sample can match at most once
      # in the first k rows; summing the hits counts correct samples.
      correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
      res.append(correct_k.mul_(100.0 / batch_size))
    return res

def prepare_data_loader(data_dir: str,
                        batch_size: int = 100,
                        workers: int = 8) -> torch.utils.data.DataLoader:
  """Builds an unshuffled DataLoader over the ImageNet validation images.

  Images are read with `datasets.ImageFolder` from `<data_dir>/validation`.
  Each image is resized with `Resize(256)`, center-cropped to 224, converted
  to a tensor, and normalized per channel with the mean/std constants below
  (the standard ImageNet statistics).

  Args:
    data_dir: Dataset root; it must contain a subdirectory named
      'validation' laid out the way `ImageFolder` expects.
    batch_size: Number of samples per batch.
    workers: Number of subprocesses used for data loading.
  Returns:
    A torch.utils.data.DataLoader over the validation set.
  """
  pipeline = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
  ])
  dataset = datasets.ImageFolder(os.path.join(data_dir, 'validation'), pipeline)
  return DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=True)

def val_imagenet():
  """Validate ONNX model on ImageNet dataset.

  All settings come from the module-level `args`. The function runs
  `args.onnx_model` with onnxruntime over the validation split under
  `args.data_dir`, then prints the top-1/top-5 accuracy and the elapsed
  wall-clock time.

  Returns:
    Tuple `(top1, top5)` of sample-weighted average accuracies in percent.
    These are the `avg` values of the two AverageMeters, so they stay 0 if
    the loader yields no batches.
  """
  print(f'Current onnx model: {args.onnx_model}')

  if args.ipu:
    providers = ["VitisAIExecutionProvider"]
    provider_options = [{"config_file": args.provider_config}]
  else:
    # Providers are listed in priority order; CPU is the fallback.
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    provider_options = None
  ort_session = onnxruntime.InferenceSession(
      args.onnx_model, providers=providers, provider_options=provider_options)
  # The input name does not change between batches, so look it up once.
  input_name = ort_session.get_inputs()[0].name

  val_loader = prepare_data_loader(args.data_dir, args.batch_size)

  top1 = AverageMeter('Acc@1', ':6.2f')
  top5 = AverageMeter('Acc@5', ':6.2f')

  start_time = time.time()
  with torch.no_grad():
    for images, targets in tqdm(val_loader, file=sys.stdout):
      # Reorder (N, C, H, W) -> (N, H, W, C). NOTE(review): assumes the
      # exported model expects channels-last input, as the original code
      # did -- confirm. .contiguous() hands onnxruntime a C-contiguous
      # buffer instead of a strided view.
      inputs = images.permute(0, 2, 3, 1).contiguous().numpy()

      outputs = ort_session.run(None, {input_name: inputs})
      logits = torch.from_numpy(outputs[0])

      acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
      # Weight each batch by its size so the final average is per sample.
      n = images.size(0)
      top1.update(acc1, n)
      top5.update(acc5, n)

  elapsed = time.time() - start_time
  print('Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.format(
      float(top1.avg), float(top5.avg), elapsed))

  return top1.avg, top5.avg

if __name__ == '__main__':
  val_imagenet()