# Source: Hugging Face Space file by sabaridsnfuji.
# Commit a42c0ed (verified): "updated the code and model". File size: 5.09 kB.
import gradio as gr
from PIL import Image
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel
import os
# NOTE: Unused preprocessing helper, kept for reference only. Re-enabling it
# requires importing `threshold_sauvola`, which this file does not import.
# def sauvola_thresholding(grayImage_, window_size=15):
#     """
#     Sauvola thresholding is a local thresholding technique that is useful for
#     images whose background is not uniform, especially for text recognition.
#     grayImage_ --- Input image; must be a 2-dimensional grayscale image.
#     window_size --- Size of the filter window.
#     """
#     # Assert the input conditions
#     assert len(grayImage_.shape) == 2, "Input image must be a 2-dimensional gray scale image."
#     assert isinstance(window_size, int) and window_size > 0, "Window size must be a positive integer."
#     thresh_sauvolavalue = threshold_sauvola(grayImage_, window_size=window_size)
#     thresholdImage_ = (grayImage_ > thresh_sauvolavalue)
#     return thresholdImage_
class OCRModel:
    """
    Word-level OCR wrapper around a trained VisionEncoderDecoderModel.

    A feature extractor and a decoder tokenizer are combined into a
    TrOCRProcessor; the trained model is loaded separately and configured
    for beam-search generation.
    """

    def __init__(self, encoder_model, decoder_model, trained_model_path):
        """
        Loads the processor and the trained model, then configures generation.
        :param encoder_model: Name or path used to load the image feature extractor.
        :param decoder_model: Name or path used to load the decoder tokenizer.
        :param trained_model_path: Name or path of the trained encoder-decoder model.
        """
        # Load processor and model
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_model)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model)
        self.processor = TrOCRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.decoder_tokenizer,
        )
        self.model = VisionEncoderDecoderModel.from_pretrained(trained_model_path)

        # Put the model on the GPU when one is available. The inputs are moved
        # to the model's device in ocr(); without this step the model would stay
        # on the CPU while the inputs went to CUDA, and generate() would fail.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Configure model settings
        tokenizer = self.processor.tokenizer
        config = self.model.config
        config.decoder_start_token_id = tokenizer.cls_token_id
        config.pad_token_id = tokenizer.pad_token_id
        config.vocab_size = config.decoder.vocab_size
        config.eos_token_id = tokenizer.sep_token_id
        # Generation settings: beam search over short (word-length) outputs.
        config.max_length = 64
        config.early_stopping = True
        config.no_repeat_ngram_size = 3
        config.length_penalty = 2.0
        config.num_beams = 4

    def read_and_show(self, image_path):
        """
        Reads an image from the provided path and converts it to RGB.
        :param image_path: String, path to the input image.
        :return: PIL Image object in RGB mode.
        """
        # convert() returns a fully loaded copy, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    def ocr(self, image):
        """
        Performs OCR on the given image.
        :param image: PIL Image object.
        :return: Extracted text from the image.
        """
        # Preprocess the image and place the tensor on the same device as the
        # model so that generate() never sees mismatched devices.
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        pixel_values = pixel_values.to(self.model.device)
        # Generate text
        generated_ids = self.model.generate(pixel_values)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
# Single module-level OCR model instance used by the request handlers below.
ocr_model = OCRModel(
    encoder_model="google/vit-base-patch16-224-in21k",
    decoder_model="surajp/RoBERTa-hindi-guj-san",
    # Local checkpoint directory; the previously commented-out alternative was
    # 'sabaridsnfuji/Tamil_Offline_Handwritten_OCR'.
    trained_model_path="./model/",
)
def main(image_path):
    """
    Loads the image at the given path and recognises the text in it.
    :param image_path: String, path to the input image.
    :return: Tuple of (RGB PIL Image object, extracted text).
    """
    picture = ocr_model.read_and_show(image_path)
    return picture, ocr_model.ocr(picture)
# Gradio Interface function
def gradio_interface(image):
    """
    Gradio callback: runs OCR on an uploaded image.
    :param image: PIL Image object supplied by the Gradio image input, or
        None when the form is submitted without an image.
    :return: Tuple of (processed PIL image, extracted text); (None, "") when
        no image was provided.
    """
    import tempfile

    # Nothing was uploaded: return empty outputs instead of failing on
    # `None.save(...)`.
    if image is None:
        return None, ""

    # Use a unique temporary file per request. A fixed file name in the
    # working directory would be shared by concurrent requests, which could
    # overwrite each other's upload before it is read back.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        image_path = handle.name
    try:
        image.save(image_path)
        # Call the main function to process the image and get the result
        processed_image, result_text = main(image_path)
    finally:
        try:
            os.remove(image_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless
    return processed_image, result_text
# Sample images offered as clickable examples in the UI. Replace any generated
# placeholder with a real cropped word image.
sample_images = [
    "./sample/16.jpg",
    "./sample/20.jpg",
    "./sample/21.jpg",
    "./sample/31.jpg",
    "./sample/35.jpg",
]

# Write a plain grey placeholder for every sample that is missing, so the
# examples list always points at existing files.
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the path actually lives in. The sample paths
        # use "./sample/", so creating a differently named directory would
        # leave the save below failing on a fresh checkout.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)
# Gradio UI: one PIL image input; the processed image and the recognised text
# are shown as outputs, with the sample images offered as examples.
_demo_input = gr.Image(type="pil")
_demo_outputs = [gr.Image(type="pil"), gr.Textbox()]

gr_interface = gr.Interface(
    fn=gradio_interface,
    inputs=_demo_input,
    outputs=_demo_outputs,
    title="Hindi Handwritten OCR Recognition",
    description="Upload a cropped image containing a word, or use the sample images below to recognize the text. This is a word recognition model. Currently, text detection is not supported.",
    examples=sample_images,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    gr_interface.launch()