ravi.naik commited on
Commit
2bae27d
·
1 Parent(s): a9a979d

Updated app.py to resolve device mapping issues

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -21,10 +21,11 @@ dropout = 0.2
21
 
22
 
23
  def load_model():
24
- if device != "cpu":
25
- model = torch.load("checkpoints/model.pth", map_location={"cpu": device})
26
- else:
27
- model = torch.load("checkpoints/model.pth", map_location=device)
 
28
  return model
29
 
30
 
 
21
 
22
 
23
  def load_model():
24
+ model_ckpt = torch.load("checkpoints/model.pth", map_location=device)
25
+ model = GPTModel(
26
+ vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
27
+ )
28
+ model.load_state_dict(model_ckpt.state_dict())
29
  return model
30
 
31