acharyaaditya26 committed on
Commit
3d967f5
·
verified ·
1 Parent(s): 00fd0d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -127
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import spaces
3
  from transformers import AutoModel, AutoTokenizer
4
  from PIL import Image
5
  import numpy as np
@@ -11,7 +11,9 @@ import tempfile
11
  import time
12
  import shutil
13
  from pathlib import Path
 
14
 
 
15
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
16
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True)
17
  model = model.eval().cuda()
@@ -19,6 +21,7 @@ model = model.eval().cuda()
19
  UPLOAD_FOLDER = "./uploads"
20
  RESULTS_FOLDER = "./results"
21
 
 
22
  for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
23
  if not os.path.exists(folder):
24
  os.makedirs(folder)
@@ -28,75 +31,44 @@ def image_to_base64(image):
28
  image.save(buffered, format="PNG")
29
  return base64.b64encode(buffered.getvalue()).decode()
30
 
31
- @spaces.GPU
32
- def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
 
 
 
 
 
 
 
 
 
33
  unique_id = str(uuid.uuid4())
34
- image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
35
- result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
36
 
37
- shutil.copy(image, image_path)
 
38
 
39
  try:
40
- if got_mode == "plain texts OCR":
 
 
 
41
  res = model.chat(tokenizer, image_path, ocr_type='ocr')
42
- return res, None
43
- elif got_mode == "format texts OCR":
44
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
45
- elif got_mode == "plain multi-crop OCR":
46
- res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
47
- return res, None
48
- elif got_mode == "format multi-crop OCR":
49
- res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
50
- elif got_mode == "plain fine-grained OCR":
51
- res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
52
- return res, None
53
- elif got_mode == "format fine-grained OCR":
54
- res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
55
-
56
- # res_markdown = f"$$ {res} $$"
57
- res_markdown = res
58
-
59
- if "format" in got_mode and os.path.exists(result_path):
60
- with open(result_path, 'r') as f:
61
- html_content = f.read()
62
- encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
63
- iframe_src = f"data:text/html;base64,{encoded_html}"
64
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
65
- download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
66
- return res_markdown, f"{download_link}<br>{iframe}"
67
- else:
68
- return res_markdown, None
69
  except Exception as e:
70
  return f"Error: {str(e)}", None
71
  finally:
72
- if os.path.exists(image_path):
73
- os.remove(image_path)
74
-
75
- def task_update(task):
76
- if "fine-grained" in task:
77
- return [
78
- gr.update(visible=True),
79
- gr.update(visible=False),
80
- gr.update(visible=False),
81
- ]
82
- else:
83
- return [
84
- gr.update(visible=False),
85
- gr.update(visible=False),
86
- gr.update(visible=False),
87
- ]
88
-
89
- def fine_grained_update(task):
90
- if task == "box":
91
- return [
92
- gr.update(visible=False, value = ""),
93
- gr.update(visible=True),
94
- ]
95
- elif task == 'color':
96
- return [
97
- gr.update(visible=True),
98
- gr.update(visible=False, value = ""),
99
- ]
100
 
101
  def cleanup_old_files():
102
  current_time = time.time()
@@ -118,83 +90,49 @@ with gr.Blocks() as demo:
118
  "🔥🔥🔥This is the official online demo of GOT-OCR-2.0 model!!!"
119
 
120
  ### Demo Guidelines
121
- You need to upload your image below and choose one mode of GOT, then click "Submit" to run GOT model. More characters will result in longer wait times.
122
- - **plain texts OCR & format texts OCR**: The two modes are for the image-level OCR.
123
- - **plain multi-crop OCR & format multi-crop OCR**: For images with more complex content, you can achieve higher-quality results with these modes.
124
- - **plain fine-grained OCR & format fine-grained OCR**: In these modes, you can specify fine-grained regions on the input image for more flexible OCR. Fine-grained regions can be coordinates of the box, red color, blue color, or green color.
125
  """)
126
 
127
  with gr.Row():
128
  with gr.Column():
129
- image_input = gr.Image(type="filepath", label="upload your image")
130
- task_dropdown = gr.Dropdown(
131
- choices=[
132
- "plain texts OCR",
133
- "format texts OCR",
134
- "plain multi-crop OCR",
135
- "format multi-crop OCR",
136
- "plain fine-grained OCR",
137
- "format fine-grained OCR",
138
- ],
139
- label="Choose one mode of GOT",
140
- value="plain texts OCR"
141
- )
142
- fine_grained_dropdown = gr.Dropdown(
143
- choices=["box", "color"],
144
- label="fine-grained type",
145
- visible=False
146
- )
147
- color_dropdown = gr.Dropdown(
148
- choices=["red", "green", "blue"],
149
- label="color list",
150
- visible=False
151
- )
152
- box_input = gr.Textbox(
153
- label="input box: [x1,y1,x2,y2]",
154
- placeholder="e.g., [0,0,100,100]",
155
- visible=False
156
- )
157
  submit_button = gr.Button("Submit")
158
 
159
  with gr.Column():
160
- ocr_result = gr.Textbox(label="GOT output")
161
 
162
  with gr.Column():
163
- gr.Markdown("**If you choose the mode with format, the mathpix result will be automatically rendered as follows:**")
164
- html_result = gr.HTML(label="rendered html", show_label=True)
165
-
166
- gr.Examples(
167
- examples=[
168
- ["assets/coco.jpg", "plain texts OCR", "", "", ""],
169
- ["assets/en_30.png", "plain texts OCR", "", "", ""],
170
- ["assets/table.jpg", "format texts OCR", "", "", ""],
171
- ["assets/eq.jpg", "format texts OCR", "", "", ""],
172
- ["assets/exam.jpg", "format texts OCR", "", "", ""],
173
- ["assets/giga.jpg", "format multi-crop OCR", "", "", ""],
174
- ["assets/aff2.png", "plain fine-grained OCR", "box", "", "[409,763,756,891]"],
175
- ["assets/color.png", "plain fine-grained OCR", "color", "red", ""],
176
- ],
177
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
178
- outputs=[ocr_result, html_result],
179
- fn=run_GOT,
180
- label="examples",
181
- )
182
 
183
- task_dropdown.change(
184
- task_update,
185
- inputs=[task_dropdown],
186
- outputs=[fine_grained_dropdown, color_dropdown, box_input]
187
- )
188
- fine_grained_dropdown.change(
189
- fine_grained_update,
190
- inputs=[fine_grained_dropdown],
191
- outputs=[color_dropdown, box_input]
192
- )
193
-
194
  submit_button.click(
195
  run_GOT,
196
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
197
- outputs=[ocr_result, html_result]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
 
200
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import fitz # PyMuPDF
3
  from transformers import AutoModel, AutoTokenizer
4
  from PIL import Image
5
  import numpy as np
 
11
  import time
12
  import shutil
13
  from pathlib import Path
14
+ import json
15
 
16
+ # Load tokenizer and model
17
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
18
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True)
19
  model = model.eval().cuda()
 
21
  UPLOAD_FOLDER = "./uploads"
22
  RESULTS_FOLDER = "./results"
23
 
24
# Ensure the working directories exist. exist_ok=True makes the call
# idempotent and avoids the check-then-create race of testing
# os.path.exists() first (another process could create the folder in between).
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    os.makedirs(folder, exist_ok=True)
 
31
  image.save(buffered, format="PNG")
32
  return base64.b64encode(buffered.getvalue()).decode()
33
 
34
def pdf_to_images(pdf_path, zoom=1.0):
    """Render every page of a PDF into a PIL image.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.
        zoom: Scale factor applied to both axes when rasterising a page.
            The default of 1.0 uses an identity matrix, i.e. the same
            output as calling ``get_pixmap()`` with no matrix. Larger
            values yield larger images.

    Returns:
        A list of RGB ``PIL.Image`` objects, one per page, in page order.
    """
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    # The context manager closes the document even if rendering a page fails,
    # so the file handle is not leaked.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            # alpha=False guarantees 3 bytes per pixel, matching the "RGB"
            # mode handed to Image.frombytes below.
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            images.append(
                Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            )
    return images
43
+
44
def run_GOT(pdf_file):
    """Run plain-text OCR on every page of an uploaded PDF.

    The PDF is copied into UPLOAD_FOLDER under a unique name, each page is
    rendered to a temporary PNG and passed to ``model.chat`` with
    ``ocr_type='ocr'``. All temporary files are removed before returning,
    whether or not processing succeeded.

    Args:
        pdf_file: Filesystem path of the uploaded PDF. A falsy value
            (``None`` or an empty string) is reported as an error.

    Returns:
        A JSON string. On success it encodes a list of
        ``{"page_number": <1-based int>, "text": <model output>}`` objects
        in page order. On failure it encodes ``{"error": <message>}``.
        A single value is always returned because the click handler is
        wired to exactly one output component.
    """
    if not pdf_file:
        return json.dumps({"error": "No PDF file provided."}, indent=4)

    unique_id = str(uuid.uuid4())
    pdf_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.pdf")
    results = []

    try:
        # Copying and rendering happen inside the try block so that a
        # missing or corrupt PDF is reported as an error and the copied
        # file is still cleaned up by the finally clause.
        shutil.copy(pdf_file, pdf_path)

        for page_number, image in enumerate(pdf_to_images(pdf_path), start=1):
            image_path = os.path.join(
                UPLOAD_FOLDER, f"{unique_id}_page_{page_number}.png"
            )
            try:
                image.save(image_path)
                res = model.chat(tokenizer, image_path, ocr_type='ocr')
            finally:
                # Remove the page image even when OCR raises.
                if os.path.exists(image_path):
                    os.remove(image_path)

            results.append({
                "page_number": page_number,
                "text": res
            })
    except Exception as e:
        # UI boundary: surface the failure as JSON instead of crashing the app.
        return json.dumps({"error": str(e)}, indent=4)
    finally:
        if os.path.exists(pdf_path):
            os.remove(pdf_path)

    return json.dumps(results, indent=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def cleanup_old_files():
74
  current_time = time.time()
 
90
  "🔥🔥🔥This is the official online demo of GOT-OCR-2.0 model!!!"
91
 
92
  ### Demo Guidelines
93
+ You need to upload your PDF below, and the model will automatically perform plain text OCR on each page.
 
 
 
94
  """)
95
 
96
  with gr.Row():
97
  with gr.Column():
98
+ pdf_input = gr.File(type="filepath", label="Upload your PDF")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  submit_button = gr.Button("Submit")
100
 
101
  with gr.Column():
102
+ ocr_result = gr.JSON(label="GOT output")
103
 
104
  with gr.Column():
105
+ gr.Markdown("**PDF Preview:**")
106
+ pdf_preview = gr.HTML(label="PDF Preview", show_label=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
 
 
 
 
 
 
 
 
 
 
 
108
  submit_button.click(
109
  run_GOT,
110
+ inputs=[pdf_input],
111
+ outputs=[ocr_result]
112
+ )
113
+
114
+ # Function to update PDF preview
115
def update_pdf_preview(pdf_file):
    """Build an HTML preview of a PDF as a stack of inline PNG images.

    Args:
        pdf_file: Filesystem path of the selected PDF, or a falsy value
            when no file is selected.

    Returns:
        An HTML string with one base64-embedded ``<img>`` tag per page
        (each followed by ``<br>``), or an empty string when no file is
        selected.
    """
    if not pdf_file:
        return ""

    fragments = []
    # The context manager closes the document once all pages are rendered,
    # including when rendering raises.
    with fitz.open(pdf_file) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap()
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            buffer = io.BytesIO()
            img.save(buffer, format='PNG')
            img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            fragments.append(
                f'<img src="data:image/png;base64,{img_base64}" width="100%"><br>'
            )
    # Join once instead of growing a string with += for every page.
    return "".join(fragments)
131
+
132
+ pdf_input.change(
133
+ update_pdf_preview,
134
+ inputs=[pdf_input],
135
+ outputs=[pdf_preview]
136
  )
137
 
138
  if __name__ == "__main__":