intuitive262 committed on
Commit da194cb · 1 Parent(s): 3243581

Updated code files

Files changed (1)
app.py +49 -50
app.py CHANGED
@@ -3,13 +3,13 @@ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoPro
 import torch
 from qwen_vl_utils import process_vision_info
 from PIL import Image
-import re
 import gradio as gr
+import re

 rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
 vlm = Qwen2VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2-VL-2B-Instruct",
-    torch_dtype=torch.float32,
+    torch_dtype=torch.float16,
     trust_remote_code=True,
     device_map="auto",
 )
@@ -26,59 +26,58 @@ def extract_text(image, query):
             ],
         }
     ]
-
+
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
+
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
-    inputs = inputs.to("cpu")
-    with torch.no_grad():
-        generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
-        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
-        return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-
-
-def ocr(image):
-    queries = [
-        # "Extract and transcribe all the text visible in the image, including any small or partially visible text.",
-        # "Look closely at the image and list any text you see, no matter how small or unclear.",
-        # "What text can you identify in this image? Include everything, even if it's partially obscured or in the background."
-        "Extract all the text in Sanskrit and English from the image."
-    ]
+    inputs = inputs.to("cuda")

-    all_extracted_text = []
-    for query in queries:
-        extracted_text = extract_text(image, query)
-        all_extracted_text.append(extracted_text)
+    generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+    return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

-    # Combine and deduplicate the results
-    final_text = "\n".join(set(all_extracted_text))
-
-    # final_text = post_process_text(final_text)
-    return final_text
-
-
-def main_fun(image, keyword):
-    ext_text = ocr(image)
-
-    if keyword:
-        highlight_text = re.sub(f'({re.escape(keyword)})', r'<span style="background-color: yellow;">\1</span>', ext_text, flags=re.IGNORECASE)
+def search_text(text, query):
+    if query:
+        searched_text = re.sub(f'({re.escape(query)})', r'<span style="background-color: yellow;">\1</span>', text, flags=re.IGNORECASE)
     else:
-        highlight_text = ext_text
-
-    return ext_text, highlight_text
+        searched_text = text
+    return searched_text

-iface = gr.Interface(
-    fn=main_fun,
-    inputs=[
-        gr.Image(type="pil", label="Upload an Image"),
-        gr.Textbox(label="Enter search term", placeholder="Search")
-    ],
-    outputs=[
-        gr.Textbox(label="Extracted Text"),
-        gr.HTML(label="Search Results")
-    ],
-    title="Document Search using OCR (English/Hindi)"
-)
+def extraction(image, query):
+    extracted_text = extract_text(image, query)
+    return extracted_text, extracted_text  # return twice - one to display output and the other for state management

-iface.launch()
+
+"""
+Main App
+"""
+with gr.Blocks() as main_app:
+    gr.Markdown("# Document Reader using OCR(English/Hindi)")
+
+    with gr.Row():
+        with gr.Column():
+            img_input = gr.Image(type="pil", label="Upload an Image")
+            query_input = gr.Textbox(label="Enter query for retrieval", placeholder="Query/Prompt")
+            search_input = gr.Textbox(label="Enter search term", placeholder="Search")
+            extract_button = gr.Button("Read Doc!")
+            search_button = gr.Button("Search!")
+
+        with gr.Column():
+            extracted_text_op = gr.Textbox(label="Output")
+            search_text_op = gr.HTML(label="Search Results")
+
+    extracted_text_state = gr.State()
+    extract_button.click(
+        extraction,
+        inputs=[img_input, query_input],
+        outputs=[extracted_text_op, extracted_text_state]
+    )
+
+    search_button.click(
+        search_text,
+        inputs=[extracted_text_state, search_input],
+        outputs=[search_text_op]
+    )
+
+main_app.launch()
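
Note: the second hunk starts partway through extract_text(image, query), so the construction of the messages list whose closing brackets appear at the top of that hunk is not part of this diff. For reference, the chat-message structure that processor.apply_chat_template() and process_vision_info() consume for Qwen2-VL typically looks like the sketch below; the helper name build_messages and the exact prompt wiring are illustrative assumptions, not code from this commit.

# Illustrative sketch only - not part of this commit. Assumes the standard
# Qwen2-VL message layout expected by qwen_vl_utils.process_vision_info.
def build_messages(image, query):
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},  # PIL image from the Gradio input
                {"type": "text", "text": query},    # the retrieval query/prompt
            ],
        }
    ]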