from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from PIL import Image

# Load the FUNSD dataset (train/test splits of annotated form images)
dataset = load_dataset("nielsr/funsd")

# Load the pre-trained model and processor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

# Preprocess the images: open each file, convert to RGB, and extract pixel values
def preprocess_images(examples):
    images = [Image.open(path).convert("RGB") for path in examples["image_path"]]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    return {"pixel_values": pixel_values}

encoded_dataset = dataset.map(preprocess_images, batched=True)

# Preprocess the labels: tokenize the word lists into a "labels" column, and
# replace pad token ids with -100 so the cross-entropy loss ignores padding
max_length = 64

def preprocess_labels(examples):
    labels = processor.tokenizer(
        examples["words"],
        is_split_into_words=True,
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).input_ids
    labels = [
        [token if token != processor.tokenizer.pad_token_id else -100 for token in seq]
        for seq in labels
    ]
    return {"labels": labels}

encoded_dataset = encoded_dataset.map(preprocess_labels, batched=True)

# Prepare for training: the decoder needs to know its start, pad, and end tokens
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./trocr-finetuned-funsd",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./trocr-finetuned-funsd/logs",
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=processor.tokenizer,
)

# Train the model
trainer.train()
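
# Once training finishes, the fine-tuned model can be used for inference.
# A minimal sketch, assuming a local test image: "form.png" is a placeholder
# path, and max_length here is illustrative rather than tuned.
image = Image.open("form.png").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=max_length)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)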