|
import sys |
|
import time |
|
import torch |
|
import argparse |
|
from omegaconf import OmegaConf |
|
|
|
sys.path.append('.') |
|
from utils.build_utils import build_from_cfg |
|
|
|
parser = argparse.ArgumentParser( |
|
prog = 'AMT', |
|
description = 'Speed¶meter benchmark', |
|
) |
|
parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') |
|
args = parser.parse_args() |
|
|
|
cfg_path = args.config |
|
network_cfg = OmegaConf.load(cfg_path).network |
|
model = build_from_cfg(network_cfg) |
|
model = model.cuda() |
|
model.eval() |
|
|
|
img0 = torch.randn(1, 3, 256, 448).cuda() |
|
img1 = torch.randn(1, 3, 256, 448).cuda() |
|
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).cuda() |
|
|
|
with torch.no_grad(): |
|
for i in range(100): |
|
out = model(img0, img1, embt, eval=True) |
|
torch.cuda.synchronize() |
|
time_stamp = time.time() |
|
for i in range(1000): |
|
out = model(img0, img1, embt, eval=True) |
|
torch.cuda.synchronize() |
|
print('Time: {:.5f}s'.format((time.time() - time_stamp) / 1)) |
|
|
|
total = sum([param.nelement() for param in model.parameters()]) |
|
print('Parameters: {:.2f}M'.format(total / 1e6)) |
|
|