Pclanglais committed on
Commit: dd838d3
1 Parent(s): ffbf266

Update app.py

Files changed (1):
  1. app.py +6 -11
app.py CHANGED
@@ -13,22 +13,19 @@ import pandas as pd
 import difflib
 from concurrent.futures import ThreadPoolExecutor
 
-# Define the device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 # OCR Correction Model
 ocr_model_name = "PleIAs/OCRonos-Vintage"
 
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
+device = "cuda"
+
 # Load pre-trained model and tokenizer
 model_name = "PleIAs/OCRonos-Vintage"
 model = GPT2LMHeadModel.from_pretrained(model_name)
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 
-# Set the device to GPU if available, otherwise use CPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 # CSS for formatting
@@ -169,7 +166,9 @@ def split_text(text, max_tokens=500):
 
 
 # Function to generate text
-def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
+@spaces.GPU
+def ocr_correction(prompt, max_new_tokens=500):
+
     prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
@@ -177,9 +176,7 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
     torch.set_num_threads(num_threads)
 
     # Generate text
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        future = executor.submit(
-            model.generate,
+    output = model.generate(
         input_ids,
         max_new_tokens=max_new_tokens,
         pad_token_id=tokenizer.eos_token_id,
@@ -188,8 +185,6 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
         do_sample=True,
         temperature=0.7
     )
-        output = future.result()
-
     # Decode and return the generated text
     result = tokenizer.decode(output[0], skip_special_tokens=True)
     print(result)
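
For context: this commit moves inference off a CPU thread pool (a ThreadPoolExecutor wrapping model.generate, with the device picked via torch.cuda.is_available()) onto Hugging Face ZeroGPU, where the @spaces.GPU decorator attaches a GPU for the duration of each call and the device is hard-coded to "cuda". Below is a minimal, self-contained sketch of the post-commit generation path, assuming a Space with the spaces package installed. Two caveats: the committed code still calls torch.set_num_threads(num_threads) even though the num_threads parameter was removed (a NameError at call time), so the sketch drops that line; and any generate() arguments hidden between the last two hunks are omitted.

import spaces
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the OCR-correction model once at startup; ZeroGPU lets the weights
# sit on "cuda" here and attaches a physical GPU per decorated call.
device = "cuda"
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.to(device)

@spaces.GPU
def ocr_correction(prompt, max_new_tokens=500):
    # Prompt framing used by the original app.py for this model.
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
    )
    # Decode and return the generated text (the echoed prompt included).
    return tokenizer.decode(output[0], skip_special_tokens=True)

Sampling with temperature=0.7 keeps corrections fluent but non-deterministic; greedy decoding (do_sample=False) is the usual switch when reproducible output matters. The diff ends at print(result), so whether the committed function also returns the text is not visible here; the sketch returns it so a caller such as a Gradio handler can use it directly.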