Spaces:
Runtime error
Runtime error
File size: 6,712 Bytes
028951c 6a43216 028951c 6a43216 028951c 6a43216 028951c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import collections
import dataclasses
import types
import pytorch_lightning as pl
import torch.utils.data
import transformers
from data import (
generate_annotated_images,
get_annotation_ground_truth_str,
DataItem,
get_extra_tokens,
Batch,
Split,
BatchCollateFunction,
)
from utils import load_pickle_or_build_object_and_save
@dataclasses.dataclass
class Model:
    """Container for the objects that together form one model instance.

    ``build_model`` creates every member and returns them bundled in this
    dataclass so that callers can pass a single object around.
    """

    # Donut processor used to turn images into ``pixel_values``.
    processor: transformers.models.donut.processing_donut.DonutProcessor
    # Tokenizer used for the target strings and for decoding generated ids.
    tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
    # The vision encoder / text decoder network itself.
    encoder_decoder: transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder.VisionEncoderDecoderModel
    # Callable that turns a list of data items into a batch.
    batch_collate_function: BatchCollateFunction
    # The configuration object the model was built from.
    config: types.SimpleNamespace
def add_unknown_tokens_to_tokenizer(
    tokenizer, encoder_decoder, unknown_tokens: list[str]
):
    """Add tokens to the tokenizer and resize the decoder embeddings to match.

    Args:
        tokenizer: Object providing ``add_tokens`` and ``__len__``.
        encoder_decoder: Object whose ``decoder`` provides
            ``resize_token_embeddings``.
        unknown_tokens: Token strings to add to the tokenizer.
    """
    tokenizer.add_tokens(unknown_tokens)
    # The size is read only after the tokens were added, so the decoder is
    # resized to the tokenizer's new length.
    vocabulary_size = len(tokenizer)
    encoder_decoder.decoder.resize_token_embeddings(vocabulary_size)
def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
    """Count the token strings that the tokenizer maps to its unknown id.

    Every annotated image yielded by ``generate_annotated_images`` has its
    ground-truth string encoded twice: once to ids and once to token strings.
    The two sequences are walked in lockstep, and each token string whose id
    equals ``tokenizer.unk_token_id`` is counted.

    Args:
        tokenizer: Tokenizer to probe; it must be callable and provide
            ``tokenize`` and ``unk_token_id``.

    Returns:
        Counter mapping each unknown token string to its number of occurrences.

    Raises:
        ValueError: If ids and token strings differ in length for a ground
            truth (raised by ``zip(..., strict=True)``).
    """
    unknown_tokens_counter = collections.Counter()
    for annotated_image in generate_annotated_images():
        ground_truth = get_annotation_ground_truth_str(annotated_image.annotation)
        token_ids = tokenizer(ground_truth).input_ids
        token_strings = tokenizer.tokenize(ground_truth, add_special_tokens=True)
        unknown_tokens_counter.update(
            token_string
            for token_id, token_string in zip(token_ids, token_strings, strict=True)
            if token_id == tokenizer.unk_token_id
        )
    return unknown_tokens_counter
def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
    tokenizer, token_ids
):
    """Overwrite every padding id in ``token_ids`` with ``-100``, in place.

    Args:
        tokenizer: Object providing ``pad_token_id``.
        token_ids: Tensor of token ids; it is modified in place.

    Returns:
        The same ``token_ids`` object that was passed in.
    """
    # -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
    # presumably the padded positions are meant to be skipped by the loss.
    is_padding = token_ids == tokenizer.pad_token_id
    token_ids[is_padding] = -100
    return token_ids
@dataclasses.dataclass
class BatchCollateFunction:
    """Callable that collates a list of ``DataItem`` objects into a ``Batch``."""

    # Processor that converts the images into ``pixel_values``.
    processor: transformers.models.donut.processing_donut.DonutProcessor
    # Tokenizer that encodes the target strings.
    tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
    # Length every tokenized target is padded or truncated to.
    decoder_sequence_max_length: int

    def __call__(self, batch: list[DataItem], split: Split) -> Batch:
        """Build one batch from the given data items.

        Args:
            batch: Data items providing ``image``, ``target_string`` and
                ``data_index``.
            split: Dataset split; random padding is requested from the
                processor only for ``Split.train``.

        Returns:
            A ``Batch`` holding the processed images, the label ids with
            padding replaced by ``-100``, and the data indices of the items.
        """
        raw_images = [item.image for item in batch]
        processed = self.processor(
            raw_images,
            random_padding=split == Split.train,
            return_tensors="pt",
        )

        target_strings = [item.target_string for item in batch]
        tokenized_targets = self.tokenizer(
            target_strings,
            add_special_tokens=False,
            max_length=self.decoder_sequence_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
            self.tokenizer, tokenized_targets.input_ids
        )

        return Batch(
            images=processed.pixel_values,
            labels=labels,
            data_indices=[item.data_index for item in batch],
        )
def build_model(config: types.SimpleNamespace) -> Model:
    """Build the processor, tokenizer, network and collate function.

    Args:
        config: Configuration object; any object exposing the attributes below
            works. Attributes read here: ``pretrained_model_name``,
            ``image_width``, ``image_height``,
            ``unknown_tokens_for_tokenizer_path`` and
            ``decoder_sequence_max_length``.

    Returns:
        A ``Model`` bundling the processor, its tokenizer, the encoder-decoder
        network, the batch collate function and ``config`` itself.
    """
    donut_processor = transformers.DonutProcessor.from_pretrained(
        config.pretrained_model_name
    )
    donut_processor.image_processor.size = dict(
        width=config.image_width, height=config.image_height
    )
    donut_processor.image_processor.do_align_long_axis = False
    tokenizer = donut_processor.tokenizer

    encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_name
    )
    # NOTE(review): the published Donut checkpoints appear to store
    # ``encoder.image_size`` as (height, width); this passes (width, height).
    # Kept as in the original -- confirm the intended order.
    encoder_decoder_config.encoder.image_size = (
        config.image_width,
        config.image_height,
    )
    encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_name, config=encoder_decoder_config
    )

    # Add the prompt tokens BEFORE looking up their ids. Looking up a token
    # the tokenizer does not know yet yields the unknown-token id, which would
    # then be stored as the decoder start / end-of-sequence id.
    extra_tokens = get_extra_tokens()
    add_unknown_tokens_to_tokenizer(
        tokenizer, encoder_decoder, list(extra_tokens.__dict__.values())
    )

    # Write the special ids to the config object the model itself holds, so
    # they take effect whether ``from_pretrained`` kept a reference to the
    # config passed in or made its own copy of it.
    model_config = encoder_decoder.config
    model_config.pad_token_id = tokenizer.pad_token_id
    model_config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
        extra_tokens.benetech_prompt
    )
    model_config.bos_token_id = model_config.decoder_start_token_id
    model_config.eos_token_id = tokenizer.convert_tokens_to_ids(
        extra_tokens.benetech_prompt_end
    )

    # The dataset scan is cached on disk: the helper is handed the path and a
    # builder callable that performs the scan.
    unknown_dataset_tokens = load_pickle_or_build_object_and_save(
        config.unknown_tokens_for_tokenizer_path,
        lambda: list(find_unknown_tokens_for_tokenizer(tokenizer).keys()),
    )
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, unknown_dataset_tokens)

    # Keep the tokenizer's end-of-sequence id in sync with the model config.
    tokenizer.eos_token_id = model_config.eos_token_id

    batch_collate_function = BatchCollateFunction(
        processor=donut_processor,
        tokenizer=tokenizer,
        decoder_sequence_max_length=config.decoder_sequence_max_length,
    )
    return Model(
        processor=donut_processor,
        tokenizer=tokenizer,
        encoder_decoder=encoder_decoder,
        batch_collate_function=batch_collate_function,
        config=config,
    )
def generate_token_strings(
    model: Model, images: torch.Tensor, skip_special_tokens=True
) -> list[str]:
    """Generate token sequences for ``images`` and decode them to strings.

    Args:
        model: Bundle providing ``encoder_decoder``, ``tokenizer`` and
            ``config``.
        images: Image tensor passed to ``encoder_decoder.generate``.
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.

    Returns:
        One decoded string per generated sequence.
    """
    # Debug runs are capped at 10 tokens; otherwise the configured decoder
    # sequence length is used.
    if model.config.debug:
        max_length = 10
    else:
        max_length = model.config.decoder_sequence_max_length

    generation_output = model.encoder_decoder.generate(
        images,
        max_length=max_length,
        eos_token_id=model.tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    generated_sequences = generation_output.sequences
    return model.tokenizer.batch_decode(
        generated_sequences, skip_special_tokens=skip_special_tokens
    )
def predict_string(image, model: Model):
    """Return the generated string for a single image.

    Args:
        image: Image passed to ``model.processor``.
        model: Bundle providing the processor and everything that
            ``generate_token_strings`` needs.

    Returns:
        The first string produced by ``generate_token_strings``.
    """
    # Random padding is disabled for prediction.
    processed = model.processor(image, random_padding=False, return_tensors="pt")
    return generate_token_strings(model, processed.pixel_values)[0]
class LightningModule(pl.LightningModule):
    """PyTorch Lightning wrapper that trains the encoder-decoder network."""

    def __init__(self, config):
        """Build the model from ``config`` and record the hyperparameters."""
        super().__init__()
        self.save_hyperparameters()
        self.model = build_model(config)
        # ``Model`` is a plain dataclass, so the network is also assigned
        # directly on this module -- presumably so that it is registered as a
        # submodule and ``self.parameters()`` includes its weights.
        self.encoder_decoder = self.model.encoder_decoder

    def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
        """Compute, log (as ``train_loss``) and return the loss of one batch."""
        train_loss = self.compute_loss(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch: Batch, batch_idx: int, dataset_idx: int = 0):
        """Compute the loss of one validation batch and log it as ``val_loss``."""
        val_loss = self.compute_loss(batch)
        self.log("val_loss", val_loss)

    def compute_loss(self, batch: Batch) -> torch.Tensor:
        """Run the network on the batch and return the loss it reports."""
        return self.encoder_decoder(
            pixel_values=batch.images, labels=batch.labels
        ).loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create an Adam optimizer using the configured learning rate."""
        learning_rate = self.hparams["config"].learning_rate
        return torch.optim.Adam(self.parameters(), lr=learning_rate)
|