Sadjad Alikhani committed on
Commit 1d6fec0 · verified · 1 Parent(s): 43db74c

Upload 3 files

Files changed (2)
  1. inference.py +1 -3
  2. lwm_model.py +3 -4
inference.py CHANGED
@@ -30,10 +30,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
     torch.cuda.empty_cache()
 
-# Folders
-# MODELS_FOLDER = 'models/'
 
-def dataset_gen(preprocessed_chs, input_type, lwm_model):
+def lwm_inference(preprocessed_chs, input_type, lwm_model):
 
     if input_type in ['cls_emb', 'channel_emb']:
        dataset = prepare_for_LWM(preprocessed_chs, device)
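The only functional change to inference.py is the rename of the entry point from dataset_gen to lwm_inference (plus removal of the commented-out folder constants), so call sites must switch to the new name. A minimal call-site sketch under that assumption; preprocessed_chs and lwm below are placeholders for objects produced elsewhere in the repo (the preprocessing output and a loaded LWM instance) and do not appear in this diff:

import torch
from inference import lwm_inference  # renamed from dataset_gen in this commit

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholders (assumed, not from the diff): preprocessed_chs is the output of the
# repo's input_preprocess step, and lwm is an LWM instance already moved to device.
output = lwm_inference(preprocessed_chs, input_type='cls_emb', lwm_model=lwm)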
lwm_model.py CHANGED
@@ -11,9 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from inference import *
-from load_data import load_DeepMIMO_data
 from input_preprocess import *
-from huggingface_hub import hf_hub_download
 
 
 ELEMENT_LENGTH = 16
@@ -131,8 +129,9 @@ class LWM(torch.nn.Module):
         model = cls().to(device)
 
         # Download model weights using Hugging Face Hub
-        ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)
-
+        # ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)
+        ckpt_path = ckpt_name
+
         # Load the model weights
         model.load_state_dict(torch.load(ckpt_path, map_location=device))
         print(f"Model loaded successfully from {ckpt_path} to {device}")