|
import os |
|
import shutil |
|
import tempfile |
|
from time import perf_counter |
|
from typing import Any, List, Union |
|
|
|
from doctr import models as models |
|
from doctr.io import DocumentFile |
|
from doctr.models import ocr_predictor |
|
from PIL import Image |
|
|
|
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest |
|
from inference.core.entities.requests.inference import InferenceRequest |
|
from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse |
|
from inference.core.entities.responses.inference import InferenceResponse |
|
from inference.core.env import MODEL_CACHE_DIR |
|
from inference.core.models.roboflow import RoboflowCoreModel |
|
from inference.core.utils.image_utils import load_image |
|
|
|
|
|
class DocTR(RoboflowCoreModel):
    """OCR pipeline combining a DocTR detection model and a recognition model.

    ``DocTRDet`` and ``DocTRRec`` are instantiated first, their ``model.pt``
    files are copied inside ``MODEL_CACHE_DIR``, and doctr's ``ocr_predictor``
    is then built from the two architectures.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR model.

        Args:
            *args: Variable length argument list (not used by this wrapper).
            model_id: Identifier stored as ``self.endpoint``. The detection and
                recognition sub-models are always created with their own
                default model ids, whatever value is passed here.
            **kwargs: Arbitrary keyword arguments. Only ``api_key`` is read.
        """
        # RoboflowCoreModel.__init__ is not invoked for this wrapper; the
        # identifying attributes are assigned by hand instead.
        self.api_key = kwargs.get("api_key")
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        # Must be set before ocr_predictor() is called below.
        # NOTE(review): only the doctr_rec directory is exported here while the
        # detection weights are copied under doctr_det/ - confirm that doctr
        # actually picks up the detection file from that location.
        os.environ["DOCTR_CACHE_DIR"] = os.path.join(MODEL_CACHE_DIR, "doctr_rec")

        # Presumably instantiating the sub-models is what places model.pt in
        # the cache - confirm against RoboflowCoreModel.
        self.det_model = DocTRDet(api_key=self.api_key)
        self.rec_model = DocTRRec(api_key=self.api_key)

        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_rec/models/", exist_ok=True)
        os.makedirs(f"{MODEL_CACHE_DIR}/doctr_det/models/", exist_ok=True)

        # Presumably these are the file names doctr expects for its pretrained
        # weights - TODO confirm against the doctr release in use.
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_det/db_resnet50/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_det/models/db_resnet50-ac60cadc.pt",
        )
        shutil.copyfile(
            f"{MODEL_CACHE_DIR}/doctr_rec/crnn_vgg16_bn/model.pt",
            f"{MODEL_CACHE_DIR}/doctr_rec/models/crnn_vgg16_bn-9762b0b0.pt",
        )

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    def clear_cache(self) -> None:
        """Clear the caches of both the detection and the recognition model."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return the image unchanged.

        DocTR pre-processes images as part of its inference pipeline, so no
        preprocessing is required here.

        Args:
            image: The input image.

        Returns:
            Image.Image: The very same image object that was passed in.
        """
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run inference for a request object and time it.

        Args:
            request: The inference request; its fields are expanded into
                keyword arguments for ``infer``.

        Returns:
            DoctrOCRInferenceResponse: The text returned by ``infer`` together
            with the elapsed time measured with ``perf_counter``.
        """
        started = perf_counter()
        text = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=text,
            time=perf_counter() - started,
        )

    def infer(self, image: Any, **kwargs) -> str:
        """Run OCR on a provided image and return the recognised text.

        Args:
            image: Any value accepted by ``load_image``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            str: The words of every line of every block on the first exported
            page, joined with single spaces.
        """
        loaded = load_image(image)
        # Only element 0 of load_image's return value is used as pixel data.
        pil_image = Image.fromarray(loaded[0])

        # A temporary directory is used rather than NamedTemporaryFile because
        # a named temporary file cannot be opened a second time by name on
        # every platform while it is still held open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "image.jpg")
            pil_image.save(image_path)
            doc = DocumentFile.from_images([image_path])
            exported = self.model(doc).export()

        blocks = exported["pages"][0]["blocks"]
        line_texts = [
            " ".join(word["value"] for word in line["words"])
            for block in blocks
            for line in block["lines"]
        ]
        return " ".join(line_texts)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|
|
|
|
class DocTRRec(RoboflowCoreModel):
    """Core model wrapper for the DocTR text recognition weights.

    ``DocTR`` passes this model's ``version_id`` to doctr's ``ocr_predictor``
    as the recognition architecture.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR recognition model.

        Args:
            *args: Positional arguments forwarded to ``RoboflowCoreModel``.
            model_id: Model identifier forwarded to ``RoboflowCoreModel``.
            **kwargs: Keyword arguments forwarded to ``RoboflowCoreModel``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|
|
|
|
class DocTRDet(RoboflowCoreModel):
    """Core model wrapper for the DocTR text detection weights.

    ``DocTR`` passes this model's ``version_id`` to doctr's ``ocr_predictor``
    as the detection architecture.
    """

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Initializes the DocTR detection model.

        Args:
            *args: Positional arguments forwarded to ``RoboflowCoreModel``.
            model_id: Model identifier forwarded to ``RoboflowCoreModel``.
            **kwargs: Keyword arguments forwarded to ``RoboflowCoreModel``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|