uzumaki06 commited on
Commit
63e71c8
·
verified ·
1 Parent(s): 06f8db8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +147 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import os
4
+ import base64
5
+ import io
6
+ import uuid
7
+ import shutil
8
+ from pathlib import Path
9
+ import time
10
+ import tempfile
11
+
12
+ model_name = "srimanth-d/GOT_CPU"
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
14
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
15
+ model = model.eval()
16
+
17
+ UPLOAD_FOLDER = "./uploads"
18
+ RESULTS_FOLDER = "./results"
19
+
20
+ for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
21
+ if not os.path.exists(folder):
22
+ os.makedirs(folder)
23
+
24
+ def image_to_base64(image):
25
+ buffered = io.BytesIO()
26
+ image.save(buffered, format="PNG")
27
+ return base64.b64encode(buffered.getvalue()).decode()
28
+
29
+ # Cleanup function for removing old files
30
+ def cleanup_old_files():
31
+ current_time = time.time()
32
+ for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
33
+ for file_path in Path(folder).glob('*'):
34
+ if current_time - file_path.stat().st_mtime > 3600: # 1 hour
35
+ file_path.unlink()
36
+
37
+ # Function to search and highlight keywords in text
38
+ def search_in_text(text, keywords):
39
+ """Searches for keywords within the text and highlights matches."""
40
+ if not keywords:
41
+ return text
42
+ highlighted_text = text
43
+ for keyword in keywords.split():
44
+ highlighted_text = highlighted_text.replace(keyword, f"<mark>{keyword}</mark>")
45
+ return highlighted_text
46
+
47
+ # OCR processing function
48
+ def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
49
+ unique_id = str(uuid.uuid4())
50
+ image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
51
+ result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
52
+
53
+ shutil.copy(image, image_path)
54
+
55
+ try:
56
+ if got_mode == "plain texts OCR":
57
+ res = model.chat(tokenizer, image_path, ocr_type='ocr')
58
+ return res, None
59
+ elif got_mode == "format texts OCR":
60
+ res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
61
+ elif got_mode == "plain multi-crop OCR":
62
+ res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
63
+ return res, None
64
+ elif got_mode == "format multi-crop OCR":
65
+ res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
66
+ elif got_mode == "plain fine-grained OCR":
67
+ res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
68
+ return res, None
69
+ elif got_mode == "format fine-grained OCR":
70
+ res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
71
+ res_markdown = res
72
+
73
+ if "format" in got_mode and os.path.exists(result_path):
74
+ with open(result_path, 'r') as f:
75
+ html_content = f.read()
76
+ encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
77
+ iframe_src = f"data:text/html;base64,{encoded_html}"
78
+ iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
79
+ download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
80
+ return res_markdown, f"{download_link}<br>{iframe}"
81
+ else:
82
+ return res_markdown, None
83
+ except Exception as e:
84
+ return f"Error: {str(e)}", None
85
+ finally:
86
+ if os.path.exists(image_path):
87
+ os.remove(image_path)
88
+
89
+ # Streamlit interface
90
+ st.title("GOT OCR 2.0 Model")
91
+
92
+ st.markdown("""
93
+ Upload your image below and select your preferred mode. Note that more characters may increase wait times.
94
+ - **Plain Texts OCR & Format Texts OCR:** Use these modes for basic image-level OCR. Format Text OCR is preferred for better results.
95
+ - **Plain Multi-Crop OCR & Format Multi-Crop OCR:** Ideal for images with complex content, offering higher-quality results.
96
+ - **Plain Fine-Grained OCR & Format Fine-Grained OCR:** These modes allow you to specify fine-grained regions on the image for more flexible OCR. Regions can be defined by coordinates or colors (red, blue, green, black or white).
97
+ """)
98
+
99
+ uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
100
+ got_mode = st.selectbox("Choose OCR mode", [
101
+ "plain texts OCR",
102
+ "format texts OCR",
103
+ "plain multi-crop OCR",
104
+ "format multi-crop OCR",
105
+ "plain fine-grained OCR",
106
+ "format fine-grained OCR"
107
+ ])
108
+
109
+ if "fine-grained" in got_mode:
110
+ ocr_box = st.text_input("Input OCR box [x1,y1,x2,y2]")
111
+ ocr_color = st.selectbox("Choose OCR color", ["red", "green", "blue", "black", "white"])
112
+ else:
113
+ ocr_box = ""
114
+ ocr_color = ""
115
+
116
+ # Maintain state for OCR result
117
+ if 'ocr_result' not in st.session_state:
118
+ st.session_state.ocr_result = None
119
+ if 'html_result' not in st.session_state:
120
+ st.session_state.html_result = None
121
+
122
+ if st.button("Run OCR"):
123
+ if uploaded_image:
124
+ with tempfile.NamedTemporaryFile(delete=False) as temp:
125
+ temp.write(uploaded_image.read())
126
+ ocr_result, html_result = run_GOT(temp.name, got_mode, ocr_box=ocr_box, ocr_color=ocr_color)
127
+ st.session_state.ocr_result = ocr_result
128
+ st.session_state.html_result = html_result
129
+ st.text_area("OCR Result", ocr_result)
130
+ else:
131
+ st.warning("Please upload an image.")
132
+
133
+ # Display the OCR result if it has been set
134
+ if st.session_state.ocr_result:
135
+ st.text_area("OCR Result", st.session_state.ocr_result,key="display_area")
136
+
137
+ # Keyword search functionality
138
+ keywords = st.text_input("Enter keywords for highlighting",key="keyword_input")
139
+ if keywords:
140
+ highlighted_text = search_in_text(st.session_state.ocr_result, keywords)
141
+ st.markdown(highlighted_text, unsafe_allow_html=True)
142
+
143
+ if st.session_state.html_result:
144
+ st.markdown(st.session_state.html_result, unsafe_allow_html=True)
145
+
146
+ if __name__ == "__main__":
147
+ cleanup_old_files()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ tiktoken
3
+ verovio
4
+ transformers
5
+ Pillow
6
+ numpy
7
+ torch
8
+ torchvision
9
+ accelerate