import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
import difflib
import shutil
import requests
from concurrent.futures import ThreadPoolExecutor

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CTranslate2 model and tokenizer
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")

# CSS for formatting
css = """
"""


# Helper functions
def generate_html_diff(old_text, new_text):
    # Word-level diff: keep unchanged words as-is, highlight words inserted by the correction
    d = difflib.Differ()
    diff = list(d.compare(old_text.split(), new_text.split()))
    html_diff = []
    for word in diff:
        if word.startswith('  '):
            html_diff.append(word[2:])
        elif word.startswith('+ '):
            # The original highlight markup was stripped during extraction; this span is a generic replacement
            html_diff.append(f'<span style="background-color: #d4fcbc;">{word[2:]}</span>')
    return ' '.join(html_diff)


def preprocess_text(text):
    # Strip HTML tags and collapse whitespace before correction
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def split_text(text, max_tokens=500):
    # Split long inputs into chunks of at most max_tokens tokens
    encoded = tokenizer.encode(text)
    splits = []
    for i in range(0, len(encoded), max_tokens):
        split = encoded[i:i + max_tokens]
        splits.append(tokenizer.decode(split))
    return splits


# Function to generate text using CTranslate2
def ocr_correction(prompt, max_new_tokens=500):
    splits = split_text(prompt, max_tokens=500)
    corrected_splits = []

    # Build one tokenized prompt per chunk, then generate the whole batch at once
    list_prompts = []
    for split in splits:
        full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
        print(full_prompt)
        encoded = tokenizer.encode(full_prompt)
        prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
        list_prompts.append(prompt_tokens)

    results = generator.generate_batch(
        list_prompts,
        max_length=max_new_tokens,
        sampling_temperature=0,
        sampling_topk=20,
        repetition_penalty=1.1,
        include_prompt_in_result=False
    )

    for result in results:
        corrected_text = tokenizer.decode(result.sequences_ids[0])
        corrected_splits.append(corrected_text)

    return " ".join(corrected_splits)


# OCR Correction Class
class OCRCorrector:
    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        self.system_prompt = system_prompt

    def correct(self, user_message):
        generated_text = ocr_correction(user_message)
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff


# Combined Processing Class
class TextProcessor:
    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        # OCR Correction
        corrected_text, html_diff = self.ocr_corrector.correct(user_message)

        # Combine results
        ocr_result = f'