# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import re
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from rapidfuzz.distance import Levenshtein
from shapely.geometry import Point

from mmocr.registry import METRICS


# TODO: CTW1500 read pair
@METRICS.register_module()
class E2EPointMetric(BaseMetric):
"""Point metric for textspotting. Proposed in SPTS.
Args:
text_score_thrs (dict): Best text score threshold searching
space. Defaults to dict(start=0.8, stop=1, step=0.01).
word_spotting (bool): Whether to work in word spotting mode. Defaults
to False.
lexicon_path (str, optional): Lexicon path for word spotting, which
points to a lexicon file or a directory. Defaults to None.
lexicon_mapping (tuple, optional): The rule to map test image name to
its corresponding lexicon file. Only effective when lexicon path
is a directory. Defaults to ('(.*).jpg', r'\1.txt').
pair_path (str, optional): Pair path for word spotting, which points
to a pair file or a directory. Defaults to None.
pair_mapping (tuple, optional): The rule to map test image name to
its corresponding pair file. Only effective when pair path is a
directory. Defaults to ('(.*).jpg', r'\1.txt').
match_dist_thr (float, optional): Matching distance threshold for
word spotting. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
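
    Examples:
        >>> # A minimal illustrative sketch: the values below are only the
        >>> # documented defaults, not recommended settings.
        >>> metric = E2EPointMetric(
        ...     text_score_thrs=dict(start=0.8, stop=1, step=0.01),
        ...     word_spotting=True)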
"""
default_prefix: Optional[str] = 'e2e_icdar'

    def __init__(self,
text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
word_spotting: bool = False,
lexicon_path: Optional[str] = None,
lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
pair_path: Optional[str] = None,
pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
match_dist_thr: Optional[float] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.text_score_thrs = np.arange(**text_score_thrs)
self.word_spotting = word_spotting
self.match_dist_thr = match_dist_thr
if lexicon_path:
self.lexicon_mapping = lexicon_mapping
self.pair_mapping = pair_mapping
self.lexicons = self._read_lexicon(lexicon_path)
            self.pairs = self._read_pair(pair_path) if pair_path else {}

    def _read_lexicon(self, lexicon_path: str) -> Union[List[str], Dict]:
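        """Read a lexicon from a single ``*.txt`` file, or from every
        ``*.txt`` file under a directory (keyed by file basename)."""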
        if lexicon_path.endswith('.txt'):
            with open(lexicon_path, 'r') as f:
                lexicon = [ele.strip() for ele in f.read().splitlines()]
else:
lexicon = {}
for file in glob.glob(osp.join(lexicon_path, '*.txt')):
basename = osp.basename(file)
lexicon[basename] = self._read_lexicon(file)
return lexicon

    def _read_pair(self, pair_path: str) -> Dict[str, str]:
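        """Read word-to-gt-transcription pairs from a single ``*.txt`` file,
        or from every ``*.txt`` file under a directory (keyed by file
        basename)."""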
pairs = {}
        if pair_path.endswith('.txt'):
            with open(pair_path, 'r') as f:
                pair_lines = f.read().splitlines()
for line in pair_lines:
line = line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word) + 1:]
pairs[word] = word_gt
else:
for file in glob.glob(osp.join(pair_path, '*.txt')):
basename = osp.basename(file)
pairs[basename] = self._read_pair(file)
return pairs

    def poly_center(self, poly_pts):
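        """Return the center of a polygon as the mean of its vertices."""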
poly_pts = np.array(poly_pts).reshape(-1, 2)
return poly_pts.mean(0)

    def process(self, data_batch: Sequence[Dict],
data_samples: Sequence[Dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

        Args:
data_batch (Sequence[Dict]): A batch of data from dataloader.
data_samples (Sequence[Dict]): A batch of outputs from
the model.
"""
for data_sample in data_samples:
pred_instances = data_sample.get('pred_instances')
pred_points = pred_instances.get('points')
text_scores = pred_instances.get('text_scores')
if isinstance(text_scores, torch.Tensor):
text_scores = text_scores.cpu().numpy()
text_scores = np.array(text_scores, dtype=np.float32)
pred_texts = pred_instances.get('texts')
gt_instances = data_sample.get('gt_instances')
gt_polys = gt_instances.get('polygons')
gt_ignore_flags = gt_instances.get('ignored')
gt_texts = gt_instances.get('texts')
if isinstance(gt_ignore_flags, torch.Tensor):
gt_ignore_flags = gt_ignore_flags.cpu().numpy()
gt_points = [self.poly_center(poly) for poly in gt_polys]
if self.word_spotting:
gt_ignore_flags, gt_texts = self._word_spotting_filter(
gt_ignore_flags, gt_texts)
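            # Predictions scored below the smallest searched threshold can
            # never be counted at any threshold, so drop them up front.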
pred_ignore_flags = text_scores < self.text_score_thrs.min()
text_scores = text_scores[~pred_ignore_flags]
pred_texts = self._get_true_elements(pred_texts,
~pred_ignore_flags)
pred_points = self._get_true_elements(pred_points,
~pred_ignore_flags)
result = dict(
                # reserved for image-level lexicons
gt_img_name=osp.basename(data_sample.get('img_path', '')),
text_scores=text_scores,
pred_points=pred_points,
gt_points=gt_points,
pred_texts=pred_texts,
gt_texts=gt_texts,
gt_ignore_flags=gt_ignore_flags)
self.results.append(result)

    def _get_true_elements(self, array: List, flags: np.ndarray) -> List:
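        """Keep only the elements of ``array`` whose flag is True."""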
return [array[i] for i in self._true_indexes(flags)]

    def compute_metrics(self, results: List[Dict]) -> Dict:
"""Compute the metrics from processed results.

        Args:
results (list[dict]): The processed results of each batch.

        Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
logger: MMLogger = MMLogger.get_current_instance()
best_eval_results = dict(hmean=-1)
num_thres = len(self.text_score_thrs)
num_preds = np.zeros(
num_thres, dtype=int) # the number of points actually predicted
num_tp = np.zeros(num_thres, dtype=int) # number of true positives
num_gts = np.zeros(num_thres, dtype=int) # number of valid gts
for result in results:
text_scores = result['text_scores']
pred_points = result['pred_points']
gt_points = result['gt_points']
gt_texts = result['gt_texts']
pred_texts = result['pred_texts']
gt_ignore_flags = result['gt_ignore_flags']
gt_img_name = result['gt_img_name']
# Correct the words with lexicon
pred_dist_flags = np.zeros(len(pred_texts), dtype=bool)
if hasattr(self, 'lexicons'):
for i, pred_text in enumerate(pred_texts):
# If it's an image-level lexicon
if isinstance(self.lexicons, dict):
lexicon_name = self._map_img_name(
gt_img_name, self.lexicon_mapping)
pair_name = self._map_img_name(gt_img_name,
self.pair_mapping)
pred_texts[i], match_dist = self._match_word(
pred_text, self.lexicons[lexicon_name],
self.pairs[pair_name])
else:
pred_texts[i], match_dist = self._match_word(
pred_text, self.lexicons, self.pairs)
if (self.match_dist_thr
and match_dist >= self.match_dist_thr):
# won't even count this as a prediction
pred_dist_flags[i] = True
# Filter out predictions by IoU threshold
for i, text_score_thr in enumerate(self.text_score_thrs):
pred_ignore_flags = pred_dist_flags | (
text_scores < text_score_thr)
filtered_pred_texts = self._get_true_elements(
pred_texts, ~pred_ignore_flags)
filtered_pred_points = self._get_true_elements(
pred_points, ~pred_ignore_flags)
gt_matched = np.zeros(len(gt_texts), dtype=bool)
num_gt = len(gt_texts) - np.sum(gt_ignore_flags)
if num_gt == 0:
continue
num_gts[i] += num_gt
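                # Match each prediction to the nearest gt center; it counts
                # as a true positive only if that gt is unmatched, not
                # ignored, and the transcriptions agree case-insensitively.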
for pred_text, pred_point in zip(filtered_pred_texts,
filtered_pred_points):
dists = [
Point(pred_point).distance(Point(gt_point))
for gt_point in gt_points
]
min_idx = np.argmin(dists)
if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
continue
if not gt_matched[min_idx] and (
pred_text.upper() == gt_texts[min_idx].upper()):
gt_matched[min_idx] = True
num_tp[i] += 1
num_preds[i] += 1
for i, text_score_thr in enumerate(self.text_score_thrs):
if num_preds[i] == 0 or num_tp[i] == 0:
recall, precision, hmean = 0, 0, 0
else:
recall = num_tp[i] / num_gts[i]
precision = num_tp[i] / num_preds[i]
hmean = 2 * recall * precision / (recall + precision)
eval_results = dict(
precision=precision, recall=recall, hmean=hmean)
logger.info(f'text score threshold: {text_score_thr:.2f}, '
f'recall: {eval_results["recall"]:.4f}, '
f'precision: {eval_results["precision"]:.4f}, '
f'hmean: {eval_results["hmean"]:.4f}\n')
if eval_results['hmean'] > best_eval_results['hmean']:
best_eval_results = eval_results
return best_eval_results

    def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
        """Map the image name to another one according to the regex mapping
        rule."""
return re.sub(mapping[0], mapping[1], img_name)

    def _true_indexes(self, array: np.ndarray) -> np.ndarray:
"""Get indexes of True elements from a 1D boolean array."""
return np.where(array)[0]

    def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
                              gt_texts: List[str]
                              ) -> Tuple[np.ndarray, List[str]]:
        """Filter out gt instances that cannot appear in a valid dictionary,
        and apply simple text preprocessing."""
for i in range(len(gt_texts)):
if gt_ignore_flags[i]:
continue
text = gt_texts[i]
if text[-2:] in ["'s", "'S"]:
text = text[:-2]
text = text.strip('-')
for char in "'!?.:,*\"()·[]/":
text = text.replace(char, ' ')
text = text.strip()
gt_ignore_flags[i] = not self._include_in_dict(text)
if not gt_ignore_flags[i]:
gt_texts[i] = text
return gt_ignore_flags, gt_texts

    def _include_in_dict(self, text: str) -> bool:
"""Check if the text could be in a valid dictionary."""
if len(text) != len(text.replace(' ', '')) or len(text) < 3:
return False
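        # Only basic/extended Latin and Greek letters plus the hyphen are
        # allowed; '×', '÷' and '·' are rejected explicitly.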
not_allowed = '×÷·'
valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
(ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
(ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
for char in text:
code = ord(char)
if (not_allowed.find(char) != -1):
return False
valid = any(code >= r[0] and code <= r[1] for r in valid_ranges)
if not valid:
return False
return True

    def _match_word(self,
text: str,
lexicons: List[str],
pairs: Optional[Dict[str, str]] = None) -> Tuple[str, int]:
"""Match the text with the lexicons and pairs."""
text = text.upper()
matched_word = ''
matched_dist = 100
for lexicon in lexicons:
lexicon = lexicon.upper()
            # Normalized edit distance in [0, 1]; lower means a closer match.
            norm_dist = Levenshtein.normalized_distance(text, lexicon)
if norm_dist < matched_dist:
matched_dist = norm_dist
if pairs:
matched_word = pairs[lexicon]
else:
matched_word = lexicon
return matched_word, matched_dist