Sadjad Alikhani commited on
Commit
c0addc2
·
verified ·
1 Parent(s): 5d9edcb

Update lwm_model.py

Browse files
Files changed (1) hide show
  1. lwm_model.py +1 -19
lwm_model.py CHANGED
@@ -2,21 +2,6 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import numpy as np
5
- import random
6
-
7
- # Set manual seed for reproducibility
8
- def set_random_seed(seed=42):
9
- torch.manual_seed(seed)
10
- np.random.seed(seed)
11
- random.seed(seed)
12
- if torch.cuda.is_available():
13
- torch.cuda.manual_seed_all(seed)
14
- # Ensures deterministic behavior
15
- torch.backends.cudnn.deterministic = True
16
- torch.backends.cudnn.benchmark = False
17
-
18
- # Call the seed function
19
- set_random_seed()
20
 
21
  ELEMENT_LENGTH = 16
22
  D_MODEL = 64
@@ -53,7 +38,7 @@ class Embedding(nn.Module):
53
  seq_len = x.size(1)
54
  pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
55
  pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
56
- tok_emb = self.proj(x.float()) # Ensure consistency in floating-point precision
57
  embedding = tok_emb + self.pos_embed(pos)
58
  return self.norm(embedding)
59
 
@@ -124,15 +109,12 @@ class LWM(torch.nn.Module):
124
  embed_weight = self.embedding.proj.weight
125
  d_model, n_dim = embed_weight.size()
126
  self.decoder = nn.Linear(d_model, n_dim, bias=False)
127
- self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
128
  self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
129
 
130
  @classmethod
131
  def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
132
- # Define model
133
  model = cls().to(device)
134
 
135
- # Load model weights
136
  ckpt_path = ckpt_name
137
  model.load_state_dict(torch.load(ckpt_path, map_location=device))
138
  print(f"Model loaded successfully from {ckpt_path} to {device}")
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  ELEMENT_LENGTH = 16
7
  D_MODEL = 64
 
38
  seq_len = x.size(1)
39
  pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
40
  pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
41
+ tok_emb = self.proj(x.float())
42
  embedding = tok_emb + self.pos_embed(pos)
43
  return self.norm(embedding)
44
 
 
109
  embed_weight = self.embedding.proj.weight
110
  d_model, n_dim = embed_weight.size()
111
  self.decoder = nn.Linear(d_model, n_dim, bias=False)
 
112
  self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
113
 
114
  @classmethod
115
  def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
 
116
  model = cls().to(device)
117
 
 
118
  ckpt_path = ckpt_name
119
  model.load_state_dict(torch.load(ckpt_path, map_location=device))
120
  print(f"Model loaded successfully from {ckpt_path} to {device}")