# Hugging Face Space "app.py", uploaded by Pclanglais.
# Last commit: eed441d "Update app.py" (verified); file size 4.88 kB.
# `spaces` is kept first, as in the original: Hugging Face ZeroGPU reportedly
# expects it to be imported before torch/CUDA packages — verify before moving.
import spaces

import difflib
import html
import os
import re
from concurrent.futures import ThreadPoolExecutor

import ctranslate2
import gradio as gr
import torch
import transformers
# Choose the inference device once, at import time: CUDA when torch reports a
# usable GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model is served through CTranslate2 from the local "ocronos_ct2"
# directory; the matching tokenizer is loaded with `from_pretrained` using the
# "PleIAs/OCRonos-Vintage" checkpoint id.
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
# Stylesheet (wrapped in a <style> tag) that TextProcessor.process prepends to
# every HTML result it returns.
# NOTE(review): of these rules only `.generation` and the bare `h2` rule match
# markup produced in this file (diff highlights use an inline style rather than
# `.inserted`); the remaining class rules appear unused here.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words that are not in *old_text*.

    The two texts are compared word by word (split on whitespace):

    * words common to both are emitted unchanged;
    * words inserted or substituted in *new_text* are each wrapped in a
      light-green ``<span>``;
    * words present only in *old_text* are dropped, so the result reads as the
      corrected text with the corrections highlighted.

    Every word is HTML-escaped, because both inputs are user text / model output
    and the result is rendered as raw HTML.

    Args:
        old_text: The original (uncorrected) text.
        new_text: The corrected text.

    Returns:
        Space-separated escaped words and ``<span>`` elements ("" if *new_text*
        has no words).
    """
    old_words = old_text.split()
    new_words = new_text.split()
    # autojunk=False: with the default heuristic, words that are very frequent
    # in inputs of 200+ words (e.g. "the", "de") are not used to seed matches,
    # which yields spurious highlights on long texts.
    matcher = difflib.SequenceMatcher(None, old_words, new_words, autojunk=False)
    html_diff = []
    for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
        # "delete" opcodes have an empty new-text range, so removed words
        # contribute nothing; "insert"/"replace" words are highlighted.
        for word in new_words[j1:j2]:
            escaped = html.escape(word)
            if tag == "equal":
                html_diff.append(escaped)
            else:
                html_diff.append(f'<span style="background-color: #90EE90;">{escaped}</span>')
    return ' '.join(html_diff)
def preprocess_text(text):
    """Flatten *text* into a single, whitespace-normalised line.

    Markup tags (anything between ``<`` and ``>``) are removed, newlines become
    spaces, runs of whitespace collapse to one space, and both ends are trimmed.
    """
    substitutions = (
        (r'<[^>]+>', ''),  # drop tag-like spans
        (r'\n', ' '),      # newlines -> spaces
        (r'\s+', ' '),     # collapse whitespace runs
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def split_text(text, max_tokens=400):
    """Split *text* into consecutive chunks of roughly ``max_tokens`` tokens.

    Chunk boundaries are placed only at whitespace, so a word (and any
    multi-byte character inside it) is never cut in half. Slicing the raw
    token-id sequence instead can break a word across two chunks, which the
    caller then re-joins with a space. It can also split a multi-byte UTF-8
    character into replacement characters.

    A single whitespace-free run longer than ``max_tokens`` tokens cannot be
    cut at whitespace, so it falls back to fixed-size token slices.

    Args:
        text: Text to split.
        max_tokens: Per-chunk token budget, counted with the module-level
            ``tokenizer`` without special tokens.

    Returns:
        Chunk strings in order, each stripped of surrounding whitespace.
        Returns an empty list when *text* is empty or whitespace-only.
    """
    # Each piece is a word plus the whitespace preceding it, so joining
    # consecutive pieces reproduces the original spacing (including newlines)
    # inside a chunk.
    pieces = re.findall(r'\s*\S+', text)
    if not pieces:
        return []
    # Count tokens per piece in one batched call. For GPT-2-style byte-level
    # BPE (presumably what this checkpoint uses — verify), pre-tokenization
    # never spans a word->whitespace boundary, so the counts add up closely to
    # those of the joined text. Treat the budget as approximate otherwise.
    piece_ids = tokenizer(pieces, add_special_tokens=False)["input_ids"]

    splits = []
    current, current_len = [], 0
    for piece, ids in zip(pieces, piece_ids):
        if current and current_len + len(ids) > max_tokens:
            splits.append("".join(current).strip())
            current, current_len = [], 0
        if len(ids) > max_tokens:
            # Overlong run with no whitespace: slice its tokens directly.
            splits.extend(
                tokenizer.decode(ids[i:i + max_tokens]).strip()
                for i in range(0, len(ids), max_tokens)
            )
            continue
        current.append(piece)
        current_len += len(ids)
    if current:
        splits.append("".join(current).strip())
    return splits
def ocr_correction(prompt, max_new_tokens=600):
    """Correct OCR errors in *prompt* with the CTranslate2 model.

    The text is split into chunks with ``split_text(..., max_tokens=400)``.
    Each chunk is wrapped in the "### Text ### / ### Correction ###" template
    and generated separately, and the per-chunk outputs are joined with single
    spaces.

    Generation samples (temperature 0.7, top-k 20), so repeated calls may
    return different text. ``max_new_tokens`` is forwarded as CTranslate2's
    ``max_length``.

    Args:
        prompt: Raw text to correct.
        max_new_tokens: Value passed as ``max_length`` to ``generate_batch``.

    Returns:
        The corrected text ("" when *prompt* yields no chunks).
    """
    def correct_chunk(chunk):
        # CTranslate2 takes token *strings*, not ids, as the prompt.
        template = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"
        token_strings = tokenizer.convert_ids_to_tokens(tokenizer.encode(template))
        batch_results = generator.generate_batch(
            [token_strings],
            max_length=max_new_tokens,
            sampling_temperature=0.7,
            sampling_topk=20,
            include_prompt_in_result=False,
        )
        return tokenizer.decode(batch_results[0].sequences_ids[0])

    chunks = split_text(prompt, max_tokens=400)
    return " ".join([correct_chunk(chunk) for chunk in chunks])
class OCRCorrector:
    """Pair the model's OCR correction with an HTML word diff of the change."""

    # Default prompt is French for "The following dialogue is a conversation".
    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): stored on the instance but never read by code in this file.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        ``corrected_text`` comes from ``ocr_correction``. ``html_diff`` is the
        corrected text rendered by ``generate_html_diff`` against the input.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
class TextProcessor:
    """Entry point wired to the Gradio button: input text -> styled HTML."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    # NOTE(review): presumably requests a Hugging Face ZeroGPU slot for up to
    # 120 s per call — confirm against the `spaces` package docs.
    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and return the diff as an HTML page fragment.

        The fragment is the module-level ``css`` stylesheet, followed by a
        centred "OCR Correction" heading and the highlighted diff inside a
        ``generation`` div. The plain corrected text is not returned.
        """
        _, diff_markup = self.ocr_corrector.correct(user_message)
        heading = '<h2 style="text-align:center">OCR Correction</h2>'
        body = f'<div class="generation">{diff_markup}</div>'
        return f"{css}{heading}\n{body}"
# Single shared processor; its `process` method is the button's click handler.
text_processor = TextProcessor()
# Gradio UI: a textbox for the raw OCR text, a button, and an HTML panel that
# displays the string returned by `text_processor.process`.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # On click, the textbox contents go to `process`, and its HTML result is
    # rendered in `text_output`.
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()