Vinay15 commited on
Commit
4a0b98f
·
verified ·
1 Parent(s): ecc49db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
- import torch # Importing torch to check CUDA availability
5
 
6
  # Check CUDA availability
7
  def check_cuda():
@@ -14,35 +14,41 @@ def check_cuda():
14
  # Load the tokenizer and model
15
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
16
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, device_map="auto", use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
17
- model = model.eval() # No need for .cuda() with device_map="auto"
18
 
19
  # Define the OCR function
20
- def perform_ocr(image):
21
  # Check for CUDA availability and print the result
22
  cuda_info = check_cuda()
23
- print(cuda_info) # This will be logged in the output
24
 
25
- # Convert PIL image to RGB format (if necessary)
26
  if image.mode != "RGB":
27
  image = image.convert("RGB")
28
 
29
- # Save the image to a temporary path
30
- image_file_path = 'temp_image.jpg'
31
- image.save(image_file_path)
32
-
33
  # Perform OCR using the model
34
- res = model.chat(tokenizer, image_file_path, ocr_type='ocr')
35
 
36
- return res
 
 
 
 
37
 
38
  # Define the Gradio interface
39
  interface = gr.Interface(
40
  fn=perform_ocr,
41
- inputs=gr.Image(type="pil", label="Upload Image"),
42
- outputs=gr.Textbox(label="Extracted Text"),
 
 
 
 
 
 
43
  title="OCR and Document Search Web Application",
44
- description="Upload an image to extract text using the GOT-OCR2_0 model."
45
  )
46
 
47
  # Launch the Gradio app
48
- interface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  from PIL import Image
4
+ import torch
5
 
6
  # Check CUDA availability
7
  def check_cuda():
 
14
  # Load the tokenizer and model
15
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
16
  model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, device_map="auto", use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
17
+ model.eval()
18
 
19
  # Define the OCR function
20
+ def perform_ocr(image, keyword):
21
  # Check for CUDA availability and print the result
22
  cuda_info = check_cuda()
23
+ print(cuda_info)
24
 
25
+ # Convert PIL image to RGB format
26
  if image.mode != "RGB":
27
  image = image.convert("RGB")
28
 
 
 
 
 
29
  # Perform OCR using the model
30
+ res = model.chat(tokenizer, image, ocr_type='ocr')
31
 
32
+ # Check for keyword in the extracted text
33
+ if keyword.lower() in res.lower():
34
+ return res, f'Keyword "{keyword}" found in the text.'
35
+ else:
36
+ return res, f'Keyword "{keyword}" not found in the text.'
37
 
38
  # Define the Gradio interface
39
  interface = gr.Interface(
40
  fn=perform_ocr,
41
+ inputs=[
42
+ gr.Image(type="pil", label="Upload Image"),
43
+ gr.Textbox(label="Enter Keyword to Search")
44
+ ],
45
+ outputs=[
46
+ gr.Textbox(label="Extracted Text"),
47
+ gr.Textbox(label="Search Result")
48
+ ],
49
  title="OCR and Document Search Web Application",
50
+ description="Upload an image to extract text using the GOT-OCR2_0 model and search for a keyword."
51
  )
52
 
53
  # Launch the Gradio app
54
+ interface.launch()