Sadjad Alikhani committed
Commit 78d5782 · verified · parent 60ef271

Update inference.py

Files changed (1): inference.py (+11, −11)
inference.py CHANGED
@@ -23,19 +23,18 @@ import numpy as np
 import warnings
 warnings.filterwarnings('ignore')

-# Set random seeds for reproducibility across CPU and GPU
-def set_random_seed(seed=42):
+def set_seed(seed=42):
     torch.manual_seed(seed)
     np.random.seed(seed)
-    #random.seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-        # Ensures deterministic behavior
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False

-# Apply random seed
-set_random_seed()
+# Use this function at the start of your code
+set_seed(42)
+
+# Force model weights and data to float32 precision
+def cast_model_weights_to_float32(model):
+    for param in model.parameters():
+        param.data = param.data.float()  # Cast all weights to float32
+    return model

 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
@@ -45,7 +44,8 @@ if torch.cuda.is_available():
 def lwm_inference(preprocessed_chs, input_type, lwm_model):

     dataset = prepare_for_LWM(preprocessed_chs, device)
-
+
+    lwm_model = cast_model_weights_to_float32(lwm_model)
     # Process data through LWM
     lwm_loss, embedding_data = evaluate(lwm_model, dataset)
     print(f'LWM loss: {lwm_loss:.4f}')
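Note on the first hunk: the commit drops the CUDA seeding and cuDNN determinism flags from the old set_random_seed. The sketch below is only an assumption about how a reader could combine the new set_seed with the removed calls if GPU-side reproducibility is still wanted; it is not part of the committed file.

import numpy as np
import torch

def set_seed(seed=42):
    # Seeding kept by this commit: PyTorch CPU RNG and NumPy RNG.
    torch.manual_seed(seed)
    np.random.seed(seed)
    # Calls removed by this commit; restore them if CUDA runs must also be
    # deterministic (disabling cuDNN autotuning can slow inference down).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)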
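Note on the second hunk: cast_model_weights_to_float32 casts every parameter to float32 before evaluate runs, which avoids dtype mismatches with float32 input data. A minimal, self-contained sketch is below; the nn.Linear model is purely illustrative and not part of inference.py, and PyTorch's built-in Module.float() performs the same cast for floating-point parameters and buffers.

import torch
import torch.nn as nn

def cast_model_weights_to_float32(model):
    # Same approach as the commit: cast each parameter tensor to float32 in place.
    for param in model.parameters():
        param.data = param.data.float()
    return model

# Illustrative model (not from inference.py), created with float64 weights.
model = nn.Linear(4, 2).double()
model = cast_model_weights_to_float32(model)
print(next(model.parameters()).dtype)  # torch.float32

# Built-in equivalent: model = model.float()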