TeacherPuffy commited on
Commit
048c839
·
verified ·
1 Parent(s): e592507

Update train_mlp_batches.py

Browse files
Files changed (1) hide show
  1. train_mlp_batches.py +3 -1
train_mlp_batches.py CHANGED
@@ -81,7 +81,7 @@ def main():
81
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
82
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
83
  parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
84
- parser.add_argument('--access_token', type=str, required=True, help='ModelScope SDK access token')
85
  parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
86
  parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
87
  args = parser.parse_args()
@@ -163,6 +163,8 @@ def main():
163
 
164
  # Upload the model to ModelScope if specified
165
  if args.upload_checkpoint:
 
 
166
  api = HubApi()
167
  api.login(args.access_token)
168
  api.push_model(
 
81
  parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
82
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
83
  parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
84
+ parser.add_argument('--access_token', type=str, help='ModelScope SDK access token (optional)')
85
  parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
86
  parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
87
  args = parser.parse_args()
 
163
 
164
  # Upload the model to ModelScope if specified
165
  if args.upload_checkpoint:
166
+ if not args.access_token:
167
+ raise ValueError("Access token is required for uploading to ModelScope.")
168
  api = HubApi()
169
  api.login(args.access_token)
170
  api.push_model(