import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
import difflib
import shutil
import requests
from concurrent.futures import ThreadPoolExecutor

# Select the inference device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CTranslate2 model and the matching Hugging Face tokenizer.
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")

# CSS for formatting the HTML output (currently empty placeholder).
css = """
"""


# Helper functions
def generate_html_diff(old_text, new_text):
    """Return *new_text* rendered as a word-level diff against *old_text*.

    Words common to both texts and words added by the correction are kept,
    in order; words present only in *old_text* are dropped.
    NOTE(review): the markup that presumably once wrapped added words has
    been stripped from this file — added words are currently emitted bare.
    """
    differ = difflib.Differ()
    diff = list(differ.compare(old_text.split(), new_text.split()))
    html_diff = []
    for word in diff:
        # difflib.Differ prefixes every token with a 2-character code:
        # '  ' = common, '+ ' = added, '- ' = removed, '? ' = hint line.
        if word.startswith(' '):
            html_diff.append(word[2:])
        elif word.startswith('+ '):
            html_diff.append(f'{word[2:]}')
    return ' '.join(html_diff)


def preprocess_text(text):
    """Strip HTML tags from *text* and collapse all whitespace to single spaces."""
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def split_text(text, max_tokens=500):
    """Split *text* into chunks of at most *max_tokens* tokenizer tokens.

    Each chunk is decoded back to a string so it can be re-embedded in a
    prompt template.
    """
    encoded = tokenizer.encode(text)
    splits = []
    for i in range(0, len(encoded), max_tokens):
        splits.append(tokenizer.decode(encoded[i:i + max_tokens]))
    return splits


def ocr_correction(prompt, max_new_tokens=500):
    """Correct OCR noise in *prompt* using the CTranslate2 generator.

    The text is chunked to fit the model context, each chunk is wrapped in
    the model's "### Text ### / ### Correction ###" prompt template, the
    whole batch is decoded greedily (temperature 0), and the corrected
    chunks are re-joined with single spaces.
    """
    splits = split_text(prompt, max_tokens=500)

    # Build one token-level prompt per chunk for batched generation.
    list_prompts = []
    for split in splits:
        full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
        encoded = tokenizer.encode(full_prompt)
        list_prompts.append(tokenizer.convert_ids_to_tokens(encoded))

    results = generator.generate_batch(
        list_prompts,
        max_length=max_new_tokens,
        sampling_temperature=0,
        sampling_topk=20,
        repetition_penalty=1.1,
        include_prompt_in_result=False,
    )

    corrected_splits = [
        tokenizer.decode(result.sequences_ids[0]) for result in results
    ]
    return " ".join(corrected_splits)
# OCR Correction Class
class OCRCorrector:
    """Run OCR correction and produce an HTML word-diff of the changes."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): system_prompt is stored but never read anywhere in
        # this file — kept for interface compatibility.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*."""
        generated_text = ocr_correction(user_message)
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff


# Combined Processing Class
class TextProcessor:
    """Glue object wiring the OCR corrector to the Gradio UI callback."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and return an HTML fragment for display."""
        # OCR Correction
        corrected_text, html_diff = self.ocr_corrector.correct(user_message)

        # Combine results.
        # NOTE(review): the HTML wrappers here were stripped/garbled in the
        # source file; reconstructed as a centered heading followed by the
        # diff body — verify against the original markup.
        ocr_result = (
            '<h2 style="text-align: center;">OCR Correction</h2>\n'
            f'<div class="generation">{html_diff}</div>'
        )
        final_output = f"{css}{ocr_result}"
        return final_output


# Create the TextProcessor instance
text_processor = TextProcessor()

# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    # NOTE(review): original header markup was stripped; reconstructed as a
    # simple centered title.
    gr.HTML('<h1 style="text-align: center;">Vintage OCR corrector</h1>')
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(
        text_processor.process, inputs=text_input, outputs=[text_output]
    )

if __name__ == "__main__":
    demo.queue().launch()