import sys
from datetime import datetime

import torch
import numpy as np

class Logger(object):
    def __init__(self, logpath, syspart=sys.stdout):
        self.terminal = syspart
        self.log = open(logpath, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        # This flush method is needed for Python 3 compatibility:
        # it handles the flush command by doing nothing.
        # You might want to specify some extra behaviour here.
        pass

def log(*args):
    print(f'[{datetime.now()}]', *args)

class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

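# Usage sketch (illustrative, not part of this module's training code): `model` and
# `ema_model` are assumed to be two instances of the same torch.nn.Module, with the
# EMA copy refreshed after every optimizer step.
def _example_ema_update():
    import torch.nn as nn
    model = nn.Linear(4, 4)
    ema_model = nn.Linear(4, 4)
    ema_model.load_state_dict(model.state_dict())
    ema = EMA(beta=0.999)
    ema.update_model_average(ema_model, model)  # blend online weights into the EMA copy
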
def sum_except_batch(x):
    return x.reshape(x.size(0), -1).sum(dim=-1)

def remove_mean(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    x = x - mean
    return x

def remove_mean_with_mask(x, node_mask):
    # Masked-out positions must already be zero, otherwise the masked mean is ill-defined.
    masked_max_abs_value = (x * (1 - node_mask)).abs().sum().item()
    assert masked_max_abs_value < 1e-5, f'Error {masked_max_abs_value} too high'
    N = node_mask.sum(1, keepdims=True)

    mean = torch.sum(x, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x

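# Sanity-check sketch, following the [batch, nodes, 3] layout used throughout this file:
# after the masked projection, the center of mass of the real (unmasked) nodes is zero.
def _example_remove_mean_with_mask():
    node_mask = torch.tensor([[[1.], [1.], [0.]]])  # one molecule: 2 real nodes, 1 padded
    x = torch.randn(1, 3, 3) * node_mask            # padded positions must already be zero
    x = remove_mean_with_mask(x, node_mask)
    assert_mean_zero_with_mask(x, node_mask)
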
def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
    """Subtract the center of mass of the fragment atoms from the coordinates of all atoms."""
    x_masked = x * center_of_mass_mask
    N = center_of_mass_mask.sum(1, keepdims=True)
    mean = torch.sum(x_masked, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x

def assert_mean_zero(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    assert mean.abs().max().item() < 1e-4

def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    largest_value = x.abs().max().item()
    error = torch.sum(x, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'

def assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    x_masked = x * center_of_mass_mask
    largest_value = x_masked.abs().max().item()
    error = torch.sum(x_masked, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Partial mean is not zero, relative_error {rel_error}'

def assert_correctly_masked(variable, node_mask):
    assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
        'Variables not masked properly.'

def check_mask_correct(variables, node_mask):
    for variable in variables:
        if len(variable) > 0:
            assert_correctly_masked(variable, node_mask)

def center_gravity_zero_gaussian_log_likelihood(x):
    assert len(x.size()) == 3
    B, N, D = x.size()
    assert_mean_zero(x)

    # r2 is invariant to a basis change in the relevant hyperplane.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and log p(x) are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant

    return log_px

def sample_center_gravity_zero_gaussian(size, device):
    assert len(size) == 3
    x = torch.randn(size, device=device)

    # This projection only works because the Gaussian is rotation-invariant
    # around zero and the samples are independent!
    x_projected = remove_mean(x)
    return x_projected

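# Usage sketch: draw zero-center-of-gravity noise for a small batch and evaluate its
# log-likelihood under the matching subspace Gaussian (shapes are illustrative).
def _example_subspace_gaussian():
    x = sample_center_gravity_zero_gaussian((2, 5, 3), device='cpu')
    log_px = center_gravity_zero_gaussian_log_likelihood(x)  # shape [2]
    return log_px
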
def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
    assert len(x.size()) == 3
    B, N_embedded, D = x.size()
    assert_mean_zero_with_mask(x, node_mask)

    # r2 is invariant to a basis change in the relevant hyperplane; the masked-out
    # values contribute zero.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    N = node_mask.squeeze(2).sum(1)  # N has shape [B]
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and log p(x) are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant

    return log_px

def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
    assert len(size) == 3
    x = torch.randn(size, device=device)
    x_masked = x * node_mask

    # This projection only works because the Gaussian is rotation-invariant
    # around zero and the samples are independent!
    # TODO: check it
    x_projected = remove_mean_with_mask(x_masked, node_mask)
    return x_projected

def standard_gaussian_log_likelihood(x):
    # Normalizing constant and log p(x) are computed:
    log_px = sum_except_batch(-0.5 * x * x - 0.5 * np.log(2 * np.pi))
    return log_px

def sample_gaussian(size, device):
    x = torch.randn(size, device=device)
    return x

def standard_gaussian_log_likelihood_with_mask(x, node_mask):
    # Normalizing constant and log p(x) are computed:
    log_px_elementwise = -0.5 * x * x - 0.5 * np.log(2 * np.pi)
    log_px = sum_except_batch(log_px_elementwise * node_mask)
    return log_px

def sample_gaussian_with_mask(size, device, node_mask):
    x = torch.randn(size, device=device)
    x_masked = x * node_mask
    return x_masked

def concatenate_features(x, h):
    xh = torch.cat([x, h['categorical']], dim=2)
    if 'integer' in h:
        xh = torch.cat([xh, h['integer']], dim=2)
    return xh

def split_features(z, n_dims, num_classes, include_charges):
    assert z.size(2) == n_dims + num_classes + include_charges
    x = z[:, :, 0:n_dims]
    h = {'categorical': z[:, :, n_dims:n_dims + num_classes]}
    if include_charges:
        h['integer'] = z[:, :, n_dims + num_classes:n_dims + num_classes + 1]

    return x, h

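# Round-trip sketch: concatenate_features followed by split_features recovers the
# coordinate block and the feature dictionary (assuming 3D coordinates, a one-hot
# over `num_classes` atom types and a single integer charge channel).
def _example_feature_roundtrip():
    B, N, n_dims, num_classes = 2, 5, 3, 4
    x = torch.randn(B, N, n_dims)
    h = {'categorical': torch.randn(B, N, num_classes), 'integer': torch.randn(B, N, 1)}
    z = concatenate_features(x, h)
    x_out, h_out = split_features(z, n_dims, num_classes, include_charges=1)
    assert torch.allclose(x, x_out) and torch.allclose(h['integer'], h_out['integer'])
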
# For gradient clipping
class Queue:
    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        self.items.insert(0, item)
        if len(self) > self.max_len:
            self.items.pop()

    def mean(self):
        return np.mean(self.items)

    def std(self):
        return np.std(self.items)

def gradient_clipping(flow, gradnorm_queue):
    # Allow the gradient norm to be at most 150% of the recent mean plus 2 standard deviations.
    max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()

    # Clips gradients in place and returns the (pre-clipping) total norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)

    if float(grad_norm) > max_grad_norm:
        # Record the clipped value so a single outlier does not inflate the running statistics.
        gradnorm_queue.add(float(max_grad_norm))
        print(f'Clipped gradient with value {grad_norm:.1f} while allowed {max_grad_norm:.1f}')
    else:
        gradnorm_queue.add(float(grad_norm))

    return grad_norm

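# Training-step sketch (illustrative): `model`, `loss` and `optimizer` stand in for
# whatever the surrounding training code provides; the queue is seeded so that its
# mean and std are defined before the first clipping decision.
def _example_gradient_clipping_step(model, loss, optimizer, gradnorm_queue):
    if len(gradnorm_queue) == 0:
        gradnorm_queue.add(3000.)  # large initial value so early steps are not clipped
        gradnorm_queue.add(3000.)
    optimizer.zero_grad()
    loss.backward()
    grad_norm = gradient_clipping(model, gradnorm_queue)
    optimizer.step()
    return grad_norm
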
def disable_rdkit_logging():
    """Disables RDKit whiny logging."""
    import rdkit.rdBase as rkrb
    import rdkit.RDLogger as rkl
    logger = rkl.logger()
    logger.setLevel(rkl.ERROR)
    rkrb.DisableLog('rdApp.error')

class FoundNaNException(Exception):
    def __init__(self, x, h):
        x_nan_idx = self.find_nan_idx(x)
        h_nan_idx = self.find_nan_idx(h)

        self.x_h_nan_idx = x_nan_idx & h_nan_idx
        self.only_x_nan_idx = x_nan_idx.difference(h_nan_idx)
        self.only_h_nan_idx = h_nan_idx.difference(x_nan_idx)

    @staticmethod
    def find_nan_idx(z):
        # Returns the set of batch indices that contain at least one NaN.
        idx = set()
        for i in range(z.shape[0]):
            if torch.any(torch.isnan(z[i])):
                idx.add(i)
        return idx

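# Usage sketch: guard a sampling or training step by raising the exception when NaNs
# appear in coordinates or features, so the caller can report which batch elements
# failed (the argument names here are illustrative).
def _example_nan_guard(x, h_cat):
    if torch.any(torch.isnan(x)) or torch.any(torch.isnan(h_cat)):
        raise FoundNaNException(x, h_cat)
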
def get_batch_idx_for_animation(batch_size, batch_idx):
    batch_indices = []
    mol_indices = []
    for idx in [0, 110, 360]:
        if idx // batch_size == batch_idx:
            batch_indices.append(idx % batch_size)
            mol_indices.append(idx)
    return batch_indices, mol_indices

# Rotation data augmentation
def random_rotation(x):
    bs, n_nodes, n_dims = x.size()
    device = x.device
    angle_range = np.pi * 2

    if n_dims == 2:
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        R_row0 = torch.cat([cos_theta, -sin_theta], dim=2)
        R_row1 = torch.cat([sin_theta, cos_theta], dim=2)
        R = torch.cat([R_row0, R_row1], dim=1)

        x = x.transpose(1, 2)
        x = torch.matmul(R, x)
        x = x.transpose(1, 2)

    elif n_dims == 3:
        # Build Rx
        Rx = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Rx[:, 1:2, 1:2] = cos
        Rx[:, 1:2, 2:3] = sin
        Rx[:, 2:3, 1:2] = -sin
        Rx[:, 2:3, 2:3] = cos

        # Build Ry
        Ry = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Ry[:, 0:1, 0:1] = cos
        Ry[:, 0:1, 2:3] = -sin
        Ry[:, 2:3, 0:1] = sin
        Ry[:, 2:3, 2:3] = cos

        # Build Rz
        Rz = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device)
        theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        Rz[:, 0:1, 0:1] = cos
        Rz[:, 0:1, 1:2] = sin
        Rz[:, 1:2, 0:1] = -sin
        Rz[:, 1:2, 1:2] = cos

        x = x.transpose(1, 2)
        x = torch.matmul(Rx, x)
        # x = torch.matmul(Rx.transpose(1, 2), x)
        x = torch.matmul(Ry, x)
        # x = torch.matmul(Ry.transpose(1, 2), x)
        x = torch.matmul(Rz, x)
        # x = torch.matmul(Rz.transpose(1, 2), x)
        x = x.transpose(1, 2)
    else:
        raise NotImplementedError("random_rotation is only implemented for 2D and 3D coordinates")

    return x.contiguous()
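
# Sanity-check sketch: a proper rotation preserves pairwise distances, so the distance
# matrices before and after augmentation should agree up to floating-point error.
def _example_random_rotation_preserves_distances():
    x = torch.randn(4, 10, 3)
    x_rot = random_rotation(x)
    assert torch.allclose(torch.cdist(x, x), torch.cdist(x_rot, x_rot), atol=1e-4)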