TeacherPuffy
commited on
Update train_mlp_batches.py
Browse files- train_mlp_batches.py +18 -11
train_mlp_batches.py
CHANGED
@@ -82,8 +82,13 @@ def main():
|
|
82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
84 |
parser.add_argument('--access_token', type=str, required=True, help='ModelScope SDK access token')
|
|
|
|
|
85 |
args = parser.parse_args()
|
86 |
|
|
|
|
|
|
|
87 |
# Load the zh-plus/tiny-imagenet dataset
|
88 |
dataset = load_dataset('zh-plus/tiny-imagenet')
|
89 |
|
@@ -125,7 +130,7 @@ def main():
|
|
125 |
val_loop=MLPValLoop,
|
126 |
val_interval=1,
|
127 |
default_hooks=dict(
|
128 |
-
checkpoint=dict(type=CheckpointHook, interval=1, save_best='auto'),
|
129 |
logger=dict(type=LoggerHook, interval=10)
|
130 |
)
|
131 |
)
|
@@ -156,17 +161,19 @@ def main():
|
|
156 |
with open(duplicate_result_path, 'w') as f:
|
157 |
f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}\n')
|
158 |
|
159 |
-
# Upload the model to ModelScope
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
166 |
|
167 |
-
# Delete the local model directory
|
168 |
-
|
169 |
-
|
|
|
170 |
|
171 |
if __name__ == '__main__':
|
172 |
main()
|
|
|
82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
84 |
parser.add_argument('--access_token', type=str, required=True, help='ModelScope SDK access token')
|
85 |
+
parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
|
86 |
+
parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
|
87 |
args = parser.parse_args()
|
88 |
|
89 |
+
# Set up Git to use hf-mirror as a proxy
|
90 |
+
os.environ['GIT_PROXY_COMMAND'] = 'proxychains4 git'
|
91 |
+
|
92 |
# Load the zh-plus/tiny-imagenet dataset
|
93 |
dataset = load_dataset('zh-plus/tiny-imagenet')
|
94 |
|
|
|
130 |
val_loop=MLPValLoop,
|
131 |
val_interval=1,
|
132 |
default_hooks=dict(
|
133 |
+
checkpoint=dict(type=CheckpointHook, interval=1, save_best='auto') if not args.delete_checkpoint else None,
|
134 |
logger=dict(type=LoggerHook, interval=10)
|
135 |
)
|
136 |
)
|
|
|
161 |
with open(duplicate_result_path, 'w') as f:
|
162 |
f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}\n')
|
163 |
|
164 |
+
# Upload the model to ModelScope if specified
|
165 |
+
if args.upload_checkpoint:
|
166 |
+
api = HubApi()
|
167 |
+
api.login(args.access_token)
|
168 |
+
api.push_model(
|
169 |
+
model_id="puffy310/MLPScaling",
|
170 |
+
model_dir=model_folder # Local model directory, the directory must contain configuration.json
|
171 |
+
)
|
172 |
|
173 |
+
# Delete the local model directory if specified
|
174 |
+
if args.delete_checkpoint:
|
175 |
+
import shutil
|
176 |
+
shutil.rmtree(model_folder)
|
177 |
|
178 |
if __name__ == '__main__':
|
179 |
main()
|