import evaluate
from evaluate.utils.file_utils import add_start_docstrings
import datasets
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
_DESCRIPTION = """
This metric evaluates CLIP models on image-text retrieval tasks using standard datasets.
It calculates Recall@K metrics for both text-to-image and image-to-text retrieval.
"""
_KWARGS_DESCRIPTION = """
Args:
    model: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32"),
        or a preloaded CLIPModel instance
    dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr")
    n_examples: Number of examples to use for evaluation (-1 for all)
Returns:
    Dictionary containing Recall@K metrics for each dataset and retrieval direction
"""
_CITATION = """
@inproceedings{radford2021learning,
title={Learning transferable visual models from natural language supervision},
author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others},
booktitle={International Conference on Machine Learning},
year={2021},
}
"""
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxClipEval(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"dataset_names": datasets.Value("string"),
}
),
)
def clip_dataset_evaluator(
self, model, device, dataset_name="mscoco", n_examples=-1
):
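        # Load the processor that matches the model checkpoint. Note that
        # config._name_or_path is a private attribute populated by
        # from_pretrained; it may be empty for models built from a bare config.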
processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
if dataset_name == "mscoco":
ds = datasets.load_dataset(
"clip-benchmark/wds_mscoco_captions", split="test"
)
elif dataset_name == "flickr":
ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
else:
            raise ValueError(f"Invalid dataset name: {dataset_name}")
if n_examples != -1:
ds = ds.select(range(min(n_examples, len(ds))))
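        # Batch over indices rather than examples so the HF dataset can be
        # sliced directly; images are only decoded when a batch is accessed.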
dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)
all_image_embeds = []
all_text_embeds = []
for indices in tqdm(dl, desc=f"Processing {dataset_name}"):
batch = ds[indices.tolist()]
            # Truncate on the tokenizer side so the final EOS token, which CLIP
            # uses to pool the text embedding, is preserved; hard-slicing
            # input_ids to 77 positions could cut it off.
            inputs = processor(
                text=batch["txt"],
                images=batch["jpg"],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,  # CLIP's maximum text sequence length
            )
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
output = model(**inputs)
all_image_embeds.append(output.image_embeds.cpu())
all_text_embeds.append(output.text_embeds.cpu())
all_image_embeds = torch.cat(all_image_embeds, dim=0)
all_text_embeds = torch.cat(all_text_embeds, dim=0)
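        # CLIPModel returns L2-normalized image/text embeddings, so this matrix
        # product is a cosine-similarity matrix (rows: texts, columns: images).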
text_img_sim = all_text_embeds @ all_image_embeds.t()
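        # Recall@K under the one-to-one assumption: caption i belongs to image
        # i, so query i scores a hit if index i appears among its top-k results.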
def get_top_k(sim_mat, k_arr):
ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
correct_winner_mask = (
ordered_winners
== torch.arange(ordered_winners.shape[0])
.unsqueeze(1)
.to(ordered_winners.device)
).long()
return [
correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
]
k_arr = [1, 5, 10]
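        # "image_recall" scores text-to-image retrieval (caption queries, image
        # results); "text_recall" uses the transposed matrix for image-to-text.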
metrics = {
**{
f"{dataset_name}:image_recall@{k}": val
for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
},
**{
f"{dataset_name}:text_recall@{k}": val
for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
},
}
return metrics
def clip_evaluator(self, model, device, n_examples=-1):
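        """Convenience helper: run retrieval evaluation on every supported dataset."""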
metrics = {}
for dataset_name in ["mscoco", "flickr"]:
metrics.update(
self.clip_dataset_evaluator(model, device, dataset_name, n_examples)
)
return metrics
def _compute(self, model, dataset_names, n_examples, **kwargs):
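        # `dataset_names` arrives as the accumulated feature column; `model`
        # and `n_examples` are plain compute kwargs, expected to be passed as
        # single-element lists (hence the [0] indexing below).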
num_examples = n_examples[0]
model_input = model[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(model_input, str):
actual_model = CLIPModel.from_pretrained(model_input).to(device)
else:
actual_model = model_input
        datasets_to_evaluate = list(dict.fromkeys(dataset_names))  # dedupe, keep order
metrics = {}
for ds_name in datasets_to_evaluate:
dataset_metrics = self.clip_dataset_evaluator(
model=actual_model,
device=device,
dataset_name=ds_name,
n_examples=num_examples,
)
metrics.update(dataset_metrics)
return metrics
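
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the metric API): evaluates a
    # small CLIP checkpoint on 100 MSCOCO pairs. `model` and `n_examples` are
    # passed as single-element lists, matching the conventions in _compute.
    metric = DmxClipEval()
    results = metric.compute(
        model=["openai/clip-vit-base-patch32"],
        dataset_names=["mscoco"],
        n_examples=[100],  # use [-1] to evaluate the full split
    )
    print(results)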