Sadjad Alikhani
commited on
Upload 3 files
Browse files- inference.py +1 -3
- lwm_model.py +3 -4
inference.py
CHANGED
@@ -30,10 +30,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
|
30 |
if torch.cuda.is_available():
|
31 |
torch.cuda.empty_cache()
|
32 |
|
33 |
-
# Folders
|
34 |
-
# MODELS_FOLDER = 'models/'
|
35 |
|
36 |
-
def
|
37 |
|
38 |
if input_type in ['cls_emb', 'channel_emb']:
|
39 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
|
|
30 |
if torch.cuda.is_available():
|
31 |
torch.cuda.empty_cache()
|
32 |
|
|
|
|
|
33 |
|
34 |
+
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
35 |
|
36 |
if input_type in ['cls_emb', 'channel_emb']:
|
37 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
lwm_model.py
CHANGED
@@ -11,9 +11,7 @@ import torch.nn as nn
|
|
11 |
import torch.nn.functional as F
|
12 |
import numpy as np
|
13 |
from inference import *
|
14 |
-
from load_data import load_DeepMIMO_data
|
15 |
from input_preprocess import *
|
16 |
-
from huggingface_hub import hf_hub_download
|
17 |
|
18 |
|
19 |
ELEMENT_LENGTH = 16
|
@@ -131,8 +129,9 @@ class LWM(torch.nn.Module):
|
|
131 |
model = cls().to(device)
|
132 |
|
133 |
# Download model weights using Hugging Face Hub
|
134 |
-
ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)
|
135 |
-
|
|
|
136 |
# Load the model weights
|
137 |
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
138 |
print(f"Model loaded successfully from {ckpt_path} to {device}")
|
|
|
11 |
import torch.nn.functional as F
|
12 |
import numpy as np
|
13 |
from inference import *
|
|
|
14 |
from input_preprocess import *
|
|
|
15 |
|
16 |
|
17 |
ELEMENT_LENGTH = 16
|
|
|
129 |
model = cls().to(device)
|
130 |
|
131 |
# Download model weights using Hugging Face Hub
|
132 |
+
# ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)
|
133 |
+
ckpt_path = ckpt_name
|
134 |
+
|
135 |
# Load the model weights
|
136 |
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
137 |
print(f"Model loaded successfully from {ckpt_path} to {device}")
|