File size: 3,607 Bytes
55d9b0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from abc import ABC
from dataclasses import dataclass
from typing import List, Union
import numpy as np
from rdkit import Chem
from rdkit.Chem.BRICS import BRICSDecompose
from rdkit.Chem.Recap import RecapDecompose

import random


@dataclass
class Fragment:
    """
    Container for a molecule (or a piece of one) in up to two representations.

    Either field may be ``None`` when that representation is not available;
    both default to ``None`` so a fragment can be built from just one of them.
    """

    # SMILES string of the molecule/fragment, or None if only tokens are known.
    smiles: Union[str, None] = None
    # Token ids of the molecule/fragment, or None if only the SMILES is known.
    tokens: Union[List[int], None] = None


class BaseFragmentCreator(ABC):
    """
    Is the base class for all fragment creators.

    The base implementation performs no fragmentation at all: it ignores the
    given fragment and returns an empty string. Subclasses override
    ``create_fragment`` with an actual strategy.
    """

    def __init__(self) -> None:
        pass

    def create_fragment(self, frag: Fragment) -> Union[Fragment, str]:
        """
        Return the "no fragment" placeholder.

        Args:
            frag: The molecule to fragment. Unused by the base class.

        Returns:
            Always the empty string ``""``. The return annotation is a union
            because subclasses in this module return either a ``Fragment``
            or a plain string.
        """
        return ""


# This is the method used in the paper
class RandomSubsliceFragmentCreator(BaseFragmentCreator):
    """
    Creates a fragment by cutting a random contiguous slice out of the tokens.
    """

    def __init__(self, max_fragment_size=50) -> None:
        """
        Args:
            max_fragment_size: Upper bound used when sampling the slice end;
                the resulting fragment has at most ``max_fragment_size - 1``
                tokens.

        Raises:
            ValueError: If ``max_fragment_size`` is smaller than 2, because
                no valid slice end could ever be sampled with such a bound.
        """
        super().__init__()
        if max_fragment_size < 2:
            # Fail here with a clear message instead of later inside numpy
            # with an opaque "low >= high" error.
            raise ValueError(
                f"max_fragment_size must be at least 2, got {max_fragment_size}"
            )
        self.max_fragment_size = max_fragment_size

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the random sub slice fragments from the tokens

        The slice start is drawn uniformly from ``[0, len(tokens) - 2]`` and
        the (exclusive) slice end uniformly from
        ``[start + 1, min(len(tokens), start + max_fragment_size) - 1]``.
        As a consequence the final token is never part of the slice.

        Args:
            frag: Fragment whose ``tokens`` are sliced. ``frag.tokens`` must
                not be ``None`` (``len(None)`` raises ``TypeError``).

        Returns:
            A new ``Fragment`` with ``smiles=None`` and the sampled token
            slice. If there are fewer than two tokens nothing can be sampled,
            so a copy of all tokens is returned instead.
        """
        tokens = frag.tokens
        num_tokens = len(tokens)

        if num_tokens < 2:
            # np.random.randint(0, num_tokens - 1) would raise for 0 or 1
            # tokens; fall back to the whole (tiny) sequence.
            return Fragment(smiles=None, tokens=tokens[:])

        start_idx = np.random.randint(0, num_tokens - 1)

        # The upper bound of randint is exclusive.
        end_idx = np.random.randint(
            start_idx + 1, min(num_tokens, start_idx + self.max_fragment_size)
        )
        return Fragment(smiles=None, tokens=tokens[start_idx:end_idx])


class BricksFragmentCreator(BaseFragmentCreator):
    """
    Creates a fragment by BRICS-decomposing the molecule with RDKit.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the Bricks fragments and takes one randomly

        Args:
            frag: Fragment whose ``smiles`` is decomposed.

        Returns:
            The SMILES string of one randomly chosen BRICS fragment, or the
            empty string ``""`` if ``frag.smiles`` is ``None``, cannot be
            parsed by RDKit, or yields no fragments.
        """
        smiles = frag.smiles
        if smiles is None:
            return ""

        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        # Sort the decomposition result so the pick made by a seeded
        # ``random`` does not depend on the iteration order of the
        # collection returned by RDKit.
        candidates = sorted(BRICSDecompose(m, minFragmentSize=3))
        if not candidates:
            # random.choice raises IndexError on an empty sequence.
            return ""
        return random.choice(candidates)


class RecapFragmentCreator(BaseFragmentCreator):
    """
    Creates a fragment by RECAP-decomposing the molecule with RDKit.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the Recap fragments and takes one randomly

        Args:
            frag: Fragment whose ``smiles`` is decomposed.

        Returns:
            The SMILES string of one randomly chosen node of the RECAP
            hierarchy, or the empty string ``""`` if ``frag.smiles`` is
            ``None``, cannot be parsed by RDKit, or the molecule has no
            RECAP children.
        """
        smiles = frag.smiles
        if smiles is None:
            return ""

        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        hierarchy = RecapDecompose(m, minFragmentSize=3)
        # GetAllChildren() returns a mapping keyed by fragment SMILES.
        # random.choice indexes its argument with an integer, which fails on
        # a mapping, so pick from the (sorted, hence reproducible) keys.
        candidates = sorted(hierarchy.GetAllChildren())
        if not candidates:
            # Nothing could be cleaved; random.choice would raise IndexError.
            return ""
        return random.choice(candidates)


class MolFragsFragmentCreator(BaseFragmentCreator):
    """
    Creates a fragment by picking one component returned by RDKit's
    ``GetMolFrags`` for the molecule.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the molecule fragments (GetMolFrags) and takes one randomly

        Args:
            frag: Fragment whose ``smiles`` is split into components.

        Returns:
            The SMILES string of one randomly chosen component, or the empty
            string ``""`` if ``frag.smiles`` is ``None``, cannot be parsed by
            RDKit, or the molecule has no components.
        """
        smiles = frag.smiles
        if smiles is None:
            return ""

        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        components = Chem.rdmolops.GetMolFrags(m, asMols=True)
        candidates = [Chem.MolToSmiles(component) for component in components]
        if not candidates:
            # random.choice raises IndexError on an empty sequence.
            return ""
        return random.choice(candidates)


def fragment_creator_factory(key: Union[str, None]):
    """
    Build the fragment creator registered under ``key``.

    Args:
        key: One of ``"mol_frags"``, ``"recap"``, ``"bricks"`` or ``"rss"``;
            ``None`` means that no fragment creator is wanted.

    Returns:
        A new instance of the matching fragment creator class, or ``None``
        when ``key`` is ``None``.

    Raises:
        ValueError: If ``key`` is not ``None`` and not a known name.
    """
    if key is None:
        return None

    # Guard clauses: each known key returns immediately.
    if key == "mol_frags":
        return MolFragsFragmentCreator()
    if key == "recap":
        return RecapFragmentCreator()
    if key == "bricks":
        return BricksFragmentCreator()
    if key == "rss":
        return RandomSubsliceFragmentCreator()

    raise ValueError(f"Do not have factory for the given key: {key}")


if __name__ == "__main__":
    # Manual smoke test: fragment one molecule, then tokenize the result and
    # print the tokens back as text.
    from tokenizer import SmilesTokenizer

    tokenizer = SmilesTokenizer()

    creator = BricksFragmentCreator()
    # creator = MolFragsFragmentCreator()

    # creator = RecapFragmentCreator()

    # create_fragment reads ``.smiles`` from its argument, so the raw SMILES
    # string has to be wrapped in a Fragment instead of being passed directly.
    molecule = Fragment(smiles="CC(=O)NC1=CC=C(C=C1)O", tokens=None)
    frag = creator.create_fragment(molecule)

    print(frag)
    tokens = tokenizer.encode(frag)
    print(tokens)
    print([tokenizer._convert_id_to_token(t) for t in tokens])