Spaces:
Running
on
Zero
Running
on
Zero
initial commit
Browse files
app.py
CHANGED
@@ -6,23 +6,28 @@ import base64
|
|
6 |
import spaces
|
7 |
|
8 |
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
|
|
|
9 |
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
|
10 |
model = model.eval().cuda()
|
|
|
11 |
|
12 |
@spaces.GPU
|
13 |
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
|
|
|
|
|
|
|
14 |
if task == "Plain Text OCR":
|
15 |
-
res = model.chat(tokenizer, image, ocr_type='ocr')
|
16 |
elif task == "Format Text OCR":
|
17 |
-
res = model.chat(tokenizer, image, ocr_type='format')
|
18 |
elif task == "Fine-grained OCR (Box)":
|
19 |
-
res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box)
|
20 |
elif task == "Fine-grained OCR (Color)":
|
21 |
-
res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color)
|
22 |
elif task == "Multi-crop OCR":
|
23 |
-
res = model.chat_crop(tokenizer, image_file=image)
|
24 |
elif task == "Render Formatted OCR":
|
25 |
-
res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html')
|
26 |
with open('./demo.html', 'r') as f:
|
27 |
html_content = f.read()
|
28 |
return res, html_content
|
|
|
6 |
import spaces
|
7 |
|
8 |
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
|
9 |
+
# Load the configuration from the same checkpoint id used for the tokenizer and
# model in this file. The previous argument, `model_name`, is not defined in
# the visible part of the file and would raise NameError at import time.
config = AutoConfig.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
|
10 |
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
|
11 |
model = model.eval().cuda()
|
12 |
+
model.config.pad_token_id = tokenizer.eos_token_id
|
13 |
|
14 |
@spaces.GPU
|
15 |
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None, render=False):
|
16 |
+
# Create attention mask
|
17 |
+
attention_mask = torch.ones((1, model.config.max_position_embeddings), dtype=torch.long, device=model.device)
|
18 |
+
|
19 |
if task == "Plain Text OCR":
|
20 |
+
res = model.chat(tokenizer, image, ocr_type='ocr', attention_mask=attention_mask)
|
21 |
elif task == "Format Text OCR":
|
22 |
+
res = model.chat(tokenizer, image, ocr_type='format', attention_mask=attention_mask)
|
23 |
elif task == "Fine-grained OCR (Box)":
|
24 |
+
res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_box=ocr_box, attention_mask=attention_mask)
|
25 |
elif task == "Fine-grained OCR (Color)":
|
26 |
+
res = model.chat(tokenizer, image, ocr_type=ocr_type, ocr_color=ocr_color, attention_mask=attention_mask)
|
27 |
elif task == "Multi-crop OCR":
|
28 |
+
res = model.chat_crop(tokenizer, image_file=image, attention_mask=attention_mask)
|
29 |
elif task == "Render Formatted OCR":
|
30 |
+
res = model.chat(tokenizer, image, ocr_type='format', render=True, save_render_file='./demo.html', attention_mask=attention_mask)
|
31 |
with open('./demo.html', 'r') as f:
|
32 |
html_content = f.read()
|
33 |
return res, html_content
|