# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import re
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from rapidfuzz.distance import Levenshtein
from shapely.geometry import Point

from mmocr.registry import METRICS


# TODO: CTW1500 read pair
@METRICS.register_module()
class E2EPointMetric(BaseMetric):
    """Point metric for textspotting. Proposed in SPTS.

    Each prediction is matched to the ground-truth instance whose polygon
    center is closest to the predicted point; it counts as a true positive
    when the (case-insensitive) texts agree and that ground truth has not
    been matched yet. Precision / recall / hmean are computed for every text
    score threshold and the best-hmean result is returned.

    Args:
        text_score_thrs (dict): Best text score threshold searching space,
            given as keyword arguments forwarded to ``np.arange`` (so
            ``stop`` is exclusive). Defaults to
            dict(start=0.8, stop=1, step=0.01).
        word_spotting (bool): Whether to work in word spotting mode.
            Defaults to False.
        lexicon_path (str, optional): Lexicon path for word spotting, which
            points to a lexicon file or a directory. Defaults to None.
        lexicon_mapping (tuple, optional): The rule to map test image name to
            its corresponding lexicon file. Only effective when lexicon path
            is a directory. Defaults to ('(.*).jpg', r'\1.txt').
        pair_path (str, optional): Pair path for word spotting, which points
            to a pair file or a directory. Only used together with
            ``lexicon_path``; when omitted, matched lexicon words are used
            as-is. Defaults to None.
        pair_mapping (tuple, optional): The rule to map test image name to
            its corresponding pair file. Only effective when pair path is a
            directory. Defaults to ('(.*).jpg', r'\1.txt').
        match_dist_thr (float, optional): Matching distance threshold for
            word spotting. Predictions whose normalized edit distance to the
            best lexicon word is at least this value are discarded. Defaults
            to None (no filtering).
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix: Optional[str] = 'e2e_icdar'

    def __init__(self,
                 text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
                 word_spotting: bool = False,
                 lexicon_path: Optional[str] = None,
                 lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
                 pair_path: Optional[str] = None,
                 pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
                 match_dist_thr: Optional[float] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        # ``text_score_thrs`` is only read here, never mutated, so the dict
        # default is not shared mutable state.
        self.text_score_thrs = np.arange(**text_score_thrs)
        self.word_spotting = word_spotting
        self.match_dist_thr = match_dist_thr
        self.lexicon_mapping = lexicon_mapping
        self.pair_mapping = pair_mapping
        # ``None`` disables lexicon-based correction in ``compute_metrics``.
        self.lexicons: Optional[Union[List[str], Dict[str, List[str]]]] = None
        self.pairs: Optional[Dict] = None
        if lexicon_path:
            self.lexicons = self._read_lexicon(lexicon_path)
            # The pair table is optional; reading it unconditionally used to
            # crash with ``pair_path=None`` (the documented default).
            if pair_path:
                self.pairs = self._read_pair(pair_path)

    def _read_lexicon(
            self,
            lexicon_path: str) -> Union[List[str], Dict[str, List[str]]]:
        """Read a lexicon file, or every ``*.txt`` lexicon in a directory.

        Returns:
            list[str] | dict[str, list[str]]: The stripped lexicon lines for a
            ``.txt`` file, otherwise a dict mapping each file's basename to
            its lexicon.
        """
        if lexicon_path.endswith('.txt'):
            with open(lexicon_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f.read().splitlines()]
        lexicon = {}
        for file in glob.glob(osp.join(lexicon_path, '*.txt')):
            lexicon[osp.basename(file)] = self._read_lexicon(file)
        return lexicon

    def _read_pair(self, pair_path: str) -> Dict[str, Union[str, Dict]]:
        """Read a pair file, or every ``*.txt`` pair file in a directory.

        Each pair line is split on its first space: the upper-cased first
        token becomes the key and the rest of the line the value.

        Returns:
            dict: ``{WORD: value}`` for a ``.txt`` file, otherwise a dict
            mapping each file's basename to such a table.
        """
        pairs = {}
        if pair_path.endswith('.txt'):
            with open(pair_path, 'r', encoding='utf-8') as f:
                pair_lines = f.read().splitlines()
            for line in pair_lines:
                line = line.strip()
                word = line.split(' ')[0]
                # Slice with the length of the ORIGINAL token: ``str.upper``
                # may change the length of some Unicode strings.
                pairs[word.upper()] = line[len(word) + 1:]
        else:
            for file in glob.glob(osp.join(pair_path, '*.txt')):
                pairs[osp.basename(file)] = self._read_pair(file)
        return pairs

    def poly_center(self, poly_pts):
        """Return the mean of the polygon's (x, y) vertices."""
        poly_pts = np.array(poly_pts).reshape(-1, 2)
        return poly_pts.mean(0)

    def process(self, data_batch: Sequence[Dict],
                data_samples: Sequence[Dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Dict]): A batch of data from dataloader.
            data_samples (Sequence[Dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_instances = data_sample.get('pred_instances')
            pred_points = pred_instances.get('points')
            text_scores = pred_instances.get('text_scores')
            if isinstance(text_scores, torch.Tensor):
                text_scores = text_scores.cpu().numpy()
            text_scores = np.array(text_scores, dtype=np.float32)
            pred_texts = pred_instances.get('texts')

            gt_instances = data_sample.get('gt_instances')
            gt_polys = gt_instances.get('polygons')
            gt_ignore_flags = gt_instances.get('ignored')
            gt_texts = gt_instances.get('texts')
            if isinstance(gt_ignore_flags, torch.Tensor):
                gt_ignore_flags = gt_ignore_flags.cpu().numpy()
            # Work on copies: ``.numpy()`` shares memory with CPU tensors and
            # word-spotting filtering below edits flags/texts in place.
            gt_ignore_flags = np.array(gt_ignore_flags, dtype=bool)
            gt_texts = list(gt_texts)
            gt_points = [self.poly_center(poly) for poly in gt_polys]
            if self.word_spotting:
                gt_ignore_flags, gt_texts = self._word_spotting_filter(
                    gt_ignore_flags, gt_texts)

            # Predictions below the lowest threshold can never be counted.
            pred_ignore_flags = text_scores < self.text_score_thrs.min()
            text_scores = text_scores[~pred_ignore_flags]
            pred_texts = self._get_true_elements(pred_texts,
                                                 ~pred_ignore_flags)
            pred_points = self._get_true_elements(pred_points,
                                                  ~pred_ignore_flags)

            result = dict(
                # reserved for image-level lexicons
                gt_img_name=osp.basename(data_sample.get('img_path', '')),
                text_scores=text_scores,
                pred_points=pred_points,
                gt_points=gt_points,
                pred_texts=pred_texts,
                gt_texts=gt_texts,
                gt_ignore_flags=gt_ignore_flags)
            self.results.append(result)

    def _get_true_elements(self, array: List, flags: np.ndarray) -> List:
        """Return the elements of ``array`` whose flag is True."""
        return [array[i] for i in self._true_indexes(flags)]

    def _correct_with_lexicon(self, pred_texts: List[str],
                              gt_img_name: str
                              ) -> Tuple[List[str], np.ndarray]:
        """Replace each predicted text with its closest lexicon word.

        Args:
            pred_texts (list[str]): Predicted texts of one image. Not
                modified.
            gt_img_name (str): Image basename, used to pick an image-level
                lexicon / pair table when those were read from directories.

        Returns:
            tuple(list[str], np.ndarray): The corrected texts (a new list) and
            a boolean array marking predictions whose match distance is at
            least ``match_dist_thr`` (they will not count as predictions).
        """
        corrected = list(pred_texts)
        dist_flags = np.zeros(len(pred_texts), dtype=bool)
        if self.lexicons is None or not pred_texts:
            return corrected, dist_flags

        # Resolve the lexicon / pair table once per image, not per text.
        if isinstance(self.lexicons, dict):
            lexicon = self.lexicons[self._map_img_name(
                gt_img_name, self.lexicon_mapping)]
            pairs = None
            if self.pairs is not None:
                pairs = self.pairs[self._map_img_name(
                    gt_img_name, self.pair_mapping)]
        else:
            lexicon, pairs = self.lexicons, self.pairs

        for i, pred_text in enumerate(pred_texts):
            corrected[i], match_dist = self._match_word(
                pred_text, lexicon, pairs)
            # ``is not None`` so that an explicit 0.0 threshold is honoured.
            if (self.match_dist_thr is not None
                    and match_dist >= self.match_dist_thr):
                dist_flags[i] = True
        return corrected, dist_flags

    def compute_metrics(self, results: List[Dict]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[dict]): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        best_eval_results = dict(hmean=-1)

        num_thres = len(self.text_score_thrs)
        num_preds = np.zeros(
            num_thres, dtype=int)  # the number of points actually predicted
        num_tp = np.zeros(num_thres, dtype=int)  # number of true positives
        num_gts = np.zeros(num_thres, dtype=int)  # number of valid gts

        for result in results:
            text_scores = result['text_scores']
            pred_points = result['pred_points']
            gt_points = result['gt_points']
            gt_texts = result['gt_texts']
            gt_ignore_flags = result['gt_ignore_flags']

            # Images without any valid gt contribute nothing at any
            # threshold, so skip them before the costly lexicon matching.
            num_gt = len(gt_texts) - np.sum(gt_ignore_flags)
            if num_gt == 0:
                continue

            pred_texts, pred_dist_flags = self._correct_with_lexicon(
                result['pred_texts'], result['gt_img_name'])
            gt_shapes = [Point(gt_point) for gt_point in gt_points]

            # Filter out predictions by text score threshold
            for i, text_score_thr in enumerate(self.text_score_thrs):
                pred_ignore_flags = pred_dist_flags | (
                    text_scores < text_score_thr)
                filtered_pred_texts = self._get_true_elements(
                    pred_texts, ~pred_ignore_flags)
                filtered_pred_points = self._get_true_elements(
                    pred_points, ~pred_ignore_flags)
                gt_matched = np.zeros(len(gt_texts), dtype=bool)
                num_gts[i] += num_gt

                for pred_text, pred_point in zip(filtered_pred_texts,
                                                 filtered_pred_points):
                    pred_shape = Point(pred_point)
                    dists = [
                        pred_shape.distance(gt_shape) for gt_shape in gt_shapes
                    ]
                    min_idx = np.argmin(dists)
                    # Predictions nearest to an ignored gt are not counted.
                    if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
                        continue
                    if not gt_matched[min_idx] and (
                            pred_text.upper() == gt_texts[min_idx].upper()):
                        gt_matched[min_idx] = True
                        num_tp[i] += 1
                    num_preds[i] += 1

        for i, text_score_thr in enumerate(self.text_score_thrs):
            if num_preds[i] == 0 or num_tp[i] == 0:
                recall, precision, hmean = 0, 0, 0
            else:
                recall = num_tp[i] / num_gts[i]
                precision = num_tp[i] / num_preds[i]
                hmean = 2 * recall * precision / (recall + precision)
            eval_results = dict(
                precision=precision, recall=recall, hmean=hmean)
            logger.info(f'text score threshold: {text_score_thr:.2f}, '
                        f'recall: {eval_results["recall"]:.4f}, '
                        f'precision: {eval_results["precision"]:.4f}, '
                        f'hmean: {eval_results["hmean"]:.4f}\n')
            if eval_results['hmean'] > best_eval_results['hmean']:
                best_eval_results = eval_results
        return best_eval_results

    def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
        """Map the image name to the another one based on mapping."""
        return re.sub(mapping[0], mapping[1], img_name)

    def _true_indexes(self, array: np.ndarray) -> np.ndarray:
        """Get indexes of True elements from a 1D boolean array."""
        return np.where(array)[0]

    def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
                              gt_texts: List[str]
                              ) -> Tuple[np.ndarray, List[str]]:
        """Filter out gt instances that cannot be in a valid dictionary, and
        do some simple preprocessing to texts.

        Both arguments are updated in place and returned.
        """
        for i in range(len(gt_texts)):
            if gt_ignore_flags[i]:
                continue
            text = gt_texts[i]
            if text[-2:] in ["'s", "'S"]:
                text = text[:-2]
            text = text.strip('-')
            for char in "'!?.:,*\"()·[]/":
                text = text.replace(char, ' ')
            text = text.strip()
            gt_ignore_flags[i] = not self._include_in_dict(text)
            if not gt_ignore_flags[i]:
                gt_texts[i] = text

        return gt_ignore_flags, gt_texts

    def _include_in_dict(self, text: str) -> bool:
        """Check if the text could be in a valid dictionary.

        The text must be a single word of at least 3 characters, made only of
        characters from the allowed Unicode ranges (minus ``not_allowed``).
        """
        if ' ' in text or len(text) < 3:
            return False
        not_allowed = '×÷·'
        valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
                        (ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
                        (ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
        for char in text:
            code = ord(char)
            if char in not_allowed:
                return False
            if not any(low <= code <= high for low, high in valid_ranges):
                return False
        return True

    def _match_word(self,
                    text: str,
                    lexicons: List[str],
                    pairs: Optional[Dict[str, str]] = None
                    ) -> Tuple[str, float]:
        """Match the text with the lexicons and pairs.

        Returns:
            tuple(str, float): The lexicon word with the smallest normalized
            Levenshtein distance to ``text`` (mapped through ``pairs`` when
            given), and that distance. Returns ``('', 100)`` for an empty
            lexicon.
        """
        text = text.upper()
        matched_word = ''
        matched_dist = 100
        for lexicon in lexicons:
            lexicon = lexicon.upper()
            norm_dist = Levenshtein.normalized_distance(text, lexicon)
            if norm_dist < matched_dist:
                matched_dist = norm_dist
                matched_word = pairs[lexicon] if pairs else lexicon
        return matched_word, matched_dist