from transformers import VisionEncoderDecoderConfig
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import re
import requests
from PIL import Image
from io import BytesIO
# Inference demo for the jjjlangem/He-Donut checkpoint: download one image,
# run greedy generation with the VisionEncoderDecoder model, strip special
# tokens from the decoded output, and print the resulting list of strings.

MODEL_ID = 'jjjlangem/He-Donut'

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCeH216oW6FXeTpN4ijvakW8_frP3vnCBIKQ&s"
# Bound the wait, and fail loudly on an HTTP error; otherwise an error page's
# body would be handed to Image.open and fail with a far less useful message.
response = requests.get(url, timeout=30)
response.raise_for_status()
# The format of a downloaded image is not under our control; force 3-channel
# RGB so palette / RGBA / grayscale sources are fed to the processor uniformly.
img = Image.open(BytesIO(response.content)).convert("RGB")
img.show()

config = VisionEncoderDecoderConfig.from_pretrained(MODEL_ID)
processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
tokenizer = processor.tokenizer

# Inference only: no autograd bookkeeping needed.
with torch.no_grad():
    pixel_values = processor(
        img, random_padding=False, return_tensors="pt"
    ).pixel_values.to(device)
    batch_size = pixel_values.shape[0]
    # Seed every sequence in the batch with the decoder start token.
    decoder_input_ids = torch.full(
        (batch_size, 1), model.config.decoder_start_token_id, device=device
    )
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=768,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Never allow the <unk> token to be generated.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

# Special tokens to strip from the decoded text, in the original order
# (eos, pad, bos). Tokens the tokenizer leaves undefined (None) are skipped
# rather than crashing str.replace.
special_tokens = [
    token
    for token in (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token)
    if token is not None
]

predictions = []
for seq in tokenizer.batch_decode(outputs.sequences):
    for token in special_tokens:
        seq = seq.replace(token, "")
    # Remove the first remaining <...> tag -- presumably the task-prompt token
    # that opens each sequence (TODO confirm for this checkpoint).
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    predictions.append(seq)
print(predictions)
# Model card metadata (copied from the Hugging Face Hub page):
# - Downloads last month: 5
# - The serverless Inference API does not yet support transformers models
#   for this pipeline type.
# - Model tree for jjjlangem/He-Donut:
#   - Base model: naver-clova-ix/donut-base