Commit d0e68c7 · verified · Parent(s): c577e1a
alimotahharynia committed

Update README.md

Files changed (1): README.md (+217 −1)
- chemistry
- biology
- medical
---
# DrugGen: Advancing Drug Discovery with Large Language Models and Reinforcement Learning Feedback

DrugGen is a GPT-2 based model specialized in generating drug-like SMILES structures conditioned on protein sequences. The model leverages the characteristics of approved drug targets and was trained with both supervised fine-tuning and reinforcement learning to enhance its ability to generate chemically valid, safe, and effective structures.

## Model Details

- Model Name: DrugGen
- Training Paradigm: Supervised Fine-Tuning (SFT) + Proximal Policy Optimization (PPO)
- Input: Protein Sequence
- Output: SMILES Structure
- Training Libraries: Hugging Face’s transformers and Transformer Reinforcement Learning (TRL)
- Model Sources: liyuesen/druggpt

## How to Get Started with the Model
```python
import pandas as pd
from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset

class SMILESGenerator:
    def __init__(self):

        # Configuration parameters
        self.config = {
            "model_name": "alimotahharynia/DrugGen",
            "dataset_name": "alimotahharynia/approved_drug_target",
            "dataset_key": "uniprot_sequence",
            "generation_kwargs": {
                "do_sample": True,
                "top_k": 9,
                "max_length": 1024,
                "top_p": 0.9,
                "num_return_sequences": 10
            },
            "max_retries": 30  # Max retry limit to avoid infinite loops
        }

        # Load model and tokenizer
        self.model_name = self.config["model_name"]
        self.model, self.tokenizer = self.load_model_and_tokenizer(self.model_name)

        # Load UniProt mapping dataset
        dataset_name = self.config["dataset_name"]
        dataset_key = self.config["dataset_key"]
        self.uniprot_to_sequence = self.load_uniprot_mapping(dataset_name, dataset_key)

        # Adjust generation parameters with token IDs
        self.generation_kwargs = self.config["generation_kwargs"]
        self.generation_kwargs["bos_token_id"] = self.tokenizer.bos_token_id
        self.generation_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        self.generation_kwargs["pad_token_id"] = self.tokenizer.pad_token_id

    def load_model_and_tokenizer(self, model_name):
        print(f"Loading model and tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        return model, tokenizer

    def load_uniprot_mapping(self, dataset_name, dataset_key):
        print(f"Loading dataset: {dataset_name}")
        try:
            dataset = load_dataset(dataset_name, dataset_key)
            return {row["UniProt_id"]: row["Sequence"] for row in dataset["uniprot_seq"]}
        except Exception as e:
            raise RuntimeError(f"Failed to load dataset {dataset_name}: {e}")

    def generate_smiles(self, sequence, num_generated):
        """
        Generate unique SMILES with a retry limit to avoid infinite loops.
        """
        generated_smiles_set = set()
        prompt = f"<|startoftext|><P>{sequence}<L>"
        encoded_prompt = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        retries = 0

        while len(generated_smiles_set) < num_generated:
            if retries >= self.config["max_retries"]:
                print("Max retries reached. Returning what has been generated so far.")
                break

            sample_outputs = self.model.generate(encoded_prompt, **self.generation_kwargs)
            for sample_output in sample_outputs:
                output_decode = self.tokenizer.decode(sample_output, skip_special_tokens=False)
                try:
                    generated_smiles = output_decode.split("<L>")[1].split("<|endoftext|>")[0]
                    if generated_smiles not in generated_smiles_set:
                        generated_smiles_set.add(generated_smiles)
                except IndexError:
                    continue

            retries += 1

        return list(generated_smiles_set)

    def generate_smiles_data(self, list_of_sequences=None, list_of_uniprot_ids=None, num_generated=10):
        """
        Generate SMILES data for sequences or UniProt IDs.
        """
        if not list_of_sequences and not list_of_uniprot_ids:
            raise ValueError("Either `list_of_sequences` or `list_of_uniprot_ids` must be provided.")

        # Prepare sequences input
        if list_of_sequences:
            sequences_input = list_of_sequences
        else:
            sequences_input = [
                self.uniprot_to_sequence[uid]
                for uid in list_of_uniprot_ids
                if uid in self.uniprot_to_sequence
            ]

        data = []
        for sequence in sequences_input:
            smiles = self.generate_smiles(sequence, num_generated)
            uniprot_id = next((uid for uid, seq in self.uniprot_to_sequence.items() if seq == sequence), None)
            data.append({"UniProt_id": uniprot_id, "sequence": sequence, "smiles": smiles})

        return pd.DataFrame(data)
```
Below is an example of how to use DrugGen to generate SMILES. Adjust the `num_generated` parameter to set the number of unique SMILES to generate per protein.
```python
if __name__ == "__main__":
    # Initialize the generator
    generator = SMILESGenerator()

    # Example input (use either list_of_sequences or list_of_uniprot_ids)
    list_of_sequences = [
        "MGAASGRRGPGLLLPLPLLLLLPPQPALALDPGLQPGNFSADEAGAQLFAQSYNSSAEQVLFQSVAASWAHDTNITAENARRQEEAALLSQEFAEAWGQKAKELYEPIWQNFTDPQLRRIIGAVRTLGSANLPLAKRQQYNALLSNMSRIYSTAKVCLPNKTATCWSLDPDLTNILASSRSYAMLLFAWEGWHNAAGIPLKPLYEDFTALSNEAYKQDGFTDTGAYWRSWYNSPTFEDDLEHLYQQLEPLYLNLHAFVRRALHRRYGDRYINLRGPIPAHLLGDMWAQSWENIYDMVVPFPDKPNLDVTSTMLQQGWNATHMFRVAEEFFTSLELSPMPPEFWEGSMLEKPADGREVVCHASAWDFYNRKDFRIKQCTRVTMDQLSTVHHEMGHIQYYLQYKDLPVSLRRGANPGFHEAIGDVLALSVSTPEHLHKIGLLDRVTNDTESDINYLLKMALEKIAFLPFGYLVDQWRWGVFSGRTPPSRYNFDWWYLRTKYQGICPPVTRNETHFDAGAKFHVPNVTPYIRYFVSFVLQFQFHEALCKEAGYEGPLHQCDIYRSTKAGAKLRKVLQAGSSRPWQEVLKDMVGLDALDAQPLLKYFQPVTQWLQEQNQQNGEVLGWPEYQWHPPLPDNYPEGIDLVTDEAEASKFVEEYDRTSQVVWNEYAEANWNYNTNITTETSKILLQKNMQIANHTLKYGTQARKFDVNQLQNTTIKRIIKKVQDLERAALPAQELEEYNKILLDMETTYSVATVCHPNGSCLQLEPDLTNVMATSRKYEDLLWAWEGWRDKAGRAILQFYPKYVELINQAARLNGYVDAGDSWRSMYETPSLEQDLERLFQELQPLYLNLHAYVRRALHRHYGAQHINLEGPIPAHLLGNMWAQTWSNIYDLVVPFPSAPSMDTTEAMLKQGWTPRRMFKEADDFFTSLGLLPVPPEFWNKSMLEKPTDGREVVCHASAWDFYNGKDFRIKQCTTVNLEDLVVAHHEMGHIQYFMQYKDLPVALREGANPGFHEAIGDVLALSVSTPKHLHSLNLLSSEGGSDEHDINFLMKMALDKIAFIPFSYLVDQWRWRVFDGSITKENYNQEWWSLRLKYQGLCPPVPRTQGDFDPGAKFHIPSSVPYIRYFVSFIIQFQFHEALCQAAGHTGPLHKCDIYQSKEAGQRLATAMKLGFSRPWPEAMQLITGQPNMSASAMLSYFKPLLDWLRTENELHGEKLGWPQYNWTPNSARSEGPLPDSGRVSFLGLDLDAQQARVGQWLLLFLGIALLVATLGLSQRLFSIRHRSLHRHSHGPQFGSEVELRHS"
    ]
    list_of_uniprot_ids = ["P12821", "P37231"]

    # Generate SMILES data for sequences
    # df = generator.generate_smiles_data(list_of_sequences=list_of_sequences, num_generated=2)

    # Generate SMILES data for UniProt IDs
    df = generator.generate_smiles_data(list_of_uniprot_ids=list_of_uniprot_ids, num_generated=2)

    # Save the output
    output_file = "seq_SMILES.txt"
    df.to_csv(output_file, sep="\t", index=False)
    print(f"Generated SMILES saved to {output_file}")
    print(df)
```

## Training Details
### Training Data
[alimotahharynia/approved_drug_target](https://huggingface.co/datasets/alimotahharynia/approved_drug_target)
- This dataset contains approved SMILES–protein sequence pairs and was used to train the model to generate SMILES strings. It can be loaded and inspected directly, as shown below.
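A minimal sketch for a quick look at the dataset's splits and columns; the configuration name is the one used in the quick-start code above.

```python
from datasets import load_dataset

# Config name taken from the quick-start code above
dataset = load_dataset("alimotahharynia/approved_drug_target", "uniprot_sequence")
print(dataset)  # shows the available splits and their column names
```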

### Training Procedure
- **Training regime:** fp32

#### Supervised Fine-Tuning
DrugGen was initially trained using supervised fine-tuning on a curated dataset of approved drug targets; a configuration sketch follows the list below.
- **Training/validation split:** 8:2
- **sft_config**
  - `num_train_epochs=5`
  - `per_device_train_batch_size=8`
  - `per_device_eval_batch_size=8`
  - `evaluation_strategy="steps"`
  - `save_strategy="epoch"`
  - `eval_steps=50`
  - `logging_steps=25`
  - `logging_strategy="steps"`
  - `do_eval=True`
  - `do_train=True`
  - `learning_rate=5e-4`
  - `adam_epsilon=1e-08`
  - `warmup_steps=100`
  - `dataloader_drop_last=True`
  - `save_safetensors=False`
  - `max_seq_length=768`

- **AdamW optimizer**
  - `lr=5e-4`
  - `eps=1e-08`

- **scheduler**
  - `get_linear_schedule_with_warmup`
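The hyperparameters above map onto TRL's `SFTConfig`/`SFTTrainer`. The sketch below is illustrative rather than the authors' exact script: the output directory, the toy dataset, the `<P>…<L>…` text formatting (inferred from the generation prompt above), and the scheduler's total step count are assumptions, and argument names can vary slightly across transformers/TRL versions.

```python
# Illustrative SFT setup using TRL; not the authors' exact training script.
import torch
from datasets import Dataset
from transformers import AutoTokenizer, GPT2LMHeadModel, get_linear_schedule_with_warmup
from trl import SFTConfig, SFTTrainer

base = "liyuesen/druggpt"  # base model listed under "Model Sources" above
tokenizer = AutoTokenizer.from_pretrained(base)
model = GPT2LMHeadModel.from_pretrained(base)

# Toy stand-in for the 8:2 train/validation split; prompt template inferred
# from the quick-start code (sequence after <P>, SMILES after <L>)
example = "<|startoftext|><P>MGAASGRRGPGLLLPL<L>CC(=O)OC1=CC=CC=C1C(=O)O<|endoftext|>"
train_dataset = Dataset.from_dict({"text": [example] * 8})
eval_dataset = Dataset.from_dict({"text": [example] * 8})

sft_config = SFTConfig(
    output_dir="druggen-sft",  # assumed
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    save_strategy="epoch",
    eval_steps=50,
    logging_steps=25,
    logging_strategy="steps",
    do_train=True,
    do_eval=True,
    learning_rate=5e-4,
    adam_epsilon=1e-08,
    warmup_steps=100,
    dataloader_drop_last=True,
    save_safetensors=False,
    max_seq_length=768,
    dataset_text_field="text",
)

# Separate AdamW optimizer and linear warmup schedule, as listed above
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-08)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000  # total steps assumed
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,  # `processing_class` in newer TRL releases
    optimizers=(optimizer, scheduler),
)
trainer.train()
```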

#### Proximal Policy Optimization

- **Rollout:** generates a response based on an input query. Generation parameters include:

  - `do_sample=True`
  - `top_k=9`
  - `max_length=1024`
  - `top_p=0.9`
  - `bos_token_id=tokenizer.bos_token_id`
  - `eos_token_id=tokenizer.eos_token_id`
  - `pad_token_id=tokenizer.pad_token_id`
  - `num_return_sequences=10`

  In each epoch, generation continued until 30 unique small molecules had been generated for each target.

- **Evaluation:** the reward function includes the following components (a sketch follows this list):

  - a binding affinity predictor: "Protein-Ligand Binding Affinity Prediction Using Pretrained Transformers (PLAPT)"
  - a customized invalid-structure assessor based on the RDKit library
  - a multiplicative penalty of 0.7 when a generated SMILES matches a molecule already present in the approved-SMILES dataset
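A minimal sketch of a reward in this spirit is shown below; the `predict_affinity` stub is a hypothetical placeholder for PLAPT, and the penalty value for invalid structures is an assumption (the card specifies only the 0.7 duplicate penalty).

```python
# Hedged sketch of the reward described above; not the authors' implementation.
# `predict_affinity` is a placeholder standing in for PLAPT, and the -1.0
# invalid-structure penalty is an assumed value.
from rdkit import Chem

def predict_affinity(protein_sequence: str, smiles: str) -> float:
    """Placeholder for the PLAPT binding-affinity predictor."""
    return 0.0

def reward(protein_sequence: str, smiles: str, approved_smiles: set) -> float:
    # Invalid-structure assessor: RDKit returns None for unparsable SMILES
    if Chem.MolFromSmiles(smiles) is None:
        return -1.0  # assumed penalty; the exact value is not stated in the card
    score = predict_affinity(protein_sequence, smiles)
    # Multiplicative 0.7 penalty for molecules already in the approved set
    if smiles in approved_smiles:
        score *= 0.7
    return score
```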

- **Optimization:**

  - **ppo_config**
    - `mini_batch_size=8`
    - `batch_size=240`
    - `learning_rate=1.41e-5`
    - `use_score_scaling=True`
    - `use_score_norm=True`

Prompts longer than 768 tokens were omitted, leaving 2053 sequences (98.09% of the initial dataset). A skeleton of one PPO round follows.
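Putting rollout, reward, and optimization together, one PPO round with TRL's classic `PPOTrainer` API might look like the sketch below. It reuses the `reward` function from the Evaluation section; the checkpoint name, toy batch, and single-response-per-query simplification (the card uses `num_return_sequences=10` during rollout) are assumptions.

```python
# Skeleton of one PPO round with TRL's classic PPOTrainer API; illustrative only.
# Assumes the `reward` function sketched in the Evaluation section above.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

ckpt = "alimotahharynia/DrugGen"  # in practice, the SFT checkpoint would be used
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLMWithValueHead.from_pretrained(ckpt)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ckpt)

ppo_config = PPOConfig(
    mini_batch_size=8,
    batch_size=240,
    learning_rate=1.41e-5,
    use_score_scaling=True,
    use_score_norm=True,
)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)

generation_kwargs = {
    "do_sample": True, "top_k": 9, "top_p": 0.9, "max_length": 1024,
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}

# `step` expects exactly `batch_size` samples; a toy batch of one repeated
# protein sequence keeps the sketch self-contained
protein_sequences = ["MGAASGRRGPGLLLPL"] * ppo_config.batch_size
approved_smiles = set()  # stand-in for the approved-SMILES lookup

query_tensors = [
    tokenizer(f"<|startoftext|><P>{seq}<L>", return_tensors="pt")["input_ids"].squeeze(0)
    for seq in protein_sequences
]
response_tensors = [
    ppo_trainer.generate(q, return_prompt=False, **generation_kwargs)[0]
    for q in query_tensors
]
rewards = [
    torch.tensor(reward(seq, tokenizer.decode(r, skip_special_tokens=True), approved_smiles))
    for seq, r in zip(protein_sequences, response_tensors)
]
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```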

## Citation
If you use this model in your research or projects, please cite it as: