TeacherPuffy
commited on
Update train_mlp_batches.py
Browse files- train_mlp_batches.py +3 -1
train_mlp_batches.py
CHANGED
@@ -81,7 +81,7 @@ def main():
|
|
81 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
84 |
-
parser.add_argument('--access_token', type=str,
|
85 |
parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
|
86 |
parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
|
87 |
args = parser.parse_args()
|
@@ -163,6 +163,8 @@ def main():
|
|
163 |
|
164 |
# Upload the model to ModelScope if specified
|
165 |
if args.upload_checkpoint:
|
|
|
|
|
166 |
api = HubApi()
|
167 |
api.login(args.access_token)
|
168 |
api.push_model(
|
|
|
81 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
84 |
+
parser.add_argument('--access_token', type=str, help='ModelScope SDK access token (optional)')
|
85 |
parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
|
86 |
parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
|
87 |
args = parser.parse_args()
|
|
|
163 |
|
164 |
# Upload the model to ModelScope if specified
|
165 |
if args.upload_checkpoint:
|
166 |
+
if not args.access_token:
|
167 |
+
raise ValueError("Access token is required for uploading to ModelScope.")
|
168 |
api = HubApi()
|
169 |
api.login(args.access_token)
|
170 |
api.push_model(
|