d-matrix-user commited on
Commit
78d92d3
Β·
verified Β·
1 Parent(s): af2f1d5

update clip_eval to accept dmx model

Browse files
Files changed (1) hide show
  1. clip_eval.py +28 -29
clip_eval.py CHANGED
@@ -29,7 +29,6 @@ _CITATION = """
29
  }
30
  """
31
 
32
-
33
  @add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
34
  class DmxClipEval(evaluate.Metric):
35
  def _info(self):
@@ -38,19 +37,17 @@ class DmxClipEval(evaluate.Metric):
38
  description=_DESCRIPTION,
39
  citation=_CITATION,
40
  inputs_description=_KWARGS_DESCRIPTION,
41
- features=[
42
- datasets.Features(
43
- {
44
- "model_name": datasets.Value("string"),
45
- "dataset_names": datasets.Value("string"),
46
- "n_examples": datasets.Value("int32"),
47
- }
48
- ),
49
- ],
50
  )
51
 
52
  def clip_dataset_evaluator(
53
- self, model, device, desc, dataset_name="mscoco", n_examples=-1
54
  ):
55
  processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
56
  if dataset_name == "mscoco":
@@ -116,34 +113,36 @@ class DmxClipEval(evaluate.Metric):
116
  }
117
  return metrics
118
 
119
- def clip_evaluator(self, model, device, desc, n_examples=-1):
120
  metrics = {}
121
- for name in ["mscoco", "flickr"]:
122
  metrics.update(
123
- self.clip_dataset_evaluator(model, device, desc, name, n_examples)
124
  )
125
  return metrics
126
 
127
- def _compute(self, model_name, dataset_names, n_examples):
128
-
129
- actual_model_name = model_name[0]
130
- actual_dataset_name_str = dataset_names[0]
131
- actual_n_examples = n_examples[0]
132
 
133
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
- model = CLIPModel.from_pretrained(actual_model_name).to(device)
135
-
136
- datasets_to_evaluate = [actual_dataset_name_str]
137
-
 
 
 
 
138
  metrics = {}
139
- for ds_name_loop_var in datasets_to_evaluate:
140
  dataset_metrics = self.clip_dataset_evaluator(
141
- model=model,
142
  device=device,
143
- desc=actual_model_name,
144
- dataset_name=ds_name_loop_var,
145
- n_examples=actual_n_examples,
146
  )
147
  metrics.update(dataset_metrics)
148
 
149
- return metrics
 
29
  }
30
  """
31
 
 
32
  @add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
33
  class DmxClipEval(evaluate.Metric):
34
  def _info(self):
 
37
  description=_DESCRIPTION,
38
  citation=_CITATION,
39
  inputs_description=_KWARGS_DESCRIPTION,
40
+ features=datasets.Features(
41
+ {
42
+ "model": datasets.Value("string"),
43
+ "dataset_names": datasets.Value("string"),
44
+ "n_examples": datasets.Value("int32"),
45
+ }
46
+ ),
 
 
47
  )
48
 
49
  def clip_dataset_evaluator(
50
+ self, model, device, dataset_name="mscoco", n_examples=-1
51
  ):
52
  processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
53
  if dataset_name == "mscoco":
 
113
  }
114
  return metrics
115
 
116
+ def clip_evaluator(self, model, device, n_examples=-1):
117
  metrics = {}
118
+ for dataset_name in ["mscoco", "flickr"]:
119
  metrics.update(
120
+ self.clip_dataset_evaluator(model, device, dataset_name, n_examples)
121
  )
122
  return metrics
123
 
124
+ def _compute(self, model, dataset_names, n_examples, **kwargs):
125
+ dataset = dataset_names[0]
126
+ num_examples = n_examples[0]
127
+ model_input = model[0]
 
128
 
129
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
+
131
+ if isinstance(model_input, str):
132
+ actual_model = CLIPModel.from_pretrained(model_input).to(device)
133
+ else:
134
+ actual_model = model_input
135
+
136
+ datasets_to_evaluate = [dataset]
137
+
138
  metrics = {}
139
+ for ds_name in datasets_to_evaluate:
140
  dataset_metrics = self.clip_dataset_evaluator(
141
+ model=actual_model,
142
  device=device,
143
+ dataset_name=ds_name,
144
+ n_examples=num_examples,
 
145
  )
146
  metrics.update(dataset_metrics)
147
 
148
+ return metrics