Create utils
utils
ADDED
@@ -0,0 +1,42 @@
import torch

def load_checkpoint(checkpoint_path, model, remove_prefix=True, device='cpu', optimizer=None):
    """
    Loads a model and optimizer state from a checkpoint file, removing the '_orig_mod.' prefix if present.

    Parameters:
    - checkpoint_path (str): Path to the checkpoint file.
    - model (torch.nn.Module): The model instance to load the state dict into.
    - remove_prefix (bool): Whether to strip the '_orig_mod.' prefix that torch.compile adds to state dict keys.
    - device (str): Device to map the checkpoint tensors to (e.g. 'cpu' or 'cuda').
    - optimizer (torch.optim.Optimizer, optional): The optimizer instance to load the state dict into.

    Returns:
    - model (torch.nn.Module): The model with loaded state dict.
    - optimizer (torch.optim.Optimizer, optional): The optimizer with loaded state dict (if provided).
    - epoch (int): The epoch number saved in the checkpoint.
    - train_loss (float): The training loss saved in the checkpoint.
    - val_loss (float): The validation loss saved in the checkpoint (if available).
    - bleu_score (float): The BLEU score saved in the checkpoint (if available).
    - cider_score (float): The CIDEr score saved in the checkpoint (if available).
    """

    # Load the checkpoint onto the requested device
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

    # Rebuild the state_dict, removing the `_orig_mod.` prefix (added by torch.compile) if requested
    new_state_dict = {}
    for key, value in checkpoint['model_state_dict'].items():
        new_key = key.replace('_orig_mod.', '') if remove_prefix else key
        new_state_dict[new_key] = value

    model.load_state_dict(new_state_dict)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Training metadata (each entry may be absent in older checkpoints)
    epoch = checkpoint.get('epoch', None)
    train_loss = checkpoint.get('train_loss', None)
    val_loss = checkpoint.get('val_loss', None)
    bleu_score = checkpoint.get('bleu_score', None)
    cider_score = checkpoint.get('cider_score', None)

    return model, optimizer, epoch, train_loss, val_loss, bleu_score, cider_score
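For reference, a minimal usage sketch, assuming the file above is saved as utils.py; the model, optimizer, and checkpoint path below are illustrative placeholders rather than part of this Space:

    import torch
    from utils import load_checkpoint

    model = torch.nn.Linear(10, 10)            # stand-in for the real model class (hypothetical)
    optimizer = torch.optim.Adam(model.parameters())

    model, optimizer, epoch, train_loss, val_loss, bleu, cider = load_checkpoint(
        'checkpoint.pt',                       # illustrative checkpoint path
        model,
        device='cpu',
        optimizer=optimizer,
    )
    print(f"Resumed at epoch {epoch}: val_loss={val_loss}, BLEU={bleu}, CIDEr={cider}")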