tags:
- chemistry
- biology
- medical
---

# DrugGen: Advancing Drug Discovery with Large Language Models and Reinforcement Learning Feedback

DrugGen is a GPT-2-based model specialized in generating drug-like SMILES structures conditioned on a protein sequence. It leverages the characteristics of approved drug targets and was trained with both supervised fine-tuning and reinforcement learning to enhance its ability to generate chemically valid, safe, and effective structures.

## Model Details

- Model Name: DrugGen
- Training Paradigm: Supervised Fine-Tuning (SFT) + Proximal Policy Optimization (PPO)
- Input: Protein sequence
- Output: SMILES structure
- Training Libraries: Hugging Face’s transformers and Transformer Reinforcement Learning (TRL)
- Model Sources: liyuesen/druggpt

## How to Get Started with the Model

The snippet below defines a small helper class that loads the model and tokenizer, maps UniProt IDs to sequences, and generates unique SMILES for each target.

```python
import pandas as pd
from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset


class SMILESGenerator:
    def __init__(self):
        # Configuration parameters
        self.config = {
            "model_name": "alimotahharynia/DrugGen",
            "dataset_name": "alimotahharynia/approved_drug_target",
            "dataset_key": "uniprot_sequence",
            "generation_kwargs": {
                "do_sample": True,
                "top_k": 9,
                "max_length": 1024,
                "top_p": 0.9,
                "num_return_sequences": 10
            },
            "max_retries": 30  # Max retry limit to avoid infinite loops
        }

        # Load model and tokenizer
        self.model_name = self.config["model_name"]
        self.model, self.tokenizer = self.load_model_and_tokenizer(self.model_name)

        # Load UniProt-to-sequence mapping dataset
        dataset_name = self.config["dataset_name"]
        dataset_key = self.config["dataset_key"]
        self.uniprot_to_sequence = self.load_uniprot_mapping(dataset_name, dataset_key)

        # Adjust generation parameters with token IDs
        self.generation_kwargs = self.config["generation_kwargs"]
        self.generation_kwargs["bos_token_id"] = self.tokenizer.bos_token_id
        self.generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        self.generation_kwargs["pad_token_id"] = self.tokenizer.pad_token_id

    def load_model_and_tokenizer(self, model_name):
        print(f"Loading model and tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        return model, tokenizer

    def load_uniprot_mapping(self, dataset_name, dataset_key):
        print(f"Loading dataset: {dataset_name}")
        try:
            dataset = load_dataset(dataset_name, dataset_key)
            return {row["UniProt_id"]: row["Sequence"] for row in dataset["uniprot_seq"]}
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset {dataset_name}: {e}")

    def generate_smiles(self, sequence, num_generated):
        """
        Generate unique SMILES with a retry limit to avoid infinite loops.
        """
        generated_smiles_set = set()
        prompt = f"<|startoftext|><P>{sequence}<L>"
        encoded_prompt = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        retries = 0

        while len(generated_smiles_set) < num_generated:
            if retries >= self.config["max_retries"]:
                print("Max retries reached. Returning what has been generated so far.")
                break

            sample_outputs = self.model.generate(encoded_prompt, **self.generation_kwargs)
            for sample_output in sample_outputs:
                output_decode = self.tokenizer.decode(sample_output, skip_special_tokens=False)
                try:
                    # The ligand SMILES sits between the <L> marker and the end-of-text token
                    generated_smiles = output_decode.split("<L>")[1].split("<|endoftext|>")[0]
                    if generated_smiles not in generated_smiles_set:
                        generated_smiles_set.add(generated_smiles)
                except IndexError:
                    continue

            retries += 1

        return list(generated_smiles_set)

    def generate_smiles_data(self, list_of_sequences=None, list_of_uniprot_ids=None, num_generated=10):
        """
        Generate SMILES data for protein sequences or UniProt IDs.
        """
        if not list_of_sequences and not list_of_uniprot_ids:
            raise ValueError("Either `list_of_sequences` or `list_of_uniprot_ids` must be provided.")

        # Prepare sequences input
        if list_of_sequences:
            sequences_input = list_of_sequences
        else:
            sequences_input = [
                self.uniprot_to_sequence[uid]
                for uid in list_of_uniprot_ids
                if uid in self.uniprot_to_sequence
            ]

        data = []
        for sequence in sequences_input:
            smiles = self.generate_smiles(sequence, num_generated)
            # Recover the UniProt ID by reverse lookup (None for sequences not in the mapping)
            uniprot_id = next((uid for uid, seq in self.uniprot_to_sequence.items() if seq == sequence), None)
            data.append({"UniProt_id": uniprot_id, "sequence": sequence, "smiles": smiles})

        return pd.DataFrame(data)
```

Below is an example of how to use DrugGen to generate SMILES. Adjust the `num_generated` parameter to set the number of unique SMILES generated per protein.

```python
if __name__ == "__main__":
    # Initialize the generator
    generator = SMILESGenerator()

    # Example input (use either list_of_sequences or list_of_uniprot_ids)
    list_of_sequences = [
        "MGAASGRRGPGLLLPLPLLLLLPPQPALALDPGLQPGNFSADEAGAQLFAQSYNSSAEQVLFQSVAASWAHDTNITAENARRQEEAALLSQEFAEAWGQKAKELYEPIWQNFTDPQLRRIIGAVRTLGSANLPLAKRQQYNALLSNMSRIYSTAKVCLPNKTATCWSLDPDLTNILASSRSYAMLLFAWEGWHNAAGIPLKPLYEDFTALSNEAYKQDGFTDTGAYWRSWYNSPTFEDDLEHLYQQLEPLYLNLHAFVRRALHRRYGDRYINLRGPIPAHLLGDMWAQSWENIYDMVVPFPDKPNLDVTSTMLQQGWNATHMFRVAEEFFTSLELSPMPPEFWEGSMLEKPADGREVVCHASAWDFYNRKDFRIKQCTRVTMDQLSTVHHEMGHIQYYLQYKDLPVSLRRGANPGFHEAIGDVLALSVSTPEHLHKIGLLDRVTNDTESDINYLLKMALEKIAFLPFGYLVDQWRWGVFSGRTPPSRYNFDWWYLRTKYQGICPPVTRNETHFDAGAKFHVPNVTPYIRYFVSFVLQFQFHEALCKEAGYEGPLHQCDIYRSTKAGAKLRKVLQAGSSRPWQEVLKDMVGLDALDAQPLLKYFQPVTQWLQEQNQQNGEVLGWPEYQWHPPLPDNYPEGIDLVTDEAEASKFVEEYDRTSQVVWNEYAEANWNYNTNITTETSKILLQKNMQIANHTLKYGTQARKFDVNQLQNTTIKRIIKKVQDLERAALPAQELEEYNKILLDMETTYSVATVCHPNGSCLQLEPDLTNVMATSRKYEDLLWAWEGWRDKAGRAILQFYPKYVELINQAARLNGYVDAGDSWRSMYETPSLEQDLERLFQELQPLYLNLHAYVRRALHRHYGAQHINLEGPIPAHLLGNMWAQTWSNIYDLVVPFPSAPSMDTTEAMLKQGWTPRRMFKEADDFFTSLGLLPVPPEFWNKSMLEKPTDGREVVCHASAWDFYNGKDFRIKQCTTVNLEDLVVAHHEMGHIQYFMQYKDLPVALREGANPGFHEAIGDVLALSVSTPKHLHSLNLLSSEGGSDEHDINFLMKMALDKIAFIPFSYLVDQWRWRVFDGSITKENYNQEWWSLRLKYQGLCPPVPRTQGDFDPGAKFHIPSSVPYIRYFVSFIIQFQFHEALCQAAGHTGPLHKCDIYQSKEAGQRLATAMKLGFSRPWPEAMQLITGQPNMSASAMLSYFKPLLDWLRTENELHGEKLGWPQYNWTPNSARSEGPLPDSGRVSFLGLDLDAQQARVGQWLLLFLGIALLVATLGLSQRLFSIRHRSLHRHSHGPQFGSEVELRHS"
    ]
    list_of_uniprot_ids = ["P12821", "P37231"]

    # Generate SMILES data for raw sequences
    # df = generator.generate_smiles_data(list_of_sequences=list_of_sequences, num_generated=2)

    # Generate SMILES data for UniProt IDs
    df = generator.generate_smiles_data(list_of_uniprot_ids=list_of_uniprot_ids, num_generated=2)

    # Save the output as a tab-separated file
    output_file = "seq_SMILES.txt"
    df.to_csv(output_file, sep="\t", index=False)
    print(f"Generated SMILES saved to {output_file}")
    print(df)
```

## Training Details

### Training Data

[alimotahharynia/approved_drug_target](https://huggingface.co/datasets/alimotahharynia/approved_drug_target)
- This dataset contains approved SMILES–protein sequence pairs and was used to train the model to generate SMILES strings.
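
To inspect the training data directly, the dataset can be loaded with `datasets`. A minimal sketch, assuming the `uniprot_sequence` config, `uniprot_seq` split, and column names used by the loader in the snippet above:

```python
from datasets import load_dataset

# SMILES–protein pair dataset used for training
dataset = load_dataset("alimotahharynia/approved_drug_target", "uniprot_sequence")

# Peek at one record; column names as in the UniProt mapping code above
record = dataset["uniprot_seq"][0]
print(record["UniProt_id"], record["Sequence"][:60])
```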

### Training Procedure

- **Training regime:** fp32

#### Supervised Fine-Tuning

DrugGen was initially trained using supervised fine-tuning on a curated dataset of approved drug targets; a configuration sketch follows the hyperparameter list below.
- **Training/validation split:** 8:2
- **sft_config**
  - `num_train_epochs=5`
  - `per_device_train_batch_size=8`
  - `per_device_eval_batch_size=8`
  - `evaluation_strategy="steps"`
  - `save_strategy="epoch"`
  - `eval_steps=50`
  - `logging_steps=25`
  - `logging_strategy="steps"`
  - `do_eval=True`
  - `do_train=True`
  - `learning_rate=5e-4`
  - `adam_epsilon=1e-08`
  - `warmup_steps=100`
  - `dataloader_drop_last=True`
  - `save_safetensors=False`
  - `max_seq_length=768`
- **AdamW optimizer**
  - `lr=5e-4`
  - `eps=1e-08`
- **scheduler**
  - `get_linear_schedule_with_warmup`
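
The hyperparameters above map onto the Hugging Face APIs roughly as follows. This is a minimal sketch, not the authors' released training script: it assumes TRL's `SFTTrainer`, a prepared 8:2 `train_dataset`/`eval_dataset` split with a hypothetical `text` column, and an assumed `output_dir`; exact argument names vary across TRL versions.

```python
from transformers import TrainingArguments
from trl import SFTTrainer

# Mirror the sft_config values listed above; the default scheduler is
# linear decay with warmup, matching get_linear_schedule_with_warmup.
args = TrainingArguments(
    output_dir="druggen-sft",  # assumed output path
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    save_strategy="epoch",
    eval_steps=50,
    logging_steps=25,
    logging_strategy="steps",
    do_eval=True,
    do_train=True,
    learning_rate=5e-4,
    adam_epsilon=1e-08,
    warmup_steps=100,
    dataloader_drop_last=True,
    save_safetensors=False,
)

trainer = SFTTrainer(
    model=model,                  # GPT-2 model loaded as in the snippet above
    args=args,
    train_dataset=train_dataset,  # assumed 8:2 split of the pair dataset
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",    # hypothetical column holding the prompt strings
    max_seq_length=768,
)
trainer.train()
```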

#### Proximal Policy Optimization

- **Rollout:** generates a response for an input query, with the following generation parameters:
  - `do_sample=True`
  - `top_k=9`
  - `max_length=1024`
  - `top_p=0.9`
  - `bos_token_id=tokenizer.bos_token_id`
  - `eos_token_id=tokenizer.eos_token_id`
  - `pad_token_id=tokenizer.pad_token_id`
  - `num_return_sequences=10`

In each epoch, generation continued until 30 unique small molecules had been generated for each target.

- **Evaluation:** the reward function combines the following components (sketched below):
  - a binding affinity predictor: "Protein-Ligand Binding Affinity Prediction Using Pretrained Transformers" (PLAPT)
  - a customized invalid-structure assessor based on the RDKit library
  - a multiplicative penalty of 0.7 when a generated SMILES matched a molecule present in the approved-SMILES dataset
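
A schematic of how such a reward could be assembled is below. This is an illustrative sketch, not the released training code: `predict_affinity` is a stand-in for the PLAPT predictor, `approved_smiles` for the set of approved SMILES, and the invalid-structure penalty value is an assumption; only the RDKit validity check is a real API call.

```python
from rdkit import Chem

def reward(protein_sequence: str, smiles: str, approved_smiles: set, predict_affinity) -> float:
    # Invalid-structure assessor: RDKit returns None for unparsable SMILES
    if Chem.MolFromSmiles(smiles) is None:
        return -1.0  # assumed penalty for chemically invalid output

    # Binding affinity predictor (PLAPT); the callable is a placeholder
    score = predict_affinity(protein_sequence, smiles)

    # Multiplicative 0.7 penalty for molecules already in the approved set
    if smiles in approved_smiles:
        score *= 0.7
    return score
```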

- **Optimization:** (a single PPO step is sketched after this list)
  - **ppo_config**
    - `mini_batch_size=8`
    - `batch_size=240`
    - `learning_rate=1.41e-5`
    - `use_score_scaling=True`
    - `use_score_norm=True`
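
Put together, a single PPO iteration with TRL might look like the sketch below. It assumes the legacy `PPOTrainer` API from TRL 0.x (the interface was later restructured) and that `queries`, `responses`, and `rewards` have already been collected via the rollout and reward steps above; everything beyond the `ppo_config` values is an assumption.

```python
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# ppo_config values as listed above
config = PPOConfig(
    mini_batch_size=8,
    batch_size=240,
    learning_rate=1.41e-5,
    use_score_scaling=True,
    use_score_norm=True,
)

# Wrap the SFT checkpoint with a value head; ref_model=None lets TRL
# create the frozen reference copy internally.
model = AutoModelForCausalLMWithValueHead.from_pretrained("alimotahharynia/DrugGen")
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer)

# One optimization step: queries/responses are lists of token-id tensors,
# rewards are the scalar scores from the reward function.
stats = ppo_trainer.step(queries, responses, rewards)
```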

Prompts whose tokenized length exceeded 768 were omitted, leaving 2053 sequences (98.09% of the initial dataset).

## Citation

If you use this model in your research or projects, please cite it as: