import sys | |
from dataset import GridDataset | |
from Trainer import Trainer | |
# Smoke test: build the trainer, pull a single batch from the training
# dataloader, move it to the GPU and run one forward pass through the model.
trainer = Trainer(write_logs=False)
trainer.load_datasets()
trainer.create_model()
# num_workers=0 loads batches in the main process, which keeps any dataset
# errors in this process's traceback.
dataloader = trainer.dataset2dataloader(
    trainer.train_dataset, num_workers=0
)

# Take only the first batch. The old `for batch in ...: break` idiom left
# `batch` unbound when the loader produced nothing, failing later with an
# unhelpful NameError.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError(
        'training dataloader yielded no batches; is the train dataset empty?'
    )

# Index with [] rather than .get(): a missing key then raises a KeyError
# naming the key, instead of "'NoneType' object has no attribute 'cuda'".
vid = batch['vid'].cuda()
# NOTE(review): txt / vid_len / txt_len are transferred but not used below;
# kept so the script still exercises their GPU transfer — confirm intent.
txt = batch['txt'].cuda()
vid_len = batch['vid_len'].cuda()
txt_len = batch['txt_len'].cuda()

# Forward pass on the video input only.
y = trainer.net(vid)
print(y)
print('>>> ')