|
import os |
|
os.chdir('/root/autodl-tmp/t2m/T2M-GPT-main') |
|
import torch |
|
import numpy as np |
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
from os.path import join as pjoin |
|
from torch.distributions import Categorical |
|
import json |
|
import clip |
|
from attack import PGDAttacker |
|
import options.option_transformer as option_trans |
|
import models.vqvae as vqvae |
|
import utils.utils_model as utils_model |
|
import eval_trans_per as eval_trans |
|
from dataset import dataset_TM_train |
|
from dataset import dataset_TM_eval |
|
from dataset import dataset_tokenize |
|
from metrics import batch_tvd,topk_overlap_loss,topK_overlap_true_loss |
|
import models.t2m_trans as trans |
|
from options.get_eval_option import get_opt |
|
from models.evaluator_wrapper import EvaluatorModelWrapper |
|
import warnings |
|
from utils.word_vectorizer import WordVectorizer |
|
from utils.losses import loss_robust |
|
import wandb |
|
|
|
# Suppress every Python warning for the rest of the run.
warnings.filterwarnings('ignore')

# Options object comes from the project's option_transformer module; its
# seed field is used to seed torch's global RNG.
args = option_trans.get_args_parser()
torch.manual_seed(args.seed)

# Outputs are written to a per-experiment sub-directory of the configured
# output directory; create it if it does not exist yet.
args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

# The VQ directory lives under the dataset root selected by args.dataname
# ('kit' -> KIT-ML, anything else -> HumanML3D).
if args.dataname == 'kit':
    args.vq_dir = os.path.join("./dataset/KIT-ML", f'{args.vq_name}')
else:
    args.vq_dir = os.path.join("./dataset/HumanML3D", f'{args.vq_name}')
os.makedirs(args.vq_dir, exist_ok=True)

# Logger and TensorBoard writer both target the experiment directory; the
# full option set is dumped to the log as sorted, indented JSON.
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
|
|
|
|
|
# Word vectorizer built from the './glove' directory with the 'our_vab'
# prefix; it is handed to the evaluation data loader below.
w_vectorizer = WordVectorizer('./glove', 'our_vab')

# Evaluation data loader for the selected dataset.
# NOTE(review): the positional arguments `True, 32` are presumably the
# "is test split" flag and the batch size -- confirm against
# dataset_TM_eval.DATALoader.
test_loader = dataset_TM_eval.DATALoader(args.dataname, True, 32, w_vectorizer)

# Option file of the evaluator checkpoint matching the dataset. It is
# computed exactly once, right before its only use.
dataset_opt_path = (
    'checkpoints/kit/Comp_v6_KLD005/opt.txt'
    if args.dataname == 'kit'
    else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
)

# Build the evaluator wrapper from those options, targeting the CUDA device.
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
|
|
# Load the CLIP ViT-B/32 model (non-JIT) onto the CUDA device, together with
# its preprocessing transform.
clip_model, clip_preprocess = clip.load(
    "ViT-B/32",
    jit=False,
    device=torch.device('cuda'),
)

# CLIP's own weight-conversion helper, then inference mode.
# NOTE(review): convert_weights is presumably the fp16 conversion from the
# CLIP repository -- confirm.
clip.model.convert_weights(clip_model)
clip_model.eval()

# CLIP stays frozen: none of its parameters should receive gradients.
for clip_param in clip_model.parameters():
    clip_param.requires_grad = False
|
|
|
# Positional hyper-parameters of the VQ-VAE, in the order its constructor
# receives them after `args`.
_vqvae_dims = (
    args.nb_code,
    args.code_dim,
    args.output_emb_width,
    args.down_t,
    args.stride_t,
    args.width,
    args.depth,
    args.dilation_growth_rate,
)
net = vqvae.HumanVQVAE(args, *_vqvae_dims)

# Keyword configuration of the text-to-motion transformer; every value is
# taken directly from the parsed options.
_gpt_config = {
    'num_vq': args.nb_code,
    'embed_dim': args.embed_dim_gpt,
    'clip_dim': args.clip_dim,
    'block_size': args.block_size,
    'num_layers': args.num_layers,
    'n_head': args.n_head_gpt,
    'drop_out_rate': args.drop_out_rate,
    'fc_rate': args.ff_rate,
}
trans_encoder = trans.Text2Motion_Transformer(**_gpt_config)
|
|
|
# --- VQ-VAE -----------------------------------------------------------------
# The checkpoint is deserialized onto the CPU and its 'net' entry is loaded
# with strict key matching; the model is then moved to the GPU in eval mode.
print('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.cuda()
net.eval()

# --- Transformer ------------------------------------------------------------
# Weights are restored only when a transformer checkpoint path was supplied;
# the 'trans' entry of that checkpoint holds the state dict.
# NOTE(review): the extent of this `if` body was reconstructed (the original
# indentation was lost) -- confirm only the three load lines are conditional.
if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    trans_encoder.load_state_dict(ckpt['trans'], strict=True)

# Evaluation only: the transformer goes to the GPU in eval mode either way.
trans_encoder.cuda()
trans_encoder.eval()

print('loading chechpoint successfully!')
|
# Initial "best so far" values handed to the evaluation routine, which
# returns updated values under the same names.
# Each one must be a plain scalar: a trailing comma after the value would
# bind a one-element tuple instead of a number.
best_fid = 1000.0
best_fid_per = 1000.0
best_iter = 0
best_div = 100.0
best_top1 = 0.0
best_top2 = 0.0
best_top3 = 0.0
best_matching = 100.0

# This script only evaluates already-trained checkpoints, so there is no
# training step to report.
# NOTE(review): 0 is assumed to be an acceptable iteration index for
# evaluation_transformer_test -- confirm against eval_trans_per.
nb_iter = 0

# Run the test-set evaluation. The loader passed here is `test_loader`, the
# only evaluation loader this script builds.
(
    best_fid,
    best_fid_per,
    best_iter,
    best_div,
    best_top1,
    best_top2,
    best_top3,
    best_matching,
    writer,
    logger,
) = eval_trans.evaluation_transformer_test(
    args.out_dir,
    test_loader,
    net,
    trans_encoder,
    logger,
    writer,
    nb_iter,
    best_fid,
    best_fid_per,
    best_iter,
    best_div,
    best_top1,
    best_top2,
    best_top3,
    best_matching,
    clip_model=clip_model,
    eval_wrapper=eval_wrapper,
)
|
|