Capx
/

File size: 987 Bytes
5e83696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder


def get_satclip_loc_encoder(ckpt_path, device):
    """Rebuild the location encoder stored in a checkpoint file.

    The checkpoint is expected to contain a ``'hyper_parameters'`` mapping
    (used to construct the positional encoding and the neural network) and a
    ``'state_dict'`` mapping (from which only entries whose key contains
    ``'nnet'`` are loaded).

    Args:
        ckpt_path: Checkpoint location, handed to ``torch.load``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A ``LocationEncoder`` converted with ``.double()``, with the ``nnet``
        weights loaded, and switched to eval mode.
    """
    # NOTE: torch.load unpickles the file; only use checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    hparams = checkpoint['hyper_parameters']

    positional_encoding = get_positional_encoding(
        hparams['le_type'],
        hparams['legendre_polys'],
        hparams['harmonics_calculation'],
        hparams['min_radius'],
        hparams['max_radius'],
        hparams['frequency_num'],
    )

    # The network's input width comes from the positional encoding, so the
    # encoding has to be built first.
    network = get_neural_network(
        hparams['pe_type'],
        positional_encoding.embedding_dim,
        hparams['embed_dim'],
        hparams['capacity'],
        hparams['num_hidden_layers'],
    )

    # Keep only the nnet parameters, dropping whatever prefix precedes the
    # first occurrence of 'nnet' in each key.
    nnet_weights = {}
    for key, tensor in checkpoint['state_dict'].items():
        if 'nnet' not in key:
            continue
        nnet_weights[key[key.index('nnet'):]] = tensor

    encoder = LocationEncoder(positional_encoding, network)
    encoder = encoder.double()
    encoder.load_state_dict(nnet_weights)
    encoder.eval()

    return encoder