import os
import subprocess
import sys
from pathlib import Path
from threading import Thread

import gradio as gr
import spaces
import torch
from gradio_pdf import PDF
from pdf2image import convert_from_path
from transformers import (
    NougatProcessor,
    TextIteratorStreamer,
    VisionEncoderDecoderModel,
)

# Install flash-attn at startup, before the models below request the
# "flash_attention_2" decoder attention. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# asks flash-attn's setup.py not to compile its CUDA extension locally.
# The current environment is merged in (not replaced) so PATH/HOME remain
# visible to pip, and `sys.executable -m pip` installs into the interpreter
# that is actually running this app. Best effort: a failed install is not
# raised here (check=False), matching the previous behavior.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Dropdown label -> [processor, model]. All models are loaded eagerly at import
# time. The base and large models use bfloat16 weights with FlashAttention-2 in
# the decoder and eager attention in the encoder.
models_supported = {
    "arabic-small-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat"),
        VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat"),
    ],
    "arabic-base-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-base-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-base-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
    "arabic-large-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-large-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
}


def _get_processor_and_model(model_name):
    """Return the ``[processor, model]`` pair registered under *model_name*.

    Raises:
        gr.Error: if *model_name* is not a key of ``models_supported`` (for
            example ``None`` when the dropdown was left empty), so the user
            sees a readable message instead of a ``KeyError`` traceback.
    """
    try:
        return models_supported[model_name]
    except KeyError:
        raise gr.Error("Please select a model first.") from None


def _stream_generation(processor, model, image, device):
    """Generate Markdown for one page *image* and yield decoded text chunks.

    ``model.generate`` runs in a background thread that feeds a fresh
    ``TextIteratorStreamer``; this generator drains that streamer. If
    ``generate`` raises, the streamer is closed so the consuming loop ends
    (instead of blocking forever waiting for tokens), and the exception is
    re-raised in the caller's thread once the worker has finished.
    """
    # max_new_tokens is set to the decoder's max_position_embeddings.
    context_length = model.decoder.config.max_position_embeddings
    pixel_values = (
        processor(image, return_tensors="pt").pixel_values.to(model.dtype).to(device)
    )
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    errors = []

    def _generate():
        try:
            model.generate(
                pixel_values=pixel_values,
                min_length=1,
                max_new_tokens=context_length,
                repetition_penalty=1.5,
                streamer=streamer,
            )
        except Exception as exc:  # re-raised in the consuming thread below
            errors.append(exc)
            streamer.end()  # push the stop signal so the consumer loop exits

    worker = Thread(target=_generate)
    worker.start()
    for chunk in streamer:
        yield chunk
    worker.join()
    if errors:
        raise errors[0]


@spaces.GPU
def extract_text_from_image(image, model_name):
    """Stream the Markdown extracted from a single *image* using *model_name*.

    Yields the accumulated output after every decoded chunk so the Gradio
    Markdown component updates live.

    Raises:
        gr.Error: if no valid model is selected or no image was provided.
    """
    print(f"Extracting text from image using model: {model_name}")
    processor, model = _get_processor_and_model(model_name)
    if image is None:
        raise gr.Error("Please upload an image first.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    output = ""
    for chunk in _stream_generation(processor, model, image, device):
        output += chunk
        yield output


@spaces.GPU
def extract_text_from_pdf(pdf_path, model_name):
    """Stream the Markdown extracted from every page of the PDF at *pdf_path*.

    Pages are converted to images with ``pdf2image.convert_from_path`` and
    processed in order, with a blank line between consecutive pages. Yields
    the accumulated output after every decoded chunk and after each page.

    Raises:
        gr.Error: if no valid model is selected or no PDF was provided.
    """
    processor, model = _get_processor_and_model(model_name)
    if not pdf_path:
        raise gr.Error("Please upload a PDF first.")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    print(f"Extracting text from PDF: {pdf_path}")
    pdf_output = ""
    for page in convert_from_path(pdf_path):
        for chunk in _stream_generation(processor, model, page, device):
            pdf_output += chunk
            yield pdf_output
        pdf_output += "\n\n"  # separate pages in the Markdown output
        yield pdf_output


model_description = """This is the official demo for the Arabic Nougat models. It is an end-to-end Markdown Extraction model that extracts text from images or PDFs and writes it in Markdown.

There are three models available:
- [arabic-small-nougat](https://huggingface.co/MohamedRashad/arabic-small-nougat): A small model that is faster but less accurate (a finetune from [facebook/nougat-small](https://huggingface.co/facebook/nougat-small)).
- [arabic-base-nougat](https://huggingface.co/MohamedRashad/arabic-base-nougat): A base model that is more accurate but slower (a finetune from [facebook/nougat-base](https://huggingface.co/facebook/nougat-base)).
- [arabic-large-nougat](https://huggingface.co/MohamedRashad/arabic-large-nougat): The largest of the three (Made from scratch using [riotu-lab/Aranizer-PBE-86k](https://huggingface.co/riotu-lab/Aranizer-PBE-86k) tokenizer and a larger transformer decoder model).

**Disclaimer**: These models hallucinate text and are not perfect. They are trained on a mix of synthetic and real data and may not work well on all types of images.
"""

# Example files shipped next to this script; sorted so the order is stable.
example_images = sorted(str(p) for p in Path(__file__).parent.glob("*.jpeg"))
example_pdfs = sorted(str(p) for p in Path(__file__).parent.glob("*.pdf"))

with gr.Blocks(title="Arabic Nougat") as demo:
    gr.HTML(
        "<h1 style='text-align: center;'>"
        "Arabic End-to-End Structured OCR for textbooks"
        "</h1>"
    )
    gr.Markdown(model_description)

    with gr.Tab("Extract Text from Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                image_model_dropdown = gr.Dropdown(
                    label="Model", choices=list(models_supported.keys()), value=None
                )
                image_submit_button = gr.Button(value="Submit", variant="primary")
            image_output = gr.Markdown(label="Output Markdown", rtl=True)
        image_submit_button.click(
            extract_text_from_image,
            inputs=[input_image, image_model_dropdown],
            outputs=image_output,
        )
        # With cache_examples=False (and run_on_click at its default), clicking
        # an example only fills the image input; the user still picks a model
        # and presses Submit.
        gr.Examples(
            example_images,
            [input_image],
            image_output,
            extract_text_from_image,
            cache_examples=False,
        )

    with gr.Tab("Extract Text from PDF"):
        with gr.Row():
            with gr.Column():
                input_pdf = PDF(label="Input PDF")
                pdf_model_dropdown = gr.Dropdown(
                    label="Model", choices=list(models_supported.keys()), value=None
                )
                pdf_submit_button = gr.Button(value="Submit", variant="primary")
            pdf_output_markdown = gr.Markdown(label="Output Markdown", rtl=True)
        pdf_submit_button.click(
            extract_text_from_pdf,
            inputs=[input_pdf, pdf_model_dropdown],
            outputs=pdf_output_markdown,
        )
        gr.Examples(
            example_pdfs,
            [input_pdf],
            pdf_output_markdown,
            extract_text_from_pdf,
            cache_examples=False,
        )

demo.queue().launch(share=False)