kgauvin603 commited on
Commit
eacf9f9
·
verified ·
1 Parent(s): b0a8caf

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +10 -5
train.py CHANGED
@@ -25,7 +25,7 @@ dst_path = "./creditcard.Rdata"
25
  wget.download(url, dst_path)
26
 
27
  # Define the directory to save the model (same as dataset directory)
28
- dataset_dir = os.path.dirname(dst_path) # Make sure this is defined before saving the model
29
 
30
  # Load the dataset
31
  parsed_res = rdata.parser.parse_file(dst_path)
@@ -82,16 +82,21 @@ joblib.dump(model_pipeline, saved_model_path)
82
 
83
  print(f"Model Serialized and Saved to {saved_model_path}")
84
 
85
- # Login to Hugging Face using the API token
 
 
 
 
 
 
86
  api = HfApi()
87
- api.set_access_token(FraudDemoWrite)
88
 
89
  # Create or use an existing repository on Hugging Face Hub
90
  repo_name = "kgauvin603/creditcard-fraud-detection" # Replace with your desired repo name
91
- repo_url = api.create_repo(repo_id=repo_name, exist_ok=True)
92
 
93
  # Initialize the repository
94
- repo = Repository(local_dir="hf_model_repo", clone_from=repo_url)
95
 
96
  # Move the model to the repository folder
97
  os.rename(saved_model_path, os.path.join("hf_model_repo", "model.joblib"))
 
25
  wget.download(url, dst_path)
26
 
27
  # Define the directory to save the model (same as dataset directory)
28
+ dataset_dir = os.path.dirname(dst_path)
29
 
30
  # Load the dataset
31
  parsed_res = rdata.parser.parse_file(dst_path)
 
82
 
83
  print(f"Model Serialized and Saved to {saved_model_path}")
84
 
85
+ # Get the Hugging Face API token securely from the secret environment variable
86
+ api_token = os.getenv('FraudDemoWrite') # Access the secret directly using its name
87
+
88
+ if not api_token:
89
+ raise ValueError("Hugging Face API token not found. Ensure 'FraudDemoWrite' is set as a secret in your environment.")
90
+
91
+ # Initialize Hugging Face API
92
  api = HfApi()
 
93
 
94
  # Create or use an existing repository on Hugging Face Hub
95
  repo_name = "kgauvin603/creditcard-fraud-detection" # Replace with your desired repo name
96
+ repo_url = api.create_repo(repo_id=repo_name, token=api_token, exist_ok=True)
97
 
98
  # Initialize the repository
99
+ repo = Repository(local_dir="hf_model_repo", clone_from=repo_url, use_auth_token=api_token)
100
 
101
  # Move the model to the repository folder
102
  os.rename(saved_model_path, os.path.join("hf_model_repo", "model.joblib"))