import os.path
import subprocess

import torch

from src.visualizer import save_xyz_file

# Number of samples named/written by generate_linkers and looked up again by
# try_to_convert_to_sdf; both functions default to this value.
N_SAMPLES = 5


def generate_linkers(ddpm, data, sample_fn, name,
                     n_samples=N_SAMPLES, output_dir='results'):
    """Sample linkers with ``ddpm`` and save them as XYZ files.

    Args:
        ddpm: Model exposing ``sample_chain(data, sample_fn=..., keep_frames=...)``
            and an ``n_dims`` attribute (number of coordinate channels).
        data: Batch dict; this function reads ``'positions'`` and
            ``'fragment_mask'``.
        sample_fn: Forwarded unchanged to ``ddpm.sample_chain``.
        name: Tag embedded in every output file name.
        n_samples: How many output names to generate. Presumably this should
            match the batch size of ``data`` -- TODO confirm against
            ``save_xyz_file``.
        output_dir: Directory handed to ``save_xyz_file``.

    Side effects:
        Writes files through ``save_xyz_file``. Judging by the paths that
        ``try_to_convert_to_sdf`` reads, these are presumably
        ``<output_dir>/output_<i>_<name>_.xyz`` -- verify.
    """
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')

    # Use the first (with keep_frames=1, presumably the only) kept frame and
    # split its last axis: the first n_dims channels are coordinates, the
    # remaining channels are atom features.
    frame = chain[0]
    x = frame[:, :, :ddpm.n_dims]
    h = frame[:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation: shift the sampled
    # coordinates by the mean position of the fragment atoms (axis 1 is
    # reduced, presumably the atom axis). Presumably the model samples in a
    # frame centred on the fragments -- TODO confirm.
    # NOTE(review): a sample whose fragment_mask is all zeros divides by zero.
    pos_masked = data['positions'] * data['fragment_mask']
    n_fragment_atoms = data['fragment_mask'].sum(1, keepdim=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / n_fragment_atoms
    # Multiplying by node_mask presumably leaves padding atoms unshifted.
    x = x + mean * node_mask

    names = [f'output_{i + 1}_{name}' for i in range(n_samples)]
    save_xyz_file(output_dir, h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')


def try_to_convert_to_sdf(name, n_samples=N_SAMPLES, output_dir='results'):
    """Convert each generated XYZ file to SDF with Open Babel, best effort.

    Args:
        name: The same tag that was passed to ``generate_linkers``.
        n_samples: Number of sample files to convert.
        output_dir: Directory containing the ``output_<i>_<name>_.xyz`` files.

    Returns:
        A list with one path per sample. The entry is the ``.sdf`` path when
        the conversion produced a non-empty file; otherwise it is the original
        ``.xyz`` path.
    """
    out_files = []
    for i in range(n_samples):
        stem = f'{output_dir}/output_{i + 1}_{name}_'
        out_xyz = f'{stem}.xyz'
        out_sdf = f'{stem}.sdf'

        # Remove an SDF left over from an earlier run, so that a failed
        # conversion now cannot be mistaken for a successful one.
        if os.path.exists(out_sdf):
            os.remove(out_sdf)

        try:
            # Pass an argument list and no shell: `name` may originate from
            # user input and must never be interpreted by a shell.
            subprocess.run(['obabel', out_xyz, '-O', out_sdf], check=False)
        except OSError as err:
            # obabel is missing or not executable. Keep the previous
            # best-effort behaviour and fall back to the XYZ file below.
            print(f'Could not run obabel: {err}')

        converted = os.path.exists(out_sdf) and os.path.getsize(out_sdf) > 0
        out_files.append(out_sdf if converted else out_xyz)
    return out_files