Spaces:
Runtime error
Runtime error
import csv | |
# import fire | |
import json | |
import numpy as np | |
import os | |
# import pandas as pd | |
import sys | |
import torch | |
import requests | |
from dataclasses import dataclass | |
from PIL import Image | |
from nltk import edit_distance | |
from torchvision import transforms as T | |
from typing import Optional, Callable, Sequence, Tuple | |
from tqdm import tqdm | |
from IndicPhotoOCR.utils.strhub.data.module import SceneTextDataModule | |
from IndicPhotoOCR.utils.strhub.models.utils import load_from_checkpoint | |
# Checkpoint table for every supported non-English language, keyed by the
# lowercase language name. "path" is the checkpoint location relative to a
# model root directory (PARseqrecogniser.ensure_model joins it onto one), and
# "url" is the GitHub release asset downloaded when that file is missing.
model_info = {
    language: {
        "path": f"models/{language}.ckpt",
        "url": f"https://github.com/anikde/STocr/releases/download/V2.0.0/{language}.ckpt",
    }
    for language in (
        "assamese",
        "bengali",
        "hindi",
        "gujarati",
        "marathi",
        "odia",
        "punjabi",
        "tamil",
        "telugu",
    )
}
class PARseqrecogniser:
    """Word-level scene-text recogniser built on PARSeq models.

    For non-English languages the checkpoint is taken from the module-level
    ``model_info`` table and downloaded on demand by ``ensure_model``; for
    English the pretrained PARSeq model from ``torch.hub`` (``baudm/parseq``)
    is used instead.
    """

    def __init__(self):
        pass

    def get_transform(self, img_size: Sequence[int], augment: bool = False, rotation: int = 0):
        """Build the preprocessing pipeline applied to a PIL image before inference.

        Args:
            img_size: Target size passed to ``T.Resize`` (``get_model_output``
                takes it from the model's ``hparams.img_size``).
            augment: If True, prepend ``rand_augment_transform()`` from the
                sibling ``augment`` module.
            rotation: If non-zero, rotate the image by this many degrees
                (PIL ``Image.rotate`` with ``expand=True``) before resizing.

        Returns:
            A ``T.Compose`` that resizes (bicubic), converts to a tensor and
            normalises with mean 0.5 / std 0.5.
        """
        transforms = []
        if augment:
            # Imported lazily: only needed when augmentation is requested.
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        return T.Compose(transforms)

    def load_model(self, device, checkpoint):
        """Load a PARSeq checkpoint file, put it in eval mode and move it to ``device``."""
        return load_from_checkpoint(checkpoint).eval().to(device)

    def get_model_output(self, device, model, image_path):
        """Recognise the text in a single image file.

        Args:
            device: Device the model lives on; the input tensor is moved there.
            model: Loaded PARSeq model exposing ``hparams.img_size``,
                ``tokenizer.decode`` and ``charset_adapter``.
            image_path: Path to the image to read.

        Returns:
            str: The decoded text of the first (only) item in the batch.
        """
        transform = self.get_transform(model.hparams.img_size, rotation=0)
        # Context manager closes the file handle; convert() returns a fully
        # loaded copy, so the image stays usable after the file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        batch = transform(img).unsqueeze(0).to(device)
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = model(batch)
        probs = logits.softmax(-1)
        preds, _ = model.tokenizer.decode(probs)
        return model.charset_adapter(preds[0])

    def ensure_model(self, model_name, root_model_dir="IndicPhotoOCR/recognition/"):
        """Return the local checkpoint path for ``model_name``, downloading it if absent.

        Args:
            model_name: Key into the module-level ``model_info`` table
                (e.g. ``"hindi"``).
            root_model_dir: Directory that the ``model_info`` relative paths are
                resolved against. The default is relative to the current
                working directory.

        Returns:
            str: Path of the checkpoint on disk.

        Raises:
            KeyError: If ``model_name`` has no entry in ``model_info``.
            requests.HTTPError: If the download server returns an error status.
        """
        if model_name not in model_info:
            raise KeyError(
                f"No recognition model for {model_name!r}; "
                f"supported: {', '.join(model_info)}"
            )
        url = model_info[model_name]["url"]
        model_path = os.path.join(root_model_dir, model_info[model_name]["path"])
        if os.path.exists(model_path):
            return model_path

        print(f"Model not found locally. Downloading {model_name} from {url}...")
        os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
        # Stream into a temporary sibling and rename only after success, so a
        # failed or interrupted transfer never leaves a truncated file at
        # model_path that a later call would mistake for a valid checkpoint.
        tmp_path = model_path + ".part"
        try:
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(tmp_path, "wb") as f, tqdm(
                    desc=model_name,
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=1024 * 1024):
                        f.write(data)
                        bar.update(len(data))
            os.replace(tmp_path, model_path)
        finally:
            # After a successful os.replace the temp file is gone; otherwise
            # remove whatever partial data was written.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"Downloaded model for {model_name}.")
        return model_path

    @staticmethod
    def _default_device():
        """Return the first CUDA device if one is available, else the CPU."""
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _load_for_language(language, checkpoint, device):
        """Load the torch.hub PARSeq model for English, else the ``checkpoint`` file."""
        if language != "english":
            return PARseqrecogniser().load_model(device, checkpoint)
        return torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval().to(device)

    @staticmethod
    def _list_image_files(image_dir):
        """Return the sorted names of the regular files directly inside ``image_dir``."""
        return sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    @staticmethod
    def bstr(checkpoint, language, image_dir, save_dir, device=None):
        """Recognise every image in a directory and save the results as JSON.

        Args:
            checkpoint (str): Path to the model checkpoint file (unused for English).
            language (str): Language name (e.g. 'hindi', 'english').
            image_dir (str): Directory containing the images to process;
                subdirectories are skipped.
            save_dir (str): Directory where ``{language}_test.json`` is written.
            device: Optional torch device; defaults to ``cuda:0`` when CUDA is
                available, otherwise the CPU.

        Returns:
            dict: Mapping of image file name to recognised text (also saved).

        Example:
            PARseqrecogniser.bstr("/path/to/checkpoint.ckpt", "hindi",
                                  "/path/to/images", "/path/to/save")
        """
        if device is None:
            device = PARseqrecogniser._default_device()
        recogniser = PARseqrecogniser()
        model = PARseqrecogniser._load_for_language(language, checkpoint, device)
        parseq_dict = {}
        for image_name in tqdm(PARseqrecogniser._list_image_files(image_dir)):
            parseq_dict[image_name] = recogniser.get_model_output(
                device, model, os.path.join(image_dir, image_name)
            )
        os.makedirs(save_dir, exist_ok=True)
        # Explicit UTF-8: ensure_ascii=False writes Indic characters verbatim,
        # which the platform default encoding may not be able to represent.
        out_path = os.path.join(save_dir, f"{language}_test.json")
        with open(out_path, 'w', encoding='utf-8') as json_file:
            json.dump(parseq_dict, json_file, indent=4, ensure_ascii=False)
        return parseq_dict

    @staticmethod
    def bstr_onImage(checkpoint, language, image_path, device=None):
        """Recognise the text in a single image.

        Args:
            checkpoint (str): Path to the model checkpoint file (unused for English).
            language (str): Language name (e.g. 'hindi', 'english').
            image_path (str): Path to the image to recognise.
            device: Optional torch device; defaults to ``cuda:0`` when CUDA is
                available, otherwise the CPU.

        Returns:
            str: The recognised text.
        """
        if device is None:
            device = PARseqrecogniser._default_device()
        model = PARseqrecogniser._load_for_language(language, checkpoint, device)
        return PARseqrecogniser().get_model_output(device, model, image_path)

    def recognise(self, checkpoint: str, image_path: str, language: str, verbose: bool, device: str) -> str:
        """Load the selected model and return the word recognised in ``image_path``.

        Args:
            checkpoint (str): For non-English languages, the ``model_info`` key
                of the model to use (e.g. 'hindi'). It is resolved, and
                downloaded if needed, by ``ensure_model``; it is not a file
                path. Ignored for English.
            image_path (str): Path to the image for which text recognition is needed.
            language (str): 'english' selects the torch.hub PARSeq model; any
                other value uses ``checkpoint``.
            verbose (bool): Forwarded to ``torch.hub.load`` (English only).
            device (str): Device to run on, e.g. 'cpu' or 'cuda:0'.

        Returns:
            str: The recognized text from the image.
        """
        # NOTE(review): the model is reloaded on every call; callers that
        # recognise many crops may want to cache it.
        if language != "english":
            model_path = self.ensure_model(checkpoint)
            model = self.load_model(device, model_path)
        else:
            model = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, verbose=verbose).eval().to(device)
        return self.get_model_output(device, model, image_path)
# if __name__ == '__main__': | |
# fire.Fire(main) |