File size: 3,234 Bytes
16d9c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import hashlib
import html
import os
import re

import streamlit as st
from transformers import AutoModel, AutoTokenizer

# Page-level settings. Streamlit has historically required set_page_config to
# be the first Streamlit command on a page, so keep this ahead of any UI calls.
_PAGE_CONFIG = dict(
    page_title="OCR and Document Search Web Application Prototype Using GOT-OCR 2.0",
    layout="wide",  # use the full browser width instead of the centered column
)
st.set_page_config(**_PAGE_CONFIG)


# st.cache_resource keeps the returned objects, so Streamlit reruns reuse the
# already-loaded checkpoint instead of loading it again.
@st.cache_resource
def init_model():
    """Load the GOT-OCR 2.0 CPU checkpoint together with its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)``, with the model put into eval mode.
    """
    repo_id = "srimanth-d/GOT_CPU"
    # trust_remote_code lets transformers run the modelling code shipped with
    # the checkpoint repository (GOT is not a built-in architecture).
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    loader_kwargs = dict(
        trust_remote_code=True,
        use_safetensors=True,
        # Pad with the tokenizer's end-of-sequence id.
        pad_token_id=tok.eos_token_id,
    )
    net = AutoModel.from_pretrained(repo_id, **loader_kwargs)
    return net.eval(), tok


# OCR entry point. st.cache_data does not hash parameters whose names start
# with an underscore, so results are memoised on the image_file value alone.
@st.cache_data
def extract_text(image_file, _model, _tokenizer):
    """Run plain OCR (``ocr_type="ocr"``) on an image with the GOT model.

    Args:
        image_file: Image reference handed to ``_model.chat`` (a file path at
            this script's call site).
        _model: The loaded GOT-OCR model exposing a ``chat`` method.
        _tokenizer: Tokenizer matching ``_model``.

    Returns:
        Whatever ``_model.chat`` returns -- presumably the recognised text as
        a ``str``; confirm against the checkpoint's remote code.
    """
    return _model.chat(_tokenizer, image_file, ocr_type="ocr")


# Function to highlight search term in extracted text
def highlight_text(text, search_term):
    """Return *text* as safe HTML with matches of *search_term* highlighted.

    The text is HTML-escaped before highlighting, so OCR output containing
    ``<``, ``>`` or ``&`` is shown literally instead of being interpreted as
    markup when rendered with ``unsafe_allow_html=True``.

    Args:
        text: Plain text to render.
        search_term: Literal, case-insensitive substring to highlight (regex
            metacharacters are matched literally). Empty or None disables
            highlighting.

    Returns:
        str: Escaped text with every match wrapped in a yellow, bold ``<span>``.
    """
    if not search_term:
        return html.escape(text, quote=False)
    # The capturing group makes re.split keep the matches at the odd indices.
    pattern = re.compile(f"({re.escape(search_term)})", re.IGNORECASE)
    pieces = []
    for index, part in enumerate(pattern.split(text)):
        escaped = html.escape(part, quote=False)
        if index % 2:
            escaped = (
                '<span style="background-color: #FFFF00; font-weight: bold;">'
                f"{escaped}</span>"
            )
        pieces.append(escaped)
    return "".join(pieces)


# Streamlit UI components
st.title("OCR and Document Search Web Application Prototype Using GOT-OCR 2.0")
st.write("Upload an image for OCR")

# Initialize model and tokenizer
model, tokenizer = init_model()

# Uploaded images are written here so the OCR model can read them from a path.
UPLOAD_DIR = "images"

# Create columns for layout
col1, col2 = st.columns([1, 2])  # Adjust proportions as needed

# Fallback text in case no image is uploaded
extracted_text = ""

with col1:
    # File uploader for images
    uploaded_image = st.file_uploader(
        "Upload Image", type=["jpg", "png", "jpeg"], key="image_upload"
    )

    if uploaded_image:
        image_bytes = uploaded_image.getvalue()
        # Name the saved file after its content: extract_text is st.cache_data
        # keyed on the path, so reusing the client's filename would return a
        # stale (or another session's) OCR result whenever two different images
        # share a name. basename() drops any directory parts in the client name.
        digest = hashlib.sha256(image_bytes).hexdigest()[:16]
        safe_name = os.path.basename(uploaded_image.name)
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        image_path = os.path.join(UPLOAD_DIR, f"{digest}_{safe_name}")
        with open(image_path, "wb") as f:
            f.write(image_bytes)

        # Button to show the full-size image on demand; the second column is
        # unused and only keeps the button at half width.
        col1a, _col1b = st.columns([0.5, 0.5])
        with col1a:
            if st.button("View Full Image"):
                st.image(image_path, caption="Full Size Image", use_column_width=True)

        # Once the image is saved, extract the text
        extracted_text = extract_text(image_path, model, tokenizer)
    else:
        st.info("Please upload an image to perform OCR.")

with col2:
    # Input field for keyword search
    search_term = st.text_input("Enter a word or phrase to search:")

    # Display search results. Use the same case-insensitive literal match as
    # highlight_text, applied to the plain text rather than the rendered HTML.
    if search_term:
        if re.search(re.escape(search_term), extracted_text, re.IGNORECASE):
            st.success(f"Word **'{search_term}'** found in the text!")
        else:
            st.error(f"Word **'{search_term}'** not found.")

    # Show the extracted text with highlighted search terms
    st.markdown(highlight_text(extracted_text, search_term), unsafe_allow_html=True)