import pytesseract
from PIL import Image
import numpy as np
from presidio_analyzer import AnalyzerEngine, RecognizerResult
from typing import List, Dict, Optional, Union, Tuple
from dataclasses import dataclass

@dataclass
class OCRResult:
    text: str
    left: int
    top: int
    width: int
    height: int

@dataclass
class CustomImageRecognizerResult:
    entity_type: str
    start: int
    end: int
    score: float
    left: int
    top: int
    width: int
    height: int
    text: str

class CustomImageAnalyzerEngine:
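    """
    Combines Tesseract OCR with Presidio's AnalyzerEngine: text is extracted
    from an image word by word, analyzed for PII entities, and each detected
    entity is mapped back to the bounding box of the word(s) it overlaps.
    """
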
    def __init__(
        self,
        analyzer_engine: Optional[AnalyzerEngine] = None,
        tesseract_config: Optional[str] = None
    ):
        if not analyzer_engine:
            analyzer_engine = AnalyzerEngine()
        self.analyzer_engine = analyzer_engine
        self.tesseract_config = tesseract_config or '--oem 3 --psm 11'

    def perform_ocr(self, image: Union[str, Image.Image, np.ndarray]) -> List[OCRResult]:
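        """Run Tesseract on the image and return one OCRResult per recognized word."""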
        # Ensure image is a PIL Image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT, config=self.tesseract_config)
        
        # Keep only non-empty words with positive confidence (Tesseract reports
        # conf as -1 for non-text blocks, and newer versions emit float values)
        valid_indices = [
            i for i, text in enumerate(ocr_data['text'])
            if text.strip() and float(ocr_data['conf'][i]) > 0
        ]
        
        return [
            OCRResult(
                text=ocr_data['text'][i],
                left=ocr_data['left'][i],
                top=ocr_data['top'][i],
                width=ocr_data['width'][i],
                height=ocr_data['height'][i]
            )
            for i in valid_indices
        ]

    def analyze_text(
        self, 
        ocr_results: List[OCRResult], 
        **text_analyzer_kwargs
    ) -> List[CustomImageRecognizerResult]:
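        """
        Analyze the combined OCR text for PII entities and map each hit back
        to word-level bounding boxes.
        """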
        # Combine all OCR text
        full_text = ' '.join([result.text for result in ocr_results])
        
        # Default to English if no language was specified
        if "language" not in text_analyzer_kwargs:
            text_analyzer_kwargs["language"] = "en"
        
        analyzer_result = self.analyzer_engine.analyze(
            text=full_text, **text_analyzer_kwargs
        )
        
        allow_list = text_analyzer_kwargs.get('allow_list', [])
        
        return self.map_analyzer_results_to_bounding_boxes(
            analyzer_result, ocr_results, full_text, allow_list
        )

    @staticmethod
    def map_analyzer_results_to_bounding_boxes(
        text_analyzer_results: List[RecognizerResult],
        ocr_results: List[OCRResult],
        full_text: str,
        allow_list: List[str],
    ) -> List[CustomImageRecognizerResult]:
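        """
        Walk the OCR words in reading order and, using character offsets into
        full_text, attach each word to the first analyzer result whose span it
        overlaps. Words in allow_list are skipped, and each word is mapped to
        at most one entity.
        """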
        pii_bboxes = []
        text_position = 0

        for ocr_result in ocr_results:
            word_end = text_position + len(ocr_result.text)
            
            for result in text_analyzer_results:
                # Does this word's character span overlap the entity's span?
                overlaps = max(text_position, result.start) < min(word_end, result.end)
                if overlaps and ocr_result.text not in allow_list:
                    pii_bboxes.append(
                        CustomImageRecognizerResult(
                            entity_type=result.entity_type,
                            start=result.start,
                            end=result.end,
                            score=result.score,
                            left=ocr_result.left,
                            top=ocr_result.top,
                            width=ocr_result.width,
                            height=ocr_result.height,
                            text=ocr_result.text
                        )
                    )
                    break
            
            text_position = word_end + 1  # +1 for the space between words

        return pii_bboxes
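
# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of how the class above is intended to be called. It
# assumes Tesseract is installed and that "example.png" is a hypothetical
# image on disk; the entity names and allow_list value are illustrative only.
# Note that allow_list is also forwarded to Presidio's analyze() call, so it
# requires a presidio-analyzer version that accepts that parameter.
if __name__ == "__main__":
    engine = CustomImageAnalyzerEngine()

    # OCR the image into word-level results with bounding boxes.
    ocr_results = engine.perform_ocr("example.png")

    # Run the Presidio analyzer over the recombined text and map any hits
    # back to the bounding boxes of the words they overlap.
    pii_boxes = engine.analyze_text(
        ocr_results,
        entities=["PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER"],
        allow_list=["Example Ltd"],
    )

    for box in pii_boxes:
        print(
            f"{box.entity_type} ({box.score:.2f}): '{box.text}' "
            f"at ({box.left}, {box.top}, {box.width}, {box.height})"
        )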