# MedVersa_Internal / medomni / tasks / image_text_pretrain.py
# Uploaded by hyzhou in commit cca9b7e ("upload everything"); 1.37 kB.
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from medomni.common.registry import registry
from medomni.tasks.base_task import BaseTask
from medomni.common.logger import MetricLogger, SmoothedValue
from medomni.datasets.data_utils import prepare_sample
import torch.distributed as dist
@registry.register_task("image_text_pretrain")
class ImageTextPretrainTask(BaseTask):
    """Task registered under the name ``"image_text_pretrain"``.

    Subclasses ``BaseTask``. It overrides ``evaluation`` to run
    ``self.valid_step`` over a data loader and collect the per-batch outputs.
    """

    def __init__(self):
        super().__init__()

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run ``self.valid_step`` on every batch and collect the outputs.

        Args:
            model: Model forwarded to ``self.valid_step``.
            data_loader: Iterable of batches. If it has no ``__next__``
                method it is wrapped with ``iter()`` before use.
            cuda_enabled: Forwarded to ``prepare_sample`` for each batch.

        Returns:
            list: The items of every ``valid_step`` output, concatenated in
            batch order. Only this process's results are returned; nothing
            is gathered across ranks here.
        """
        if not hasattr(data_loader, "__next__"):
            # Convert to an iterator if not already one.
            # NOTE(review): presumably MetricLogger.log_every expects an
            # iterator here -- confirm against medomni.common.logger.
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []
        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        # Synchronize ranks before returning, but only when a default process
        # group exists: dist.barrier() fails if it was never initialized.
        # Checked inline with the module's own ``dist`` import rather than an
        # external helper, so no extra import is required.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

        return results