Pclanglais commited on
Commit
e98a756
1 Parent(s): e2b4df4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -157
app.py CHANGED
@@ -1,28 +1,18 @@
1
- import spaces
2
  import transformers
3
  import re
4
- from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
5
  import torch
6
  import gradio as gr
7
- import json
8
- import os
9
- import shutil
10
- import requests
11
- import pandas as pd
12
  import difflib
13
  from concurrent.futures import ThreadPoolExecutor
 
14
 
15
  # OCR Correction Model
16
- ocr_model_name = "PleIAs/OCRonos-Vintage"
17
-
18
- import torch
19
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
20
-
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  # Load pre-trained model and tokenizer
24
- model_name = "PleIAs/OCRonos-Vintage"
25
- model = GPT2LMHeadModel.from_pretrained(model_name)
26
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
27
 
28
  # CSS for formatting
@@ -33,78 +23,12 @@ css = """
33
  margin-right: 2em;
34
  font-size: 1.2em;
35
  }
36
- :target {
37
- background-color: #CCF3DF;
38
- }
39
- .source {
40
- float: left;
41
- max-width: 17%;
42
- margin-left: 2%;
43
- }
44
- .tooltip {
45
- position: relative;
46
- cursor: pointer;
47
- font-variant-position: super;
48
- color: #97999b;
49
- }
50
- .tooltip:hover::after {
51
- content: attr(data-text);
52
- position: absolute;
53
- left: 0;
54
- top: 120%;
55
- white-space: pre-wrap;
56
- width: 500px;
57
- max-width: 500px;
58
- z-index: 1;
59
- background-color: #f9f9f9;
60
- color: #000;
61
- border: 1px solid #ddd;
62
- border-radius: 5px;
63
- padding: 5px;
64
- display: block;
65
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
66
- }
67
- .deleted {
68
- background-color: #ffcccb;
69
- text-decoration: line-through;
70
- }
71
  .inserted {
72
  background-color: #90EE90;
73
  }
74
- .manuscript {
75
- display: flex;
76
- margin-bottom: 10px;
77
- align-items: baseline;
78
- }
79
- .annotation {
80
- width: 15%;
81
- padding-right: 20px;
82
- color: grey !important;
83
- font-style: italic;
84
- text-align: right;
85
- }
86
- .content {
87
- width: 80%;
88
- }
89
- h2 {
90
- margin: 0;
91
- font-size: 1.5em;
92
- }
93
- .title-content h2 {
94
- font-weight: bold;
95
- }
96
- .bibliography-content {
97
- color: darkgreen !important;
98
- margin-top: -5px;
99
- }
100
- .paratext-content {
101
- color: #a4a4a4 !important;
102
- margin-top: -5px;
103
- }
104
  </style>
105
  """
106
 
107
- # Helper functions
108
  def generate_html_diff(old_text, new_text):
109
  d = difflib.Differ()
110
  diff = list(d.compare(old_text.split(), new_text.split()))
@@ -113,64 +37,31 @@ def generate_html_diff(old_text, new_text):
113
  if word.startswith(' '):
114
  html_diff.append(word[2:])
115
  elif word.startswith('+ '):
116
- html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
117
  return ' '.join(html_diff)
118
 
119
- def preprocess_text(text):
120
- text = re.sub(r'<[^>]+>', '', text)
121
- text = re.sub(r'\n', ' ', text)
122
- text = re.sub(r'\s+', ' ', text)
123
- return text.strip()
124
-
125
- def split_text(text, max_tokens=500):
126
- parts = text.split("\n")
127
  chunks = []
128
- current_chunk = ""
129
-
130
- for part in parts:
131
- if current_chunk:
132
- temp_chunk = current_chunk + "\n" + part
133
- else:
134
- temp_chunk = part
135
 
136
- num_tokens = len(tokenizer.tokenize(temp_chunk))
137
-
138
- if num_tokens <= max_tokens:
139
- current_chunk = temp_chunk
140
- else:
141
- if current_chunk:
142
- chunks.append(current_chunk)
143
- current_chunk = part
144
 
145
  if current_chunk:
146
- chunks.append(current_chunk)
147
-
148
- if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
149
- long_text = chunks[0]
150
- chunks = []
151
- while len(tokenizer.tokenize(long_text)) > max_tokens:
152
- split_point = len(long_text) // 2
153
- while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
154
- split_point += 1
155
- if split_point >= len(long_text):
156
- split_point = len(long_text) - 1
157
- chunks.append(long_text[:split_point].strip())
158
- long_text = long_text[split_point:].strip()
159
- if long_text:
160
- chunks.append(long_text)
161
 
162
  return chunks
163
 
164
-
165
- # Function to generate text
166
  def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
167
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
168
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
169
 
170
- # Set the number of threads for PyTorch
171
  torch.set_num_threads(num_threads)
172
 
173
- # Generate text
174
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
175
  future = executor.submit(
176
  model.generate,
@@ -183,41 +74,23 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
183
  )
184
  output = future.result()
185
 
186
- # Decode and return the generated text
187
  result = tokenizer.decode(output[0], skip_special_tokens=True)
188
- print(result)
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- result = result.split("### Correction ###")[1]
191
- return result
192
-
193
- # OCR Correction Class
194
- class OCRCorrector:
195
- def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
196
- self.system_prompt = system_prompt
197
-
198
- def correct(self, user_message):
199
- generated_text = ocr_correction(user_message)
200
- html_diff = generate_html_diff(user_message, generated_text)
201
- return generated_text, html_diff
202
-
203
- # Combined Processing Class
204
- class TextProcessor:
205
- def __init__(self):
206
- self.ocr_corrector = OCRCorrector()
207
-
208
- @spaces.GPU(duration=120)
209
- def process(self, user_message):
210
- #OCR Correction
211
- corrected_text, html_diff = self.ocr_corrector.correct(user_message)
212
-
213
- # Combine results
214
- ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
215
-
216
- final_output = f"{css}{ocr_result}"
217
- return final_output
218
-
219
- # Create the TextProcessor instance
220
- text_processor = TextProcessor()
221
 
222
  # Define the Gradio interface
223
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
@@ -225,7 +98,7 @@ with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
225
  text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
226
  process_button = gr.Button("Process Text")
227
  text_output = gr.HTML(label="Processed text")
228
- process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
229
 
230
  if __name__ == "__main__":
231
  demo.queue().launch()
 
 
1
  import transformers
2
  import re
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
  import gradio as gr
 
 
 
 
 
6
  import difflib
7
  from concurrent.futures import ThreadPoolExecutor
8
+ import os
9
 
10
  # OCR Correction Model
11
+ model_name = "PleIAs/OCRonos-Vintage"
 
 
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # Load pre-trained model and tokenizer
15
+ model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
 
16
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
17
 
18
  # CSS for formatting
 
23
  margin-right: 2em;
24
  font-size: 1.2em;
25
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  .inserted {
27
  background-color: #90EE90;
28
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  </style>
30
  """
31
 
 
32
  def generate_html_diff(old_text, new_text):
33
  d = difflib.Differ()
34
  diff = list(d.compare(old_text.split(), new_text.split()))
 
37
  if word.startswith(' '):
38
  html_diff.append(word[2:])
39
  elif word.startswith('+ '):
40
+ html_diff.append(f'<span class="inserted">{word[2:]}</span>')
41
  return ' '.join(html_diff)
42
 
43
+ def split_text(text, max_tokens=400):
44
+ tokens = tokenizer.tokenize(text)
 
 
 
 
 
 
45
  chunks = []
46
+ current_chunk = []
 
 
 
 
 
 
47
 
48
+ for token in tokens:
49
+ current_chunk.append(token)
50
+ if len(current_chunk) >= max_tokens:
51
+ chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
52
+ current_chunk = []
 
 
 
53
 
54
  if current_chunk:
55
+ chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  return chunks
58
 
 
 
59
  def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
60
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
61
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
62
 
 
63
  torch.set_num_threads(num_threads)
64
 
 
65
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
66
  future = executor.submit(
67
  model.generate,
 
74
  )
75
  output = future.result()
76
 
 
77
  result = tokenizer.decode(output[0], skip_special_tokens=True)
78
+ return result.split("### Correction ###")[1].strip()
79
+
80
+ def process_text(user_message):
81
+ chunks = split_text(user_message)
82
+ corrected_chunks = []
83
+
84
+ for chunk in chunks:
85
+ corrected_chunk = ocr_correction(chunk)
86
+ corrected_chunks.append(corrected_chunk)
87
+
88
+ corrected_text = ' '.join(corrected_chunks)
89
+ html_diff = generate_html_diff(user_message, corrected_text)
90
 
91
+ ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
92
+ final_output = f"{css}{ocr_result}"
93
+ return final_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # Define the Gradio interface
96
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
 
98
  text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
99
  process_button = gr.Button("Process Text")
100
  text_output = gr.HTML(label="Processed text")
101
+ process_button.click(process_text, inputs=text_input, outputs=[text_output])
102
 
103
  if __name__ == "__main__":
104
  demo.queue().launch()