intuitive262 committed on
Commit 4902aa0 · 1 Parent(s): 10c178b

Uploaded code files

Files changed (2)
  1. app.py +88 -73
  2. requirements.txt +7 -7
app.py CHANGED
@@ -1,88 +1,103 @@
- import gradio as gr
- import numpy as np
- from PIL import Image
  import torch
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
  import re

- # Load the first OCR model (Microsoft's TrOCR)
- ms_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
- ms_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
-
- # Load the second OCR model (Surya-OCR)
- surya_processor = TrOCRProcessor.from_pretrained("suryavarmaaddala/suryaocr")
- surya_model = VisionEncoderDecoderModel.from_pretrained("suryavarmaaddala/suryaocr")
-
- def preprocess_image(image):
-     if isinstance(image, str):
-         image = Image.open(image).convert("RGB")
-     elif isinstance(image, np.ndarray):
-         image = Image.fromarray(image).convert("RGB")
-     return image
-
- def microsoft_ocr(image):
-     image = preprocess_image(image)
-     pixel_values = ms_processor(image, return_tensors="pt").pixel_values
-
-     generated_ids = ms_model.generate(pixel_values)
-     generated_text = ms_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-     return generated_text
-
- def surya_ocr(image):
-     image = preprocess_image(image)
-     pixel_values = surya_processor(image, return_tensors="pt").pixel_values
-
-     generated_ids = surya_model.generate(pixel_values)
-     generated_text = surya_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-     return generated_text
-
  def post_process_text(text):
-     # Simple post-processing to split into lines
-     return '\n'.join(text.split('. '))
-
- def search_text(text, query):
-     try:
-         pattern = re.compile(query, re.IGNORECASE)
-         lines = text.split('\n')
-         matching_lines = [line for line in lines if pattern.search(line)]
-         return '\n'.join(matching_lines) if matching_lines else "No matches found."
-     except re.error:
-         return "Invalid regex pattern. Please try again."
-
- def process_and_search(image, search_query):
-     try:
-         ms_text = microsoft_ocr(image)
-         surya_text = surya_ocr(image)
-
-         result = f"Microsoft OCR Result:\n{ms_text}\n\nSurya OCR Result:\n{surya_text}"
-         processed_text = post_process_text(result)
-
-         search = None
-         if search_query:
-             search = search_text(processed_text, search_query)
-         return image, processed_text, search
-     except Exception as e:
-         return None, f"An error occurred: {str(e)}", None
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column(scale=1):
-             image_input = gr.Image(type="filepath", label="Upload your image")
-             search_query_input = gr.Textbox(label="Enter search query")
-             submit_button = gr.Button("Submit")
-
-         with gr.Column(scale=2):
-             displayed_image = gr.Image(label="Uploaded Image")
-             ocr_result = gr.Textbox(label="OCR Result", lines=10)
-             search_result = gr.Textbox(label="Search Result", lines=5)
-
-     submit_button.click(
-         fn=process_and_search,
-         inputs=[image_input, search_query_input],
-         outputs=[displayed_image, ocr_result, search_result]
-     )
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ from byaldi import RAGMultiModalModel
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
  import torch
+ from qwen_vl_utils import process_vision_info
+ from PIL import Image
+ import os
  import re

+ rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
+ vlm = Qwen2VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2-VL-2B-Instruct",
+     torch_dtype=torch.float16,
+     trust_remote_code=True,
+     device_map="auto",
+ )
+
+ rag.index(
+     input_path="./test1.png",
+     index_name="index",
+     store_collection_with_index=False,
+     overwrite=True,
+ )
+
+ text_query = "What is the text content displayed in the image?"
+ res = rag.search(text_query, k=1)
+
+ image = Image.open("./test2.jpg")
+ image_index = res[0]["page_num"] - 1
+
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+
+ def extract_text(image, query):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": query},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
+     inputs = inputs.to(vlm.device)
+     with torch.no_grad():
+         generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
+     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+     return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
  def post_process_text(text):
+     # Remove extra whitespace
+     text = re.sub(r'\s+', ' ', text).strip()
+     # Remove repeated phrases (which sometimes occur in multi-pass extraction)
+     phrases = text.split('. ')
+     unique_phrases = list(dict.fromkeys(phrases))
+     text = '. '.join(unique_phrases)
+     return text
+
+ def ocr(image):
+     queries = [
+         "Extract and transcribe all the text visible in the image, including any small or partially visible text.",
+         "Look closely at the image and list any text you see, no matter how small or unclear.",
+         "What text can you identify in this image? Include everything, even if it's partially obscured or in the background."
+     ]
+
+     all_extracted_text = []
+     for query in queries:
+         extracted_text = extract_text(image, query)
+         all_extracted_text.append(extracted_text)
+
+     # Combine and deduplicate the results
+     final_text = "\n".join(set(all_extracted_text))
+
+     final_text = post_process_text(final_text)
+     return final_text
+
+
+ def main_fun(image, keyword):
+     ext_text = ocr(image)
+
+     highlight_text = ext_text
+     if keyword:
+         highlight_text = re.sub(f'({re.escape(keyword)})', r'<span style="background-color: yellow;">\1</span>', ext_text, flags=re.IGNORECASE)
+
+     return ext_text, highlight_text
+
+ iface = gr.Interface(
+     fn=main_fun,
+     inputs=[
+         gr.Image(type="pil", label="Upload an Image"),
+         gr.Textbox(label="Enter search term")
+     ],
+     outputs=[
+         gr.Textbox(label="Extracted Text"),
+         gr.HTML(label="Search Results")
+     ],
+     title="Document Search using OCR (English/Hindi)"
+ )
+
+ iface.launch()
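
For anyone reviewing the new pipeline, a minimal way to exercise it without launching the Gradio UI is sketched below. It assumes app.py has already been executed (or its functions imported); the image path and the keyword "invoice" are placeholders, not files or terms shipped with this commit.

    # Hypothetical smoke test for the committed OCR + keyword-highlight flow.
    from PIL import Image

    sample = Image.open("./test2.jpg")                 # any local document image
    plain_text, highlighted = main_fun(sample, "invoice")
    print(plain_text)    # combined text from the three Qwen2-VL prompts
    print(highlighted)   # same text with the keyword wrapped in a highlight <span>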
 
requirements.txt CHANGED
@@ -1,10 +1,10 @@
  gradio
- Pillow
- surya-ocr
  torch
- transformers
- tiktoken
  torchvision
- verovio
- accelerate
- rapidfuzz

  gradio
+ byaldi
+ qwen-vl-utils
+ numpy==1.24.4
+ Pillow==10.3.0
+ Requests==2.31.0
  torch
  torchvision
+ git+https://github.com/huggingface/transformers.git
+ accelerate
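
One note on the updated requirements: transformers is installed from the GitHub main branch, presumably because Qwen2-VL support was only available on the development branch when this commit was made. After pip install -r requirements.txt, an optional sanity check (a sketch, not part of the repo) is to confirm the Qwen2-VL classes resolve:

    # Optional check that the git install of transformers exposes Qwen2-VL.
    import transformers
    from transformers import Qwen2VLForConditionalGeneration, AutoProcessor  # should import cleanly
    print(transformers.__version__)  # git installs report a dev version ending in .dev0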