Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import cairosvg | |
| from data_utils.common_utils import trans2_white_bg | |
| from PIL import Image | |
| import numpy as np | |
def select_imgs(images_of_onefont, selected_cls, opts):
    """Pick the glyph images belonging to the requested character classes.

    Args:
        images_of_onefont: per-class image tensor, expected to be
            [bs, 52, opts.img_size, opts.img_size].
        selected_cls: integer tensor [bs, nshot] of class indices into dim 1.
        opts: config object; only ``opts.img_size`` is read.

    Returns:
        Tensor [bs, nshot, opts.img_size, opts.img_size] whose entry [b, k]
        is ``images_of_onefont[b, selected_cls[b, k]]``.
    """
    batch = images_of_onefont.size(0)
    n_shot = selected_cls.size(1)
    # Broadcast every class index over both spatial dims so that gather
    # copies a whole image for each (batch, shot) pair.
    index = selected_cls[:, :, None, None].expand(
        batch, n_shot, opts.img_size, opts.img_size
    )
    return images_of_onefont.gather(1, index)
def select_seqs(seqs_of_onefont, selected_cls, opts, seq_dim):
    """Pick the command sequences belonging to the requested character classes.

    Args:
        seqs_of_onefont: per-class sequence tensor, expected to be
            [bs, num_classes, opts.max_seq_len, seq_dim] — TODO confirm with callers.
        selected_cls: integer tensor [bs, nshot] of class indices into dim 1.
        opts: config object; only ``opts.max_seq_len`` is read.
        seq_dim: size of the last (feature) dimension to gather.

    Returns:
        Tensor [bs, nshot, opts.max_seq_len, seq_dim] whose entry [b, k] is
        ``seqs_of_onefont[b, selected_cls[b, k]]``.
    """
    batch = seqs_of_onefont.size(0)
    n_shot = selected_cls.size(1)
    # Repeat each class index across the time and feature dims for gather.
    index = selected_cls[:, :, None, None].expand(
        batch, n_shot, opts.max_seq_len, seq_dim
    )
    return seqs_of_onefont.gather(1, index)
def select_seqlens(seqlens_of_onefont, selected_cls, opts):
    """Pick the sequence lengths belonging to the requested character classes.

    Args:
        seqlens_of_onefont: per-class length tensor [bs, num_classes, 1].
        selected_cls: integer tensor [bs, nshot] of class indices into dim 1.
        opts: unused; kept for signature parity with the other ``select_*``
            helpers.

    Returns:
        Tensor [bs, nshot, 1] whose entry [b, k, 0] is
        ``seqlens_of_onefont[b, selected_cls[b, k], 0]``.
    """
    batch = seqlens_of_onefont.size(0)
    n_shot = selected_cls.size(1)
    index = selected_cls[:, :, None].expand(batch, n_shot, 1)
    return seqlens_of_onefont.gather(1, index)
def trgcls_to_onehot(trg_cls, opts):
    """One-hot encode target character classes.

    Args:
        trg_cls: integer class tensor; a [bs, 1] input yields a [bs, char_num]
            result because the singleton dim 1 is squeezed away.
        opts: config object; only ``opts.char_num`` is read.

    Returns:
        Integer one-hot tensor with ``opts.char_num`` classes in the last dim.
    """
    encoded = F.one_hot(trg_cls, num_classes=opts.char_num)
    return encoded.squeeze(dim=1)
def shift_right(x, pad_value=None):
    """Shift a 3-D tensor one step forward along dim 0.

    A new first step is inserted and the final step is discarded, so the
    output has the same length along dim 0 as ``x``.

    Args:
        x: 3-D tensor, e.g. [seq_len, batch, depth].
        pad_value: optional tensor to insert at the front (must be
            concatenable with ``x`` on dim 0); zeros are used when ``None``.

    Returns:
        Tensor shaped like ``x`` whose step t equals the input's step t-1.
    """
    if pad_value is not None:
        extended = torch.cat([pad_value, x], dim=0)
    else:
        # F.pad lists pads from the last dim backwards: (last, 2nd, 3rd) —
        # only the 3rd-from-last (dim 0) gets one leading zero row.
        extended = F.pad(x, (0, 0, 0, 0, 1, 0))
    return extended[:-1, :, :]
def length_form_embedding(emb):
    """Compute the length of each sequence in the batch.

    A time step counts towards a sequence's length when its embedding
    vector has at least one non-zero component; all-zero vectors are
    treated as padding. Steps are counted wherever they occur, so a zero
    vector in the middle of a sequence is also excluded.

    Args:
        emb: [seq_len, batch, depth] embedding tensor.

    Returns:
        torch.long tensor [batch] holding the number of non-zero steps of
        each sequence (a count, not a 0/1 mask).
    """
    # [seq_len, batch] bool: True where the step's vector is not all zeros.
    present = (emb != 0).any(dim=2)
    return present.sum(dim=0, dtype=torch.long)
def lognormal(y, mean, logstd, logsqrttwopi):
    """Element-wise log-density of ``y`` under N(mean, exp(logstd)**2).

    Args:
        y: observations; broadcast against ``mean`` (the original notes show
            y as [N, 1] against mean/logstd of [N, 50]).
        mean: Gaussian means.
        logstd: natural log of the standard deviations.
        logsqrttwopi: precomputed constant log(sqrt(2*pi)), supplied by caller.

    Returns:
        -0.5 * ((y - mean) / std)**2 - logstd - logsqrttwopi, broadcast.
    """
    standardized = (y - mean) / torch.exp(logstd)
    return -0.5 * standardized ** 2 - logstd - logsqrttwopi
def sequence_mask(lengths, max_len=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        lengths: 1-D integer tensor [batch] of sequence lengths.
        max_len: width of the mask. When ``None`` it defaults to
            ``lengths.max()`` (or 0 for an empty batch). An explicit 0 is
            honoured rather than being replaced by the default.

    Returns:
        Bool tensor [batch, max_len] with ``mask[b, t] == (t < lengths[b])``,
        on the same device as ``lengths``.
    """
    batch_size = lengths.numel()
    if max_len is None:
        # `max_len or lengths.max()` would override an explicit 0, and
        # .max() raises on an empty tensor — handle both explicitly.
        max_len = int(lengths.max()) if batch_size > 0 else 0
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    return (positions
            .unsqueeze(0)
            .expand(batch_size, max_len)
            .lt(lengths.unsqueeze(1)))
def svg2img(path_svg, path_img, img_size):
    """Rasterise an SVG file to a square PNG, then post-process it.

    Args:
        path_svg: path/URL of the source SVG (passed to cairosvg as ``url``).
        path_img: destination path the PNG is written to.
        img_size: output width and height in pixels.

    Returns:
        Whatever ``trans2_white_bg(path_img)`` returns — presumably the image
        as an array composited on a white background; confirm in
        data_utils.common_utils.
    """
    cairosvg.svg2png(
        url=path_svg,
        output_width=img_size,
        output_height=img_size,
        write_to=path_img,
    )
    return trans2_white_bg(path_img)
def cal_img_l1_dist(path_img1, path_img2):
    """Mean absolute pixel difference between two image files.

    Args:
        path_img1: image path; assumed to load as a 2-D (single-channel)
            array so it matches channel 0 of the second image — TODO confirm.
        path_img2: image path; loaded as H x W x C and only channel 0 is used.

    Returns:
        numpy float64 mean of |img1 - img2[:, :, 0]| in pixel-value units.
    """
    # Convert to float before subtracting: PIL images typically load as
    # uint8, where e.g. 10 - 20 wraps to 246 and np.abs cannot undo it.
    with Image.open(path_img1) as im1:
        img1 = np.asarray(im1, dtype=np.float64)
    with Image.open(path_img2) as im2:
        img2 = np.asarray(im2, dtype=np.float64)
    return np.mean(np.abs(img1 - img2[:, :, 0]))
def cal_iou(path_img1, path_img2):
    """Compare the thresholded foreground of two image files.

    Pixels with value below 3/4 of 255 are treated as foreground (presumably
    dark glyph strokes on a light background). The first image is used as
    loaded; only channel 0 of the second image is used.

    Returns:
        (iou, l1_dist):
        - iou: |fg1 AND fg2| / |fg1 OR fg2| (nan if both masks are empty).
        - l1_dist: fraction of pixels where the two masks disagree.
    """
    threshold = 255 * 3 / 4
    fg1 = np.array(Image.open(path_img1)) < threshold
    fg2 = np.array(Image.open(path_img2))[:, :, 0] < threshold

    intersection = np.logical_and(fg1, fg2).sum()
    union = np.logical_or(fg1, fg2).sum()
    iou = intersection / union

    mismatch = np.abs(fg1.astype(float) - fg2.astype(float)).mean()
    return iou, mismatch