from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from qwen_vl_utils import process_vision_info
import gradio as gr
import re

# ColPali retriever (loaded for document retrieval; not wired into this demo's UI)
rag = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Qwen2-VL vision-language model used for OCR / document Q&A
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    device_map="auto",
)
# The processor checkpoint should match the model checkpoint (2B, not 7B)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
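
# A minimal sketch of how the byaldi retriever above is typically used,
# assuming a local "docs/" folder of PDFs (hypothetical path and index name):
#
#   rag.index(input_path="docs/", index_name="doc_index", overwrite=True)
#   results = rag.search("What is the invoice total?", k=3)  # top-3 matching pages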

def extract_text(image, query):
    # Build a chat-style message pairing the uploaded image with the user's query
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    inputs = inputs.to(vlm.device)  # keep inputs on the same device the model was mapped to
    with torch.no_grad():
        generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
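
# Example call (assuming `img` is a PIL image of a scanned document):
#   extract_text(img, "What text can you identify in this image? "
#                     "Include everything, even if it's partially obscured or in the background.")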

def search_text(text, query):
    # Wrap every case-insensitive match of the query in a highlight span
    if query:
        searched_text = re.sub(f'({re.escape(query)})', r'<span style="background-color: yellow;">\1</span>', text, flags=re.IGNORECASE)
    else:
        searched_text = text
    return searched_text
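
# For illustration: search_text("Invoice total: 42 USD", "total") returns
# 'Invoice <span style="background-color: yellow;">total</span>: 42 USD',
# which gr.HTML renders with the match highlighted.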

def extraction(image, query):
    extracted_text = extract_text(image, query)
    # Return twice: once for the visible output box, once for state management
    return extracted_text, extracted_text

"""
Main App
"""

with gr.Blocks() as main_app:
    gr.Markdown("# Document Reader using OCR (English/Hindi)")
    gr.Markdown("### Use Doc_Reader to extract text from document images (OCR) or ask questions about the input image")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            gr.Markdown("""
            ### Please use this prompt for text extraction
            **What text can you identify in this image? Include everything, even if it's partially obscured or in the background.**
            """)
            query_input = gr.Textbox(label="Enter query for retrieval", placeholder="Query/Prompt")
            extract_button = gr.Button("Read Doc!")
            search_input = gr.Textbox(label="Enter search term", placeholder="Search")
            search_button = gr.Button("Search!")
        with gr.Column():
            extracted_text_op = gr.Textbox(label="Output")
            search_text_op = gr.HTML(label="Search Results")
            download_button = gr.Button("Download Plain Text")
            # File component must exist in the layout to receive the download output
            download_file = gr.File(label="Download Extracted Text")
    # Retrieval
    extracted_text_state = gr.State()
    extract_button.click(
        extraction,
        inputs=[img_input, query_input],
        outputs=[extracted_text_op, extracted_text_state]
    )
    # Search
    search_button.click(
        search_text,
        inputs=[extracted_text_state, search_input],
        outputs=[search_text_op]
    )
    # Download
    # gr.File has no save_text_to_file API; write the text to disk and return the path
    def save_text_file(text):
        path = "extracted_text.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(text or "")
        return path
    download_button.click(
        save_text_file,
        inputs=[extracted_text_state],
        outputs=[download_file]
    )

main_app.launch()