import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces  # Required for ZeroGPU management on Hugging Face Spaces

# Load the Byaldi and Qwen2-VL models on CPU (no .cuda() here); the model is
# moved to the GPU inside the @spaces.GPU-decorated function below.
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)


@spaces.GPU  # Request a GPU for the duration of this call
def ocr_and_extract(image, text_query):
    try:
        # Save the uploaded image temporarily; convert to RGB first so that
        # RGBA or paletted PNG uploads can be written as JPEG.
        temp_image_path = "temp_image.jpg"
        image.convert("RGB").save(temp_image_path)

        # Index the image with Byaldi
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",
            store_collection_with_index=False,
            overwrite=True,
        )

        # Perform the search query on the indexed image (with a single indexed
        # page the top hit is that page; the image itself is passed to Qwen2-VL below)
        results = rag_model.search(text_query, k=1)

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Process the message and prepare the model inputs for Qwen2-VL
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the Qwen2-VL model and inputs to the GPU
        qwen_model.to("cuda")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate the output with Qwen2-VL, then trim the echoed prompt tokens
        # so only the newly generated answer is decoded
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Clean up the temporary file
        os.remove(temp_image_path)

        return output_text[0]

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"


# Gradio interface for image input
iface = gr.Interface(
    fn=ocr_and_extract,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Enter your query (optional)"),
    ],
    outputs="text",
    title="Image OCR with Byaldi + Qwen2-VL",
    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR.",
)

# Launch the Gradio app
iface.launch()