File size: 1,369 Bytes
cca9b7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from medomni.common.registry import registry
from medomni.tasks.base_task import BaseTask
from medomni.common.logger import MetricLogger, SmoothedValue
from medomni.datasets.data_utils import prepare_sample
import torch.distributed as dist
@registry.register_task("image_text_pretrain")
class ImageTextPretrainTask(BaseTask):
    """Task registered in the registry under the name ``image_text_pretrain``.

    Extends ``BaseTask`` and defines its own ``evaluation`` loop.
    """

    def __init__(self):
        super().__init__()

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run ``valid_step`` on every batch of ``data_loader`` and collect the outputs.

        Args:
            model: Passed unchanged to ``self.valid_step`` for each batch.
            data_loader: Source of batches. It is wrapped with ``iter()`` when
                it does not already expose ``__next__``.
            cuda_enabled: Forwarded to ``prepare_sample`` for each batch.

        Returns:
            list: The per-batch outputs of ``valid_step``, concatenated in
            iteration order.
        """
        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []

        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            # valid_step is not defined in this class; presumably inherited
            # from BaseTask. Its output must be iterable for extend().
            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        # Synchronize ranks once all batches are processed, but only when
        # torch.distributed is built in and a process group is initialized,
        # so single-process runs skip the barrier.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

        return results
|