# Hugging Face file-page residue, not part of the script itself:
# uploaded by uartimcs ("Upload 22 files", commit 7cd9ba4, verified), 3.16 kB.
"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import argparse
import json
import os
import re
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from donut import DonutModel, JSONParseEvaluator, load_json, save_json
def _predicted_field(output, key):
    """Return ``output[key]``, or ``None`` if *output* is not a dict or lacks *key*.

    A single unparsable prediction should score 0 rather than abort the whole
    evaluation run with a KeyError/TypeError.
    """
    return output.get(key) if isinstance(output, dict) else None


def _build_prompt(task_name, ground_truth):
    """Return the decoder prompt used to run inference on one sample.

    For ``docvqa`` the prompt embeds the lower-cased question of the first
    ground-truth parse and opens the answer field; every other task is prompted
    with its task start token only.
    """
    if task_name == "docvqa":
        question = ground_truth["gt_parses"][0]["question"].lower()
        return f"<s_{task_name}><s_question>{question}</s_question><s_answer>"
    return f"<s_{task_name}>"


def _score_sample(task_name, output, ground_truth, evaluator):
    """Score one prediction against its ground truth.

    Returns:
        ``(score, gt)`` where ``gt`` is the ground-truth object the prediction
        was compared with (collected later for the F1 computation).
    """
    if task_name == "rvlcdip":
        gt = ground_truth["gt_parse"]
        score = float(_predicted_field(output, "class") == gt["class"])
    elif task_name == "docvqa":
        # Note: we evaluated the model on the official website.
        # In this script, an exact-match based score will be returned instead
        gt = ground_truth["gt_parses"]
        answers = {qa_parse["answer"] for qa_parse in gt}
        score = float(_predicted_field(output, "answer") in answers)
    else:
        gt = ground_truth["gt_parse"]
        score = evaluator.cal_acc(output, gt)
    return score, gt


def test(args):
    """Run a Donut model over one dataset split and report accuracy scores.

    Args:
        args: namespace providing ``pretrained_model_name_or_path``,
            ``dataset_name_or_path``, ``split``, ``task_name`` and ``save_path``
            (``save_path`` may be ``None``/empty to skip saving).

    Returns:
        The list of per-sample model predictions, in dataset order.

    Side effects:
        Prints a summary line. When ``save_path`` is set, creates its parent
        directory if needed and writes the scores, predictions and ground truths
        there via ``save_json``.
    """
    pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)
    if torch.cuda.is_available():
        pretrained_model.half()
        pretrained_model.to("cuda")
    pretrained_model.eval()

    if args.save_path:
        # Path(...).parent is "." for a bare filename, whereas os.path.dirname
        # returns "" there and os.makedirs("") raises FileNotFoundError.
        Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []
    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for sample in tqdm(dataset, total=len(dataset)):
        ground_truth = json.loads(sample["ground_truth"])
        output = pretrained_model.inference(
            image=sample["image"],
            prompt=_build_prompt(args.task_name, ground_truth),
        )["predictions"][0]
        score, gt = _score_sample(args.task_name, output, ground_truth, evaluator)
        accs.append(score)
        predictions.append(output)
        ground_truths.append(gt)

    # The "ted_*" keys are kept for output compatibility; for rvlcdip/docvqa
    # they hold exact-match scores rather than tree-edit-distance accuracy.
    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both are needed by test(); fail early with a clear argparse error instead
    # of an opaque TypeError deep inside model/dataset loading.
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--dataset_name_or_path", type=str, required=True)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    # Unknown command-line arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    if args.task_name is None:
        # Path(...).name ignores a trailing separator ("datasets/cord/" -> "cord"),
        # whereas os.path.basename would return "" and yield the prompt "<s_>".
        args.task_name = Path(args.dataset_name_or_path).name
    predictions = test(args)