# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT
import glob
import os
import pickle

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

from params import *


def crop_(img_arr):
    """Crop a grayscale image array to the bounding box of its ink pixels."""
    image = Image.fromarray(img_arr).convert('L')
    # Ink (dark) pixels map to 255 and background to 0, so getbbox()
    # returns the bounding box of the ink.
    binary_image = image.point(lambda x: 0 if x > 127 else 255, '1')
    bbox = binary_image.getbbox()
    if bbox is None:  # blank image: nothing to crop
        return np.array(image)
    return np.array(image.crop(bbox))


def get_transform(grayscale=False, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def load_itw_samples(folder_path, num_samples=15):
    """Load `num_samples` in-the-wild style images from a folder (or a list of paths)."""
    if isinstance(folder_path, str):
        paths = glob.glob(f'{folder_path}/*')
    else:
        paths = folder_path
    # Sample with replacement only when there are too few images.
    paths = np.random.choice(paths, num_samples, replace=len(paths) < num_samples)
    # Filenames encode the transcriptions (three-character extension stripped).
    words = [os.path.basename(path_i)[:-4] for path_i in paths]

    imgs = [np.array(Image.open(p).convert('L')) for p in paths]
    imgs = [crop_(im) for im in imgs]
    # Resize to a fixed height of 32 px while preserving the aspect ratio.
    imgs = [cv2.resize(im, (int(32 * (im.shape[1] / im.shape[0])), 32)) for im in imgs]

    max_width = 192
    imgs_pad = []
    imgs_wids = []
    trans_fn = get_transform(grayscale=True)
    for img in imgs:
        img = 255 - img  # invert so the zero padding is background
        img_height, img_width = img.shape[0], img.shape[1]
        img_width = min(img_width, max_width)  # wider images are truncated
        outImg = np.zeros((img_height, max_width), dtype='float32')
        outImg[:, :img_width] = img[:, :max_width]
        img = 255 - outImg  # invert back: dark ink on a white canvas
        imgs_pad.append(trans_fn(Image.fromarray(img.astype(np.uint8))))
        imgs_wids.append(img_width)

    imgs_pad = torch.cat(imgs_pad, 0)
    return imgs_pad.unsqueeze(0), torch.Tensor(imgs_wids).unsqueeze(0)


class TextDataset(Dataset):
    """Training split: each item pairs one real image with `num_examples` style images of the same writer."""

    def __init__(self, base_path=DATASET_PATHS, num_examples=15, target_transform=None):
        self.NUM_EXAMPLES = num_examples
        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)['train']
        self.IMG_DATA = dict(list(self.IMG_DATA.items()))  # [:NUM_WRITERS]
        self.IMG_DATA.pop('None', None)
        self.author_id = list(self.IMG_DATA.keys())
        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.collate_fn = TextCollator()

    def __len__(self):
        return len(self.author_id)

    def __getitem__(self, index):
        NUM_SAMPLES = self.NUM_EXAMPLES
        author_id = self.author_id[index]
        img_data_author = self.IMG_DATA[author_id]
        random_idxs = np.random.choice(len(img_data_author), NUM_SAMPLES, replace=True)

        # One random real sample (image plus transcription) for this writer.
        rand_id_real = np.random.choice(len(img_data_author))
        real_img = self.transform(img_data_author[rand_id_real]['img'].convert('L'))
        real_labels = img_data_author[rand_id_real]['label'].encode()

        imgs = [np.array(img_data_author[idx]['img'].convert('L')) for idx in random_idxs]
        labels = [img_data_author[idx]['label'].encode() for idx in random_idxs]

        # Pad every style image onto a fixed 192-px-wide white canvas.
        max_width = 192
        imgs_pad = []
        imgs_wids = []
        for img in imgs:
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            img_width = min(img_width, max_width)
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]
            img = 255 - outImg
            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        return {'simg': imgs_pad, 'swids': imgs_wids,
                'img': real_img, 'label': real_labels,
                'img_path': 'img_path', 'idx': 'indexes', 'wcl': index}
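
# Note on the expected pickle layout (inferred from the lookups in this file,
# not a guaranteed spec): both TextDataset above and TextDatasetval below
# assume the file at `base_path` unpickles to
#
#   {'train': {writer_id: [{'img': PIL.Image.Image, 'label': str}, ...], ...},
#    'test':  {writer_id: [{'img': PIL.Image.Image, 'label': str}, ...], ...}}
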
class TextDatasetval(Dataset):
    """Validation split: identical to TextDataset but reads the 'test' partition."""

    def __init__(self, base_path=DATASET_PATHS, num_examples=15, target_transform=None):
        self.NUM_EXAMPLES = num_examples
        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)['test']
        self.IMG_DATA = dict(list(self.IMG_DATA.items()))  # [NUM_WRITERS:]
        self.IMG_DATA.pop('None', None)
        self.author_id = list(self.IMG_DATA.keys())
        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.collate_fn = TextCollator()

    def __len__(self):
        return len(self.author_id)

    def __getitem__(self, index):
        NUM_SAMPLES = self.NUM_EXAMPLES
        author_id = self.author_id[index]
        img_data_author = self.IMG_DATA[author_id]
        random_idxs = np.random.choice(len(img_data_author), NUM_SAMPLES, replace=True)

        rand_id_real = np.random.choice(len(img_data_author))
        real_img = self.transform(img_data_author[rand_id_real]['img'].convert('L'))
        real_labels = img_data_author[rand_id_real]['label'].encode()

        imgs = [np.array(img_data_author[idx]['img'].convert('L')) for idx in random_idxs]
        labels = [img_data_author[idx]['label'].encode() for idx in random_idxs]

        max_width = 192
        imgs_pad = []
        imgs_wids = []
        for img in imgs:
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            img_width = min(img_width, max_width)
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]
            img = 255 - outImg
            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        return {'simg': imgs_pad, 'swids': imgs_wids,
                'img': real_img, 'label': real_labels,
                'img_path': 'img_path', 'idx': 'indexes', 'wcl': index}


class TextCollator(object):
    """Pads the real images in a batch to the widest one and stacks the rest."""

    def __init__(self):
        self.resolution = resolution  # from params

    def __call__(self, batch):
        img_path = [item['img_path'] for item in batch]
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        simgs = torch.stack([item['simg'] for item in batch], 0)
        wcls = torch.Tensor([item['wcl'] for item in batch])
        swids = torch.Tensor([item['swids'] for item in batch])
        # White (value 1 after normalization) canvas, padded to the widest image.
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0],
                           batch[0]['img'].shape[1], max(width)],
                          dtype=torch.float32)
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except Exception:
                print('TextCollator: size mismatch,', imgs.shape, 'vs', item['img'].shape)
        collated = {'img': imgs, 'img_path': img_path, 'idx': indexes,
                    'simg': simgs, 'swids': swids, 'wcl': wcls}
        if 'label' in batch[0].keys():
            collated['label'] = [item['label'] for item in batch]
        if 'z' in batch[0].keys():
            collated['z'] = torch.stack([item['z'] for item in batch])
        return collated
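

# Minimal smoke test: an illustrative sketch, not part of the training pipeline.
# It assumes `params.DATASET_PATHS` points to a valid pickle of the layout
# described above and that the pickled images share a fixed height (otherwise
# the collator's size-mismatch branch is hit).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = TextDataset(num_examples=15)
    loader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    print(batch['simg'].shape)  # (2, 15, 32, 192): padded style images per writer
    print(batch['img'].shape)   # (2, 1, H, max batch width): real target images
    print(batch['wcl'])         # writer-class indices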