# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Any, Dict, List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner.checkpoint import _load_checkpoint

from mmocr.registry import DATASETS
from mmocr.structures import KIEDataSample
from mmocr.utils import ConfigType
from .base_mmocr_inferencer import BaseMMOCRInferencer, ModelType, PredType

InputType = Dict
InputsType = Sequence[Dict]


class KIEInferencer(BaseMMOCRInferencer):
    """Key Information Extraction Inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "sdmgr_unet16_60e_wildreceipt" or
            "configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py".
            If model is not specified, the user must provide the ``weights``
            saved by MMEngine, which contain the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and ``model`` is a model name in the metafile, the
            weights will be loaded from the metafile. Defaults to None.
        device (str, optional): Device to run inference on. If None, the
            available device will be used automatically. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmocr".
    """

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmocr') -> None:
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        self._load_metainfo_to_visualizer(weights, self.cfg)
        self.collate_fn = self.kie_collate

    def _load_metainfo_to_visualizer(self, weights: Optional[str],
                                     cfg: ConfigType) -> None:
        """Load meta information to the visualizer."""
        if hasattr(self, 'visualizer'):
            if weights is not None:
                w = _load_checkpoint(weights, map_location='cpu')
                if w and 'meta' in w and 'dataset_meta' in w['meta']:
                    self.visualizer.dataset_meta = w['meta']['dataset_meta']
                    return
            if 'test_dataloader' in cfg:
                dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
                dataset_cfg['lazy_init'] = True
                dataset_cfg['metainfo'] = None
                dataset = DATASETS.build(dataset_cfg)
                self.visualizer.dataset_meta = dataset.metainfo
            else:
                raise ValueError(
                    'KIEVisualizer requires meta information from weights or '
                    'test dataset, but neither of them is provided.')

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        idx = self._get_transform_idx(pipeline_cfg, 'LoadKIEAnnotations')
        if idx == -1:
            raise ValueError(
                'LoadKIEAnnotations is not found in the test pipeline')
        pipeline_cfg[idx]['with_label'] = False
        self.novisual = all(
            self._get_transform_idx(pipeline_cfg, t) == -1
            for t in self.loading_transforms)
        # Remove Resize from the test pipeline, since SDMGR requires bbox
        # annotations to be resized together with the pictures, but
        # visualization loads the original image from the disk.
        # TODO: find a more elegant way to fix this
        idx = self._get_transform_idx(pipeline_cfg, 'Resize')
        if idx != -1:
            pipeline_cfg.pop(idx)
        # In no-visual mode a single self.pipeline is used; otherwise the
        # base class builds separate file and ndarray pipelines.
        if self.novisual:
            return Compose(pipeline_cfg)
        return super()._init_pipeline(cfg)

    @staticmethod
    def kie_collate(data_batch: Sequence) -> Any:
        """A collate function designed for KIE, where the first element
        (input) of each data item is a dict and we only want to keep it
        as-is instead of batching the elements inside.
        Returns:
            Any: Transposed data in the same format as the elements of
            ``data_batch``.
        """  # noqa: E501
        transposed = list(zip(*data_batch))
        for i in range(1, len(transposed)):
            transposed[i] = pseudo_collate(transposed[i])
        return transposed

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type.

        The inputs can be a dict or list[dict], where each dictionary
        contains the following keys:

        - img (str or ndarray): Path to the image or the image itself. If
          KIE Inferencer is used in no-visual mode, this key is not required.
          Note: If it's a numpy array, it should be in BGR order.
        - img_shape (tuple(int, int)): Image shape in (H, W). Only required
          in no-visual mode when ``img`` is not provided.
        - instances (list[dict]): A list of instances, where each instance
          contains:

          - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box in
            (x1, y1, x2, y2) order.
          - text (str): Annotation text.

        A complete input therefore looks like the following:

        .. code-block:: python

            {
                'img': 'path/to/image.jpg',
                'instances': [
                    {
                        'bbox': np.array([x1, y1, x2, y2], dtype=np.float32),
                        'text': 'text1',
                    },
                    {
                        'bbox': np.array([x1, y1, x2, y2], dtype=np.float32),
                        'text': 'text2',
                    },
                ],
            }

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of inputs for :meth:`preprocess`.
        """
        processed_inputs = []

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        for single_input in inputs:
            if self.novisual:
                processed_input = copy.deepcopy(single_input)
                if 'img' not in single_input and \
                        'img_shape' not in single_input:
                    raise ValueError(
                        'KIEInferencer in no-visual mode requires the input '
                        'to contain "img" or "img_shape", but neither is '
                        'found.')
                if 'img' in single_input:
                    img = single_input['img']
                    if isinstance(img, str):
                        img_bytes = mmengine.fileio.get(img)
                        img = mmcv.imfrombytes(img_bytes)
                        processed_input['img'] = img
                    processed_input['img_shape'] = img.shape[:2]
                processed_inputs.append(processed_input)
            else:
                if 'img' not in single_input:
                    raise ValueError(
                        'This inferencer is constructed to accept image '
                        'inputs, but the input does not contain an "img" '
                        'key.')
                if isinstance(single_input['img'], str):
                    processed_input = {
                        k: v
                        for k, v in single_input.items() if k != 'img'
                    }
                    processed_input['img_path'] = single_input['img']
                    processed_inputs.append(processed_input)
                elif isinstance(single_input['img'], np.ndarray):
                    processed_inputs.append(copy.deepcopy(single_input))
                else:
                    atype = type(single_input['img'])
                    raise ValueError(f'Unsupported input type: {atype}')

        return processed_inputs

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  save_vis: bool = False,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            save_vis (bool): Whether to save the visualization result.
                Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Visualization results, only if
            applicable.
        """
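        # Visualization flow: skip entirely unless the caller asked to show,
        # save or return the rendered images; for each sample, recover an RGB
        # image (read from disk, convert a BGR ndarray, or fall back to a
        # blank canvas of ``img_shape`` in no-visual mode), then delegate the
        # drawing to ``self.visualizer.add_datasample``.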
""" if self.visualizer is None or not (show or save_vis or return_vis): return None if getattr(self, 'visualizer') is None: raise ValueError('Visualization needs the "visualizer" term' 'defined in the config, but got None.') results = [] for single_input, pred in zip(inputs, preds): assert 'img' in single_input or 'img_shape' in single_input if 'img' in single_input: if isinstance(single_input['img'], str): img_bytes = mmengine.fileio.get(single_input['img']) img = mmcv.imfrombytes(img_bytes, channel_order='rgb') elif isinstance(single_input['img'], np.ndarray): img = single_input['img'].copy()[:, :, ::-1] # To RGB elif 'img_shape' in single_input: img = np.zeros(single_input['img_shape'], dtype=np.uint8) else: raise ValueError('Input does not contain either "img" or ' '"img_shape"') img_name = osp.splitext(osp.basename(pred.img_path))[0] if save_vis and img_out_dir: out_file = osp.splitext(img_name)[0] out_file = f'{out_file}.jpg' out_file = osp.join(img_out_dir, out_file) else: out_file = None visualization = self.visualizer.add_datasample( img_name, img, pred, show=show, wait_time=wait_time, draw_gt=False, draw_pred=draw_pred, pred_score_thr=pred_score_thr, out_file=out_file, ) results.append(visualization) return results def pred2dict(self, data_sample: KIEDataSample) -> Dict: """Extract elements necessary to represent a prediction into a dictionary. It's better to contain only basic data elements such as strings and numbers in order to guarantee it's json-serializable. Args: data_sample (TextRecogDataSample): The data sample to be converted. Returns: dict: The output dictionary. """ result = {} pred = data_sample.pred_instances result['scores'] = pred.scores.cpu().numpy().tolist() result['edge_scores'] = pred.edge_scores.cpu().numpy().tolist() result['edge_labels'] = pred.edge_labels.cpu().numpy().tolist() result['labels'] = pred.labels.cpu().numpy().tolist() return result