inoid commited on
Commit
d9f1440
·
verified ·
1 Parent(s): cc2101c

Create spanish_medica_llm.py

Browse files
Files changed (1) hide show
  1. spanish_medica_llm.py +250 -0
spanish_medica_llm.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Optional
7
+ import subprocess
8
+ import sys
9
+ import torch
10
+ import transformers
11
+
12
+ ef parse_args():
13
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
14
+ parser.add_argument(
15
+ "--pretrained_model_name_or_path",
16
+ type=str,
17
+ default=None,
18
+ #required=True,
19
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
20
+ )
21
+ parser.add_argument(
22
+ "--tokenizer_name",
23
+ type=str,
24
+ default=None,
25
+ help="Pretrained tokenizer name or path if not the same as model_name",
26
+ )
27
+ parser.add_argument(
28
+ "--instance_data_dir",
29
+ type=str,
30
+ default=None,
31
+ #required=True,
32
+ help="A folder containing the training data of instance images.",
33
+ )
34
+ parser.add_argument(
35
+ "--class_data_dir",
36
+ type=str,
37
+ default=None,
38
+ required=False,
39
+ help="A folder containing the training data of class images.",
40
+ )
41
+ parser.add_argument(
42
+ "--instance_prompt",
43
+ type=str,
44
+ default=None,
45
+ help="The prompt with identifier specifying the instance",
46
+ )
47
+ parser.add_argument(
48
+ "--class_prompt",
49
+ type=str,
50
+ default="",
51
+ help="The prompt to specify images in the same class as provided instance images.",
52
+ )
53
+ parser.add_argument(
54
+ "--with_prior_preservation",
55
+ default=False,
56
+ action="store_true",
57
+ help="Flag to add prior preservation loss.",
58
+ )
59
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
60
+ parser.add_argument(
61
+ "--num_class_images",
62
+ type=int,
63
+ default=100,
64
+ help=(
65
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
66
+ " sampled with class_prompt."
67
+ ),
68
+ )
69
+ parser.add_argument(
70
+ "--output_dir",
71
+ type=str,
72
+ default="",
73
+ help="The output directory where the model predictions and checkpoints will be written.",
74
+ )
75
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
76
+ parser.add_argument(
77
+ "--resolution",
78
+ type=int,
79
+ default=512,
80
+ help=(
81
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
82
+ " resolution"
83
+ ),
84
+ )
85
+ parser.add_argument(
86
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
87
+ )
88
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
89
+ parser.add_argument(
90
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
91
+ )
92
+ parser.add_argument(
93
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
94
+ )
95
+ parser.add_argument("--num_train_epochs", type=int, default=1)
96
+ parser.add_argument(
97
+ "--max_train_steps",
98
+ type=int,
99
+ default=None,
100
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
101
+ )
102
+ parser.add_argument(
103
+ "--gradient_accumulation_steps",
104
+ type=int,
105
+ default=1,
106
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
107
+ )
108
+ parser.add_argument(
109
+ "--gradient_checkpointing",
110
+ action="store_true",
111
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
112
+ )
113
+ parser.add_argument(
114
+ "--learning_rate",
115
+ type=float,
116
+ default=5e-6,
117
+ help="Initial learning rate (after the potential warmup period) to use.",
118
+ )
119
+ parser.add_argument(
120
+ "--scale_lr",
121
+ action="store_true",
122
+ default=False,
123
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
124
+ )
125
+ parser.add_argument(
126
+ "--lr_scheduler",
127
+ type=str,
128
+ default="constant",
129
+ help=(
130
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
131
+ ' "constant", "constant_with_warmup"]'
132
+ ),
133
+ )
134
+ parser.add_argument(
135
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
136
+ )
137
+ parser.add_argument(
138
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
139
+ )
140
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
141
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
142
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
143
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
144
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
145
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
146
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
147
+ parser.add_argument(
148
+ "--hub_model_id",
149
+ type=str,
150
+ default=None,
151
+ help="The name of the repository to keep in sync with the local `output_dir`.",
152
+ )
153
+ parser.add_argument(
154
+ "--logging_dir",
155
+ type=str,
156
+ default="logs",
157
+ help=(
158
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
159
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
160
+ ),
161
+ )
162
+ parser.add_argument(
163
+ "--mixed_precision",
164
+ type=str,
165
+ default="no",
166
+ choices=["no", "fp16", "bf16"],
167
+ help=(
168
+ "Whether to use mixed precision. Choose"
169
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
170
+ "and an Nvidia Ampere GPU."
171
+ ),
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--save_n_steps",
176
+ type=int,
177
+ default=1,
178
+ help=("Save the model every n global_steps"),
179
+ )
180
+
181
+
182
+ parser.add_argument(
183
+ "--save_starting_step",
184
+ type=int,
185
+ default=1,
186
+ help=("The step from which it starts saving intermediary checkpoints"),
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--stop_text_encoder_training",
191
+ type=int,
192
+ default=1000000,
193
+ help=("The step at which the text_encoder is no longer trained"),
194
+ )
195
+
196
+
197
+ parser.add_argument(
198
+ "--image_captions_filename",
199
+ action="store_true",
200
+ help="Get captions from filename",
201
+ )
202
+
203
+
204
+ parser.add_argument(
205
+ "--dump_only_text_encoder",
206
+ action="store_true",
207
+ default=False,
208
+ help="Dump only text encoder",
209
+ )
210
+
211
+ parser.add_argument(
212
+ "--train_only_unet",
213
+ action="store_true",
214
+ default=False,
215
+ help="Train only the unet",
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--Session_dir",
220
+ type=str,
221
+ default="",
222
+ help="Current session directory",
223
+ )
224
+
225
+
226
+
227
+
228
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
229
+
230
+ args = parser.parse_args()
231
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
232
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
233
+ args.local_rank = env_local_rank
234
+
235
+ #if args.instance_data_dir is None:
236
+ # raise ValueError("You must specify a train data directory.")
237
+
238
+ #if args.with_prior_preservation:
239
+ # if args.class_data_dir is None:
240
+ # raise ValueError("You must specify a data directory for class images.")
241
+ # if args.class_prompt is None:
242
+ # raise ValueError("You must specify prompt for class images.")
243
+
244
+ return args
245
+
246
+
247
+ def run_training(args_imported):
248
+ args_default = parse_args()
249
+ args = merge_args(args_default, args_imported)
250
+ return(args)