gauri-sharan committed on
Commit
40f7360
·
verified ·
1 Parent(s): d85fa29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -71
app.py CHANGED
@@ -6,34 +6,43 @@ import torch
6
  from PIL import Image
7
  import os
8
  import traceback
 
9
  import re
10
 
11
- # Load models
12
- rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
 
 
 
 
13
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
14
  "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
15
- )
 
 
16
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
17
 
18
- extracted_text = "" # Store the extracted text globally for keyword search
 
19
 
20
- def ocr_and_extract(image, text_query=None):
 
21
  global extracted_text
22
  try:
23
  # Save the uploaded image temporarily
24
  temp_image_path = "temp_image.jpg"
25
  image.save(temp_image_path)
26
 
27
- # Index the image with Byaldi
28
  rag_model.index(
29
  input_path=temp_image_path,
30
- index_name="image_index",
31
  store_collection_with_index=False,
32
- overwrite=True
33
  )
34
 
35
  # Perform the search query on the indexed image
36
- results = rag_model.search(text_query, k=1)
37
 
38
  # Prepare the input for Qwen2-VL
39
  image_data = Image.open(temp_image_path)
@@ -43,31 +52,33 @@ def ocr_and_extract(image, text_query=None):
43
  "role": "user",
44
  "content": [
45
  {"type": "image", "image": image_data},
46
- {"type": "text", "text": text_query},
47
  ],
48
  }
49
  ]
50
 
51
- # Process input for Qwen2-VL
52
  text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
53
  image_inputs, _ = process_vision_info(messages)
54
 
 
55
  inputs = processor(
56
  text=[text_input],
57
  images=image_inputs,
58
  padding=True,
59
  return_tensors="pt",
60
- )
61
-
62
- qwen_model.to("cuda")
63
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
64
 
65
  # Generate the output with Qwen2-VL
66
  generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
67
- output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
68
-
69
- # Store the extracted text for keyword search
70
- extracted_text = output_text[0]
 
 
 
 
 
71
  os.remove(temp_image_path)
72
 
73
  return extracted_text
@@ -77,57 +88,36 @@ def ocr_and_extract(image, text_query=None):
77
  traceback.print_exc()
78
  return f"Error: {error_message}"
79
 
80
- def search_keywords(keyword):
81
- global extracted_text
82
  if not extracted_text:
83
  return "No text extracted yet. Please upload an image."
84
-
85
- # Perform basic keyword search within the extracted text
86
- if re.search(rf"\b{re.escape(keyword)}\b", extracted_text, re.IGNORECASE):
87
- highlighted_text = re.sub(rf"({re.escape(keyword)})", r"<mark>\1</mark>", extracted_text, flags=re.IGNORECASE)
88
- return f"Keyword found! {highlighted_text}"
89
- else:
90
- return "Keyword not found in the extracted text."
91
-
92
- # Gradio interface
93
- image_input = gr.Image(type="pil")
94
- text_output = gr.Textbox(label="Extracted Text", interactive=True)
95
- keyword_search = gr.Textbox(label="Enter keywords to search")
96
- search_button = gr.Button("Search Keywords")
97
- search_output = gr.HTML()
98
-
99
- extract_button = gr.Button("Extract Text")
100
-
101
- # Layout update
102
- iface = gr.Interface(
103
- fn=ocr_and_extract,
104
- inputs=[image_input],
105
- outputs=[text_output],
106
- title="Image OCR with Byaldi + Qwen2-VL",
107
- description="Upload an image containing Hindi and English text for OCR. Then, search for specific keywords.",
108
- )
109
-
110
- # Keyword search layout
111
- iface_search = gr.Interface(
112
- fn=search_keywords,
113
- inputs=[keyword_search],
114
- outputs=[search_output],
115
- )
116
-
117
- # Move extract button above the text output
118
- def combined_interface(image, keyword):
119
- ocr_text = ocr_and_extract(image)
120
- search_result = search_keywords(keyword)
121
- return ocr_text, search_result
122
-
123
- combined_iface = gr.Interface(
124
- fn=combined_interface,
125
- inputs=[image_input, keyword_search],
126
- outputs=[text_output, search_output],
127
- live=True,
128
- title="Image OCR & Keyword Search",
129
- description="Extract text from the image and search for specific keywords."
130
- )
131
-
132
- # Launch the app
133
- combined_iface.launch()
 
6
  from PIL import Image
7
  import os
8
  import traceback
9
+ import spaces
10
  import re
11
 
12
+ # Check if CUDA is available
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"Using device: {device}")
15
+
16
+ # Load the Byaldi and Qwen2-VL models
17
+ rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali") # Byaldi model
18
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
19
  "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
20
+ ).to(device) # Move Qwen2-VL to GPU
21
+
22
+ # Processor for Qwen2-VL
23
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
24
 
25
+ # Global variable to store extracted text
26
+ extracted_text = ""
27
 
28
+ @spaces.GPU(duration=120) # Increased GPU duration to 120 seconds
29
+ def ocr_and_extract(image):
30
  global extracted_text
31
  try:
32
  # Save the uploaded image temporarily
33
  temp_image_path = "temp_image.jpg"
34
  image.save(temp_image_path)
35
 
36
+ # Index the image with Byaldi, and force overwrite of the existing index
37
  rag_model.index(
38
  input_path=temp_image_path,
39
+ index_name="image_index", # Reuse the same index
40
  store_collection_with_index=False,
41
+ overwrite=True # Overwrite the index for every new image
42
  )
43
 
44
  # Perform the search query on the indexed image
45
+ results = rag_model.search("", k=1)
46
 
47
  # Prepare the input for Qwen2-VL
48
  image_data = Image.open(temp_image_path)
 
52
  "role": "user",
53
  "content": [
54
  {"type": "image", "image": image_data},
 
55
  ],
56
  }
57
  ]
58
 
59
+ # Process the message and prepare for Qwen2-VL
60
  text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
  image_inputs, _ = process_vision_info(messages)
62
 
63
+ # Move the image inputs and processor outputs to CUDA
64
  inputs = processor(
65
  text=[text_input],
66
  images=image_inputs,
67
  padding=True,
68
  return_tensors="pt",
69
+ ).to(device)
 
 
 
70
 
71
  # Generate the output with Qwen2-VL
72
  generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
73
+ output_text = processor.batch_decode(
74
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
75
+ )
76
+
77
+ # Filter out "You are a helpful assistant" and "assistant" labels
78
+ filtered_output = [line for line in output_text[0].split("\n") if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])]
79
+ extracted_text = "\n".join(filtered_output).strip()
80
+
81
+ # Clean up the temporary file
82
  os.remove(temp_image_path)
83
 
84
  return extracted_text
 
88
  traceback.print_exc()
89
  return f"Error: {error_message}"
90
 
91
def search_keywords(keywords):
    """Return the extracted text as HTML with every keyword occurrence highlighted.

    Args:
        keywords: Whitespace-separated search terms. ``None`` or an empty
            string means "no terms"; the text is then returned unhighlighted.

    Returns:
        A placeholder message when the module-level ``extracted_text`` is
        empty; otherwise an HTML string in which each case-insensitive match
        of any term is wrapped in ``<mark>...</mark>`` and all other text is
        HTML-escaped.
    """
    import html  # local import: only this function needs it

    if not extracted_text:
        return "No text extracted yet. Please upload an image."

    # Longest terms first so that "foobar" wins over "foo" in the alternation.
    terms = sorted(set((keywords or "").split()), key=len, reverse=True)
    if not terms:
        return html.escape(extracted_text, quote=False)

    # One combined pattern, applied in a single pass over the *raw* text.
    # Substituting keyword by keyword would re-scan the <mark> tags inserted
    # by earlier iterations and corrupt them (e.g. for the keyword "mark").
    pattern = re.compile("|".join(re.escape(term) for term in terms), re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(extracted_text):
        # Escape the text around and inside each match separately so that
        # OCR output containing "<" or "&" is shown literally, not as markup.
        pieces.append(html.escape(extracted_text[cursor:match.start()], quote=False))
        pieces.append(f"<mark>{html.escape(match.group(), quote=False)}</mark>")
        cursor = match.end()
    pieces.append(html.escape(extracted_text[cursor:], quote=False))

    return "".join(pieces)
102
+
103
+ # Gradio interface for image input and keyword search
104
+ with gr.Blocks() as iface:
105
+ # Image upload and text extraction section
106
+ with gr.Column():
107
+ img_input = gr.Image(type="pil", label="Upload an Image")
108
+ extracted_output = gr.Textbox(label="Extracted Text", interactive=False)
109
+
110
+ # Functionality to trigger the OCR and extraction
111
+ img_button = gr.Button("Extract Text")
112
+ img_button.click(fn=ocr_and_extract, inputs=img_input, outputs=extracted_output)
113
+
114
+ # Keyword search section
115
+ with gr.Column():
116
+ search_input = gr.Textbox(label="Enter keywords to search")
117
+ search_output = gr.HTML(label="Search Results")
118
+
119
+ # Functionality to search within the extracted text
120
+ search_button = gr.Button("Search")
121
+ search_button.click(fn=search_keywords, inputs=search_input, outputs=search_output)
122
+
123
+ iface.launch()