update clip_eval to accept dmx model

clip_eval.py CHANGED (+28 -29)
@@ -29,7 +29,6 @@ _CITATION = """
 }
 """
 
-
 @add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class DmxClipEval(evaluate.Metric):
     def _info(self):
@@ -38,19 +37,17 @@ class DmxClipEval(evaluate.Metric):
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            features=
-
-
-
-
-
-
-            ),
-            ],
+            features=datasets.Features(
+                {
+                    "model": datasets.Value("string"),
+                    "dataset_names": datasets.Value("string"),
+                    "n_examples": datasets.Value("int32"),
+                }
+            ),
         )
 
     def clip_dataset_evaluator(
-        self, model, device,
+        self, model, device, dataset_name="mscoco", n_examples=-1
     ):
         processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
         if dataset_name == "mscoco":
@@ -116,34 +113,36 @@ class DmxClipEval(evaluate.Metric):
         }
         return metrics
 
-    def clip_evaluator(self, model, device,
+    def clip_evaluator(self, model, device, n_examples=-1):
        metrics = {}
-        for
+        for dataset_name in ["mscoco", "flickr"]:
             metrics.update(
-                self.clip_dataset_evaluator(model, device,
+                self.clip_dataset_evaluator(model, device, dataset_name, n_examples)
             )
         return metrics
 
-    def _compute(self,
-
-
-
-        actual_n_examples = n_examples[0]
+    def _compute(self, model, dataset_names, n_examples, **kwargs):
+        dataset = dataset_names[0]
+        num_examples = n_examples[0]
+        model_input = model[0]
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-
-
+
+        if isinstance(model_input, str):
+            actual_model = CLIPModel.from_pretrained(model_input).to(device)
+        else:
+            actual_model = model_input
+
+        datasets_to_evaluate = [dataset]
+
         metrics = {}
-        for
+        for ds_name in datasets_to_evaluate:
             dataset_metrics = self.clip_dataset_evaluator(
-                model=
+                model=actual_model,
                 device=device,
-
-
-                n_examples=actual_n_examples,
+                dataset_name=ds_name,
+                n_examples=num_examples,
             )
             metrics.update(dataset_metrics)
 
-        return metrics
+        return metrics
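For orientation, a minimal usage sketch of the metric after this change. The load path is illustrative, and the list-wrapped arguments reflect how evaluate.Metric batches inputs before _compute unwraps them with model[0], dataset_names[0], and n_examples[0]:

    import evaluate

    # Illustrative local path; point this at wherever clip_eval.py lives.
    metric = evaluate.load("clip_eval.py")

    # Passing a checkpoint name as a string: _compute sees a str and loads it
    # with CLIPModel.from_pretrained before running the evaluation.
    results = metric.compute(
        model=["openai/clip-vit-base-patch32"],
        dataset_names=["mscoco"],
        n_examples=[100],  # cap the number of examples; -1 uses the full split
    )
    print(results)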
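And a sketch of the new code path named in the commit title: passing an already-instantiated model instead of a string. The DmxModel wrapping step is an assumption about the dmx tooling and is left as a comment; any model exposing config._name_or_path exercises the same branch, since clip_dataset_evaluator derives its CLIPProcessor from that attribute. Note that the declared features still type "model" as a string, so whether evaluate's input encoding accepts a live object as-is may need checking.

    import evaluate
    import torch
    from transformers import CLIPModel

    metric = evaluate.load("clip_eval.py")  # illustrative path, as above

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    # Assumed dmx wrapping step, e.g.:
    #   from dmx.compressor.modeling import DmxModel  # assumed import path
    #   model = DmxModel.from_torch(model)

    results = metric.compute(
        model=[model],           # non-str input skips from_pretrained in _compute
        dataset_names=["flickr"],
        n_examples=[-1],         # -1 evaluates the full split
    )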