alikayh commited on
Commit
ea4688b
·
verified ·
1 Parent(s): c40d74b

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +59 -0
train.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+ from PIL import Image
5
+ import requests
6
+
7
+ # Load dataset
8
+ dataset = load_dataset("nielsr/funsd")
9
+
10
+ # Load pre-trained model and processor
11
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
12
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
13
+
14
+ # Preprocess the dataset
15
+ def preprocess_images(examples):
16
+ images = [Image.open(img).convert("RGB") for img in examples['image_path']]
17
+ pixel_values = processor(images=images, return_tensors="pt").pixel_values
18
+ return {"pixel_values": pixel_values}
19
+
20
+ encoded_dataset = dataset.map(preprocess_images, batched=True)
21
+
22
+ # Preprocess the labels
23
+ max_length = 64
24
+
25
+ def preprocess_labels(examples):
26
+ labels = processor.tokenizer(examples['words'], is_split_into_words=True, padding="max_length", max_length=max_length, truncation=True)
27
+ return labels
28
+
29
+ encoded_dataset = encoded_dataset.map(preprocess_labels, batched=True)
30
+
31
+ # Prepare for training
32
+ model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
33
+ model.config.pad_token_id = processor.tokenizer.pad_token_id
34
+
35
+ # Define training arguments
36
+ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
37
+
38
+ training_args = Seq2SeqTrainingArguments(
39
+ output_dir="./trocr-finetuned-funsd",
40
+ per_device_train_batch_size=8,
41
+ per_device_eval_batch_size=8,
42
+ learning_rate=5e-5,
43
+ num_train_epochs=3,
44
+ weight_decay=0.01,
45
+ logging_dir="./trocr-finetuned-funsd/logs",
46
+ logging_steps=10,
47
+ evaluation_strategy="epoch",
48
+ save_strategy="epoch",
49
+ )
50
+ trainer = Seq2SeqTrainer(
51
+ model=model,
52
+ args=training_args,
53
+ train_dataset=encoded_dataset["train"],
54
+ eval_dataset=encoded_dataset["test"],
55
+ tokenizer=processor.tokenizer,
56
+ )
57
+
58
+ # Train the model
59
+ trainer.train()