TusharGoel's picture
Update README.md
74b7d9d
|
raw
history blame
1.42 kB
metadata
license: mit
language:
  - en
pipeline_tag: document-question-answering

This model was trained on a sample of 15,000 questions from the DocVQA dataset.

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForDocumentQuestionAnswering

# Usage example: answer a question about a scanned form with the fine-tuned
# document-question-answering checkpoint.
model_checkpoint = "TusharGoel/LayoutLM-Finetuned-DocVQA"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
model_predict = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)

# Inference only: switch off dropout and other training-time behaviour.
model_predict.eval()
dataset = load_dataset("nielsr/funsd", split="train")
example = dataset[0]

question = "What's Licensee Number?"

# One bounding box per document word.
words = example["words"]
boxes = example["bboxes"]

# Encode the pair (question words, document words); both are already split
# into words, hence is_split_into_words=True.
encoding = tokenizer(
    question.split(),
    words,
    is_split_into_words=True,
    return_token_type_ids=True,
    return_tensors="pt",
)

# Per-token bookkeeping from the tokenizer: which sequence a token belongs to
# (1 for document tokens) and which word of that sequence it came from.
sequence_ids = encoding.sequence_ids(0)
word_ids = encoding.word_ids(0)

# Build one box per token: document tokens inherit the box of their source
# word, separator tokens get [1000, 1000, 1000, 1000], everything else
# (question tokens and remaining special tokens) gets [0, 0, 0, 0].
bbox = []
for token_id, sequence_id, word_id in zip(encoding.input_ids[0], sequence_ids, word_ids):
    if sequence_id == 1:
        bbox.append(boxes[word_id])
    elif token_id == tokenizer.sep_token_id:
        bbox.append([1000] * 4)
    else:
        bbox.append([0] * 4)
encoding["bbox"] = torch.tensor([bbox])

# No gradients are needed for prediction; skipping them saves memory and time.
with torch.no_grad():
    outputs = model_predict(**encoding)

# No start/end labels are passed, so only the span logits are used here.
start_scores = outputs.start_logits
end_scores = outputs.end_logits

start_index = start_scores.argmax(-1).item()
end_index = end_scores.argmax(-1).item()

# word_ids is None for special tokens and, for question tokens, indexes the
# question's words rather than the document's. Only map the prediction back
# to document words when both ends of the span fall inside the document.
if sequence_ids[start_index] == 1 and sequence_ids[end_index] == 1:
    start, end = word_ids[start_index], word_ids[end_index]
    # An inverted span (start > end) yields an empty slice, i.e. no answer.
    answer = " ".join(words[start : end + 1])
else:
    start = end = None
    answer = ""
print(answer)