import torch


class StyleTransferController(torch.nn.Module):
    def __init__(
        self,
        num_control_params,
        edim,
        hidden_dim=256,
        agg_method="mlp",
    ):
"""Plugin parameter controller module to map from input to target style. | |
Args: | |
num_control_params (int): Number of plugin parameters to predicted. | |
edim (int): Size of the encoder representations. | |
hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256 | |
agg_method (str, optional): Input/reference embed aggregation method ["conv" or "linear", "mlp"]. Default: "mlp" | |
""" | |
        super().__init__()
        self.num_control_params = num_control_params
        self.edim = edim
        self.hidden_dim = hidden_dim
        self.agg_method = agg_method

        if agg_method == "conv":
            # aggregate the stacked input/reference embeddings with a 1D convolution
            self.agg = torch.nn.Conv1d(
                2,
                1,
                kernel_size=129,
                stride=1,
                padding="same",
                bias=False,
            )
            mlp_in_dim = edim
        elif agg_method == "linear":
            # aggregate the concatenated embeddings with a learned linear projection
            self.agg = torch.nn.Linear(edim * 2, edim)
            mlp_in_dim = edim  # the projection maps the concatenation back down to edim
        elif agg_method == "mlp":
            # no aggregation; the MLP consumes the concatenated embeddings directly
            self.agg = None
            mlp_in_dim = edim * 2
        else:
            raise ValueError(f"Invalid agg_method = {self.agg_method}.")
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(mlp_in_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, num_control_params),
            torch.nn.Sigmoid(),  # normalize between 0 and 1
        )
    def forward(self, e_x, e_y, z=None):
        """Forward pass to generate plugin parameters.

        Args:
            e_x (tensor): Input signal embedding of shape (batch, edim)
            e_y (tensor): Target signal embedding of shape (batch, edim)
            z (tensor, optional): Unused. Default: None

        Returns:
            p (tensor): Estimated control parameters of shape (batch, num_control_params)
        """
        # aggregate the input and reference embeddings
        if self.agg_method == "conv":
            e_xy = torch.stack((e_x, e_y), dim=1)  # stack along a new channel dim
            e_xy = self.agg(e_xy)
        elif self.agg_method == "linear":
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim
            e_xy = self.agg(e_xy)
        else:
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim

        # pass through the MLP to project to control parameters
        p = self.mlp(e_xy.squeeze(1))

        return p
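

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # build a controller and map a pair of input/reference embeddings to a set of
    # normalized plugin parameters. The batch size, embedding size, and number of
    # control parameters below are assumed values chosen for this example.
    controller = StyleTransferController(num_control_params=24, edim=512)

    e_x = torch.randn(4, 512)  # embedding of the input signal
    e_y = torch.randn(4, 512)  # embedding of the reference (target style) signal

    p = controller(e_x, e_y)
    print(p.shape)  # torch.Size([4, 24]); values lie in (0, 1) due to the Sigmoid

    # The "conv" aggregation path can be exercised the same way:
    conv_controller = StyleTransferController(
        num_control_params=24, edim=512, agg_method="conv"
    )
    p_conv = conv_controller(e_x, e_y)  # also (4, 24)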