Sadjad Alikhani
commited on
Update lwm_model.py
Browse files- lwm_model.py +1 -19
lwm_model.py
CHANGED
@@ -2,21 +2,6 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
import numpy as np
|
5 |
-
import random
|
6 |
-
|
7 |
-
# Set manual seed for reproducibility
|
8 |
-
def set_random_seed(seed=42):
|
9 |
-
torch.manual_seed(seed)
|
10 |
-
np.random.seed(seed)
|
11 |
-
random.seed(seed)
|
12 |
-
if torch.cuda.is_available():
|
13 |
-
torch.cuda.manual_seed_all(seed)
|
14 |
-
# Ensures deterministic behavior
|
15 |
-
torch.backends.cudnn.deterministic = True
|
16 |
-
torch.backends.cudnn.benchmark = False
|
17 |
-
|
18 |
-
# Call the seed function
|
19 |
-
set_random_seed()
|
20 |
|
21 |
ELEMENT_LENGTH = 16
|
22 |
D_MODEL = 64
|
@@ -53,7 +38,7 @@ class Embedding(nn.Module):
|
|
53 |
seq_len = x.size(1)
|
54 |
pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
|
55 |
pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
|
56 |
-
tok_emb = self.proj(x.float())
|
57 |
embedding = tok_emb + self.pos_embed(pos)
|
58 |
return self.norm(embedding)
|
59 |
|
@@ -124,15 +109,12 @@ class LWM(torch.nn.Module):
|
|
124 |
embed_weight = self.embedding.proj.weight
|
125 |
d_model, n_dim = embed_weight.size()
|
126 |
self.decoder = nn.Linear(d_model, n_dim, bias=False)
|
127 |
-
self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
|
128 |
self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
|
129 |
|
130 |
@classmethod
|
131 |
def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
|
132 |
-
# Define model
|
133 |
model = cls().to(device)
|
134 |
|
135 |
-
# Load model weights
|
136 |
ckpt_path = ckpt_name
|
137 |
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
138 |
print(f"Model loaded successfully from {ckpt_path} to {device}")
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
ELEMENT_LENGTH = 16
|
7 |
D_MODEL = 64
|
|
|
38 |
seq_len = x.size(1)
|
39 |
pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
|
40 |
pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
|
41 |
+
tok_emb = self.proj(x.float())
|
42 |
embedding = tok_emb + self.pos_embed(pos)
|
43 |
return self.norm(embedding)
|
44 |
|
|
|
109 |
embed_weight = self.embedding.proj.weight
|
110 |
d_model, n_dim = embed_weight.size()
|
111 |
self.decoder = nn.Linear(d_model, n_dim, bias=False)
|
|
|
112 |
self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
|
113 |
|
114 |
@classmethod
|
115 |
def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
|
|
|
116 |
model = cls().to(device)
|
117 |
|
|
|
118 |
ckpt_path = ckpt_name
|
119 |
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
120 |
print(f"Model loaded successfully from {ckpt_path} to {device}")
|