ameerazam08's picture
Upload folder using huggingface_hub (#1)
9e7a39a
raw
history blame
1.74 kB
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
class ParseqPredictor(nn.Module):
    """Wrap a PARSeq text recogniser for string prediction and a recognition loss.

    The recogniser is loaded with ``torch.hub`` from the local ``./src/parseq``
    checkout and switched to eval mode; by default its parameters are frozen so
    it can act as a fixed critic for another model.
    """

    def __init__(self, ckpt_path=None, freeze=True, *args, **kwargs):
        """Build the predictor.

        Args:
            ckpt_path: Path to a PARSeq ``state_dict`` file. When ``None`` no
                checkpoint is loaded and the weights returned by
                ``torch.hub.load`` (with the entry point's default arguments)
                are used as-is. Previously ``None`` crashed in ``torch.load``.
            freeze: If true, disable gradients for every PARSeq parameter.
            *args, **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.parseq = torch.hub.load('./src/parseq', 'parseq', source='local').eval()
        if ckpt_path is not None:
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            self.parseq.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        # Resize to the recogniser's configured input size, then apply
        # (x - 0.5) / 0.5; presumably inputs are in [0, 1] -- confirm with callers.
        self.parseq_transform = transforms.Compose([
            transforms.Resize(self.parseq.hparams.img_size, transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.Normalize(0.5, 0.5)
        ])
        if freeze:
            self.freeze()

    def freeze(self):
        """Disable gradient computation for every PARSeq parameter.

        NOTE(review): this does not pin eval mode -- ``nn.Module.train()`` on a
        parent module recurses into ``self.parseq`` and switches it to train mode.
        """
        for param in self.parseq.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run PARSeq on a batch of images and return its raw logits.

        Args:
            x: An iterable of image tensors (a batched tensor or a list). Each
                element gets a leading batch dim and is transformed on its own,
                so elements may differ in spatial size before resizing.

        Returns:
            The tensor returned by ``self.parseq``; presumably ``(B, L, C)``
            logits (batch, decoding steps, classes).
        """
        batch = torch.cat([self.parseq_transform(img[None]) for img in x])
        # Move to wherever the module's parameters live.
        device = next(self.parameters()).device
        return self.parseq(batch.to(device))

    def img2txt(self, x):
        """Greedily decode images to strings.

        Runs under ``torch.no_grad()`` because only strings are returned, so no
        autograd graph is needed.

        Args:
            x: Images, in any form accepted by ``forward``.

        Returns:
            The decoded labels from ``self.parseq.tokenizer.decode``.
        """
        with torch.no_grad():
            logits = self(x)
            # Softmax mirrors PARSeq's documented usage. It does not change the
            # argmax, so the labels are the same as decoding raw logits.
            labels, _confidence = self.parseq.tokenizer.decode(logits.softmax(-1))
        return labels

    def calc_loss(self, x, label):
        """Per-sample character cross-entropy between PARSeq output and ``label``.

        Args:
            x: Images, in any form accepted by ``forward``.
            label: Sequence of ground-truth strings, one per image.

        Returns:
            1-D tensor with one entry per sample. Each entry is the mean
            cross-entropy over that sample's characters (the EOS step is not
            scored), clamped to at most 1.0. An empty label contributes 0.0.
        """
        preds = self(x)  # presumably (B, L, C)
        tokenizer = self.parseq.tokenizer
        # One row per label; index 0 is skipped below (presumably BOS), and the
        # characters run up to the EOS token.
        gt_ids = tokenizer.encode(label).to(preds.device)
        eos_token = tokenizer.eos_id
        losses = []
        for pred, row in zip(preds, gt_ids):
            eos_pos = int((row == eos_token).nonzero()[0])  # first EOS in the row
            # Step i of the prediction scores character i (row[1 + i]). The
            # prediction can be shorter than the label: labels longer than the
            # decoder length, or decoding that stops early (NOTE(review): verify
            # against src/parseq). Only the overlap is scored, instead of raising
            # a shape mismatch.
            n_chars = min(eos_pos - 1, pred.shape[0])
            if n_chars <= 0:
                # Empty label: a mean over zero targets would be NaN. Use a zero
                # that stays attached to the autograd graph instead.
                losses.append(pred.sum() * 0.0)
                continue
            ce_loss = nn.functional.cross_entropy(pred[:n_chars], row[1:1 + n_chars])
            losses.append(ce_loss.clamp(max=1.0))
        return torch.stack(losses)