Pclanglais commited on
Commit
dfbcb2e
·
verified ·
1 Parent(s): 28a19ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -17,15 +17,13 @@ ocr_model_name = "PleIAs/OCRonos-Vintage"
17
  import torch
18
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
19
 
20
- device = "cuda"
21
 
22
  # Load pre-trained model and tokenizer
23
  model_name = "PleIAs/OCRonos-Vintage"
24
  model = GPT2LMHeadModel.from_pretrained(model_name)
25
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
26
 
27
- model.to(device)
28
-
29
  # CSS for formatting
30
  css = """
31
  <style>
@@ -166,6 +164,7 @@ def split_text(text, max_tokens=500):
166
  # Function to generate text
167
  @spaces.GPU
168
  def ocr_correction(prompt, max_new_tokens=500):
 
169
 
170
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
171
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
17
  import torch
18
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
19
 
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  # Load pre-trained model and tokenizer
23
  model_name = "PleIAs/OCRonos-Vintage"
24
  model = GPT2LMHeadModel.from_pretrained(model_name)
25
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
26
 
 
 
27
  # CSS for formatting
28
  css = """
29
  <style>
 
164
  # Function to generate text
165
  @spaces.GPU
166
  def ocr_correction(prompt, max_new_tokens=500):
167
+ model.to(device)
168
 
169
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
170
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)