import math

import torch.nn.functional as F

from .utils import *  # provides batch_tensor_to_dict / batch_dict_to_tensor used below


def get_schedule(timesteps, schedule):
    # Keep only the first `schedule` fraction of the timesteps (e.g. 0.6 keeps the first 60%).
    end = round(len(timesteps) * schedule)
    timesteps = timesteps[:end]
    return timesteps


def get_elem(l, i, default=0.0):
    # Safe indexing: return l[i], or `default` when i is out of range.
    if i >= len(l):
        return default
    return l[i]


def pad_list(l_1, l_2, pad=0.0):
    # Right-pad the shorter list with `pad` so both lists have the same length.
    max_len = max(len(l_1), len(l_2))
    l_1 = l_1 + [pad] * (max_len - len(l_1))
    l_2 = l_2 + [pad] * (max_len - len(l_2))
    return l_1, l_2
    
    
def normalize(x, dim):
    # Standardize x along `dim` to zero mean and unit standard deviation.
    x_mean = x.mean(dim=dim, keepdim=True)
    x_std = x.std(dim=dim, keepdim=True)
    x_normalized = (x - x_mean) / x_std
    return x_normalized


# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
def appearance_mean_std(q_c_normed, k_s_normed, v_s):  # c: content, s: style
    q_c = q_c_normed  # q_c and k_s must be projected from normalized features
    k_s = k_s_normed
    # Attention-weighted first moment of the style values: mean = E[v_s]
    mean = F.scaled_dot_product_attention(q_c, k_s, v_s)  # Use scaled_dot_product_attention for efficiency
    # Attention-weighted standard deviation: std = sqrt(E[v_s^2] - E[v_s]^2); relu() clamps
    # small negative values caused by floating-point error before the sqrt()
    std = (F.scaled_dot_product_attention(q_c, k_s, v_s.square()) - mean.square()).relu().sqrt()

    return mean, std
    

def feature_injection(features, batch_order):
    # Structure control: overwrite the output ("cond") features with the structure
    # branch's ("structure_cond") features.
    assert features.shape[0] % len(batch_order) == 0
    features_dict = batch_tensor_to_dict(features, batch_order)
    features_dict["cond"] = features_dict["structure_cond"]
    features = batch_dict_to_tensor(features_dict, batch_order)
    return features


def appearance_transfer(features, q_normed, k_normed, batch_order, v=None, reshape_fn=None):
    # Appearance control: re-normalize the output ("cond") features along the token
    # dimension, then scale/shift them with attention-weighted statistics of the
    # appearance branch ("appearance_cond"), i.e. an AdaIN-style transfer.
    assert features.shape[0] % len(batch_order) == 0

    features_dict = batch_tensor_to_dict(features, batch_order)
    q_normed_dict = batch_tensor_to_dict(q_normed, batch_order)
    k_normed_dict = batch_tensor_to_dict(k_normed, batch_order)
    v_dict = features_dict
    if v is not None:
        v_dict = batch_tensor_to_dict(v, batch_order)
    
    mean_cond, std_cond = appearance_mean_std(
        q_normed_dict["cond"], k_normed_dict["appearance_cond"], v_dict["appearance_cond"],
    )

    if reshape_fn is not None:
        mean_cond = reshape_fn(mean_cond)
        std_cond = reshape_fn(std_cond)

    features_dict["cond"] = std_cond * normalize(features_dict["cond"], dim=-2) + mean_cond
    
    features = batch_dict_to_tensor(features_dict, batch_order)
    return features
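

if __name__ == "__main__":
    # Minimal sanity check (an illustrative sketch, not part of the original pipeline):
    # run appearance_mean_std on random content/style features and verify output shapes.
    # Real usage projects q/k from normalized U-Net features; normalizing random tensors
    # along the token dimension here is only a stand-in. Note the relative import above
    # requires package context, e.g. run as `python -m <package>.<module>`.
    import torch

    q_c = normalize(torch.randn(1, 8, 64, 40), dim=-2)  # content queries
    k_s = normalize(torch.randn(1, 8, 77, 40), dim=-2)  # style keys
    v_s = torch.randn(1, 8, 77, 40)                      # style values
    mean, std = appearance_mean_std(q_c, k_s, v_s)
    print(mean.shape, std.shape)  # both torch.Size([1, 8, 64, 40])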