File size: 4,393 Bytes
301d77a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
import streamlit as st
import os
from PIL import Image
import requests
import torch
import json
from torchvision import io
from typing import Dict
import re

@st.cache_resource
def init_model():
    """Load the CPU build of GOT-OCR2.0 together with its tokenizer.

    Decorated with ``st.cache_resource`` so the load happens once and the
    same objects are reused across Streamlit reruns.

    Returns:
        tuple: ``(model, tokenizer)``; the model is switched to eval mode.
    """
    repo_id = 'srimanth-d/GOT_CPU'
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    # Reuse EOS as the padding id so generation does not warn about a
    # missing pad token.
    ocr_model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tok.eos_token_id,
    ).eval()
    return ocr_model, tok

@st.cache_resource
def init_gpu_model():
    """Load the CUDA build of GOT-OCR2.0 together with its tokenizer.

    Cached with ``st.cache_resource`` (like ``init_model``) so the weights are
    loaded once instead of on every Streamlit rerun.

    Returns:
        tuple: ``(model, tokenizer)``; the model is in eval mode on the GPU.
    """
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    # Reuse EOS as the padding id so generation does not warn about a
    # missing pad token.
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    model = model.eval().cuda()
    return model, tokenizer

@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct and its processor for CPU inference.

    The weights are loaded in float32. Half-precision kernels may not be
    implemented on CPU for many PyTorch builds. Float16 weights fed float32
    inputs also raise a dtype mismatch during generation. The trade-off is
    roughly double the memory of a float16 load.

    Cached with ``st.cache_resource`` so repeated calls reuse one instance.

    Returns:
        tuple: ``(model, processor)``.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model, processor

def get_quen_op(image_file, model, processor, max_new_tokens=32):
    """Ask Qwen2-VL to extract the text from an image.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        model: A Qwen2-VL generation model (e.g. from ``init_qwen_model``).
        processor: The processor matching ``model``.
        max_new_tokens: Upper bound on generated tokens. The default of 32
            keeps the previous behaviour but will truncate longer text.

    Returns:
        str: One of the following.
            - The decoded text, when the model produced some.
            - ``"No text extracted from the image."``, when the output is
              empty or whitespace.
            - ``"An error occurred: ..."``, when any step raised.
    """
    try:
        # Close the underlying file once the RGB copy is materialised.
        with Image.open(image_file) as raw:
            image = raw.convert('RGB')
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Extract text from this image."},
                ],
            }
        ]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
        # Match floating-point inputs (e.g. pixel values) to the weight dtype.
        # A hard-coded dtype fails whenever the model was loaded differently.
        inputs = {
            k: v.to(model.dtype) if torch.is_floating_point(v) else v
            for k, v in inputs.items()
        }

        # Greedy decoding. Sampling knobs (top_k/top_p/temperature) are
        # omitted because do_sample=False ignores them.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        # The output begins with the prompt tokens; keep only the new ones.
        prompt_len = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 0
        generated_ids = output_ids[:, prompt_len:]

        decoded = processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        # Batch size is 1, so the single decoded string is the answer.
        text = decoded[0] if decoded else ""
        return text if text.strip() else "No text extracted from the image."

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"An error occurred: {str(e)}"

def get_text(image_file, _model, _tokenizer):
    """Run GOT-OCR2.0 plain-text OCR (``ocr_type='ocr'``) on an image file.

    Results are cached by image *content* as well as path. The same path can
    hold different bytes over time (e.g. a re-upload with the same file
    name), so caching on the path alone would return stale text.

    Args:
        image_file: Filesystem path of the image to read.
        _model: OCR model exposing ``chat``. The leading underscore tells
            Streamlit not to hash it.
        _tokenizer: Tokenizer passed through to ``_model.chat``; also
            excluded from hashing.

    Returns:
        Whatever ``_model.chat`` returns for the image.
    """
    import hashlib

    with open(image_file, "rb") as fh:
        digest = hashlib.sha256(fh.read()).hexdigest()
    return _cached_ocr(image_file, digest, _model, _tokenizer)


@st.cache_data
def _cached_ocr(image_file, content_digest, _model, _tokenizer):
    """Cached OCR call; ``content_digest`` exists only to key the cache."""
    return _model.chat(_tokenizer, image_file, ocr_type='ocr')

def highlight_text(text, search_term):
    """Return ``text`` as HTML-safe markup with matches of ``search_term`` highlighted.

    The result is rendered with ``unsafe_allow_html=True``, and ``text`` comes
    from OCR of a user-supplied image. Every piece of it is therefore
    HTML-escaped, so characters like ``<`` or ``&`` render literally instead
    of being interpreted as markup.

    Args:
        text: The text to display.
        search_term: Case-insensitive literal to highlight; falsy disables
            highlighting.

    Returns:
        str: Escaped text; each match is wrapped in a grey-background ``<span>``.
    """
    import html

    if not search_term:
        return html.escape(text)
    # Split BEFORE escaping so the search never matches inside an entity such
    # as "&amp;". The capturing group keeps matches at the odd indices.
    pattern = re.compile(f"({re.escape(search_term)})", re.IGNORECASE)
    pieces = []
    for idx, part in enumerate(pattern.split(text)):
        safe = html.escape(part)
        if idx % 2:
            pieces.append(f'<span style="background-color: grey;">{safe}</span>')
        else:
            pieces.append(safe)
    return "".join(pieces)

def save_text_to_json(file_name, text_data):
    """Persist extracted text as JSON and show a success notice in the UI.

    Writes ``{"extracted_text": text_data}`` with 4-space indentation,
    overwriting ``file_name`` if it already exists.
    """
    payload = {"extracted_text": text_data}
    with open(file_name, 'w') as out:
        json.dump(payload, out, indent=4)
    notice = f"Text saved to {file_name}"
    st.success(notice)

st.title("Extract text from the image using  - GOT-OCR2.0 and search keyword")
st.write("Upload an image")

# init_model() returns (model, tokenizer); the tokenizer is what get_text
# hands to the model's chat() call.
MODEL, PROCESSOR = init_model()

image_file = st.file_uploader("Upload Image", type=['jpg', 'png', 'jpeg'])

if image_file:
    # The upload name is client-controlled. Keep only the final path
    # component so a crafted name like "../../x.png" cannot write outside
    # images/.
    safe_name = os.path.basename(image_file.name)
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    os.makedirs("images", exist_ok=True)
    image_path = f"images/{safe_name}"
    with open(image_path, "wb") as f:
        f.write(image_file.getbuffer())

    text = get_text(image_path, MODEL, PROCESSOR)

    print(text)

    # Add search functionality
    search_term = st.text_input("Enter a word or phrase to search:")
    highlighted_text = highlight_text(text, search_term)

    st.markdown(highlighted_text, unsafe_allow_html=True)

    # Save the extracted text in JSON next to the stored image.
    json_file_path = f"{image_path}_extracted.json"
    save_text_to_json(json_file_path, text)