MohamedRashad committed
Commit e825481
Parent(s): 9d6dab6
Update README.md
README.md CHANGED
@@ -36,7 +36,11 @@ from transformers import NougatProcessor, VisionEncoderDecoderModel
 
 # Load the model and processor
 processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat")
-model = VisionEncoderDecoderModel.from_pretrained(
+model = VisionEncoderDecoderModel.from_pretrained(
+    "MohamedRashad/arabic-large-nougat",
+    torch_dtype=torch.bfloat16,
+    attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
+)
 
 # Get the max context length of the model & dtype of the weights
 context_length = model.decoder.config.max_position_embeddings
@@ -46,10 +50,13 @@ torch_dtype = model.dtype
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
+
 def predict(img_path):
     # prepare PDF image for the model
     image = Image.open(img_path)
-    pixel_values =
+    pixel_values = (
+        processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
+    )
 
     # generate transcription
     outputs = model.generate(
@@ -61,10 +68,11 @@ def predict(img_path):
     )
 
     page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-    page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
     return page_sequence
 
+
 print(predict("path/to/page_image.jpg"))
+
 ```
 
 
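For reference, here is the updated README usage snippet assembled end to end as a minimal sketch. The imports of `torch` and `PIL.Image` and the keyword arguments passed to `model.generate(...)` are not visible in these hunks, so they are assumptions (generic Nougat-style settings), not the README's exact values; everything else follows the diff and its context lines.

```python
import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel

# Load the model and processor
# (bfloat16 weights; FlashAttention-2 for the decoder requires the flash-attn package)
processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat")
model = VisionEncoderDecoderModel.from_pretrained(
    "MohamedRashad/arabic-large-nougat",
    torch_dtype=torch.bfloat16,
    attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
)

# Get the max context length of the model & dtype of the weights
context_length = model.decoder.config.max_position_embeddings
torch_dtype = model.dtype

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def predict(img_path):
    # prepare PDF image for the model, matching the model's dtype and device
    image = Image.open(img_path)
    pixel_values = (
        processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
    )

    # generate transcription
    # NOTE: the README's actual generation arguments are not shown in this diff;
    # min_length/max_length below are placeholder assumptions
    outputs = model.generate(
        pixel_values,
        min_length=1,
        max_length=context_length,
    )

    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return page_sequence


print(predict("path/to/page_image.jpg"))
```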