File size: 1,190 Bytes
92263a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os.path
import subprocess
import torch

from src.visualizer import save_xyz_file

# Number of samples handled per run: generate_linkers builds this many output
# file names and try_to_convert_to_sdf attempts this many conversions.
# NOTE(review): assumes the batch passed to generate_linkers contains exactly
# this many samples — confirm against the caller that builds `data`.
N_SAMPLES = 5


def generate_linkers(ddpm, data, sample_fn, name):
    """Sample linkers with ``ddpm`` and save the result as XYZ files.

    Runs ``ddpm.sample_chain`` keeping a single frame and splits that frame into
    coordinates (first ``ddpm.n_dims`` channels) and node features (the rest).
    The coordinates are shifted by the mean position of the fragment atoms
    before being saved with ``save_xyz_file`` into the ``results`` directory,
    using the names ``output_{i}_{name}`` for i = 1..N_SAMPLES.

    Args:
        ddpm: model exposing ``sample_chain(...)`` and ``n_dims``.
        data: mapping with at least ``'positions'`` and ``'fragment_mask'``
            tensors (presumably (batch, nodes, 3) and (batch, nodes, 1) —
            TODO confirm).
        sample_fn: forwarded unchanged to ``ddpm.sample_chain``.
        name: suffix embedded in every output file name.
    """
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')

    # keep_frames=1: only one frame is kept; split it into coordinates and
    # per-node features along the last axis.
    final_frame = chain[0]
    x = final_frame[:, :, :ddpm.n_dims]
    h = final_frame[:, :, ddpm.n_dims:]

    # Put the molecule back to the initial position: add back the mean
    # position of the fragment atoms (presumably the sampler works in a frame
    # centred on the fragments — TODO confirm against ddpm.sample_chain).
    fragment_mask = data['fragment_mask']
    pos_masked = data['positions'] * fragment_mask
    n_fragment_atoms = fragment_mask.sum(1, keepdim=True)
    # Clamp avoids 0/0 -> NaN coordinates for a sample with no fragment atoms;
    # its numerator is also zero, so such a sample is simply left unshifted.
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / n_fragment_atoms.clamp(min=1)
    x = x + mean * node_mask

    # NOTE(review): N_SAMPLES names are generated regardless of the actual
    # batch size of `data`; this assumes the two agree — confirm.
    names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')


def try_to_convert_to_sdf(name):
    """Convert each sample's XYZ file to SDF with Open Babel, best effort.

    For i = 1..N_SAMPLES, runs
    ``obabel results/output_{i}_{name}_.xyz -O results/output_{i}_{name}_.sdf``.
    If the SDF file exists after the attempt, its path is used for that
    sample. Otherwise the XYZ path is used instead.

    Args:
        name: the same suffix that was used when the XYZ files were written.

    Returns:
        list[str]: one path per sample. It is the SDF file when conversion
        produced one, and the XYZ file otherwise.
    """
    out_files = []
    for i in range(N_SAMPLES):
        out_xyz = f'results/output_{i + 1}_{name}_.xyz'
        out_sdf = f'results/output_{i + 1}_{name}_.sdf'

        # Drop any SDF left over from an earlier run with the same name, so the
        # existence check below reflects only this conversion attempt.
        try:
            os.remove(out_sdf)
        except FileNotFoundError:
            pass

        # Pass an argument list instead of a shell string: `name` is embedded
        # in both paths, and with shell=True spaces or shell metacharacters in
        # it would split the command or be executed by the shell.
        try:
            subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        except OSError:
            # e.g. obabel is not installed. Treat it like any other failed
            # conversion and fall back to the XYZ file below.
            pass

        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)

    return out_files