Ulian7 commited on
Commit
ec3de94
·
verified ·
1 Parent(s): f664fd4

modify model load process

Browse files
Files changed (1) hide show
  1. README.md +22 -3
README.md CHANGED
@@ -9,12 +9,31 @@ lon_std = 0.0006184167829766685
9
  # TO RUN:
10
 
11
  from huggingface_hub import hf_hub_download
 
12
  import torch
 
13
 
14
- repo_id = "cis-5190-final-fall24/ImageToGPSproject_model"
 
15
  filename = "final_model.pth"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
18
 
19
- model_test = torch.load(model_path)
20
- model_test.eval()
 
 
 
9
  # TO RUN:
10
 
11
  from huggingface_hub import hf_hub_download
12
+ import torchvision.models as models
13
  import torch
14
+ import torch.nn as nn
15
 
16
+ # Specify the repository and the filename of the model you want to load
17
+ repo_id = "cis-5190-final-fall24/ImageToGPSproject_model" # Replace with your repo name
18
  filename = "final_model.pth"
19
 
20
+ class ResNetGPSModel(nn.Module):
21
+ def __init__(self):
22
+ super(ResNetGPSModel, self).__init__()
23
+ self.resnet = models.resnet101() # Updated for PyTorch >=0.13
24
+ self.resnet.fc = nn.Sequential(
25
+ nn.Dropout(0.4), # Dropout for regularization
26
+ nn.Linear(self.resnet.fc.in_features, 2) # Latitude and Longitude
27
+ )
28
+
29
+ def forward(self, x):
30
+ return self.resnet(x)
31
+
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ model = ResNetGPSModel().to(device)
34
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
35
 
36
+ # Load the model using torch
37
+ state_dict = torch.load(model_path)
38
+ model.load_state_dict(state_dict)
39
+ model.eval() # Set the model to evaluation mode