TeacherPuffy committed on
Commit
e592507
·
verified ·
1 Parent(s): 150c211

Update train_mlp_batches.py

Browse files
Files changed (1) hide show
  1. train_mlp_batches.py +18 -11
train_mlp_batches.py CHANGED
@@ -82,8 +82,13 @@ def main():
82
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
83
  parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
84
  parser.add_argument('--access_token', type=str, required=True, help='ModelScope SDK access token')
 
 
85
  args = parser.parse_args()
86
 
 
 
 
87
  # Load the zh-plus/tiny-imagenet dataset
88
  dataset = load_dataset('zh-plus/tiny-imagenet')
89
 
@@ -125,7 +130,7 @@ def main():
125
  val_loop=MLPValLoop,
126
  val_interval=1,
127
  default_hooks=dict(
128
- checkpoint=dict(type=CheckpointHook, interval=1, save_best='auto'),
129
  logger=dict(type=LoggerHook, interval=10)
130
  )
131
  )
@@ -156,17 +161,19 @@ def main():
156
  with open(duplicate_result_path, 'w') as f:
157
  f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}\n')
158
 
159
- # Upload the model to ModelScope
160
- api = HubApi()
161
- api.login(args.access_token)
162
- api.push_model(
163
- model_id="puffy310/MLPScaling",
164
- model_dir=model_folder # Local model directory, the directory must contain configuration.json
165
- )
 
166
 
167
- # Delete the local model directory
168
- import shutil
169
- shutil.rmtree(model_folder)
 
170
 
171
  if __name__ == '__main__':
172
  main()
 
82
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
83
  parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
84
  parser.add_argument('--access_token', type=str, required=True, help='ModelScope SDK access token')
85
+ parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
86
+ parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
87
  args = parser.parse_args()
88
 
89
+ # Set up Git to use hf-mirror as a proxy
90
+ os.environ['GIT_PROXY_COMMAND'] = 'proxychains4 git'
91
+
92
  # Load the zh-plus/tiny-imagenet dataset
93
  dataset = load_dataset('zh-plus/tiny-imagenet')
94
 
 
130
  val_loop=MLPValLoop,
131
  val_interval=1,
132
  default_hooks=dict(
133
+ checkpoint=dict(type=CheckpointHook, interval=1, save_best='auto') if not args.delete_checkpoint else None,
134
  logger=dict(type=LoggerHook, interval=10)
135
  )
136
  )
 
161
  with open(duplicate_result_path, 'w') as f:
162
  f.write(f'Layer Count: {args.layer_count}, Width: {args.width}, Parameter Count: {param_count}\n')
163
 
164
+ # Upload the model to ModelScope if specified
165
+ if args.upload_checkpoint:
166
+ api = HubApi()
167
+ api.login(args.access_token)
168
+ api.push_model(
169
+ model_id="puffy310/MLPScaling",
170
+ model_dir=model_folder # Local model directory, the directory must contain configuration.json
171
+ )
172
 
173
+ # Delete the local model directory if specified
174
+ if args.delete_checkpoint:
175
+ import shutil
176
+ shutil.rmtree(model_folder)
177
 
178
  if __name__ == '__main__':
179
  main()