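# Hosting-page residue kept for provenance (not part of the program):
#   author avatar caption: Pclanglais's picture
#   commit message: Update app.py
#   commit: ba05a34 (verified)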
import transformers
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import difflib
from concurrent.futures import ThreadPoolExecutor
import os
# OCR correction model: the identifier handed to `from_pretrained` below.
model_name = "PleIAs/OCRonos-Vintage"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pre-trained GPT-2 LM-head model (moved onto `device`) and its
# tokenizer once, at import time; both are module-level globals used by the
# functions defined further down in this file.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# CSS for formatting: a <style> element that process_text prepends to the HTML
# it returns. `.generation` adds side margins and a larger font to the
# container holding the corrected text; `.inserted` gives a light-green
# background to the words generate_html_diff marks as added.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
.inserted {
background-color: #90EE90;
}
</style>
"""
def generate_html_diff(old_text, new_text):
    """Render ``new_text`` as HTML with the words absent from ``old_text`` highlighted.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Words common to both texts are emitted unchanged, words
    that only occur in ``new_text`` are wrapped in
    ``<span class="inserted">...</span>``, and everything else Differ reports
    (words removed from ``old_text`` and its ``?`` hint lines) is dropped.

    Every emitted word is HTML-escaped: the text comes from the user and from
    the model, so raw ``<``, ``>``, ``&`` or quotes must not be interpreted as
    markup by the page that displays the result.

    Args:
        old_text: The original (uncorrected) text.
        new_text: The corrected text.

    Returns:
        A single string of space-separated, HTML-safe words.
    """
    import html  # function-scope import keeps this block self-contained

    pieces = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each entry with a two-character code:
        # '  ' unchanged, '+ ' added, '- ' removed, '? ' intraline hint.
        safe_word = html.escape(entry[2:])
        if entry.startswith(' '):
            pieces.append(safe_word)
        elif entry.startswith('+ '):
            pieces.append(f'<span class="inserted">{safe_word}</span>')
    return ' '.join(pieces)
def split_text(text, max_tokens=400):
    """Cut ``text`` into consecutive pieces of at most ``max_tokens`` tokens.

    The text is tokenized with the module-level ``tokenizer``; tokens are
    grouped in order, a group being closed as soon as it holds ``max_tokens``
    tokens, and each group is turned back into a string.

    Args:
        text: The text to split.
        max_tokens: Token count at which a piece is closed.

    Returns:
        The list of piece strings, in order; empty when ``text`` yields no
        tokens.
    """
    def _windows(token_seq):
        # Yield successive token groups; the last one may be shorter.
        pending = []
        for tok in token_seq:
            pending.append(tok)
            if len(pending) < max_tokens:
                continue
            yield pending
            pending = []
        if pending:
            yield pending

    return [
        tokenizer.convert_tokens_to_string(window)
        for window in _windows(tokenizer.tokenize(text))
    ]
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Ask the model for a corrected version of one piece of OCR text.

    The text is wrapped in the ``### Text ### / ### Correction ###`` template,
    encoded with the module-level ``tokenizer`` and passed to
    ``model.generate`` (greedy decoding, ``do_sample=False``).

    Args:
        prompt: The raw text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Value passed to ``torch.set_num_threads``. The default is
            ``os.cpu_count()``, which may be ``None``; a falsy value leaves
            torch's current thread setting untouched.

    Returns:
        The generated correction, stripped of surrounding whitespace and cut
        at the first repeated ``### Correction ###`` marker, if any.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # os.cpu_count() can return None, which torch.set_num_threads rejects.
    if num_threads:
        torch.set_num_threads(num_threads)

    # Call generate directly: handing a single job to a thread pool and
    # waiting on it right away adds overhead without any parallelism.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        num_return_sequences=1,
        do_sample=False,
    )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, so that a "### Correction ###" occurring inside the user's
    # own text can never be mistaken for the template marker.
    generated_ids = output[0][input_ids.shape[1]:]
    generated = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Keep only the text before a marker the model may have emitted again.
    return generated.split("### Correction ###")[0].strip()
def process_text(user_message):
    """Correct ``user_message`` and return an HTML report of the result.

    The message is split with ``split_text``, each piece is corrected with
    ``ocr_correction``, the corrections are joined with single spaces, and the
    joined text is diffed against the original message by
    ``generate_html_diff``.

    Args:
        user_message: The text entered by the user.

    Returns:
        An HTML string: the module-level ``css`` followed by a heading and a
        ``generation`` div containing the diff markup.
    """
    fixed_pieces = [ocr_correction(piece) for piece in split_text(user_message)]
    diff_markup = generate_html_diff(user_message, ' '.join(fixed_pieces))
    section = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{diff_markup}</div>'
    return css + section
# Gradio interface: a title, a five-line input box, a button, and an HTML
# output area. Clicking the button runs process_text on the box's content and
# shows the returned markup in the output area.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector (CPU)</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(
        fn=process_text,
        inputs=text_input,
        outputs=[text_output],
    )

# Start the app (with request queuing enabled) only when run as a script.
if __name__ == "__main__":
    demo.queue().launch()