shredder-31 commited on
Commit
583ff20
·
verified ·
1 Parent(s): ea5ee6e

Create utils

Browse files
Files changed (1) hide show
  1. utils +42 -0
utils ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def load_checkpoint(checkpoint_path, model, remove_prefix=True ,device='cpu', optimizer=None):
4
+ """
5
+ Loads a model and optimizer state from a checkpoint file, removing '_orig_mod.' prefix if present.
6
+
7
+ Parameters:
8
+ - checkpoint_path (str): Path to the checkpoint file.
9
+ - model (torch.nn.Module): The model instance to load the state dict into.
10
+ - optimizer (torch.optim.Optimizer, optional): The optimizer instance to load the state dict into.
11
+
12
+ Returns:
13
+ - model (torch.nn.Module): The model with loaded state dict.
14
+ - optimizer (torch.optim.Optimizer, optional): The optimizer with loaded state dict (if provided).
15
+ - epoch (int): The epoch number saved in the checkpoint.
16
+ - train_loss (float): The training loss saved in the checkpoint.
17
+ - val_loss (float): The validation loss saved in the checkpoint (if available).
18
+ - bleu_score (float): The BLEU score saved in the checkpoint (if available).
19
+ - cider_score (float): The CIDEr score saved in the checkpoint (if available).
20
+ """
21
+
22
+ # Load the checkpoint
23
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
24
+
25
+ # Modify the state_dict to remove the `_orig_mod.` prefix, if it exists
26
+ new_state_dict = {}
27
+ for key, value in checkpoint['model_state_dict'].items():
28
+ new_key = key.replace('_orig_mod.', '') # Remove the prefix if present
29
+ new_state_dict[new_key] = value
30
+
31
+ model.load_state_dict(new_state_dict)
32
+
33
+ if optimizer is not None:
34
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
35
+
36
+ epoch = checkpoint.get('epoch', None)
37
+ train_loss = checkpoint.get('train_loss', None)
38
+ val_loss = checkpoint.get('val_loss', None)
39
+ bleu_score = checkpoint.get('bleu_score', None)
40
+ cider_score = checkpoint.get('cider_score', None)
41
+
42
+ return model, optimizer, epoch, train_loss, val_loss, bleu_score, cider_score