MohamedRashad committed on
Commit
e825481
1 Parent(s): 9d6dab6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -3
README.md CHANGED
@@ -36,7 +36,11 @@ from transformers import NougatProcessor, VisionEncoderDecoderModel
36
 
37
  # Load the model and processor
38
  processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat")
39
- model = VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-large-nougat", torch_dtype=torch.bfloat16, attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"})
 
 
 
 
40
 
41
  # Get the max context length of the model & dtype of the weights
42
  context_length = model.decoder.config.max_position_embeddings
@@ -46,10 +50,13 @@ torch_dtype = model.dtype
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  model.to(device)
48
 
 
49
  def predict(img_path):
50
  # prepare PDF image for the model
51
  image = Image.open(img_path)
52
- pixel_values = processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
 
 
53
 
54
  # generate transcription
55
  outputs = model.generate(
@@ -61,10 +68,11 @@ def predict(img_path):
61
  )
62
 
63
  page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
64
- page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
65
  return page_sequence
66
 
 
67
  print(predict("path/to/page_image.jpg"))
 
68
  ```
69
 
70
 
 
36
 
37
  # Load the model and processor
38
  processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat")
39
+ model = VisionEncoderDecoderModel.from_pretrained(
40
+ "MohamedRashad/arabic-large-nougat",
41
+ torch_dtype=torch.bfloat16,
42
+ attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
43
+ )
44
 
45
  # Get the max context length of the model & dtype of the weights
46
  context_length = model.decoder.config.max_position_embeddings
 
50
  device = "cuda" if torch.cuda.is_available() else "cpu"
51
  model.to(device)
52
 
53
+
54
  def predict(img_path):
55
  # prepare PDF image for the model
56
  image = Image.open(img_path)
57
+ pixel_values = (
58
+ processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
59
+ )
60
 
61
  # generate transcription
62
  outputs = model.generate(
 
68
  )
69
 
70
  page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
 
71
  return page_sequence
72
 
73
+
74
  print(predict("path/to/page_image.jpg"))
75
+
76
  ```
77
 
78