Heramb26 commited on
Commit
0218a3d
1 Parent(s): 9206f14

Custom Model

Browse files
Files changed (2) hide show
  1. app.py +18 -24
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,39 +1,33 @@
1
- import gradio as gr
2
- import os
3
  import torch
4
  from PIL import Image
5
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
 
6
 
7
- # Set up device
8
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
- # Load the fine-tuned model
11
- checkpoint_path = './checkpoint-2070' # Path to your fine-tuned model checkpoint
12
- model = VisionEncoderDecoderModel.from_pretrained(checkpoint_path).to(device)
13
 
14
- # Use the original model's processor (tokenizer and feature extractor)
 
15
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
16
 
17
  def ocr_image(image):
18
  """
19
- Perform OCR on a single image.
20
- :param image: PIL Image object.
21
- :return: Extracted text from the image.
22
  """
23
- pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
 
24
  generated_ids = model.generate(pixel_values)
25
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
26
  return generated_text
27
 
28
- # Define the Gradio interface
29
- interface = gr.Interface(
30
- fn=ocr_image, # Function to call for prediction
31
- inputs=gr.inputs.Image(type="pil"), # Accept an image as input
32
- outputs="text", # Return extracted text
33
- title="OCR with TrOCR",
34
- description="Upload an image, and the fine-tuned TrOCR model will extract the text for you."
35
- )
36
-
37
- # Launch the Gradio app
38
- if __name__ == "__main__":
39
- interface.launch()
 
 
 
1
  import torch
2
  from PIL import Image
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+ from huggingface_hub import hf_hub_download
5
+ import os
6
 
7
+ # Load the model checkpoint and tokenizer files from Hugging Face Model Hub
8
+ checkpoint_folder = hf_hub_download(repo_id="Heramb26/tr-ocr-custom-checkpoints", filename="checkpoint-2070")
9
 
10
+ # Set up the device (GPU or CPU)
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
12
 
13
+ # Load the fine-tuned model and processor from the downloaded folder
14
+ model = VisionEncoderDecoderModel.from_pretrained(checkpoint_folder).to(device)
15
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
16
 
17
  def ocr_image(image):
18
  """
19
+ Perform OCR on an image using the loaded model.
20
+ :param image: Input PIL image.
21
+ :return: Extracted text.
22
  """
23
+ # Preprocess image and generate OCR text
24
+ pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
25
  generated_ids = model.generate(pixel_values)
26
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
27
  return generated_text
28
 
29
+ # Example usage
30
+ image_path = "path/to/your/image.jpg" # Update with the path to your image
31
+ image = Image.open(image_path) # Open the image file using PIL
32
+ extracted_text = ocr_image(image) # Perform OCR on the image
33
+ print("Extracted Text:", extracted_text)
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ pillow