Spaces:
Running
Running
File size: 1,822 Bytes
e17e8cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""
import torch
def linear(a, b, x, min_x, max_x):
    """Ramp linearly from ``a`` to ``b`` as ``x`` moves from ``min_x`` to ``max_x``.

    The result is clamped: ``a`` for ``x <= min_x`` and ``b`` for
    ``x >= max_x``::

        b           ___________
                   /|
                  / |
        a _______/  |
                 |  |
             min_x  max_x

    If ``min_x == max_x`` the ramp has zero width and is treated as a step
    located at that point: ``a`` strictly below it, ``b`` at or above it.
    """
    span = max_x - min_x
    if span == 0:
        # A zero-width ramp would divide by zero; fall back to a step function.
        fraction = 1 if x >= max_x else 0
    else:
        # Position of x inside [min_x, max_x], clamped to the range [0, 1].
        fraction = min(max((x - min_x) / span, 0), 1)
    return a + fraction * (b - a)
def batchify(data, device):
    """Lazily prepend a size-1 axis to every item of ``data`` and move it to ``device``.

    The result is a single-pass generator; nothing is converted until it is
    iterated.
    """
    return (item.unsqueeze(0).to(device) for item in data)
def _make_seq_first(*args):
# N, G, S, ... -> S, G, N, ...
if len(args) == 1:
arg, = args
return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)
def _make_batch_first(*args):
# S, G, N, ... -> N, G, S, ...
if len(args) == 1:
arg, = args
return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)
def _pack_group_batch(*args):
# S, G, N, ... -> S, G * N, ...
if len(args) == 1:
arg, = args
return arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None
return (*(arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None for arg in args),)
def _unpack_group_batch(N, *args):
# S, G * N, ... -> S, G, N, ...
if len(args) == 1:
arg, = args
return arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None
return (*(arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None for arg in args),)
|