Tonic commited on
Commit
00ee90b
Β·
unverified Β·
1 Parent(s): 402e56a

initial commit

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -6,23 +6,28 @@ import base64
6
  import spaces
7
 
8
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
 
9
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
10
  model = model.eval().cuda()
 
11
 
12
  @spaces.GPU
13
  def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
 
 
 
14
  if task == "Plain Text OCR":
15
- res = model.chat(tokenizer, image, ocr_type='ocr')
16
  elif task == "Format Text OCR":
17
- res = model.chat(tokenizer, image, ocr_type='format')
18
  elif task == "Fine-grained OCR (Box)":
19
- res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box)
20
  elif task == "Fine-grained OCR (Color)":
21
- res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color)
22
  elif task == "Multi-crop OCR":
23
- res = model.chat_crop(tokenizer, image_file=image)
24
  elif task == "Render Formatted OCR":
25
- res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html')
26
  with open('./demo.html', 'r') as f:
27
  html_content = f.read()
28
  return res, html_content
 
6
  import spaces
7
 
8
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
9
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
10
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
11
  model = model.eval().cuda()
12
+ model.config.pad_token_id = tokenizer.eos_token_id
13
 
14
  @spaces.GPU
15
  def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
16
+ # Create attention mask
17
+ attention_mask = torch.ones((1, model.config.max_position_embeddings), dtype=torch.long, device=model.device)
18
+
19
  if task == "Plain Text OCR":
20
+ res = model.chat(tokenizer, image, ocr_type='ocr', attention_mask=attention_mask)
21
  elif task == "Format Text OCR":
22
+ res = model.chat(tokenizer, image, ocr_type='format', attention_mask=attention_mask)
23
  elif task == "Fine-grained OCR (Box)":
24
+ res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, attention_mask=attention_mask)
25
  elif task == "Fine-grained OCR (Color)":
26
+ res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, attention_mask=attention_mask)
27
  elif task == "Multi-crop OCR":
28
+ res = model.chat_crop(tokenizer, image_file=image, attention_mask=attention_mask)
29
  elif task == "Render Formatted OCR":
30
+ res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html', attention_mask=attention_mask)
31
  with open('./demo.html', 'r') as f:
32
  html_content = f.read()
33
  return res, html_content