|
from abc import ABC |
|
from typing import Optional |
|
|
|
import math |
|
import torch |
|
from torch import _dynamo

# Suppress TorchDynamo/torch.compile errors globally and fall back to eager execution.
_dynamo.config.suppress_errors = True
|
import torch.nn.functional as F |
|
from torch import nn |
|
from torch.nn.functional import mse_loss, l1_loss, binary_cross_entropy, cross_entropy, kl_div, nll_loss |
|
from pyro.distributions.conjugate import BetaBinomial |
|
from pyro.distributions import Normal |
|
from torch_geometric.nn import MessagePassing |
|
|
|
|
|
class NeighborEmbedding(MessagePassing, ABC): |
|
def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper): |
|
super(NeighborEmbedding, self).__init__(aggr="add") |
|
self.distance_proj = nn.Linear(num_rbf, hidden_channels) |
|
self.combine = nn.Linear(hidden_channels * 2, hidden_channels) |
|
self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.xavier_uniform_(self.distance_proj.weight) |
|
nn.init.xavier_uniform_(self.combine.weight) |
|
self.distance_proj.bias.data.fill_(0) |
|
self.combine.bias.data.fill_(0) |
|
|
|
def forward(self, x, edge_index, edge_weight, edge_attr): |
|
|
|
        # Drop self-loops so nodes do not pass messages to themselves.
        mask = edge_index[0] != edge_index[1]
|
if not mask.all(): |
|
edge_index = edge_index[:, mask] |
|
edge_weight = edge_weight[mask] |
|
edge_attr = edge_attr[mask] |
|
|
|
        # Smooth cutoff envelope applied to the RBF-projected filter weights.
        C = self.cutoff(edge_weight)

        W = self.distance_proj(edge_attr) * C.view(-1, 1)
|
|
|
        # Aggregate RBF-filtered neighbor features, then fuse them with the node's own features.
        x_neighbors = self.propagate(edge_index, x=x, W=W, size=None)

        x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))
|
return x_neighbors |
|
|
|
def message(self, x_j, W): |
|
return x_j * W |
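

# Usage sketch (shapes are illustrative assumptions, not part of this module):
# with node features x of shape (num_nodes, hidden_channels), edge_index of
# shape (2, num_edges), edge_weight of shape (num_edges,) and RBF-expanded
# edge_attr of shape (num_edges, num_rbf):
#
#   emb = NeighborEmbedding(hidden_channels=128, num_rbf=50,
#                           cutoff_lower=0.0, cutoff_upper=5.0)
#   x = emb(x, edge_index, edge_weight, edge_attr)  # -> (num_nodes, 128)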
|
|
|
|
|
class GaussianSmearing(nn.Module): |
|
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): |
|
super(GaussianSmearing, self).__init__() |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.num_rbf = num_rbf |
|
self.trainable = trainable |
|
|
|
offset, coeff = self._initial_params() |
|
if trainable: |
|
self.register_parameter("coeff", nn.Parameter(coeff)) |
|
self.register_parameter("offset", nn.Parameter(offset)) |
|
else: |
|
self.register_buffer("coeff", coeff) |
|
self.register_buffer("offset", offset) |
|
|
|
def _initial_params(self): |
|
offset = torch.linspace(self.cutoff_lower, self.cutoff_upper, self.num_rbf) |
|
coeff = -0.5 / (offset[1] - offset[0]) ** 2 |
|
return offset, coeff |
|
|
|
def reset_parameters(self): |
|
offset, coeff = self._initial_params() |
|
self.offset.data.copy_(offset) |
|
self.coeff.data.copy_(coeff) |
|
|
|
def forward(self, dist): |
|
dist = dist.unsqueeze(-1) - self.offset |
|
return torch.exp(self.coeff * torch.pow(dist, 2)) |
|
|
|
|
|
class ExpNormalSmearing(nn.Module): |
|
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): |
|
super(ExpNormalSmearing, self).__init__() |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.num_rbf = num_rbf |
|
self.trainable = trainable |
|
|
|
self.cutoff_fn = CosineCutoff(0, cutoff_upper) |
|
self.alpha = 5.0 / (cutoff_upper - cutoff_lower) |
|
|
|
means, betas = self._initial_params() |
|
if trainable: |
|
self.register_parameter("means", nn.Parameter(means)) |
|
self.register_parameter("betas", nn.Parameter(betas)) |
|
else: |
|
self.register_buffer("means", means) |
|
self.register_buffer("betas", betas) |
|
|
|
def _initial_params(self): |
|
|
|
|
|
        # Initialize means and betas following the PhysNet defaults
        # (https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181).
        start_value = torch.exp(
            torch.scalar_tensor(-self.cutoff_upper + self.cutoff_lower)
        )
        means = torch.linspace(start_value, 1, self.num_rbf)
|
betas = torch.tensor( |
|
[(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf |
|
) |
|
return means, betas |
|
|
|
def reset_parameters(self): |
|
means, betas = self._initial_params() |
|
self.means.data.copy_(means) |
|
self.betas.data.copy_(betas) |
|
|
|
def forward(self, dist): |
|
dist = dist.unsqueeze(-1) |
|
return self.cutoff_fn(dist) * torch.exp( |
|
-self.betas |
|
* (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2 |
|
) |
|
|
|
|
|
class ExpNormalSmearingUnlimited(nn.Module): |
|
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): |
|
super(ExpNormalSmearingUnlimited, self).__init__() |
|
self.num_rbf = num_rbf |
|
self.trainable = trainable |
|
|
|
self.alpha = 1 / 20 |
|
|
|
means, betas = self._initial_params() |
|
if trainable: |
|
self.register_parameter("means", nn.Parameter(means)) |
|
self.register_parameter("betas", nn.Parameter(betas)) |
|
else: |
|
self.register_buffer("means", means) |
|
self.register_buffer("betas", betas) |
|
|
|
def _initial_params(self): |
|
|
|
|
|
        # Fixed start value: this variant applies no distance cutoff, so the
        # means are simply spread over [0.1, 1].
        start_value = 0.1
        means = torch.linspace(start_value, 1, self.num_rbf)
|
betas = torch.tensor( |
|
[(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf |
|
) |
|
return means, betas |
|
|
|
def reset_parameters(self): |
|
means, betas = self._initial_params() |
|
self.means.data.copy_(means) |
|
self.betas.data.copy_(betas) |
|
|
|
def forward(self, dist): |
|
dist = dist.unsqueeze(-1) |
|
return torch.exp( |
|
-self.betas * (torch.exp(self.alpha * (-dist)) - self.means) ** 2 |
|
) |
|
|
|
|
|
class ShiftedSoftplus(nn.Module): |
|
def __init__(self): |
|
super(ShiftedSoftplus, self).__init__() |
|
        self.shift = torch.log(torch.tensor(2.0)).item()  # softplus(0) = log(2), so the output is 0 at x = 0
|
|
|
def forward(self, x): |
|
return F.softplus(x) - self.shift |
|
|
|
|
|
class CosineCutoff(nn.Module): |
|
def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0): |
|
super(CosineCutoff, self).__init__() |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
|
|
def forward(self, distances): |
|
        if self.cutoff_lower > 0:
            # Smooth window: 0 at cutoff_lower, rising to 1 mid-range, and back to 0 at cutoff_upper.
|
cutoffs = 0.5 * ( |
|
torch.cos( |
|
math.pi |
|
* ( |
|
2 |
|
* (distances - self.cutoff_lower) |
|
/ (self.cutoff_upper - self.cutoff_lower) |
|
+ 1.0 |
|
) |
|
) |
|
+ 1.0 |
|
) |
|
|
|
cutoffs = cutoffs * (distances < self.cutoff_upper).float() |
|
cutoffs = cutoffs * (distances > self.cutoff_lower).float() |
|
return cutoffs |
|
        else:
            # Standard cosine cutoff: 1 at zero distance, decaying smoothly to 0 at cutoff_upper.
            cutoffs = 0.5 * (torch.cos(distances * math.pi / self.cutoff_upper) + 1.0)
|
|
|
cutoffs = cutoffs * (distances < self.cutoff_upper).float() |
|
return cutoffs |
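

# For cutoff_lower == 0 this is the standard (Behler-style) cosine cutoff
#   f(d) = 0.5 * (cos(pi * d / cutoff_upper) + 1) for d < cutoff_upper, else 0,
# equal to 1 at d = 0 and decaying smoothly to 0 at the upper cutoff.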
|
|
|
|
|
class Distance(nn.Module): |
|
def __init__( |
|
self, |
|
cutoff_lower, |
|
cutoff_upper, |
|
return_vecs=False, |
|
loop=False, |
|
): |
|
super(Distance, self).__init__() |
|
self.cutoff_lower = cutoff_lower |
|
self.cutoff_upper = cutoff_upper |
|
self.return_vecs = return_vecs |
|
self.loop = loop |
|
|
|
def forward(self, pos, edge_index): |
|
edge_vec = pos[edge_index[0]] - pos[edge_index[1]] |
|
|
|
mask: Optional[torch.Tensor] = None |
|
if self.loop: |
|
|
|
|
|
|
|
            # Mask out self-loops when computing distances: the norm of a zero
            # vector produces NaN gradients.
            mask = edge_index[0] != edge_index[1]

            edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype)

            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
|
else: |
|
edge_weight = torch.norm(edge_vec, dim=-1) |
|
|
|
lower_mask = edge_weight >= self.cutoff_lower |
|
if self.loop and mask is not None: |
|
|
|
            # Always keep self-loops, even though their zero weight falls below cutoff_lower.
            lower_mask = lower_mask | ~mask
|
edge_index = edge_index[:, lower_mask] |
|
edge_weight = edge_weight[lower_mask] |
|
|
|
if self.return_vecs: |
|
edge_vec = edge_vec[lower_mask] |
|
return edge_index, edge_weight, edge_vec |
|
|
|
|
|
return edge_index, edge_weight, None |
|
|
|
|
|
class DistanceV2(nn.Module): |
|
def __init__( |
|
self, |
|
return_vecs=True, |
|
loop=False, |
|
): |
|
super(DistanceV2, self).__init__() |
|
self.return_vecs = return_vecs |
|
self.loop = loop |
|
|
|
def forward(self, pos, coords, edge_index): |
|
|
|
|
|
        # Build several inter-residue difference vectors (CA-CA plus vectors
        # between other backbone atoms, per the assumed coords layout) and stack
        # them along a trailing axis.
        ca_ca = pos[edge_index[1]] - pos[edge_index[0]]
|
cb_cb = coords[edge_index[1], :, [0]] - coords[edge_index[0], :, [0]] |
|
cb_N = coords[edge_index[1], :, [2]] - coords[edge_index[0], :, [0]] |
|
cb_O = coords[edge_index[1], :, [3]] - coords[edge_index[0], :, [0]] |
|
edge_vec = torch.cat([ca_ca.unsqueeze(-1), |
|
cb_cb.unsqueeze(-1), |
|
cb_N.unsqueeze(-1), |
|
cb_O.unsqueeze(-1)], dim=-1) |
|
mask: Optional[torch.Tensor] = None |
|
if self.loop: |
|
mask = edge_index[0] != edge_index[1] |
|
edge_weight = torch.zeros(ca_ca.size(0), device=ca_ca.device, dtype=ca_ca.dtype) |
|
edge_weight[mask] = torch.norm(ca_ca[mask], dim=-1) |
|
else: |
|
edge_weight = torch.norm(ca_ca, dim=-1) |
|
|
|
return edge_index, edge_weight, edge_vec |
|
|
|
|
|
|
|
rbf_class_mapping = {
    "gauss": GaussianSmearing,
    "expnorm": ExpNormalSmearing,
    "expnormunlim": ExpNormalSmearingUnlimited,
}
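

# Example (hypothetical values): build an RBF expansion by name and expand a
# batch of edge distances into per-edge feature vectors.
#
#   rbf = rbf_class_mapping["expnorm"](cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50)
#   edge_attr = rbf(edge_weight)  # (num_edges,) -> (num_edges, 50)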
|
|
|
|
|
class AbsTanh(nn.Module): |
|
def __init__(self): |
|
super(AbsTanh, self).__init__() |
|
|
|
@staticmethod |
|
def forward(x: torch.Tensor) -> torch.Tensor: |
|
return torch.abs(torch.tanh(x)) |
|
|
|
|
|
class Tanh2(nn.Module): |
|
def __init__(self): |
|
super(Tanh2, self).__init__() |
|
|
|
@staticmethod |
|
def forward(x: torch.Tensor) -> torch.Tensor: |
|
return torch.square(torch.tanh(x)) |
|
|
|
def gelu(x): |
|
"""Implementation of the gelu activation function. |
|
|
|
For information: OpenAI GPT's gelu is slightly different |
|
(and gives slightly different results): |
|
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
|
""" |
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
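

# Note: this matches torch.nn.functional.gelu(x) in its default exact
# (erf-based) formulation.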
|
|
|
act_class_mapping = { |
|
"ssp": ShiftedSoftplus, |
|
"softplus": nn.Softplus, |
|
"silu": nn.SiLU, |
|
"leaky_relu": nn.LeakyReLU, |
|
"tanh": nn.Tanh, |
|
"sigmoid": nn.Sigmoid, |
|
"pass": nn.Identity, |
|
"abs_tanh": AbsTanh, |
|
"tanh2": Tanh2, |
|
"softmax": nn.Softmax, |
|
"gelu": nn.GELU, |
|
} |
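

# Example: act = act_class_mapping["silu"]() instantiates an activation by name.
# Note that nn.Softmax should normally be given a ``dim`` argument to avoid the
# implicit-dimension deprecation warning.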
|
|
|
|
|
def cosin_contrastive_loss(input, target, margin=0): |
|
if target.ndim == 1: |
|
target = target.unsqueeze(1) |
|
if input.shape[0] == 1: |
|
return torch.tensor(0, dtype=input.dtype, device=input.device) |
|
|
|
    # Pairwise cosine similarity between all embeddings.
    dist = F.cosine_similarity(input.unsqueeze(1), input.unsqueeze(0), dim=2)

    # sim[i, j] = +1 if samples i and j share a label, -1 otherwise.
    sim = torch.eq(target, target.T)

    sim = sim.float() * 2 - 1

    # Same-label pairs are penalized by (1 - dist); different-label pairs by (dist - margin).
    loss = - dist * sim + (sim + 1) / 2 + (sim - 1) * margin / 2

    # Average the clamped upper triangle so each unordered pair counts once.
    loss = torch.clamp(loss.triu(diagonal=1), min=0).sum() / (target.shape[0] * (target.shape[0] - 1) / 2)
|
return loss |
|
|
|
|
|
def euclid_contrastive_loss(input, target): |
|
if target.ndim == 1: |
|
target = target.unsqueeze(1) |
|
if input.shape[0] == 1: |
|
return torch.tensor(0, dtype=input.dtype, device=input.device) |
|
|
|
    # Heuristic margin that scales with the embedding dimensionality.
    margin = 10 * input.shape[1]
|
|
|
dist = torch.cdist(input, input) |
|
|
|
sim = torch.eq(target, target.T) |
|
|
|
sim = sim.float() * 2 - 1 |
|
|
|
    # Same-label pairs are pulled together, different-label pairs pushed apart;
    # different-label pairs already farther apart than the margin contribute nothing.
    mask = (dist > margin).float() * (sim == -1).float()
    loss = dist * sim * (1 - mask)
|
|
|
loss = loss.triu(diagonal=1).sum() / (target.shape[0] * (target.shape[0] - 1) / 2) |
|
return loss |
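

# Illustrative call (shapes assumed): embeddings of shape (N, D) and integer
# class labels of shape (N,), e.g.
#   loss = euclid_contrastive_loss(torch.randn(8, 16), torch.randint(0, 2, (8,)))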
|
|
|
|
|
class WeightedCombinedLoss(nn.modules.loss._WeightedLoss): |
|
""" |
|
Weighted combined loss function. |
|
Input weight should be a tensor of shape (5,). |
|
The first 2 weights are for the patho/beni loss |
|
The last 3 weights are for the beni/gof/lof loss |
|
""" |
|
def __init__(self, weight: Optional[torch.Tensor] = None, |
|
task_weight: float = 10.0, |
|
size_average=None, ignore_index: int = -100, |
|
reduce=None, reduction: str = 'mean') -> None: |
|
super().__init__(weight, size_average, reduce, reduction) |
|
self.ignore_index = ignore_index |
|
self.task_weight = task_weight |
|
|
|
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
|
return combined_loss(input, target, |
|
weight_1=self.weight[:2], |
|
weight_2=self.weight[2:], |
|
weight=self.task_weight, |
|
reduction=self.reduction) |
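

# Usage sketch (weight values are illustrative assumptions): the 5-element
# weight tensor re-balances the patho/beni term (first two entries) and the
# beni/gof/lof term (remaining entries).
#
#   criterion = WeightedCombinedLoss(weight=torch.ones(5), task_weight=10.0)
#   loss = criterion(input, target)  # input: (N, 2) probabilities, target in {-1, 0, 1}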
|
|
|
|
|
class WeightedLoss1(nn.modules.loss._WeightedLoss): |
|
""" |
|
Weighted combined loss function. |
|
Input weight should be a tensor of shape (5,). |
|
The first 2 weights are for the patho/beni loss |
|
The last 3 weights are for the beni/gof/lof loss |
|
""" |
|
def __init__(self, weight: Optional[torch.Tensor] = None, |
|
task_weight: float = 10.0, |
|
size_average=None, ignore_index: int = -100, |
|
reduce=None, reduction: str = 'mean') -> None: |
|
super().__init__(weight, size_average, reduce, reduction) |
|
self.ignore_index = ignore_index |
|
self.task_weight = task_weight |
|
|
|
    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None, reduce=True, reduction=None) -> torch.Tensor:
        # Binary pathogenic/benign cross-entropy. Positive samples (target == 1)
        # are re-weighted by weight[1] / weight[0]; an optional per-sample
        # ``weight`` further scales each term.
        if reduction is None:
|
reduction = self.reduction |
|
weight_1 = self.weight[:2] |
|
target_1 = (target).float() |
|
weight_loss_1 = torch.ones_like(target_1, dtype=input.dtype, device=input.device) |
|
weight_loss_1[target == 1] = weight_1[1] / weight_1[0] |
|
if weight is not None: |
|
weight_loss_1 *= weight |
|
loss_1 = binary_cross_entropy(input=input, |
|
target=target_1, |
|
weight=weight_loss_1, |
|
reduce=reduce, |
|
reduction=reduction) |
|
return loss_1 |
|
|
|
|
|
class WeightedLoss2(nn.modules.loss._WeightedLoss): |
|
""" |
|
Weighted combined loss function. |
|
Input weight should be a tensor of shape (5,). |
|
The first 2 weights are for the patho/beni loss |
|
The last 3 weights are for the beni/gof/lof loss |
|
""" |
|
def __init__(self, weight: Optional[torch.Tensor] = None, |
|
task_weight: float = 10.0, |
|
size_average=None, ignore_index: int = -100, |
|
reduce=None, reduction: str = 'mean') -> None: |
|
super().__init__(weight, size_average, reduce, reduction) |
|
self.ignore_index = ignore_index |
|
self.task_weight = task_weight |
|
|
|
    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None, reduce=True, reduction=None) -> torch.Tensor:
        # Targets are encoded as -1 / 0 / +1, with 0 = benign and ±1 the two
        # pathogenic classes. The cubic map below sends {-1, 0, 1} -> {1, 0, 1},
        # giving a binary pathogenic-vs-benign target for the first BCE term.
        target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).float()
|
if reduction is None: |
|
reduction = self.reduction |
|
loss_1 = binary_cross_entropy(input=input, |
|
target=target_1, |
|
weight=weight, |
|
reduce=reduce, |
|
reduction=reduction) |
|
|
|
        # The second term is defined only on pathogenic samples (target == ±1).
        filter = (target == -1) | (target == 1)

        if not filter.any():
            # No pathogenic samples in the batch: return a zero that keeps the autograd graph.
            return 0 * loss_1
|
|
|
weight_2 = self.weight[2:] |
|
|
|
        # Map {-1, 1} -> {1, 0}: binary target separating the two pathogenic classes.
        target_2 = (1/2 * (-target + 1)).float()
|
|
|
weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device) |
|
weight_loss_2[target == 1] = weight_2[1] / weight_2[0] |
|
if weight is not None: |
|
weight_loss_2 *= weight |
|
loss_2 = binary_cross_entropy(input=input[filter], |
|
target=target_2[filter], |
|
weight=weight_loss_2[filter], |
|
reduce=reduce, |
|
reduction=reduction) |
|
return loss_2 |
|
|
|
|
|
class WeightedLoss3(nn.modules.loss._WeightedLoss): |
|
""" |
|
Weighted combined loss function. |
|
Input weight should be a tensor of shape (5,). |
|
The first 2 weights are for the patho/beni loss |
|
The last 3 weights are for the beni/gof/lof loss |
|
""" |
|
def __init__(self, weight: Optional[torch.Tensor] = None, |
|
task_weight: float = 10.0, |
|
size_average=None, ignore_index: int = -100, |
|
reduce=None, reduction: str = 'mean') -> None: |
|
super().__init__(weight, size_average, reduce, reduction) |
|
self.ignore_index = ignore_index |
|
self.task_weight = task_weight |
|
|
|
    def forward(self, input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None, reduce=True, reduction=None) -> torch.Tensor:
        # Same two-part scheme as WeightedLoss2, but ``input`` carries two Beta
        # concentration parameters per sample: the first BCE term uses their
        # normalized mean, the second a Beta-Binomial log-likelihood.
        target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).float()
|
if reduction is None: |
|
reduction = self.reduction |
|
loss_1 = binary_cross_entropy(input=input[:, 0]/(input[:, 0] + input[:, 1]), |
|
target=target_1, |
|
weight=weight, |
|
reduce=reduce, |
|
reduction=reduction) |
|
|
|
filter = (target == -1) | (target == 1) |
|
|
|
if not filter.any(): |
|
return 0 * loss_1 |
|
|
|
weight_2 = self.weight[2:] |
|
|
|
target_2 = (1/2 * (-target + 1)).float() |
|
|
|
weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device) |
|
weight_loss_2[target == 1] = weight_2[1] / weight_2[0] |
|
if weight is not None: |
|
weight_loss_2 *= weight |
|
loss_2 = -BetaBinomial( |
|
concentration1=input[:, 0][filter], |
|
concentration0=input[:, 1][filter], |
|
total_count=1 |
|
).log_prob(target_2[filter]) |
|
|
|
loss_2 *= weight_loss_2[filter] |
|
|
|
        loss_2 = loss_2.mean()  # always mean-reduced, regardless of ``reduction``
|
return loss_2 |
|
|
|
|
|
class RegressionWeightedLoss(nn.modules.loss._WeightedLoss): |
|
""" |
|
Weighted combined loss function. |
|
Input weight should be a tensor of shape (5,). |
|
The first 2 weights are for the patho/beni loss |
|
The last 3 weights are for the beni/gof/lof loss |
|
""" |
|
def __init__(self, weight: Optional[torch.Tensor] = None, |
|
task_weight: float = 10.0, |
|
size_average=None, ignore_index: int = -100, |
|
reduce=None, reduction: str = 'mean') -> None: |
|
super().__init__(weight, size_average, reduce, reduction) |
|
self.ignore_index = ignore_index |
|
self.task_weight = task_weight |
|
|
|
    def forward(self, input, target) -> torch.Tensor:
        # Columns 1..n of ``input``/``target`` are auxiliary regression outputs
        # scored with MSE; column 0 follows the same two-part classification
        # scheme as WeightedLoss2.
        regression_target = target[:, 1:]
|
regression_input = input[:, 1:] |
|
regression_loss = mse_loss(input=regression_input, |
|
target=regression_target, |
|
reduction=self.reduction) |
|
target = target[:, [0]] |
|
input = input[:, [0]] |
|
target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).float() |
|
loss_1 = binary_cross_entropy(input=input, |
|
target=target_1, |
|
reduction=self.reduction) |
|
|
|
filter = (target == -1) | (target == 1) |
|
|
|
if not filter.any(): |
|
return 0 * loss_1 |
|
|
|
weight_2 = self.weight[2:] |
|
|
|
target_2 = (1/2 * (-target + 1)).float() |
|
|
|
weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device) |
|
weight_loss_2[target == 1] = weight_2[1] / weight_2[0] |
|
loss_2 = binary_cross_entropy(input=input[filter], |
|
target=target_2[filter], |
|
weight=weight_loss_2[filter], |
|
reduction=self.reduction) |
|
return loss_2 + regression_loss |
|
|
|
|
|
class GPLoss(nn.modules.loss._WeightedLoss): |
|
    def __init__(self):
        super().__init__()
|
|
|
|
|
def combined_loss(input: torch.Tensor, target: torch.Tensor, |
|
weight: float=10.0, |
|
weight_1: Optional[torch.Tensor]=None, |
|
weight_2: Optional[torch.Tensor]=None, |
|
reduction: str = 'mean') -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Targets are encoded as -1 / 0 / +1. loss_1 is a pathogenic-vs-benign BCE on
    # input[:, 0]; loss_2 separates the two pathogenic classes via input[:, 1] and
    # is added with factor ``weight`` whenever pathogenic samples are present.
    if target.ndim == 2:
|
target = target.squeeze(1) |
|
|
|
    target_1 = (-1/3 * target**3 + target**2 + 1/3 * target).float()  # {-1, 0, 1} -> {1, 0, 1}
|
|
|
    target_2 = (1/2 * (target + 1)).float()  # {-1, 1} -> {0, 1}
|
|
|
|
|
weight_loss_1 = torch.ones_like(target_1, dtype=input.dtype, device=input.device) |
|
weight_loss_1[target_1 == 1] = weight_1[0] / weight_1[1] |
|
loss_1 = binary_cross_entropy(input=input[:, 0], |
|
target=target_1, |
|
weight=weight_loss_1, |
|
reduction=reduction) |
|
|
|
|
|
filter = (target == -1) | (target == 1) |
|
|
|
if not filter.any(): |
|
return loss_1 |
|
weight_loss_2 = torch.ones_like(target_2, dtype=input.dtype, device=input.device) |
|
weight_loss_2[target_2 == 1] = weight_2[0] / weight_2[1] |
|
loss_2 = binary_cross_entropy(input=input[filter, 1], |
|
target=target_2[filter], |
|
weight=weight_loss_2[filter], |
|
reduction=reduction) |
|
|
|
|
|
    # If the batch contains no benign samples, only the pathogenic-class term is used.
    if not (target == 0).any():
|
loss = loss_2 |
|
else: |
|
loss = loss_1 + weight * loss_2 |
|
return loss |
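

# Illustrative call (shapes assumed): probabilities for the two heads in
# input[:, 0] and input[:, 1], and targets in {-1, 0, 1}.
#
#   input = torch.rand(8, 2)
#   target = torch.randint(-1, 2, (8,))
#   loss = combined_loss(input, target, weight_1=torch.ones(2), weight_2=torch.ones(3))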
|
|
|
|
|
def gaussian_loss(input: torch.Tensor, target: torch.Tensor):
    # Negative log-likelihood of ``target`` under Normal(mean=input[:, 0],
    # scale=softplus(input[:, 1])), plus a softplus penalty that discourages
    # inflating the predicted scale.
    loss = -Normal(loc=input[:, 0], scale=torch.nn.functional.softplus(input[:, 1])).log_prob(target).mean()
|
loss += torch.nn.functional.softplus(input[:, 1]).mean() |
|
return loss |
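

# Sketch (assumed layout): gaussian_loss expects input of shape (N, 2), with
# column 0 the predicted mean and column 1 a pre-softplus scale, e.g.
#   loss = gaussian_loss(torch.randn(8, 2), torch.randn(8))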
|
|
|
|
|
def mse_loss_weighted(input: torch.Tensor, target: torch.Tensor, weight: torch.Tensor = None, reduce=True, reduction=None) -> torch.Tensor:
    # Element-wise squared error with optional per-element weights.
    # ``reduction`` is accepted for interface compatibility but ignored.
    mse = (input - target).pow(2)
|
if weight is not None: |
|
mse *= weight |
|
if reduce: |
|
return mse.mean() |
|
else: |
|
return mse |
|
|
|
|
|
loss_fn_mapping = { |
|
"mse_loss": mse_loss, |
|
"mse_loss_weighted": mse_loss_weighted, |
|
"l1_loss": l1_loss, |
|
"binary_cross_entropy": binary_cross_entropy, |
|
"cross_entropy": cross_entropy, |
|
"kl_div": kl_div, |
|
"cosin_contrastive_loss": cosin_contrastive_loss, |
|
"euclid_contrastive_loss": euclid_contrastive_loss, |
|
"combined_loss": combined_loss, |
|
"weighted_combined_loss": WeightedCombinedLoss, |
|
"weighted_loss": WeightedLoss2, |
|
"weighted_loss_betabinomial": WeightedLoss3, |
|
"gaussian_loss": gaussian_loss, |
|
"weighted_loss_pretrain": WeightedLoss1, |
|
"regression_weighted_loss": RegressionWeightedLoss, |
|
"GP_loss": GPLoss, |
|
} |
|
|
|
|
|
def get_template_fn(template):
    # Returns a (template function, number of features per pair) tuple.
    if template == 'plain-distance':
        return plain_distance, 1
    elif template == 'exp-normal-smearing-distance':
        return exp_normal_smearing_distance, 50
    raise ValueError(f"Unknown template: {template}")
|
|
|
def plain_distance(pos):
    # Pairwise distances between the atoms at index 3 of the atom dimension;
    # eps keeps the gradient of the square root finite at zero distance.
    eps = 1e-10
    CA = pos[..., 3, :]
    d = (eps + (CA[..., None, :, :] - CA[..., :, None, :]).pow(2).sum(dim=-1, keepdim=True)) ** 0.5
    return d
|
|
|
def exp_normal_smearing_distance(pos, cutoff_upper=100, cutoff_lower=0, num_rbf=50):
    # Stateless counterpart of ExpNormalSmearing, applied to all pairwise distances.
    alpha = 5.0 / (cutoff_upper - cutoff_lower)
|
start_value = torch.exp( |
|
torch.scalar_tensor(-cutoff_upper + cutoff_lower) |
|
).to(pos.device) |
|
means = torch.linspace(start_value, 1, num_rbf).to(pos.device) |
|
betas = torch.tensor( |
|
[(2 / num_rbf * (1 - start_value)) ** -2] * num_rbf |
|
).to(pos.device) |
|
dist = plain_distance(pos) |
|
cutoffs = 0.5 * (torch.cos(dist * math.pi / cutoff_upper).to(pos.device) + 1.0) |
|
|
|
cutoffs = cutoffs * (dist < cutoff_upper).float() |
|
return cutoffs * torch.exp( |
|
-betas * (torch.exp(alpha * (-dist + cutoff_lower)) - means) ** 2 |
|
) |
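

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the library): check that all
    # registered RBF expansions and the cosine cutoff produce the expected shapes.
    d = torch.rand(32) * 5.0
    for name, rbf_cls in rbf_class_mapping.items():
        feats = rbf_cls(cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50)(d)
        assert feats.shape == (32, 50), (name, feats.shape)
    assert CosineCutoff(0.0, 5.0)(d).shape == (32,)
    print("smoke test passed")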
|
|
|
|